#pragma once #include "op.h" namespace ctranslate2 { namespace ops { class SoftMax : public UnaryOp { public: SoftMax(bool log = false); using UnaryOp::operator(); void operator()(StorageView& x) const; void operator()(const StorageView& x, StorageView& y) const override; void operator()(const StorageView& x, const StorageView& lengths, StorageView& y) const; void operator()(const StorageView& x, const StorageView* lengths, StorageView& y) const; private: template void compute(const StorageView& input, const StorageView* lengths, StorageView& output) const; bool _log; }; class LogSoftMax : public SoftMax { public: LogSoftMax(); }; } }