#pragma once #include "op.h" namespace ctranslate2 { namespace ops { class GELU : public UnaryOp { public: enum class Approximation { None, Tanh, Sigmoid, }; GELU(const Approximation approximation = Approximation::None); void operator()(const StorageView& x, StorageView& y) const override; private: template void compute(const StorageView& x, StorageView& y) const { switch (_approximation) { case Approximation::None: primitives::gelu(x.data(), y.data(), x.size()); break; case Approximation::Tanh: primitives::gelu_tanh(x.data(), y.data(), x.size()); break; case Approximation::Sigmoid: primitives::gelu_sigmoid(x.data(), y.data(), x.size()); break; } } const Approximation _approximation; }; } }