/** * \file dnn/test/x86/matrix_mul.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #include "test/x86/fixture.h" #include "src/x86/utils.h" #include "test/common/benchmarker.h" #include "test/common/checker.h" #include "test/common/matrix_mul.h" #include "test/common/rng.h" #include "test/common/task_record_check.h" using namespace megdnn; using namespace test; using namespace megdnn::x86; TEST_F(X86, MATRIX_MUL_RECORD) { TaskRecordChecker checker(0); using Param = MatrixMul::Param; auto args = matrix_mul::get_matmul_args(); auto arg = args[0]; auto m = arg.m, n = arg.n, k = arg.k; auto mask = arg.mask; Param param; param.transposeA = mask & 1; param.transposeB = mask & 2; TensorShape AS, BS, CS; if (param.transposeA) AS = TensorShape{k, m}; else AS = TensorShape{m, k}; if (param.transposeB) BS = TensorShape{n, k}; else BS = TensorShape{k, n}; CS = TensorShape{m, n}; TensorLayout AL, BL, CL; AL = TensorLayout(AS, dtype::Float32()); BL = TensorLayout(BS, dtype::Float32()); CL = TensorLayout(CS, dtype::Float32()); checker.set_param(param); checker.execl({AL, BL, CL}); } #if MEGDNN_X86_WITH_VNNI TEST_F(X86, MATRIX_MUL_VNNI_8X8X32) { matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "X86_INT8X8X32_VNNI"); } #endif #if MEGDNN_X86_WITH_MKL_DNN TEST_F(X86, MATRIX_MUL_MKLDNN_8X8X32) { if (is_supported(SIMDType::VNNI)) { matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "X86_INT8X8X32_MKLDNN"); } else { std::cout << "can not do mkldnn matmul check for no vnni support" << std::endl; matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle()); } } #endif //! FIXME: need to add tests of GEMV and QUINT8 TEST_F(X86, MATRIX_MUL_AVX2_8X8X32) { matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "X86_INT8X8X32_AVX2_2X4X16", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false); matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "X86_INT8X8X32_AVX2_4X16X2", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false); } TEST_F(X86, MATRIX_MUL_AVX2_8X8X16) { matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(), "X86_INT8X8X16_AVX2", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false); } TEST_F(X86, MATRIX_MUL_SSE_8X8X16) { matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int16{}, handle(), "X86_INT8X8X16_SSE", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false); } TEST_F(X86, MATRIX_MUL_SSE_8X8X32) { matrix_mul::check_matrix_mul( dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "X86_INT8X8X32_SSE_4X8X2", param::MatrixMul::Format::DEFAULT, 8, 1e-3, false); } #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM TEST_F(X86, MATRIX_MUL_MKL_PACKA) { matrix_mul::check_matrix_mul( dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "X86_F32_MKL_PACKA"); } #endif TEST_F(X86, MATRIX_MUL_AVX2_MK8_8X8) { matrix_mul::check_matrix_mul( dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "X86_F32MK8_8X8", param::MatrixMul::Format::MK8, 1, 1e-3, false); } TEST_F(X86, MATRIX_MUL_AVX2_6x16) { matrix_mul::check_matrix_mul( dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, handle(), "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, 1, 1e-3, false); } #if MEGDNN_WITH_BENCHMARK TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_MK8_8X8) { auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); matrix_mul::benchmark_with_contrast( handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, "X86_F32MK8_8X8", param::MatrixMul::Format::MK8, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS"); } TEST_F(X86, BENCHMARK_MATRIX_MUL_AVX2_6x16) { auto args = matrix_mul::get_benchmark_matmul_mk_packed_args(8); matrix_mul::benchmark_with_contrast( handle(), args, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, "X86_F32_6x16", param::MatrixMul::Format::DEFAULT, dtype::Float32{}, dtype::Float32{}, dtype::Float32{}, "X86_F32_BLAS"); } TEST_F(X86, BENCHMARK_MATRIX_MUL_8X8X32) { constexpr size_t RUNS = 50; auto rng = std::make_unique(-127, 127); #if MEGDNN_X86_WITH_VNNI Benchmarker benchmarker_vnni(handle()); benchmarker_vnni.set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int32{}) .set_display(false) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_vnni.set_before_exec_callback( AlgoChecker("X86_INT8X8X32_VNNI")); #endif #if MEGDNN_X86_WITH_MKL_DNN Benchmarker benchmarker_mkldnn(handle()); benchmarker_mkldnn.set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int32{}) .set_display(false) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_mkldnn.set_before_exec_callback( AlgoChecker("X86_INT8X8X32_MKLDNN")); #endif Benchmarker benchmarker_avx2_4x16x2(handle()); benchmarker_avx2_4x16x2.set_display(false) .set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int32{}) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_avx2_4x16x2.set_before_exec_callback( AlgoChecker("X86_INT8X8X32_AVX2_4X16X2")); Benchmarker benchmarker_avx2_4x16x2_8816(handle()); benchmarker_avx2_4x16x2_8816.set_display(false) .set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int16{}) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_avx2_4x16x2_8816.set_before_exec_callback( AlgoChecker("X86_INT8X8X16_AVX2")); Benchmarker benchmarker_sse_4x8x2_8816(handle()); benchmarker_sse_4x8x2_8816.set_display(false) .set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int16{}) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_sse_4x8x2_8816.set_before_exec_callback( AlgoChecker("X86_INT8X8X16_SSE")); Benchmarker benchmarker_avx2_2x4x16(handle()); benchmarker_avx2_2x4x16.set_display(false) .set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int32{}) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_avx2_2x4x16.set_before_exec_callback( AlgoChecker("X86_INT8X8X32_AVX2_2X4X16")); Benchmarker benchmarker_sse_4x8x2(handle()); benchmarker_sse_4x8x2.set_display(false) .set_times(RUNS) .set_dtype(0, dtype::Int8{}) .set_dtype(1, dtype::Int8{}) .set_dtype(2, dtype::Int32{}) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_sse_4x8x2.set_before_exec_callback( AlgoChecker("X86_INT8X8X32_SSE_4X8X2")); Benchmarker benchmarker_float(handle()); benchmarker_float.set_display(false) .set_times(RUNS) .set_rng(0, rng.get()) .set_rng(1, rng.get()); benchmarker_float.set_before_exec_callback(AlgoChecker("X86_F32_BLAS")); auto run = [&](size_t M, size_t N, size_t K) { const float computations = 2.f * M * K * N * 1e-6; std::cout << "run : {" << M << "," << N << "," << K << "} "; auto float_used = benchmarker_float.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "float: " << float_used << " ms, " << computations / float_used << " Gflops, "; #if MEGDNN_X86_WITH_VNNI if (is_supported(SIMDType::VNNI)) { auto vnni_used = benchmarker_vnni.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "vnni: " << vnni_used << " ms, " << computations / vnni_used << " Gflops, " << "speed_up " << float_used / vnni_used << ", "; } #endif #if MEGDNN_X86_WITH_MKL_DNN if (is_supported(SIMDType::VNNI)) { auto mkldnn_used = benchmarker_mkldnn.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "mkldnn: " << mkldnn_used << " ms, " << computations / mkldnn_used << " Gflops, " << "speed_up " << float_used / mkldnn_used << ", "; } #endif if (is_supported(SIMDType::AVX2)) { auto avx2_used_4x16x2 = benchmarker_avx2_4x16x2.exec({{M, K}, {K, N}, {}}) / RUNS; auto avx2_used_2x4x16 = benchmarker_avx2_2x4x16.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "avx2_k2: " << avx2_used_4x16x2 << " ms, k2 throughput " << computations / avx2_used_4x16x2 << " Gflops, " << "k2_speed_up " << float_used / avx2_used_4x16x2 << ", k16_speed_up " << float_used / avx2_used_2x4x16 << ","; auto avx2_used_4x16x2_8816 = benchmarker_avx2_4x16x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "avx2_8816: " << avx2_used_4x16x2_8816 << " ms, 8816 throughput " << computations / avx2_used_4x16x2_8816 << " Gflops,"; } if (is_supported(SIMDType::SSE4_1)) { auto sse_used = benchmarker_sse_4x8x2.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "sse: " << sse_used << " ms, " << computations / sse_used << " Gflops, " << "speed_up " << float_used / sse_used << ", "; auto sse_used_8816 = benchmarker_sse_4x8x2_8816.exec({{M, K}, {K, N}, {}}) / RUNS; std::cout << "sse_8816: " << sse_used_8816 << " ms, " << computations / sse_used_8816 << " Gflops, "; } std::cout << std::endl; }; run(256, 256, 256); for (size_t M : {8, 64, 112, 256, 512}) { for (size_t K : {8, 16, 32, 64, 112, 256, 512}) { for (size_t N : {8, 64, 112, 256, 512}) { run(M, N, K); } } } } #endif // MEGDNN_WITH_BENCHMARK // vim: syntax=cpp.doxygen