import asyncio
import collections
import itertools
import queue
import threading

from typing import AsyncIterable, Callable, Iterable, List, Optional, Union

from ctranslate2._ext import (
    GenerationResult,
    GenerationStepResult,
    Generator,
    ScoringResult,
    TranslationResult,
    Translator,
)


def register_extensions():
    """Registers additional attributes to compiled modules."""
    setattr(Translator, "translate_iterable", translator_translate_iterable)
    setattr(Translator, "score_iterable", translator_score_iterable)
    setattr(Translator, "generate_tokens", translator_generate_tokens)
    setattr(Generator, "generate_iterable", generator_generate_iterable)
    setattr(Generator, "score_iterable", generator_score_iterable)
    setattr(Generator, "generate_tokens", generator_generate_tokens)
    setattr(Generator, "async_generate_tokens", generator_async_generate_tokens)


def translator_translate_iterable(
    translator: Translator,
    source: Iterable[List[str]],
    target_prefix: Optional[Iterable[List[str]]] = None,
    max_batch_size: int = 32,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[TranslationResult]:
    """Translates an iterable of tokenized examples.

    This method is built on top of :meth:`ctranslate2.Translator.translate_batch`
    to efficiently translate an arbitrarily large stream of data. It enables the
    following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel translations (if the translator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      source: An iterable of tokenized source examples.
      target_prefix: An optional iterable of tokenized target prefixes. It must
        have the same length as :obj:`source`.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any translation options accepted by
        :meth:`ctranslate2.Translator.translate_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.TranslationResult` instances.

    Raises:
      ValueError: If :obj:`max_batch_size` is lower than 1, if :obj:`batch_type`
        is invalid, or if the input iterables do not have the same length.

    Example:
      This method can be used to efficiently translate text files:

      .. code-block:: python

          # Replace by your own tokenization and detokenization functions.
          tokenize_fn = lambda line: line.strip().split()
          detokenize_fn = lambda tokens: " ".join(tokens)

          with open("input.txt") as input_file:
              source = map(tokenize_fn, input_file)
              results = translator.translate_iterable(source, max_batch_size=64)

              for result in results:
                  tokens = result.hypotheses[0]
                  target = detokenize_fn(tokens)
                  print(target)
    """
    iterables = [source]
    if target_prefix is not None:
        iterables.append(target_prefix)

    yield from _process_iterable(
        translator.translate_batch,
        iterables,
        max_batch_size,
        batch_type,
        **kwargs,
    )


def translator_score_iterable(
    translator: Translator,
    source: Iterable[List[str]],
    target: Iterable[List[str]],
    max_batch_size: int = 64,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[ScoringResult]:
    """Scores an iterable of tokenized examples.

    This method is built on top of :meth:`ctranslate2.Translator.score_batch`
    to efficiently score an arbitrarily large stream of data. It enables the
    following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel scoring (if the translator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      source: An iterable of tokenized source examples.
      target: An iterable of tokenized target examples. It must have the same
        length as :obj:`source`.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any scoring options accepted by
        :meth:`ctranslate2.Translator.score_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.ScoringResult` instances.

    Raises:
      ValueError: If :obj:`max_batch_size` is lower than 1, if :obj:`batch_type`
        is invalid, or if the input iterables do not have the same length.
    """
    yield from _process_iterable(
        translator.score_batch,
        [source, target],
        max_batch_size,
        batch_type,
        **kwargs,
    )


def generator_generate_iterable(
    generator: Generator,
    start_tokens: Iterable[List[str]],
    max_batch_size: int = 32,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[GenerationResult]:
    """Generates from an iterable of tokenized prompts.

    This method is built on top of :meth:`ctranslate2.Generator.generate_batch`
    to efficiently run generation on an arbitrarily large stream of data. It enables
    the following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel generations (if the generator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      start_tokens: An iterable of tokenized prompts.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any generation options accepted by
        :meth:`ctranslate2.Generator.generate_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.GenerationResult` instances.

    Raises:
      ValueError: If :obj:`max_batch_size` is lower than 1 or if :obj:`batch_type`
        is invalid.
    """
    yield from _process_iterable(
        generator.generate_batch,
        [start_tokens],
        max_batch_size,
        batch_type,
        **kwargs,
    )


def generator_score_iterable(
    generator: Generator,
    tokens: Iterable[List[str]],
    max_batch_size: int = 64,
    batch_type: str = "examples",
    **kwargs,
) -> Iterable[ScoringResult]:
    """Scores an iterable of tokenized examples.

    This method is built on top of :meth:`ctranslate2.Generator.score_batch`
    to efficiently score an arbitrarily large stream of data. It enables
    the following optimizations:

    * stream processing (the iterable is not fully materialized in memory)
    * parallel scoring (if the generator has multiple workers)
    * asynchronous batch prefetching
    * local sorting by length

    Arguments:
      tokens: An iterable of tokenized examples.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      **kwargs: Any score options accepted by
        :meth:`ctranslate2.Generator.score_batch`.

    Returns:
      A generator iterator over :class:`ctranslate2.ScoringResult` instances.

    Raises:
      ValueError: If :obj:`max_batch_size` is lower than 1 or if :obj:`batch_type`
        is invalid.
    """
    yield from _process_iterable(
        generator.score_batch,
        [tokens],
        max_batch_size,
        batch_type,
        **kwargs,
    )


def translator_generate_tokens(
    translator: Translator,
    source: List[str],
    target_prefix: Optional[List[str]] = None,
    *,
    max_decoding_length: int = 256,
    min_decoding_length: int = 1,
    sampling_topk: int = 1,
    sampling_topp: float = 1,
    sampling_temperature: float = 1,
    return_log_prob: bool = False,
    repetition_penalty: float = 1,
    no_repeat_ngram_size: int = 0,
    disable_unk: bool = False,
    suppress_sequences: Optional[List[List[str]]] = None,
    end_token: Optional[Union[str, List[str], List[int]]] = None,
    max_input_length: int = 1024,
    use_vmap: bool = False,
) -> Iterable[GenerationStepResult]:
    """Yields tokens as they are generated by the model.

    Arguments:
      source: Source tokens.
      target_prefix: Optional target prefix tokens.
      max_decoding_length: Maximum prediction length.
      min_decoding_length: Minimum prediction length.
      sampling_topk: Randomly sample predictions from the top K candidates.
      sampling_topp: Keep the most probable tokens whose cumulative probability exceeds
        this value.
      sampling_temperature: Sampling temperature to generate more random samples.
      return_log_prob: Include the token log probability in the result.
      repetition_penalty: Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).
      no_repeat_ngram_size: Prevent repetitions of ngrams with this size
        (set 0 to disable).
      disable_unk: Disable the generation of the unknown token.
      suppress_sequences: Disable the generation of some sequences of tokens.
      end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
      max_input_length: Truncate inputs after this many tokens (set 0 to disable).
      use_vmap: Use the vocabulary mapping file saved in this model.

    Returns:
      A generator iterator over :class:`ctranslate2.GenerationStepResult` instances.

    Note:
      This generation method is not compatible with beam search which requires a complete
      decoding.
    """
    yield from _generate_tokens(
        translator.translate_batch,
        [source],
        [target_prefix] if target_prefix is not None else None,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        disable_unk=disable_unk,
        suppress_sequences=suppress_sequences,
        end_token=end_token,
        max_decoding_length=max_decoding_length,
        min_decoding_length=min_decoding_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        return_scores=return_log_prob,
        max_input_length=max_input_length,
        use_vmap=use_vmap,
    )


def generator_generate_tokens(
    generator: Generator,
    prompt: Union[List[str], List[List[str]]],
    max_batch_size: int = 0,
    batch_type: str = "examples",
    *,
    max_length: int = 512,
    min_length: int = 0,
    sampling_topk: int = 1,
    sampling_topp: float = 1,
    sampling_temperature: float = 1,
    return_log_prob: bool = False,
    repetition_penalty: float = 1,
    no_repeat_ngram_size: int = 0,
    disable_unk: bool = False,
    suppress_sequences: Optional[List[List[str]]] = None,
    end_token: Optional[Union[str, List[str], List[int]]] = None,
    static_prompt: Optional[List[str]] = None,
    cache_static_prompt: bool = True,
    callback: Optional[Callable[[GenerationStepResult], bool]] = None,
) -> Iterable[GenerationStepResult]:
    """Yields tokens as they are generated by the model.

    Arguments:
      prompt: Batch of start tokens (a single prompt is also accepted). If the decoder
        starts from a special start token like <s>, this token should be added to
        this input.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      max_length: Maximum generation length.
      min_length: Minimum generation length.
      sampling_topk: Randomly sample predictions from the top K candidates.
      sampling_topp: Keep the most probable tokens whose cumulative probability exceeds
        this value.
      sampling_temperature: Sampling temperature to generate more random samples.
      return_log_prob: Include the token log probability in the result.
      repetition_penalty: Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).
      no_repeat_ngram_size: Prevent repetitions of ngrams with this size
        (set 0 to disable).
      disable_unk: Disable the generation of the unknown token.
      suppress_sequences: Disable the generation of some sequences of tokens.
      end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
      static_prompt: If the model expects a static prompt (a.k.a. system prompt)
        it can be set here to simplify the inputs and optionally cache the model
        state for this prompt to accelerate future generations.
      cache_static_prompt: Cache the model state after the static prompt and
        reuse it for future generations using the same static prompt.
      callback: Optional function that is called for each generated token when
        :obj:`beam_size` is 1. If the callback function returns ``True``, the
        decoding will stop for this batch index.

    Returns:
      A generator iterator over :class:`ctranslate2.GenerationStepResult` instances.

    Note:
      This generation method is not compatible with beam search which requires a complete
      decoding.
    """
    # Accept a single prompt (list of tokens) as a batch of size 1.
    if len(prompt) > 0 and isinstance(prompt[0], str):
        prompt = [prompt]
    yield from _generate_tokens(
        generator.generate_batch,
        prompt,
        max_batch_size=max_batch_size,
        batch_type=batch_type,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        disable_unk=disable_unk,
        suppress_sequences=suppress_sequences,
        end_token=end_token,
        max_length=max_length,
        min_length=min_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        return_scores=return_log_prob,
        static_prompt=static_prompt,
        cache_static_prompt=cache_static_prompt,
        include_prompt_in_result=False,
        callback=callback,
    )


async def generator_async_generate_tokens(
    generator: Generator,
    prompt: Union[List[str], List[List[str]]],
    max_batch_size: int = 0,
    batch_type: str = "examples",
    *,
    max_length: int = 512,
    min_length: int = 0,
    sampling_topk: int = 1,
    sampling_topp: float = 1,
    sampling_temperature: float = 1,
    return_log_prob: bool = False,
    repetition_penalty: float = 1,
    no_repeat_ngram_size: int = 0,
    disable_unk: bool = False,
    suppress_sequences: Optional[List[List[str]]] = None,
    end_token: Optional[Union[str, List[str], List[int]]] = None,
    static_prompt: Optional[List[str]] = None,
    cache_static_prompt: bool = True,
    callback: Optional[Callable[[GenerationStepResult], bool]] = None,
) -> AsyncIterable[GenerationStepResult]:
    """Yields tokens asynchronously as they are generated by the model.

    Arguments:
      prompt: Batch of start tokens (a single prompt is also accepted). If the decoder
        starts from a special start token like <s>, this token should be added to
        this input.
      max_batch_size: The maximum batch size.
      batch_type: Whether :obj:`max_batch_size` is the number of "examples" or "tokens".
      max_length: Maximum generation length.
      min_length: Minimum generation length.
      sampling_topk: Randomly sample predictions from the top K candidates.
      sampling_topp: Keep the most probable tokens whose cumulative probability exceeds
        this value.
      sampling_temperature: Sampling temperature to generate more random samples.
      return_log_prob: Include the token log probability in the result.
      repetition_penalty: Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).
      no_repeat_ngram_size: Prevent repetitions of ngrams with this size
        (set 0 to disable).
      disable_unk: Disable the generation of the unknown token.
      suppress_sequences: Disable the generation of some sequences of tokens.
      end_token: Stop the decoding on one of these tokens (defaults to the model EOS token).
      static_prompt: If the model expects a static prompt (a.k.a. system prompt)
        it can be set here to simplify the inputs and optionally cache the model
        state for this prompt to accelerate future generations.
      cache_static_prompt: Cache the model state after the static prompt and
        reuse it for future generations using the same static prompt.
      callback: Optional function that is called for each generated token when
        :obj:`beam_size` is 1. If the callback function returns ``True``, the
        decoding will stop for this batch index.

    Returns:
      An async generator iterator over :class:`ctranslate2.GenerationStepResult` instances.

    Note:
      This generation method is not compatible with beam search which requires a complete
      decoding.
    """
    # Accept a single prompt (list of tokens) as a batch of size 1.
    if len(prompt) > 0 and isinstance(prompt[0], str):
        prompt = [prompt]
    async_generator = AsyncGenerator(
        generator.generate_batch,
        prompt,
        max_batch_size=max_batch_size,
        batch_type=batch_type,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        disable_unk=disable_unk,
        suppress_sequences=suppress_sequences,
        end_token=end_token,
        max_length=max_length,
        min_length=min_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        return_scores=return_log_prob,
        static_prompt=static_prompt,
        cache_static_prompt=cache_static_prompt,
        include_prompt_in_result=False,
        callback=callback,
    )
    try:
        async for step_result in async_generator:
            yield step_result
    finally:
        # If the consumer stops early (break, aclose, cancellation), signal the
        # producer so that it stops the decoding instead of generating tokens
        # that nobody reads.
        async_generator.shutdown_event.set()


class AsyncGenerator:
    """Asynchronous iterator over the step results of a token generation.

    The blocking :func:`_generate_tokens` generator is consumed in a dedicated
    worker thread, and every step result is forwarded to an :class:`asyncio.Queue`
    with ``loop.call_soon_threadsafe``, so the event loop is never blocked while
    waiting for the next token.

    Queue protocol: step results, then optionally the exception raised by the
    generation, then ``None`` to mark the end of the stream.
    """

    def __init__(self, process_func, *args, **kwargs):
        self.queue = asyncio.Queue()
        # Set by the consumer to request an early stop, or once the stream ended.
        self.shutdown_event = threading.Event()
        self.iterator_task = None
        self.process_func = process_func
        self.args = args
        self.kwargs = kwargs

    async def producer(self):
        """Starts the worker thread that feeds :attr:`queue`."""
        loop = asyncio.get_running_loop()
        thread = threading.Thread(target=self._produce, args=(loop,), daemon=True)
        thread.start()

    def _produce(self, loop):
        """Worker-thread body: runs the generation and forwards its results to the loop."""

        def _forward(item):
            try:
                loop.call_soon_threadsafe(self.queue.put_nowait, item)
            except RuntimeError:
                # The event loop is closed: nobody can receive the results anymore.
                self.shutdown_event.set()

        step_results = _generate_tokens(self.process_func, *self.args, **self.kwargs)
        try:
            for step_result in step_results:
                _forward(step_result)
                if self.shutdown_event.is_set():
                    break
        except Exception as e:
            # Forward the error so that the consumer raises it instead of
            # waiting forever for the end marker.
            _forward(e)
        finally:
            # Closing the generator makes the decoding stop early if it is still
            # running, and waits for the job to terminate (see _generate_tokens).
            step_results.close()
            _forward(None)

    def __aiter__(self):
        self.iterator_task = asyncio.create_task(self.producer())
        return self

    async def __anext__(self):
        if self.shutdown_event.is_set():
            raise StopAsyncIteration

        try:
            item = await self.queue.get()
        except asyncio.CancelledError:
            # NOTE(review): cancellation is turned into the end of the iteration
            # (kept for backward compatibility), although asyncio generally
            # expects CancelledError to be propagated.
            self.shutdown_event.set()
            raise StopAsyncIteration

        if item is None:
            self.shutdown_event.set()
            raise StopAsyncIteration
        if isinstance(item, Exception):
            self.shutdown_event.set()
            raise item
        return item


def _generate_tokens(process_func, *args, **kwargs):
    """Runs ``process_func`` asynchronously and yields step results as they arrive.

    ``process_func`` is called with ``asynchronous=True``, ``beam_size=1`` and a
    callback that pushes every step result into a thread-safe queue (the user
    callback, if any, is still invoked first). Exceptions raised by the job are
    re-raised by this generator. When the generator is closed early, the callback
    starts returning ``True`` so that the decoding stops, and the generator waits
    for the job to terminate.
    """
    step_results = queue.Queue()
    generator_closed = threading.Event()

    user_callback = kwargs.get("callback", None)
    if user_callback is None:
        user_callback = lambda step_result: False

    def _callback(step_result):
        user_callback_result = user_callback(step_result)
        step_results.put(step_result)
        # Returning True stops the decoding for this batch index.
        return generator_closed.is_set() or user_callback_result

    kwargs.update(
        {
            "asynchronous": True,
            "beam_size": 1,
            "callback": _callback,
        }
    )

    async_results = process_func(*args, **kwargs)

    def _catch_exception():
        # Wait for every job and forward the first error, then the end marker.
        try:
            for result in async_results:
                result.result()
        except Exception as e:
            step_results.put(e)
        step_results.put(None)

    thread = threading.Thread(target=_catch_exception, daemon=True)
    thread.start()

    while True:
        step_result = step_results.get()

        if step_result is None:
            break

        if isinstance(step_result, Exception):
            raise step_result

        try:
            yield step_result
        except GeneratorExit:
            generator_closed.set()
            break

    # Wait for the job to terminate before exiting.
    thread.join()


# Fill value used by zip_longest to mark an exhausted input iterable. A private
# sentinel (rather than None) ensures that None elements in the inputs are
# forwarded to the batch function instead of being mistaken for a length mismatch.
_MISSING = object()


def _process_iterable(process_func, iterables, max_batch_size, batch_type, **kwargs):
    """Batches parallel input iterables and yields the results of ``process_func``.

    Examples are read in chunks of up to ``16 * max_batch_size`` (when
    ``max_batch_size > 1``), each chunk is submitted asynchronously to
    ``process_func``, and results are yielded in input order as soon as the
    oldest pending result is done.

    Raises:
      ValueError: If ``max_batch_size`` is lower than 1, if ``batch_type`` is
        invalid, or if the input iterables do not have the same length.
    """
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be >= 1")

    # Always zip, even a single iterable, so that every example reaching
    # _batch_iterator is a tuple with one element per input stream. Otherwise an
    # example that is itself a tuple of tokens would be split into several streams.
    iterable = itertools.zip_longest(*iterables, fillvalue=_MISSING)

    kwargs.update(
        {
            "max_batch_size": max_batch_size,
            "batch_type": batch_type,
            "asynchronous": True,
        }
    )

    read_batch_size = max_batch_size * 16 if max_batch_size > 1 else max_batch_size
    pending = collections.deque()

    for streams in _batch_iterator(iterable, read_batch_size, batch_type):
        pending.extend(process_func(*streams, **kwargs))

        # Yield the results that are already available without blocking.
        while pending and pending[0].done():
            yield pending.popleft().result()

    while pending:
        yield pending.popleft().result()


def _batch_iterator(iterable, batch_size, batch_type):
    """Groups examples into tuples of per-stream lists of about ``batch_size``.

    ``batch_size`` counts examples when ``batch_type`` is "examples", or the
    tokens of the first stream when it is "tokens".

    Raises:
      ValueError: If ``batch_type`` is invalid or if an example has a missing
        element (the input iterables do not have the same length).
    """
    if batch_type not in ("examples", "tokens"):
        raise ValueError("Invalid batch type %s" % batch_type)

    streams = None
    cur_batch_size = 0

    for example in iterable:
        if not isinstance(example, tuple):
            example = (example,)

        if streams is None:
            streams = tuple([] for _ in example)
        for batch, element in zip(streams, example):
            if element is _MISSING:
                raise ValueError("Input iterables do not have the same length")
            batch.append(element)

        if batch_type == "examples":
            cur_batch_size += 1
        else:
            cur_batch_size += len(example[0])

        if cur_batch_size >= batch_size:
            yield streams
            streams = None
            cur_batch_size = 0

    if streams is not None:
        yield streams