/** * \file dnn/test/common/tile_repeat.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ #pragma once #include "megdnn/oprs.h" namespace megdnn { namespace test { namespace tile_repeat { struct Arg { TensorShape times, src, dst; Arg(TensorShape times, TensorShape src) : times(times), src(src) { dst = src; for (size_t i = 0; i < src.ndim; ++i) { dst[i] *= times[i]; } } TileRepeatBase::Param param() { TileRepeatBase::Param param; param.times = times; return param; } }; inline std::vector get_args() { std::vector args; args.emplace_back(TensorShape{3}, TensorShape{10000}); args.emplace_back(TensorShape{1, 1}, TensorShape{200, 300}); args.emplace_back(TensorShape{1, 3}, TensorShape{200, 300}); args.emplace_back(TensorShape{2, 1}, TensorShape{200, 300}); args.emplace_back(TensorShape{2, 3}, TensorShape{200, 300}); for (unsigned mask = 0; mask < 32; ++mask) { auto b = [mask](unsigned bit) { return (mask >> bit) & 1; }; args.emplace_back( TensorShape{b(0) + 1, b(1) + 1, b(2) + 1, b(3) + 1, b(4) + 1}, TensorShape{3, 4, 5, 6, 7}); } for (size_t i = 1; i < 10; ++i) for (size_t j = 1; j < 10; ++j) { args.emplace_back(TensorShape{i, j}, TensorShape{3, 4}); } return args; } } // namespace tile_repeat } // namespace test } // namespace megdnn // vim: syntax=cpp.doxygen