#pragma once #include "op.h" namespace ctranslate2 { namespace ops { class Mul : public BinaryOp { public: void operator()(const StorageView& a, const StorageView& b, StorageView& c) const override; private: template void compute(const StorageView& a, const StorageView& b, StorageView& c) const { c.resize_as(a); if (b.is_scalar()) { primitives::mul(b.data()[0], a.data(), c.data(), c.size()); } else { primitives::mul(a.data(), b.data(), c.data(), c.size()); } } }; } }