#pragma once #include "ctranslate2/layers/decoder.h" #include "ctranslate2/layers/encoder.h" #include "ctranslate2/models/model.h" #include "ctranslate2/scoring.h" #include "ctranslate2/translation.h" #include "ctranslate2/vocabulary_map.h" namespace ctranslate2 { namespace models { class SequenceToSequenceModel : public Model { public: size_t num_source_vocabularies() const; const Vocabulary& get_source_vocabulary(size_t index = 0) const; const Vocabulary& get_target_vocabulary() const; const VocabularyMap* get_vocabulary_map() const; bool with_source_bos() const { return config["add_source_bos"]; } bool with_source_eos() const { return config["add_source_eos"]; } const std::string* decoder_start_token() const { auto& start_token = config["decoder_start_token"]; return start_token.is_null() ? nullptr : start_token.get_ptr(); } protected: virtual void initialize(ModelReader& model_reader) override; private: std::vector> _source_vocabularies; std::shared_ptr _target_vocabulary; std::shared_ptr _vocabulary_map; void load_vocabularies(ModelReader& model_reader); }; class SequenceToSequenceReplica : public ModelReplica { public: SequenceToSequenceReplica(const std::shared_ptr& model) : ModelReplica(model) { } static std::unique_ptr create_from_model(const Model& model) { return model.as_sequence_to_sequence(); } std::vector score(const std::vector>& source, const std::vector>& target, const ScoringOptions& options = ScoringOptions()); std::vector translate(const std::vector>& source, const std::vector>& target_prefix = {}, const TranslationOptions& options = TranslationOptions()); protected: virtual bool skip_scoring(const std::vector& source, const std::vector& target, const ScoringOptions& options, ScoringResult& result) { (void)source; (void)target; (void)options; (void)result; return false; } virtual bool skip_translation(const std::vector& source, const std::vector& target_prefix, const TranslationOptions& options, TranslationResult& result) { (void)source; (void)target_prefix; (void)options; (void)result; return false; } virtual std::vector run_scoring(const std::vector>& source, const std::vector>& target, const ScoringOptions& options) = 0; virtual std::vector run_translation(const std::vector>& source, const std::vector>& target_prefix, const TranslationOptions& options) = 0; }; class EncoderDecoderReplica : public SequenceToSequenceReplica { public: EncoderDecoderReplica(const std::shared_ptr& model, std::unique_ptr encoder, std::unique_ptr decoder); layers::Encoder& encoder() { return *_encoder; } layers::Decoder& decoder() { return *_decoder; } protected: bool skip_scoring(const std::vector& source, const std::vector& target, const ScoringOptions& options, ScoringResult& result) override; bool skip_translation(const std::vector& source, const std::vector& target_prefix, const TranslationOptions& options, TranslationResult& result) override; std::vector run_scoring(const std::vector>& source, const std::vector>& target, const ScoringOptions& options) override; std::vector run_translation(const std::vector>& source, const std::vector>& target_prefix, const TranslationOptions& options) override; private: std::vector>> make_source_ids(const std::vector>>& source_features, size_t max_length = 0) const; std::vector> make_target_ids(const std::vector>& target, size_t max_length = 0, bool is_prefix = false) const; size_t get_source_length(const std::vector& source, bool include_special_tokens) const; void encode(const std::vector>>& ids, StorageView& memory, StorageView& memory_lengths); const std::shared_ptr _model; const std::unique_ptr _encoder; const std::unique_ptr _decoder; }; } }