#pragma once #include "op.h" namespace ctranslate2 { namespace ops { class Sub : public BinaryOp { public: void operator()(const StorageView& a, const StorageView& b, StorageView& c) const override; private: template void compute(const StorageView& a, const StorageView& b, StorageView& c) const { c.resize_as(a); if (b.is_scalar()) { primitives::sub(b.data()[0], a.data(), c.data(), c.size()); } else { primitives::sub(a.data(), b.data(), c.data(), c.size()); } } }; } }