/**
 * \file dnn/test/rocm/argmxx.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "test/rocm/fixture.h"

#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/common/rng.h"

namespace {

using namespace megdnn;
using namespace test;

/**
 * \brief Input generator used by the argmax/argmin tests.
 *
 * For every floating-point computing dtype, writes the values
 * 0, 1, ..., nr_elems - 1 into the tensor's element span and then hands that
 * range to COMPAT_RANDOM.
 *
 * NOTE(review): COMPAT_RANDOM is defined outside this file; it presumably
 * shuffles the range so that all elements stay distinct (giving a unique
 * extremum along any axis) -- confirm against its definition.
 */
class ArgmxxRNG final : public RNG {
public:
    /**
     * \brief Fill \p tensor with the sequence described on the class.
     * \param tensor destination tensor; its layout supplies the dtype and the
     *        element span that gets written.
     *
     * Throws via megdnn_throw when the dtype is not matched by any of the
     * MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT cases.
     */
    void gen(const TensorND& tensor) override {
        auto offset = tensor.layout.span().low_elem;
        auto nr_elems = tensor.layout.span().dist_elem();

        // One dtype case per expansion: write i at position (offset + i), pass
        // the written range to COMPAT_RANDOM, and return early on a match.
        // (No comments inside the macro body: line splicing would merge them
        // with the following continuation lines.)
#define cb(DType)                                                 \
    if (tensor.layout.dtype == DType()) {                         \
        using ctype = typename DTypeTrait<DType>::ctype;          \
        auto ptr = tensor.ptr<ctype>();                           \
        for (size_t i = 0; i < nr_elems; ++i) {                   \
            ptr[offset + i] = i;                                  \
        }                                                         \
        COMPAT_RANDOM(ptr + offset, ptr + offset + nr_elems);     \
        return;                                                   \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
        // Reached only when no dtype case above returned.
        megdnn_throw(ssprintf("Unsupported DType: %s", tensor.layout.dtype.name()));
    }
};

/**
 * \brief Run the shared argmax/argmin check suite on \p handle.
 * \tparam Argmxx operator type under test; must expose a nested \c Param with
 *         an \c axis member.
 * \param handle handle the Checker is constructed with.
 *
 * Tensor 0 is fed by ArgmxxRNG; tensor 1 is typed Int32. The second shape in
 * every execs() call is passed empty.
 */
template <typename Argmxx>
void test_argmxx(Handle* handle) {
    Checker<Argmxx> checker(handle);
    checker.set_dtype(1, dtype::Int32());
    using Param = typename Argmxx::Param;
    ArgmxxRNG rng;
    checker.set_rng(0, &rng);

    // Every axis of a small 4-d tensor, for both Float32 and Float16 inputs.
    for (size_t axis = 0; axis < 4; ++axis) {
        Param param;
        param.axis = axis;
        checker.set_param(param)
                .set_dtype(0, dtype::Float32())
                .execs({{2, 3, 4, 5}, {}});
        checker.set_param(param)
                .set_dtype(0, dtype::Float16())
                .execs({{2, 3, 4, 5}, {}});
    }

    // Remaining cases all use Float32 input and reduce along axis 1.
    checker.set_dtype(0, dtype::Float32());
    Param param;
    param.axis = 1;
    checker.set_param(param);
    // 1-step
    checker.execs({{2, 64, 32}, {}});
    // 2-step
    checker.execs({{2, 192, 32}, {}});
    // 3-step
    checker.execs({{2, 4333, 32}, {}});
    // single reduce
    checker.execs({{2, 1, 1}, {}});
    checker.execs({{2, 1 + 1, 1}, {}});
    checker.execs({{2, 2048 + 1, 1}, {}});
    checker.execs({{2, 2048 * 2048 + 1, 1}, {}});
    checker.execs({{2, 1 + 1, 31}, {}});
    checker.execs({{2, 16 + 1, 31}, {}});
    checker.execs({{2, 16 * 16 + 1, 31}, {}});
    checker.execs({{2, 16 * 16 * 16 + 1, 31}, {}});
    checker.execs({{2, 16 * 16 * 16 * 16 + 1, 31}, {}});
    checker.execs({{3, 256 * 256 + 1, 2}, {}});
    checker.execs({{3, 128 * 128 + 1, 3}, {}});
    checker.execs({{3, 64 * 64 + 1, 7}, {}});
    checker.execs({{3, 32 * 32 + 1, 15}, {}});
    checker.execs({{3, 512, 500}, {}});
    // very large reduce
    checker.execs({{1, 4194304, 1}, {}});
}

}  // anonymous namespace

namespace megdnn {
namespace test {

// Argmax on the ROCm handle provided by the ROCM fixture.
TEST_F(ROCM, ARGMAX) {
    test_argmxx<Argmax>(handle_rocm());
}

// Argmin on the ROCm handle provided by the ROCM fixture.
TEST_F(ROCM, ARGMIN) {
    test_argmxx<Argmin>(handle_rocm());
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen