/** * \file dnn/test/common/indexing_multi_axis_vec.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "test/common/opr_proxy.h" namespace megdnn { namespace test { struct OprProxyIndexingMultiAxisVecHelper { size_t axes[TensorLayout::MAX_NDIM]; /*! * \brief OprProxy for indexing multi-vec family oprs * * \param init_axes axes that are indexed */ OprProxyIndexingMultiAxisVecHelper(std::initializer_list init_axes = {}) { size_t i = 0; for (auto ax : init_axes) axes[i++] = ax; } OprProxyIndexingMultiAxisVecHelper(SmallVector init_axes) { size_t i = 0; for (auto ax : init_axes) axes[i++] = ax; } IndexingMultiAxisVec::IndexDesc make_index_desc( const TensorNDArray& tensors) const { megdnn_assert(tensors.size() >= 3); IndexingMultiAxisVec::IndexDesc ret; ret.resize(tensors.size() - 2); for (size_t i = 2; i < tensors.size(); ++i) { ret[i - 2] = {axes[i - 2], tensors[i]}; } return ret; } size_t get_index_ndim(const TensorNDArray& tensors) const { megdnn_assert(tensors.size() >= 3); size_t ndim = 0; for (size_t i = 2; i < tensors.size(); ++i) { ndim = std::max(tensors[i].layout.ndim, ndim); } return ndim; } IndexingMultiAxisVec::IndexDescLayoutOnly make_index_layout( const TensorLayoutArray& layouts) const { megdnn_assert(layouts.size() >= 3); IndexingMultiAxisVec::IndexDescLayoutOnly ret; ret.resize(layouts.size() - 2); for (size_t i = 2; i < layouts.size(); ++i) { ret[i - 2] = {axes[i - 2], layouts[i]}; } return ret; } }; template <> struct OprProxy : public OprProxyIndexingMultiAxisVecHelper { using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; void exec(IndexingMultiAxisVec* opr, const TensorNDArray& tensors) const { WorkspaceWrapper W( opr->handle(), opr->get_workspace_in_bytes( tensors[1].layout, axes, tensors.size() - 2, get_index_ndim(tensors))); opr->exec(tensors[0], make_index_desc(tensors), tensors[1], W.workspace()); } void deduce_layout(IndexingMultiAxisVec* opr, TensorLayoutArray& layouts) { opr->deduce_layout(layouts[0], make_index_layout(layouts), layouts[1]); } }; template <> struct OprProxy : public OprProxyIndexingMultiAxisVecHelper { using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; void exec(IndexingIncrMultiAxisVec* opr, const TensorNDArray& tensors) const { WorkspaceWrapper W( opr->handle(), opr->get_workspace_in_bytes( tensors[1].layout, axes, tensors.size() - 2, get_index_ndim(tensors))); opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); } void deduce_layout(IndexingIncrMultiAxisVec*, TensorLayoutArray&) {} }; template <> struct OprProxy : public OprProxyIndexingMultiAxisVecHelper { using OprProxyIndexingMultiAxisVecHelper::OprProxyIndexingMultiAxisVecHelper; void exec(IndexingSetMultiAxisVec* opr, const TensorNDArray& tensors) const { WorkspaceWrapper W( opr->handle(), opr->get_workspace_in_bytes( tensors[1].layout, axes, tensors.size() - 2, get_index_ndim(tensors))); opr->exec(tensors[0], tensors[1], make_index_desc(tensors), W.workspace()); } void deduce_layout(IndexingSetMultiAxisVec*, TensorLayoutArray&) {} }; } // namespace test } // namespace megdnn // vim: syntax=cpp.doxygen