// config.h // // Copyright (c) 2023-2024 Junpei Kawamoto // // This software is released under the MIT License. // // http://opensource.org/licenses/mit-license.php #pragma once #include #include #include #include #include "rust/cxx.h" using ctranslate2::BatchType; using ctranslate2::ComputeType; using ctranslate2::Device; using ctranslate2::get_device_count; using ctranslate2::get_log_level; using ctranslate2::get_random_seed; using ctranslate2::LogLevel; using ctranslate2::ReplicaPoolConfig; using ctranslate2::set_log_level; using ctranslate2::set_random_seed; inline std::unique_ptr replica_pool_config( size_t num_threads_per_replica, int32_t max_queued_batches, int cpu_core_offset ) { return std::make_unique(ReplicaPoolConfig { num_threads_per_replica, static_cast(max_queued_batches), cpu_core_offset, }); } struct Config { Device device; ComputeType compute_type; rust::Slice device_indices; bool tensor_parallel; std::unique_ptr replica_pool_config; }; inline std::unique_ptr config( Device device, ComputeType compute_type, rust::Slice device_indices, bool tensor_parallel, std::unique_ptr replica_pool_config ) { return std::make_unique(Config { device, compute_type, device_indices, tensor_parallel, std::move(replica_pool_config), }); }