#pragma once #include "op.h" namespace ctranslate2 { namespace ops { class LayerNorm : public TernaryOp { public: LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5); using TernaryOp::operator(); void operator()(const StorageView& beta, const StorageView& gamma, const StorageView& input, StorageView& output) const; void operator()(StorageView& input) const; void operator()(const StorageView& input, StorageView& output) const; private: void operator()(const StorageView* beta, const StorageView* gamma, const StorageView& input, StorageView& output) const; template void compute(const StorageView* beta, const StorageView* gamma, const StorageView& input, const dim_t axis, const dim_t outer_size, const dim_t axis_size, const dim_t inner_size, StorageView& output) const; const dim_t _axis; const float _epsilon; }; } }