#pragma once #include "ctranslate2/layers/encoder.h" #include "ctranslate2/layers/decoder.h" #include "ctranslate2/models/model.h" #include "ctranslate2/encoding.h" #include "ctranslate2/generation.h" #include "ctranslate2/scoring.h" namespace ctranslate2 { namespace models { // Base class for language models. class LanguageModel : public Model { public: LanguageModel(); const Vocabulary& get_vocabulary() const; // The returned cache is thread safe. layers::DecoderStateCache& get_state_cache() const; protected: void initialize(ModelReader& model_reader) override; private: std::shared_ptr _vocabulary; std::shared_ptr _state_cache; }; // Base class for generative language models. class SequenceGeneratorReplica : public ModelReplica { public: SequenceGeneratorReplica(const std::shared_ptr& model) : ModelReplica(model) , _model(model) { } static std::unique_ptr create_from_model(const Model& model) { return model.as_sequence_generator(); } std::vector score(const std::vector>& tokens, const ScoringOptions& options = ScoringOptions()); std::vector generate(const std::vector>& start_tokens, const GenerationOptions& options = GenerationOptions()); StorageView forward(const std::vector>& tokens, const bool return_log_probs); StorageView forward(const std::vector>& ids, const bool return_log_probs); StorageView forward(const StorageView& ids, const StorageView& lengths, const bool return_log_probs); protected: virtual bool skip_scoring(const std::vector& tokens, const ScoringOptions& options, ScoringResult& result) { (void)tokens; (void)options; (void)result; return false; } virtual std::vector run_scoring(const std::vector>& tokens, const ScoringOptions& options) = 0; virtual std::vector run_generation(const std::vector>& start_tokens, const GenerationOptions& options) = 0; virtual StorageView forward(const StorageView& ids, const StorageView& lengths) = 0; private: const std::shared_ptr _model; }; // A model generating sequences with a decoder. class DecoderReplica : public SequenceGeneratorReplica { public: DecoderReplica(const std::shared_ptr& model, std::unique_ptr decoder); protected: bool skip_scoring(const std::vector& tokens, const ScoringOptions& options, ScoringResult& result) override; std::vector run_scoring(const std::vector>& tokens, const ScoringOptions& options) override; std::vector run_generation(const std::vector>& start_tokens, const GenerationOptions& options) override; StorageView forward(const StorageView& ids, const StorageView& lengths) override; private: const std::shared_ptr _model; const std::unique_ptr _decoder; }; // Base class for sequence encoders. class SequenceEncoderReplica : public ModelReplica { public: SequenceEncoderReplica(const std::shared_ptr& model) : ModelReplica(model) , _model(model) { } static std::unique_ptr create_from_model(const Model& model) { return model.as_sequence_encoder(); } EncoderForwardOutput forward(const std::vector>& tokens, const std::vector>& token_type_ids = {}); EncoderForwardOutput forward(const std::vector>& ids, const std::vector>& token_type_ids = {}); EncoderForwardOutput forward(const StorageView& ids, const StorageView& lengths, const std::vector>& token_type_ids = {}); protected: virtual EncoderForwardOutput forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) = 0; private: const std::shared_ptr _model; }; // A model encoding sequences with an encoder layer. class EncoderReplica : public SequenceEncoderReplica { public: EncoderReplica(const std::shared_ptr& model, std::unique_ptr encoder); protected: EncoderForwardOutput forward_impl(const StorageView& ids, const StorageView& lengths, const StorageView& token_type_ids) override; private: const std::shared_ptr _model; const std::unique_ptr _encoder; const ops::ActivationType _pooler_activation; const std::unique_ptr _pooler_dense; }; } }