/**
 * \file dnn/test/rocm/matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "hcc_detail/hcc_defs_prologue.h"
#include "test/rocm/fixture.h"

#include "test/common/checker.h"
#include "test/common/matrix_mul.h"

#include "src/rocm/utils.h"

namespace megdnn {
namespace test {

//! Checks MatrixMul on the ROCm handle with Checker<MatrixMul>.
//!
//! Part 1 runs every (transposeA, transposeB) combination on a fixed
//! m x k * k x n problem for each dtype in `dtypes`.
//! Part 2 replays the cases from matrix_mul::get_matmul_args() with Float32
//! operands, honouring the optional per-operand row strides.
TEST_F(ROCM, MATRIX_MUL) {
    // The operator type must be spelled out: it cannot be deduced from the
    // handle passed to the constructor.
    Checker<MatrixMul> checker(handle_rocm());
    using Param = MatrixMul::Param;

    //! Storage shape of an operand that is logically rows x cols; it is
    //! stored as cols x rows when `transposed` is set.
    auto operand_shape = [](bool transposed, size_t rows, size_t cols) {
        return transposed ? TensorShape{cols, rows} : TensorShape{rows, cols};
    };

    //! Float32 layout for `shape`. A stride of 0 requests the default
    //! layout; otherwise `stride` is used as the outer (row) stride with an
    //! inner stride of 1.
    auto make_layout = [](const TensorShape& shape, size_t stride) {
        if (stride == 0)
            return TensorLayout(shape, dtype::Float32());
        return TensorLayout(
                shape, {static_cast<ptrdiff_t>(stride), 1}, dtype::Float32());
    };

    // ---- Part 1: every transpose combination for each dtype ----
    const size_t m = 12, n = 16, k = 20;
    //! Int32 output (Int8x8x32) is excluded because its results are reported
    //! as incorrect, so it is not checked for correctness here.
    // An explicit element type is required: the initializers have different
    // dtype::* types, so class template argument deduction cannot be used.
    std::vector<DType> dtypes{DNN_INC_FLOAT16(dtype::Float16() MEGDNN_COMMA)
                                      dtype::Float32() /*, dtype::Int32()*/};
    for (const auto& dtype : dtypes) {
        for (unsigned mask = 0; mask < 4; ++mask) {
            Param param;
            param.transposeA = mask & 1;
            param.transposeB = mask & 2;
            // Integer output is computed from Int8 inputs; floats use the
            // same dtype for inputs and output.
            DType stype = dtype == dtype::Int32() ? dtype::Int8() : dtype;
            const TensorShape a_shape = operand_shape(param.transposeA, m, k);
            const TensorShape b_shape = operand_shape(param.transposeB, k, n);
            checker.set_param(param)
                    .set_dtype(0, stype)
                    .set_dtype(1, stype)
                    .set_dtype(2, dtype)
                    .set_epsilon(
                            DNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)
                                    ? 5e-2
                                    : 5e-3)
                    .execs({a_shape, b_shape, {}});
        }
    }

    // ---- Part 2: general shape / stride cases (Float32) ----
    // Set the Float32 tolerance explicitly instead of relying on the value
    // left behind by the last iteration of part 1.
    checker.set_epsilon(5e-3);
    for (const auto& arg : matrix_mul::get_matmul_args()) {
        Param param;
        param.transposeA = arg.mask & 1;
        param.transposeB = arg.mask & 2;
        const TensorLayout a_layout = make_layout(
                operand_shape(param.transposeA, arg.m, arg.k), arg.A_stride);
        const TensorLayout b_layout = make_layout(
                operand_shape(param.transposeB, arg.k, arg.n), arg.B_stride);
        const TensorLayout c_layout =
                make_layout(TensorShape{arg.m, arg.n}, arg.C_stride);
        checker.set_param(param).execl({a_layout, b_layout, c_layout});
    }
}

}  // namespace test
}  // namespace megdnn

// vim: syntax=cpp.doxygen