#pragma once #include #include #include "ops/tile.h" #include "storage_view.h" namespace ctranslate2 { inline void split_batch_beam(StorageView& input, dim_t beam_size) { Shape shape = input.shape(); shape.insert(shape.begin() + 1, beam_size); shape[0] /= beam_size; input.reshape(std::move(shape)); } inline void merge_batch_beam(StorageView& input) { Shape shape = input.shape(); shape[0] *= shape[1]; shape.erase(shape.begin() + 1); input.reshape(std::move(shape)); } inline void repeat_batch(StorageView& input, dim_t repeats) { input.expand_dims(1); ops::Tile(/*axis=*/1, repeats)(input); merge_batch_beam(input); } inline bool is_eos(const size_t id, const std::vector& end_ids) { return std::find(end_ids.begin(), end_ids.end(), id) != end_ids.end(); } // Helper class to disable tokens in the model output. class DisableTokens { public: DisableTokens(StorageView& logits, const float disable_value = std::numeric_limits::lowest()); void add(dim_t batch_id, dim_t token_id) { const auto flat_index = batch_id * _vocabulary_size + token_id; if (_logits_data) { // On CPU we directly assign the value. _logits_data[flat_index] = _disable_value; } else { // On GPU we prepare a list of unique index to disable. const auto it = std::lower_bound(_flat_indices.begin(), _flat_indices.end(), flat_index); if (it == _flat_indices.end() || *it != flat_index) _flat_indices.insert(it, flat_index); } } // Disable a token for all batches. void add(dim_t token_id) { for (dim_t batch_id = 0; batch_id < _batch_size; ++batch_id) add(batch_id, token_id); } void apply(); private: StorageView& _logits; float* _logits_data; const float _disable_value; const dim_t _batch_size; const dim_t _vocabulary_size; std::vector _flat_indices; }; // Base class for processing the output logits. class LogitsProcessor { public: virtual ~LogitsProcessor() = default; virtual bool apply_first() const { return false; } virtual void apply(dim_t step, StorageView& logits, DisableTokens& disable_tokens, const StorageView& sequences, const std::vector& batch_offset, const std::vector>* prefix) = 0; protected: dim_t get_batch_index(const dim_t batch_size, const dim_t batch_id, const std::vector& batch_offset) const { const auto beam_size = batch_size / batch_offset.size(); return batch_offset[batch_id / beam_size]; } dim_t get_sample_begin(const dim_t batch_size, const dim_t batch_id, const std::vector& batch_offset, const std::vector>* prefix) const { return prefix ? prefix->at(get_batch_index(batch_size, batch_id, batch_offset)).size() : 0; } }; // Apply a penalty to the score of previously generated tokens. class RepetitionPenalty : public LogitsProcessor { public: RepetitionPenalty(const float penalty); void apply(dim_t step, StorageView& logits, DisableTokens& disable_tokens, const StorageView& sequences, const std::vector& batch_offset, const std::vector>* prefix) override; private: const float _penalty; }; // Prevent repetitions of ngrans with a specific size. class NoRepeatNgram : public LogitsProcessor { public: NoRepeatNgram(const size_t ngram_size); void apply(dim_t step, StorageView& logits, DisableTokens& disable_tokens, const StorageView& sequences, const std::vector& batch_offset, const std::vector>* prefix) override; private: const dim_t _ngram_size; }; // Disable the generation of some sequences of tokens. class SuppressSequences : public LogitsProcessor { public: SuppressSequences(std::vector> sequences); void apply(dim_t step, StorageView& logits, DisableTokens& disable_tokens, const StorageView& sequences, const std::vector& batch_offset, const std::vector>* prefix) override; private: std::vector _ids; std::vector> _sequences; }; // Disable the generation of some tokens. class SuppressTokens : public LogitsProcessor { public: SuppressTokens(std::vector ids); void apply(dim_t step, StorageView& logits, DisableTokens& disable_tokens, const StorageView& sequences, const std::vector& batch_offset, const std::vector>* prefix) override; private: const std::vector _ids; }; // Disable the generation of some tokens at the first unconstrained decoding step. class SuppressTokensBegin : public LogitsProcessor { public: SuppressTokensBegin(std::vector ids); void apply(dim_t step, StorageView& logits, DisableTokens& disable_tokens, const StorageView& sequences, const std::vector& batch_offset, const std::vector>* prefix) override; private: const std::vector _ids; }; }