#include "module.h" #include #include "replica_pool.h" namespace ctranslate2 { namespace python { class GeneratorWrapper : public ReplicaPoolHelper { public: using ReplicaPoolHelper::ReplicaPoolHelper; std::variant, std::vector>> generate_batch(const BatchTokens& tokens, size_t max_batch_size, const std::string& batch_type_str, bool asynchronous, size_t beam_size, float patience, size_t num_hypotheses, float length_penalty, float repetition_penalty, size_t no_repeat_ngram_size, bool disable_unk, const std::optional>>& suppress_sequences, const std::optional& end_token, bool return_end_token, size_t max_length, size_t min_length, const std::optional>& static_prompt, bool cache_static_prompt, bool include_prompt_in_result, bool return_scores, bool return_logits_vocab, bool return_alternatives, float min_alternative_expansion_prob, size_t sampling_topk, float sampling_topp, float sampling_temperature, std::function callback) { if (tokens.empty()) return {}; BatchType batch_type = str_to_batch_type(batch_type_str); GenerationOptions options; options.beam_size = beam_size; options.patience = patience; options.length_penalty = length_penalty; options.repetition_penalty = repetition_penalty; options.no_repeat_ngram_size = no_repeat_ngram_size; options.disable_unk = disable_unk; options.sampling_topk = sampling_topk; options.sampling_topp = sampling_topp; options.sampling_temperature = sampling_temperature; options.max_length = max_length; options.min_length = min_length; options.num_hypotheses = num_hypotheses; options.return_end_token = return_end_token; options.return_scores = return_scores; options.return_logits_vocab = return_logits_vocab; options.return_alternatives = return_alternatives; options.cache_static_prompt = cache_static_prompt; options.include_prompt_in_result = include_prompt_in_result; options.min_alternative_expansion_prob = min_alternative_expansion_prob; options.callback = std::move(callback); if (suppress_sequences) options.suppress_sequences = suppress_sequences.value(); if (end_token) options.end_token = end_token.value(); if (static_prompt) options.static_prompt = static_prompt.value(); std::shared_lock lock(_mutex); assert_model_is_ready(); auto futures = _pool->generate_batch_async(tokens, options, max_batch_size, batch_type); return maybe_wait_on_futures(std::move(futures), asynchronous); } std::variant, std::vector>> score_batch(const BatchTokens& tokens, size_t max_batch_size, const std::string& batch_type_str, size_t max_input_length, bool asynchronous) { const auto batch_type = str_to_batch_type(batch_type_str); ScoringOptions options; options.max_input_length = max_input_length; std::shared_lock lock(_mutex); assert_model_is_ready(); auto futures = _pool->score_batch_async(tokens, options, max_batch_size, batch_type); return maybe_wait_on_futures(std::move(futures), asynchronous); } StorageView forward_batch(const std::variant& inputs, const std::optional& lengths, const bool return_log_probs) { std::future future; switch (inputs.index()) { case 0: future = _pool->forward_batch_async(std::get(inputs), return_log_probs); break; case 1: future = _pool->forward_batch_async(std::get(inputs), return_log_probs); break; case 2: if (!lengths) throw std::invalid_argument("lengths vector is required when passing a dense input"); future = _pool->forward_batch_async(std::get(inputs), lengths.value(), return_log_probs); break; } return future.get(); } }; void register_generator(py::module& m) { py::class_( m, "Generator", R"pbdoc( A text generator. 

    void register_generator(py::module& m) {
      py::class_<GeneratorWrapper>(
        m, "Generator",
        R"pbdoc(
            A text generator.

            Example:

                >>> generator = ctranslate2.Generator("model/", device="cpu")
                >>> generator.generate_batch([["<s>"]], max_length=50, sampling_topk=20)
        )pbdoc")

        .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&,
                      const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(),
             py::arg("model_path"),
             py::arg("device")="cpu",
             py::kw_only(),
             py::arg("device_index")=0,
             py::arg("compute_type")="default",
             py::arg("inter_threads")=1,
             py::arg("intra_threads")=0,
             py::arg("max_queued_batches")=0,
             py::arg("flash_attention")=false,
             py::arg("tensor_parallel")=false,
             py::arg("files")=py::none(),
             R"pbdoc(
                 Initializes the generator.

                 Arguments:
                   model_path: Path to the CTranslate2 model directory.
                   device: Device to use (possible values are: cpu, cuda, auto).
                   device_index: Device ID or list of device IDs where this generator is placed.
                   compute_type: Model computation type or a dictionary mapping a device name
                     to the computation type (possible values are: default, auto, int8, int8_float32,
                     int8_float16, int8_bfloat16, int16, float16, bfloat16, float32).
                   inter_threads: Maximum number of parallel generations.
                   intra_threads: Number of OpenMP threads per generator (0 to use a default value).
                   max_queued_batches: Maximum number of batches in the queue (-1 for unlimited,
                     0 for an automatic value). When the queue is full, future requests will block
                     until a free slot is available.
                   flash_attention: Run the model with flash attention 2 for the self-attention layers.
                   tensor_parallel: Run the model with tensor parallel mode.
                   files: Load model files from the memory. This argument is a dictionary mapping
                     file names to file contents as file-like or bytes objects. If this is set,
                     :obj:`model_path` acts as an identifier for this model.
             )pbdoc")

        .def_property_readonly("device", &GeneratorWrapper::device,
                               "Device this generator is running on.")
        .def_property_readonly("device_index", &GeneratorWrapper::device_index,
                               "List of device IDs where this generator is running on.")
        .def_property_readonly("compute_type", &GeneratorWrapper::compute_type,
                               "Computation type used by the model.")
        .def_property_readonly("num_generators", &GeneratorWrapper::num_replicas,
                               "Number of generators backing this instance.")
        .def_property_readonly("num_queued_batches", &GeneratorWrapper::num_queued_batches,
                               "Number of batches waiting to be processed.")
        .def_property_readonly("tensor_parallel", &GeneratorWrapper::tensor_parallel,
                               "Whether the model is running with tensor parallel mode.")
        .def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches,
                               "Number of batches waiting to be processed or currently processed.")

        .def("generate_batch", &GeneratorWrapper::generate_batch,
             py::arg("start_tokens"),
             py::kw_only(),
             py::arg("max_batch_size")=0,
             py::arg("batch_type")="examples",
             py::arg("asynchronous")=false,
             py::arg("beam_size")=1,
             py::arg("patience")=1,
             py::arg("num_hypotheses")=1,
             py::arg("length_penalty")=1,
             py::arg("repetition_penalty")=1,
             py::arg("no_repeat_ngram_size")=0,
             py::arg("disable_unk")=false,
             py::arg("suppress_sequences")=py::none(),
             py::arg("end_token")=py::none(),
             py::arg("return_end_token")=false,
             py::arg("max_length")=512,
             py::arg("min_length")=0,
             py::arg("static_prompt")=py::none(),
             py::arg("cache_static_prompt")=true,
             py::arg("include_prompt_in_result")=true,
             py::arg("return_scores")=false,
             py::arg("return_logits_vocab")=false,
             py::arg("return_alternatives")=false,
             py::arg("min_alternative_expansion_prob")=0,
             py::arg("sampling_topk")=1,
             py::arg("sampling_topp")=1,
             py::arg("sampling_temperature")=1,
             py::arg("callback")=nullptr,
             py::call_guard<py::gil_scoped_release>(),
             R"pbdoc(
                 Generates from a batch of start tokens.
                 Note:
                   The way the start tokens are forwarded in the decoder depends on the argument
                   :obj:`include_prompt_in_result`:

                   * If :obj:`include_prompt_in_result` is ``True`` (the default), the decoding loop
                     is constrained to generate the start tokens that are then included in the result.
                   * If :obj:`include_prompt_in_result` is ``False``, the start tokens are forwarded
                     in the decoder at once to initialize its state (i.e. the KV cache for Transformer
                     models). For variable-length inputs, only the tokens up to the minimum length in
                     the batch are forwarded at once. The remaining tokens are generated in the
                     decoding loop with constrained decoding.

                   Consider setting ``include_prompt_in_result=False`` to increase the performance
                   for long inputs (see the example at the end of this docstring).

                 Arguments:
                   start_tokens: Batch of start tokens. If the decoder starts from a special
                     start token like ``<s>``, this token should be added to this input.
                   max_batch_size: The maximum batch size. If the number of inputs is greater than
                     :obj:`max_batch_size`, the inputs are sorted by length and split by chunks of
                     :obj:`max_batch_size` examples so that the number of padding positions is
                     minimized.
                   batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
                   asynchronous: Run the generation asynchronously.
                   beam_size: Beam size (1 for greedy search).
                   patience: Beam search patience factor, as described in
                     https://arxiv.org/abs/2204.05424. The decoding will continue until
                     beam_size*patience hypotheses are finished.
                   num_hypotheses: Number of hypotheses to return.
                   length_penalty: Exponential penalty applied to the length during beam search.
                   repetition_penalty: Penalty applied to the score of previously generated tokens
                     (set > 1 to penalize).
                   no_repeat_ngram_size: Prevent repetitions of ngrams with this size
                     (set 0 to disable).
                   disable_unk: Disable the generation of the unknown token.
                   suppress_sequences: Disable the generation of some sequences of tokens.
                   end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
                   return_end_token: Include the end token in the results.
                   max_length: Maximum generation length.
                   min_length: Minimum generation length.
                   static_prompt: If the model expects a static prompt (a.k.a. system prompt),
                     it can be set here to simplify the inputs and optionally cache the model
                     state for this prompt to accelerate future generations.
                   cache_static_prompt: Cache the model state after the static prompt and reuse it
                     for future generations using the same static prompt.
                   include_prompt_in_result: Include the :obj:`start_tokens` in the result.
                   return_scores: Include the scores in the output.
                   return_logits_vocab: Include the log probabilities of each token in the output.
                   return_alternatives: Return alternatives at the first unconstrained decoding position.
                   min_alternative_expansion_prob: Minimum initial probability to expand an alternative.
                   sampling_topk: Randomly sample predictions from the top K candidates.
                   sampling_topp: Keep the most probable tokens whose cumulative probability exceeds
                     this value.
                   sampling_temperature: Sampling temperature to generate more random samples.
                   callback: Optional function that is called for each generated token when
                     :obj:`beam_size` is 1. If the callback function returns ``True``, the decoding
                     will stop for this batch index.

                 Returns:
                   A list of generation results.

                 See Also:
                   The ``GenerationOptions`` structure in the C++ library.
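
                 Example:
                   A minimal sketch (``model/`` and the prompt tokens below are placeholders):

                       >>> generator = ctranslate2.Generator("model/", device="cpu")
                       >>> prompt = ["<s>", "Hello", ","]
                       >>> results = generator.generate_batch(
                       ...     [prompt], max_length=30, include_prompt_in_result=False)
                       >>> results[0].sequences[0]  # Tokens generated after the prompt.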
)pbdoc") .def("score_batch", &GeneratorWrapper::score_batch, py::arg("tokens"), py::kw_only(), py::arg("max_batch_size")=0, py::arg("batch_type")="examples", py::arg("max_input_length")=1024, py::arg("asynchronous")=false, py::call_guard(), R"pbdoc( Scores a batch of tokens. Arguments: tokens: Batch of tokens to score. If the model expects special start or end tokens, they should also be added to this input. max_batch_size: The maximum batch size. If the number of inputs is greater than :obj:`max_batch_size`, the inputs are sorted by length and split by chunks of :obj:`max_batch_size` examples so that the number of padding positions is minimized. batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens". max_input_length: Truncate inputs after this many tokens (0 to disable). asynchronous: Run the scoring asynchronously. Returns: A list of scoring results. )pbdoc") .def("forward_batch", &GeneratorWrapper::forward_batch, py::arg("inputs"), py::arg("lengths")=py::none(), py::kw_only(), py::arg("return_log_probs")=false, py::call_guard(), R"pbdoc( Forwards a batch of sequences in the generator. Arguments: inputs: A batch of sequences either as string tokens or token IDs. This argument can also be a dense int32 array with shape ``[batch_size, max_length]`` (e.g. created from a Numpy array or PyTorch tensor). lengths: The length of each sequence as a int32 array with shape ``[batch_size]``. Required when :obj:`inputs` is a dense array. return_log_probs: If ``True``, the method returns the log probabilties instead of the unscaled logits. Returns: The output logits, or the output log probabilities if :obj:`return_log_probs` is enabled. )pbdoc") .def("unload_model", &GeneratorWrapper::unload_model, py::arg("to_cpu")=false, py::call_guard(), R"pbdoc( Unloads the model attached to this generator but keep enough runtime context to quickly resume generator on the initial device. The model is not guaranteed to be unloaded if generations are running concurrently. Arguments: to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded. )pbdoc") .def("load_model", &GeneratorWrapper::load_model, py::arg("keep_cache")=false, py::call_guard(), R"pbdoc( Loads the model back to the initial device. Arguments: keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists. )pbdoc") .def_property_readonly("model_is_loaded", &GeneratorWrapper::model_is_loaded, "Whether the model is loaded on the initial device and ready to be used.") ; } } }