// translator.h // // Copyright (c) 2023-2024 Junpei Kawamoto // // This software is released under the MIT License. // // http://opensource.org/licenses/mit-license.php #pragma once #include #include #include "rust/cxx.h" #include "config.h" struct VecStr; struct TranslationOptions; struct TranslationResult; struct GenerationStepResult; struct TranslationCallbackBox; class Translator { private: std::unique_ptr impl; public: Translator(std::unique_ptr impl) : impl(std::move(impl)) { } rust::Vec translate_batch( const rust::Vec& source, const TranslationOptions& options, bool has_callback, TranslationCallbackBox& callback ) const; rust::Vec translate_batch_with_target_prefix( const rust::Vec& source, const rust::Vec& target_prefix, const TranslationOptions& options, bool has_callback, TranslationCallbackBox& callback ) const; inline size_t num_queued_batches() const { return this->impl->num_queued_batches(); } inline size_t num_active_batches() const { return this->impl->num_active_batches(); } inline size_t num_replicas() const { return this->impl->num_replicas(); } }; inline std::unique_ptr translator( rust::Str model_path, std::unique_ptr config ) { return std::make_unique(std::make_unique( static_cast(model_path), config->device, config->compute_type, std::vector(config->device_indices.begin(), config->device_indices.end()), config->tensor_parallel, *config->replica_pool_config )); }