/**
 * \file dnn/test/common/matrix_mul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "test/common/matrix_mul.h"

#include <algorithm>
#include <cstdio>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

#include "src/common/utils.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"

using namespace megdnn;
using namespace test;

//! Out-of-line definition of the static constexpr member, needed when it is
//! odr-used under pre-C++17 rules.
constexpr size_t matrix_mul::TestArg::UNSET_STRIDE_VAL;

//! Contiguous, non-transposed (mask == 0) shapes: an exhaustive grid of small
//! (m, n, k), a run of (m, m + 1, m + 2) shapes, and a few large square shapes.
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_no_mask() {
    std::vector<TestArg> args;

    for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 15, 16, 32})
        for (size_t n : {1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13,
                         14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 32})
            for (size_t k : {1, 2, 4, 8, 11, 12, 15, 16, 31, 32, 37})
                args.emplace_back(m, n, k, 0);

    for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17})
        args.emplace_back(m, m + 1, m + 2, 0);

    //! large square cases: every offset must be emitted, so all of
    //! 75, 267 and 523 are covered (not just the first one)
    for (size_t mbase : {11})
        for (size_t test_case_offset : {64, 256, 512}) {
            size_t mnk = mbase + test_case_offset;
            args.emplace_back(mnk, mnk, mnk, 0);
        }
    return args;
}

//! Shapes for packed (MK4/MK8-style) formats; n is scaled by \p nbase.
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_mk_packed_args(
        size_t nbase) {
    std::vector<TestArg> args;
    for (size_t m : {1, 2, 3, 4, 5, 6, 7, 8, 11})
        for (size_t n : {1, 2, 3, 4, 5, 8, 12, 16, 24})
            for (size_t k : {1, 2, 3, 4, 5, 9, 10, 11})
                args.emplace_back(m, n * nbase, k, 0);
    return args;
}

//! Batched (b == 2) shapes used for cublasLt, with default strides.
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_cublaslt() {
    std::vector<TestArg> args;
    for (size_t m : {4, 6, 8, 16}) {
        for (size_t n : {4, 6, 8, 16}) {
            //[TODO]: larger k values are disabled because cublasLt
            // (version 10020) produces wrong results when k is in [65, 97];
            // re-enable them once that bug is fixed.
            for (size_t k : {32, 64}) {
                args.emplace_back(
                        m, n, k, 0, TestArg::UNSET_STRIDE_VAL,
                        TestArg::UNSET_STRIDE_VAL, TestArg::UNSET_STRIDE_VAL, 2);
            }
        }
    }
    return args;
}

//! Batched (b == 2) int8x8x32 shapes with default strides.
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_int8x8x32() {
    std::vector<TestArg> args;
    for (size_t m : {1, 2, 3, 4, 5, 8, 64}) {
        for (size_t n : {1, 2, 3, 4, 5, 8, 64}) {
            for (size_t k : {1, 2, 3, 4, 5, 8, 64}) {
                args.emplace_back(
                        m, n, k, 0, TestArg::UNSET_STRIDE_VAL,
                        TestArg::UNSET_STRIDE_VAL, TestArg::UNSET_STRIDE_VAL, 2);
            }
        }
    }
    return args;
}

//! All no-mask shapes re-tagged with \p mask (bit 0: transpose A, bit 1:
//! transpose B), plus one non-contiguous case with padded row strides.
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_mask(uint8_t mask) {
    std::vector<TestArg> args = matrix_mul::get_matmul_args_no_mask();
    for (auto& arg : args) {
        arg.mask = mask;
    }
    // non-contiguous case
    for (size_t m : {110})
        for (size_t n : {119})
            for (size_t k : {120}) {
                // A: (m, k), or (k, m) when transposed
                size_t Astride = mask & 1 ? m + 2 : k + 2;
                // B: (k, n), or (n, k) when transposed
                size_t Bstride = mask & 2 ? k + 2 : n + 2;
                size_t Cstride = n * 2 + 2;
                args.emplace_back(m, n, k, mask, Astride, Bstride, Cstride);
            }
    return args;
}

//! Concatenation of get_matmul_args_mask() over all four transpose masks.
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args() {
    std::vector<TestArg> args;
    for (size_t mask = 0; mask < 4; ++mask) {
        std::vector<TestArg> args_temp = matrix_mul::get_matmul_args_mask(mask);
        args.insert(args.end(), args_temp.begin(), args_temp.end());
    }
    return args;
}

//! get_matmul_args() restricted to cases with k > n (original order kept).
std::vector<matrix_mul::TestArg> matrix_mul::get_matmul_args_split_k() {
    std::vector<TestArg> args = get_matmul_args();
    //! erase-remove runs in linear time, unlike erasing one element at a time
    args.erase(
            std::remove_if(
                    args.begin(), args.end(),
                    [](const TestArg& arg) { return arg.k <= arg.n; }),
            args.end());
    return args;
}

//! get_matmul_args_mask(mask) repeated for batch sizes 1, 2 and 3.
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args_mask(
        uint8_t mask) {
    std::vector<TestArg> args;
    for (size_t b : {1, 2, 3}) {
        std::vector<TestArg> args_temp =
                megdnn::test::matrix_mul::get_matmul_args_mask(mask);
        for (auto& arg : args_temp) {
            arg.b = b;
            args.emplace_back(arg);
        }
    }
    return args;
}

//! Concatenation of get_batched_matmul_args_mask() over all four masks.
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_args() {
    std::vector<TestArg> args;
    for (size_t mask = 0; mask < 4; ++mask) {
        std::vector<TestArg> args_temp =
                matrix_mul::get_batched_matmul_args_mask(mask);
        args.insert(args.end(), args_temp.begin(), args_temp.end());
    }
    return args;
}

//! Concatenation of get_batched_matmul_broadcast_args_mask() over all masks.
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_broadcast_args() {
    std::vector<TestArg> args;
    for (size_t mask = 0; mask < 4; ++mask) {
        std::vector<TestArg> args_temp =
                matrix_mul::get_batched_matmul_broadcast_args_mask(mask);
        args.insert(args.end(), args_temp.begin(), args_temp.end());
    }
    return args;
}

//! Batched shapes whose A has batch stride 0, i.e. one A shared by all batches.
std::vector<matrix_mul::TestArg> matrix_mul::get_batched_matmul_broadcast_args_mask(
        uint8_t mask) {
    std::vector<TestArg> args = matrix_mul::get_batched_matmul_args_mask(mask);
    for (auto& arg : args) {
        arg.A_batch_stride = 0;
    }
    return args;
}

//! Run Checker<Opr> over a set of matmul shapes.
//!
//! \param algo        if non-empty, restricts execution to that algorithm
//! \param format      DEFAULT uses (possibly strided) layouts; other formats
//!                    use packed shapes with pack_size(format)
//! \param nbase       n multiplier for the packed-format argument set
//! \param eps         checker tolerance (raised to at least 1e-2 for fp16)
//! \param user_args   if empty, a built-in argument set is used
template <typename Opr>
void matrix_mul::check_matrix_mul(
        DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
        const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format,
        size_t nbase, float eps, std::vector<TestArg>&& user_args,
        bool force_deduce_dst, param::MatrixMul::ComputeMode compute_mode) {
    megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
    Checker<Opr> checker(handle);
    checker.set_force_deduce_dst(force_deduce_dst);
    if (!algo.name.empty()) {
        checker.set_before_exec_callback(AlgoChecker<Opr>(algo));
    }
    std::unique_ptr<RNG> rng;
    checker.set_epsilon(eps);
    if (A_dtype.enumv() == DTypeEnum::Int8 ||
        A_dtype.enumv() == DTypeEnum::QuantizedS8) {
        //! use larger rng to check the overflow
        rng = std::make_unique<UniformIntRNG>(-127, 127);
    } else if (
            A_dtype.enumv() == DTypeEnum::Uint8 ||
            A_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
        rng = std::make_unique<NormalRNG>(128.f);
    } else if (A_dtype.enumv() == DTypeEnum::Int16) {
        rng = std::make_unique<UniformIntRNG>(-32767, 32767);
    } else if (A_dtype.enumv() == DTypeEnum::Float16) {
        rng = std::make_unique<NormalRNG>(2.f);
        //! if fp16 eps is not set (default 1e-3), loosen it to 1e-2
        if (eps < 1e-2) {
            checker.set_epsilon(1e-2);
        }
    }
    if (rng) {
        checker.set_rng(0, rng.get()).set_rng(1, rng.get());
    }
    //! dtypes do not depend on the test case, so set them once
    checker.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);

    //! return expect if stride is UNSET_STRIDE_VAL, stride otherwise
    auto stride_val = [](size_t stride, size_t expect) -> size_t {
        if (stride == TestArg::UNSET_STRIDE_VAL) {
            return expect;
        } else {
            return stride;
        }
    };

    constexpr static bool batched =
            std::is_same<Opr, megdnn::BatchedMatrixMul>::value;
    using Param = MatrixMul::Param;
    std::vector<TestArg> args;
    if (user_args.empty()) {
        if (format == param::MatrixMul::Format::DEFAULT) {
            if (batched) {
                args = matrix_mul::get_batched_matmul_args();
            } else {
                args = matrix_mul::get_matmul_args();
            }
        } else {
            megdnn_assert(!batched, "BatchedMatrixMul does not support MK4/MK8");
            args = matrix_mul::get_matmul_mk_packed_args(nbase);
        }
    } else {
        args = std::move(user_args);
    }

    size_t pack_size = MatrixMulForward::pack_size(format);
    for (auto& arg : args) {
        size_t m = arg.m, n = arg.n, k = arg.k;

        if (handle->type() == Handle::HandleType::CUDA) {
            //! NOTE: cublas can only process 4B aligned 8-bit input matrix;
            bool is_dt_8bit = A_dtype.enumv() == DTypeEnum::Int8 ||
                              A_dtype.enumv() == DTypeEnum::QuantizedS8 ||
                              A_dtype.enumv() == DTypeEnum::Uint8 ||
                              A_dtype.enumv() == DTypeEnum::Quantized8Asymm;
            if (is_dt_8bit && ((m % 4 != 0) || (n % 4 != 0))) {
                continue;
            }
        }

        Param param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        param.compute_mode = compute_mode;
        param.format = format;

        size_t A0 = m, A1 = k, B0 = k, B1 = n;
        if (param.transposeA) {
            std::swap(A0, A1);
        }
        if (param.transposeB) {
            std::swap(B0, B1);
        }
        ptrdiff_t A_stride = arg.A_stride, B_stride = arg.B_stride,
                  C_stride = arg.C_stride, A_batch_stride = arg.A_batch_stride,
                  B_batch_stride = arg.B_batch_stride,
                  C_batch_stride = arg.C_batch_stride;
        A_stride = stride_val(A_stride, A1);
        B_stride = stride_val(B_stride, B1);
        C_stride = stride_val(C_stride, n);
        A_batch_stride = stride_val(A_batch_stride, A0 * A_stride);
        B_batch_stride = stride_val(B_batch_stride, B0 * B_stride);
        C_batch_stride = stride_val(C_batch_stride, m * C_stride);

        checker.set_param(param);
        if (format == param::MatrixMul::Format::DEFAULT) {
            if (batched) {
                checker.execl(
                        {TensorLayout{
                                 {arg.b, A0, A1},
                                 {A_batch_stride, A_stride, 1},
                                 A_dtype},
                         TensorLayout{
                                 {arg.b, B0, B1},
                                 {B_batch_stride, B_stride, 1},
                                 B_dtype},
                         TensorLayout{
                                 {arg.b, m, n},
                                 {C_batch_stride, C_stride, 1},
                                 C_dtype}});
            } else {
                checker.execl(
                        {TensorLayout{{A0, A1}, {A_stride, 1}, A_dtype},
                         TensorLayout{{B0, B1}, {B_stride, 1}, B_dtype},
                         TensorLayout{{m, n}, {C_stride, 1}, C_dtype}});
            }
        } else {
            //! ignore non-contiguous, only DEFAULT format support
            //! non-contiguous input
            checker.execs(
                    {{A0, A1, pack_size, pack_size}, {B0, B1, pack_size}, {}});
        }
    }
}

//! check_matrix_mul() for BatchedMatrixMul with the DEFAULT format.
void matrix_mul::check_batched_matrix_mul(
        DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
        const ExecutionPolicyAlgoName& algo, float eps, std::vector<TestArg>&& args,
        bool force_deduce_dst) {
    check_matrix_mul<megdnn::BatchedMatrixMul>(
            A_dtype, B_dtype, C_dtype, handle, algo,
            param::MatrixMul::Format::DEFAULT, 8, eps, std::move(args),
            force_deduce_dst);
}

//! check_matrix_mul() for MatrixMul with the built-in argument set.
void matrix_mul::check_matrix_mul(
        DType A_dtype, DType B_dtype, DType C_dtype, Handle* handle,
        const ExecutionPolicyAlgoName& algo, param::MatrixMul::Format format,
        size_t nbase, float eps, bool force_deduce_dst) {
    check_matrix_mul<megdnn::MatrixMul>(
            A_dtype, B_dtype, C_dtype, handle, algo, format, nbase, eps, {},
            force_deduce_dst);
}

#if MEGDNN_WITH_BENCHMARK

//! Benchmark shapes: one fixed case, a gemv grid (n == 1) and a gemm grid.
std::vector<matrix_mul::TestArg> matrix_mul::get_benchmark_matmul_args() {
    std::vector<matrix_mul::TestArg> args;
    args.emplace_back(256, 12 * 24, 256, 0);

    //////////////////////// gemv //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 64, 112, 256}) {
            args.emplace_back(M, 1, K, 0);
        }
    }

    //////////////////////// gemm //////////////////////////
    for (size_t M : {8, 64, 112, 256}) {
        for (size_t K : {8, 16, 32, 64, 112, 256}) {
            for (size_t N : {8, 64, 112, 256}) {
                args.emplace_back(M, N, K, 0);
            }
        }
    }
    return args;
}

//! Benchmark shapes for packed formats; n is scaled by \p nbase.
std::vector<matrix_mul::TestArg> matrix_mul::get_benchmark_matmul_mk_packed_args(
        size_t nbase) {
    std::vector<TestArg> args;
    for (size_t m : {2, 4, 8, 16, 24, 32, 64})
        for (size_t n : {1, 2, 3, 4, 8, 16, 32, 64})
            for (size_t k : {2, 4, 8, 16, 24, 32, 64})
                args.emplace_back(m, n * nbase, k, 0);
    return args;
}

//! Time (algo, format, dtypes) against (contrast_algo, contrast_format,
//! contrast dtypes) on every shape in \p args and print time, Gflops and
//! speedup. A null algo name means no algorithm is forced.
void matrix_mul::benchmark_with_contrast(
        Handle* handle, const std::vector<TestArg>& args, DType A_dtype,
        DType B_dtype, DType C_dtype, const char* algo,
        param::MatrixMul::Format format, DType contrast_A_dtype,
        DType contrast_B_dtype, DType contrast_C_dtype, const char* contrast_algo,
        param::MatrixMul::Format contrast_format) {
    using Param = MatrixMul::Param;

    megdnn_assert(A_dtype.enumv() == B_dtype.enumv());
    megdnn_assert(contrast_A_dtype.enumv() == contrast_B_dtype.enumv());
    Benchmarker<MatrixMul> benchmark_contrast(handle);
    Benchmarker<MatrixMul> benchmark(handle);

    constexpr size_t RUNS = 50;
    if (algo) {
        benchmark.set_before_exec_callback(AlgoChecker<MatrixMul>(algo));
    }
    if (contrast_algo) {
        benchmark_contrast.set_before_exec_callback(
                AlgoChecker<MatrixMul>(contrast_algo));
    }
    benchmark.set_dtype(0, A_dtype).set_dtype(1, B_dtype).set_dtype(2, C_dtype);
    benchmark.set_times(RUNS);
    benchmark_contrast.set_dtype(0, contrast_A_dtype)
            .set_dtype(1, contrast_B_dtype)
            .set_dtype(2, contrast_C_dtype);
    benchmark_contrast.set_times(RUNS);

    //! Average time per run in ms. DEFAULT format expands m and k by
    //! pack_size so both sides multiply the same logical matrices.
    auto bench = [](Benchmarker<MatrixMul>& benchmark, Param param,
                    param::MatrixMul::Format format, size_t m, size_t n, size_t k,
                    size_t pack_size) -> float {
        param.format = format;
        benchmark.set_param(param);
        float used_algo = 1.0;
        if (format == param::MatrixMul::Format::DEFAULT) {
            size_t A0 = m * pack_size, A1 = k * pack_size, B0 = k * pack_size,
                   B1 = n;
            if (param.transposeA) {
                std::swap(A0, A1);
            }
            if (param.transposeB) {
                std::swap(B0, B1);
            }
            used_algo = benchmark.execs({{A0, A1}, {B0, B1}, {}}) / RUNS;
        } else {
            size_t A0 = m, A1 = k, B0 = k, B1 = n;
            if (param.transposeA) {
                std::swap(A0, A1);
            }
            if (param.transposeB) {
                std::swap(B0, B1);
            }
            used_algo =
                    benchmark.execs(
                            {{A0, A1, pack_size, pack_size}, {B0, B1, pack_size}, {}}) /
                    RUNS;
        }
        return used_algo;
    };

    size_t mk_size = MatrixMulForward::pack_size(format);
    size_t mk_size_contrast = MatrixMulForward::pack_size(contrast_format);
    size_t pack_size = std::max(mk_size, mk_size_contrast);
    //! printf("%s", nullptr) is undefined behaviour; label a null algo instead
    const char* algo_label = algo ? algo : "default";
    for (auto& arg : args) {
        Param param;
        param.transposeA = arg.mask & 0x1;
        param.transposeB = arg.mask & 0x2;
        auto used_contrast =
                bench(benchmark_contrast, param, contrast_format, arg.m, arg.n,
                      arg.k, pack_size);
        auto used_algo =
                bench(benchmark, param, format, arg.m, arg.n, arg.k, pack_size);
        //! MFLOP count; MFLOP / ms == Gflops
        float computations =
                2.f * arg.m * pack_size * arg.k * pack_size * arg.n * 1e-6;
        printf("run: {(%zu, %zu) x (%zu, %zu)} contrast: %f ms %f Gflops %s: "
               "%f "
               "ms "
               "%f Gflops "
               "speedup: %f \n",
               arg.m * pack_size, arg.k * pack_size, arg.k * pack_size, arg.n,
               used_contrast, computations / used_contrast, algo_label, used_algo,
               computations / used_algo, used_contrast / used_algo);
    }
}
#endif

// vim: syntax=cpp.doxygen