// generator.h // // Copyright (c) 2023-2024 Junpei Kawamoto // // This software is released under the MIT License. // // http://opensource.org/licenses/mit-license.php #pragma once #include #include #include "rust/cxx.h" #include "config.h" struct VecStr; struct GenerationOptions; struct GenerationResult; struct GenerationStepResult; struct GenerationCallbackBox; class Generator { private: std::unique_ptr impl; public: Generator(std::unique_ptr impl) : impl(std::move(impl)) { } rust::Vec generate_batch( const rust::Vec& start_tokens, const GenerationOptions& options, bool has_callback, GenerationCallbackBox& callback ) const; inline size_t num_queued_batches() const { return this->impl->num_queued_batches(); } inline size_t num_active_batches() const { return this->impl->num_active_batches(); } inline size_t num_replicas() const { return this->impl->num_replicas(); } }; inline std::unique_ptr generator( rust::Str model_path, std::unique_ptr config ) { return std::make_unique(std::make_unique( static_cast(model_path), config->device, config->compute_type, std::vector(config->device_indices.begin(), config->device_indices.end()), config->tensor_parallel, *config->replica_pool_config )); }