#pragma once #include "op.h" namespace ctranslate2 { namespace ops { class Transpose : public UnaryOp { public: Transpose() = default; Transpose(const std::vector& perm); void operator()(const StorageView& x, StorageView& y) const override; private: std::vector _perm; template void compute(const StorageView& x, const std::vector& perm, StorageView& y) const { if (x.rank() == 2) { y.resize({x.dim(1), x.dim(0)}); primitives::transpose_2d(x.data(), x.shape().data(), y.data()); } else if (x.rank() == 3) { y.resize({x.dim(perm[0]), x.dim(perm[1]), x.dim(perm[2])}); primitives::transpose_3d(x.data(), x.shape().data(), perm.data(), y.data()); } else if (x.rank() == 4) { y.resize({x.dim(perm[0]), x.dim(perm[1]), x.dim(perm[2]), x.dim(perm[3])}); primitives::transpose_4d(x.data(), x.shape().data(), perm.data(), y.data()); } else { throw std::invalid_argument("Transpose: rank " + std::to_string(x.rank()) + " is not supported, supported ranks are: 2, 3, 4"); } } }; } }