#pragma once

#include "ctranslate2/layers/attention.h"
#include "ctranslate2/layers/flash_attention.h"
#include "ctranslate2/layers/common.h"
#include "ctranslate2/layers/decoder.h"
#include "ctranslate2/layers/encoder.h"
#include "ctranslate2/padder.h"

namespace ctranslate2 {
  namespace layers {

    // Position-wise feed-forward block: two linear projections with an
    // activation in between, and optional pre- or post-layer normalization.
    class FeedForwardNetwork : public Layer
    {
    public:
      FeedForwardNetwork(const models::Model& model,
                         const std::string& scope,
                         const bool pre_norm = true,
                         const ops::ActivationType activation_type = ops::ActivationType::ReLU);

      void operator()(const StorageView& input, StorageView& output) const;

      DataType output_type() const override {
        return _ff2.output_type();
      }

      dim_t output_size() const override {
        return _ff2.output_size();
      }

    private:
      const std::unique_ptr<const LayerNorm> _layer_norm;
      const bool _pre_norm;
      const ops::ActivationType _activation_type;
      const Dense _ff1;
      const std::unique_ptr<const Dense> _ff1_noact;
      const Dense _ff2;
      const bool _tensor_parallel;
    };

    // A single Transformer encoder block: self-attention followed by the
    // feed-forward network.
    class TransformerEncoderLayer : public Layer
    {
    public:
      TransformerEncoderLayer(const models::Model& model,
                              const std::string& scope,
                              const dim_t num_heads,
                              const bool pre_norm = true,
                              const ops::ActivationType activation_type = ops::ActivationType::ReLU,
                              bool use_flash_attention = false);

      void operator()(const StorageView& input,
                      const StorageView* lengths,
                      StorageView& output,
                      const Padder* padder = nullptr,
                      StorageView* position_bias = nullptr) const;

      DataType output_type() const override {
        return _ff.output_type();
      }

      dim_t output_size() const override {
        return _ff.output_size();
      }

      const AttentionLayer& get_self_attention() const {
        return *_self_attention;
      }

    private:
      std::unique_ptr<AttentionLayer> _self_attention;
      const FeedForwardNetwork _ff;
    };

    // A single Transformer decoder block: self-attention, optional
    // cross-attention over the encoder output, and the feed-forward network.
    class TransformerDecoderLayer : public Layer
    {
    public:
      TransformerDecoderLayer(const models::Model& model,
                              const std::string& scope,
                              const dim_t num_heads,
                              const bool pre_norm = true,
                              const ops::ActivationType activation_type = ops::ActivationType::ReLU,
                              const bool use_flash_attention = true,
                              Alibi* alibi = nullptr);

      void operator()(const StorageView& input,
                      const StorageView* input_lengths,
                      const StorageView* memory,
                      const StorageView* memory_lengths,
                      StorageView* cached_self_attn_keys,
                      StorageView* cached_self_attn_values,
                      StorageView* cached_attn_keys,
                      StorageView* cached_attn_values,
                      StorageView& output,
                      StorageView* attention = nullptr,
                      const Padder* input_padder = nullptr,
                      const Padder* memory_padder = nullptr,
                      bool return_normalized_attention = true,
                      StorageView* position_bias = nullptr,
                      dim_t offset = 0) const;

      DataType output_type() const override {
        return _ff.output_type();
      }

      dim_t output_size() const override {
        return _ff.output_size();
      }

      bool has_cross_attention() const {
        return bool(_encoder_attention);
      }

      const AttentionLayer& get_self_attention() const {
        return *_self_attention;
      }

    private:
      const std::unique_ptr<AttentionLayer> _self_attention;
      const std::unique_ptr<const LayerNorm> _shared_layer_norm;
      const std::unique_ptr<const LayerNorm> _input_layer_norm;
      const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
      const std::unique_ptr<const LayerNorm> _pre_feedforward_layer_norm;
      const std::unique_ptr<const LayerNorm> _post_feedforward_layer_norm;
      const std::unique_ptr<const MultiHeadAttention> _encoder_attention;
      const FeedForwardNetwork _ff;
    };

    // Stack of encoder layers over the input embeddings, with optional
    // positional encoding and embedding/output normalization.
    class TransformerEncoder : public Encoder
    {
    public:
      TransformerEncoder(const models::Model& model, const std::string& scope);

      void operator()(const std::vector<StorageView>& ids,
                      const StorageView* lengths,
                      StorageView& output) override;

      size_t num_input_features() const override {
        return _embeddings.num_inputs();
      }

      DataType output_type() const override {
        return _layers.back()->output_type();
      }

      dim_t output_size() const override {
        return _layers.back()->output_size();
      }
    private:
      const ParallelEmbeddings _embeddings;
      const std::unique_ptr<const StorageView> _embeddings_scale;
      const dim_t _num_heads;
      const ComputeType _compute_type;
      const std::unique_ptr<const LayerNorm> _layernorm_embedding;
      const std::unique_ptr<const LayerNorm> _output_norm;
      const bool _use_flash_attention;
      const std::vector<std::unique_ptr<const TransformerEncoderLayer>> _layers;
      const std::unique_ptr<PositionEncoder> _position_encoder;
      const bool _tensor_parallel;
    };

    // Stack of decoder layers with the output projection. Supports both
    // step-by-step (iterative) decoding with cached attention state and
    // full-sequence decoding.
    class TransformerDecoder : public Decoder
    {
    public:
      TransformerDecoder(const models::Model& model, const std::string& scope);

      DecoderState initial_state(bool iterative_decoding = true) const override;
      bool replicate_state(const std::string& name) const override;

      void operator()(dim_t step,
                      const StorageView& ids,
                      DecoderState& state,
                      StorageView* logits = nullptr,
                      StorageView* attention = nullptr) override;
      void operator()(const StorageView& ids,
                      const StorageView& lengths,
                      DecoderState& state,
                      StorageView& logits,
                      StorageView* attention = nullptr) override;

      void set_alignment_heads(const dim_t layer, const dim_t num_heads_to_average);
      void set_alignment_heads(const std::vector<std::pair<dim_t, dim_t>>& alignment_heads);

      std::unique_ptr<const StorageView> get_layer_alignment_heads(const dim_t layer,
                                                                   const dim_t batch_size) const;

      virtual bool return_normalized_attention() const {
        return true;
      }

    protected:
      Dense& output_layer() override {
        return _proj;
      }

      void decode(const StorageView& ids,
                  const StorageView* lengths,
                  dim_t step,
                  DecoderState& state,
                  StorageView* outputs = nullptr,
                  StorageView* attention = nullptr,
                  bool return_logits = true);

      const dim_t _num_heads;
      const ComputeType _compute_type;
      const Embeddings _embeddings;
      const bool _start_from_zero_embedding;
      const std::unique_ptr<const StorageView> _embeddings_scale;
      std::unique_ptr<const StorageView> _outputs_scale;
      const std::unique_ptr<const LayerNorm> _layernorm_embedding;
      const std::unique_ptr<const LayerNorm> _output_norm;
      const std::unique_ptr<const Dense> _project_in;
      const std::unique_ptr<const Dense> _project_out;
      const std::unique_ptr<Alibi> _alibi;
      const bool _use_flash_attention;
      const std::vector<std::unique_ptr<const TransformerDecoderLayer>> _layers;
      const std::unique_ptr<PositionEncoder> _position_encoder;
      const bool _with_encoder_attention;
      std::vector<std::pair<dim_t, dim_t>> _alignment_heads;
      bool _average_alignment_heads;
      Dense _proj;

      const dim_t _sliding_window;
      const bool _tensor_parallel;
    };

  }
}
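
// Usage sketch (illustrative only, not part of the interface above): a minimal,
// hypothetical example of running the encoder and stepping the decoder with the
// signatures declared in this header. How the model is loaded, how token ids are
// packed into StorageView objects, and how the decoder state is populated with the
// encoder output are handled elsewhere in CTranslate2 and are only assumed here;
// the scope names "encoder" and "decoder" are likewise assumptions.
//
//   const models::Model& model = /* loaded elsewhere */;
//
//   layers::TransformerEncoder encoder(model, "encoder");
//   layers::TransformerDecoder decoder(model, "decoder");
//
//   StorageView source_ids;      // source token ids (filled elsewhere)
//   StorageView source_lengths;  // per-example lengths (filled elsewhere)
//   StorageView memory(encoder.output_type());
//   encoder({source_ids}, &source_lengths, memory);
//
//   // Iterative decoding: one target position per call; the cached attention
//   // state lives in `state` and is created by the decoder itself. The encoder
//   // output (and lengths) are added to `state` by the surrounding decoding code.
//   DecoderState state = decoder.initial_state(/*iterative_decoding=*/true);
//   StorageView target_ids;  // ids for the current decoding step (filled elsewhere)
//   StorageView logits(decoder.output_type());
//   for (dim_t step = 0; /* not finished */; ++step)
//     decoder(step, target_ids, state, &logits);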