/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#include <iostream>
#include <fstream>
#include <sstream>
#include <cstdint>

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/device/tensor_relu.h"

#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/device/gemm_universal.h"

#include "dual_gemm_common.h"
#include "helper.h"

#define CHECK_GT(val1, val2) \
  if((val1) <= (val2)) \
    std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n";

#define CHECK_TRUE(val) \
  if(!(val)) \
    std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n";

template <
  typename OutputOp,
  typename Element,
  typename Layout>
struct TensorEpilogueForEachFunc {
  /// View type
  using TensorView = cutlass::TensorView<Element, Layout>;

  /// Coordinate in tensor's index space
  using TensorCoord = typename TensorView::TensorCoord;

  /// Parameters structure
  struct Params {

    //
    // Data members
    //

    TensorView view_x0;
    TensorView view_x1;
    TensorView view_y;
    OutputOp output_op;

    //
    // Methods
    //

    Params(
      TensorView view_x0_ = TensorView(),
      TensorView view_x1_ = TensorView(),
      TensorView view_y_ = TensorView(),
      OutputOp output_op_ = OutputOp(typename OutputOp::Params{})
    ):
      view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { }
  };

  Params params;

  CUTLASS_DEVICE
  TensorEpilogueForEachFunc(Params const &params): params(params) { }
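  /// Applies the binary output op element-wise: y(coord) = output_op(x0(coord), x1(coord))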
  CUTLASS_DEVICE
  void operator()(TensorCoord const &coord) {
    Element const & x0 = params.view_x0.at(coord);
    Element const & x1 = params.view_x1.at(coord);
    Element& y = params.view_y.at(coord);
    y = params.output_op(x0, x1);
  }
};

template <
  typename OutputOp,
  typename Element,
  typename Layout>
void TensorEpilogueForEach(
  cutlass::TensorView<Element, Layout> x0,
  cutlass::TensorView<Element, Layout> x1,
  cutlass::TensorView<Element, Layout> y) {

  using Func = TensorEpilogueForEachFunc<OutputOp, Element, Layout>;
  using Params = typename Func::Params;

  cutlass::reference::device::TensorForEach<Func, Layout::kRank, Params>(
    y.extent(),
    Params(x0, x1, y)
  );
}

////////////////////////////////////////////////////////////////////////////////

template <typename Gemm0_, typename Gemm1_>
struct NonFusedDualGemmRun {

  using Gemm0 = Gemm0_;
  using Gemm1 = Gemm1_;
  using ElementAccumulator = typename Gemm0::ElementAccumulator;
  using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  NonFusedDualGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { }

  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {
      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {
      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {
      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    cutlass::gemm::GemmCoord problem_size,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(0),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(0),
    bool is_profiling = true,
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {

    //
    // Allocate the GEMM workspace
    //

    cutlass::HostTensor<
      typename Gemm0::ElementA,
      typename Gemm0::LayoutA> tensor_A0(problem_size.mk());

    cutlass::HostTensor<
      typename Gemm0::ElementB,
      typename Gemm0::LayoutB> tensor_B0(problem_size.kn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_C0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()});

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> tensor_D0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm0::ElementC,
      typename Gemm0::LayoutC> reference_D0(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementB,
      typename Gemm1::LayoutB> tensor_B1(problem_size.kn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_C1(problem_size.mn());
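    // Bias0/Bias1 are stored as 1 x N vectors; the epilogues broadcast them across all
    // M rows by passing a zero leading stride for the C operand (see the arguments below).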
    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()});

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> tensor_D1(problem_size.mn());

    cutlass::HostTensor<
      typename Gemm1::ElementC,
      typename Gemm1::LayoutC> reference_D1(problem_size.mn());

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013));

    cutlass::reference::host::TensorFill(
      tensor_D0.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());
    cutlass::reference::host::TensorFill(
      reference_D0.host_view());
    cutlass::reference::host::TensorFill(
      reference_D1.host_view());

    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_D0.sync_device();
    reference_D0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D1.sync_device();
    reference_D1.sync_device();

    //
    // Initialize the GEMM operator
    //

    int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1;
    typename Gemm0::Arguments arguments_0{
      problem_size,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      tensor_D0.device_ref(),
      {alpha0, beta0},
      split_k_slices
    };

    split_k_slices = Gemm1::kSplitKSerial ? 2 : 1;
    typename Gemm1::Arguments arguments_1{
      problem_size,
      tensor_A0.device_ref(),
      tensor_B1.device_ref(),
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      tensor_D1.device_ref(),
      {alpha1, beta1},
      split_k_slices
    };

    Gemm0 gemm_op_0;
    Gemm1 gemm_op_1;

    // Allocate workspace memory
    cutlass::device_memory::allocation<uint8_t> workspace0(gemm_op_0.get_workspace_size(arguments_0));
    cutlass::device_memory::allocation<uint8_t> workspace1(gemm_op_1.get_workspace_size(arguments_1));

    cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get());
    CUTLASS_CHECK(status);

    status = gemm_op_1.initialize(arguments_1, workspace1.get());
    CUTLASS_CHECK(status);

    for(int i = 0; i < warm_ups; i++) {
      status = gemm_op_0();
      CUTLASS_CHECK(status);
      status = gemm_op_1();
      CUTLASS_CHECK(status);
    }

    if (is_profiling) {
      //
      // Profile the GEMM
      //

      cudaEvent_t start, stop1, stop2;
      cudaEventCreate(&start);
      cudaEventCreate(&stop1);
      cudaEventCreate(&stop2);
      cudaEventRecord(start);

      for(int i = 0; i < runs; i++) {
        status = gemm_op_0();
        CUTLASS_CHECK(status);
      }

      cudaEventRecord(stop1);

      for(int i = 0; i < runs; i++) {
        status = gemm_op_1();
        CUTLASS_CHECK(status);
      }

      cudaEventRecord(stop2);
      cudaDeviceSynchronize();

      float gemm0Time, gemm1Time, totalTime;
      cudaEventElapsedTime(&gemm0Time, start, stop1);
      cudaEventElapsedTime(&gemm1Time, stop1, stop2);
      cudaEventElapsedTime(&totalTime, start, stop2);
      std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n";
      std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n";
      std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n";
    }

    tensor_D0.sync_host();
    tensor_D1.sync_host();

    //
    // Verify
    //
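    // Each GEMM result is checked against a device-side reference GEMM that reads the
    // same broadcast bias vector as its C operand, optionally followed by a ReLU.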
    cutlass::reference::device::Gemm<
      typename Gemm0::ElementA, typename Gemm0::LayoutA,
      typename Gemm0::ElementB, typename Gemm0::LayoutB,
      typename Gemm0::ElementC, typename Gemm0::LayoutC,
      ElementCompute, ElementAccumulator, typename Gemm0::Operator>
      reference_gemm_0;

    cutlass::reference::device::Gemm<
      typename Gemm1::ElementA, typename Gemm1::LayoutA,
      typename Gemm1::ElementB, typename Gemm1::LayoutB,
      typename Gemm1::ElementC, typename Gemm1::LayoutC,
      ElementCompute, ElementAccumulator, typename Gemm1::Operator>
      reference_gemm_1;

    reference_gemm_0(
      problem_size,
      alpha0,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      beta0,
      {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)},
      reference_D0.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
    }

    reference_gemm_1(
      problem_size,
      alpha1,
      tensor_A0.device_ref(),
      tensor_B1.device_ref(),
      beta1,
      {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)},
      reference_D1.device_ref()
    );

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    // Wait for kernels to finish
    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);

    bool passed0 = cutlass::reference::host::TensorEquals(
      reference_D0.host_view(),
      tensor_D0.host_view());
    CHECK_TRUE(passed0);

    bool passed1 = cutlass::reference::host::TensorEquals(
      reference_D1.host_view(),
      tensor_D1.host_view());
    CHECK_TRUE(passed1);

    if (!passed0 || !passed1) {

      std::stringstream fname;

      fname << "error_DualGemm_device_nonfused.txt";
      std::cerr << "Dumping results in " << fname.str() << "\n";

      std::ofstream file(fname.str());

      file
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nD0 =\n" << tensor_D0.host_view()
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference =\n" << reference_D1.host_view()
        << "\nComputed =\n" << tensor_D1.host_view();
    }

    return passed0 && passed1;
  }
};

template <typename DualGemm_>
struct DualFusedGemmRun {

  using DualGemm = DualGemm_;
  using ElementAccumulator = typename DualGemm::ElementAccumulator;
  using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute;
  using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2;

  /// Initialization
  cutlass::Distribution::Kind init_A;
  cutlass::Distribution::Kind init_B;
  cutlass::Distribution::Kind init_C;
  cutlass::Distribution::Kind init_Scale;
  cutlass::Distribution::Kind init_Bias;
  uint64_t seed;

  //
  // Methods
  //

  DualFusedGemmRun(
    cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform,
    cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform,
    uint64_t seed_ = 2080
  ):
    init_A(init_A_), init_B(init_B_), init_C(init_C_),
    init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { }
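  // Uniform fills below draw small integer values (range [-2, 2], zero fractional bits),
  // which keeps device and reference results bitwise comparable via TensorEquals.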
  /// Helper to initialize a tensor view
  template <typename Element, typename Layout>
  bool initialize_tensor(
    cutlass::TensorView<Element, Layout> view,
    cutlass::Distribution::Kind dist_kind,
    uint64_t seed) {

    if (dist_kind == cutlass::Distribution::Uniform) {
      cutlass::reference::host::TensorFillRandomUniform(
        view, seed, 2, -2, 0);
    }
    else if (dist_kind == cutlass::Distribution::Identity) {
      cutlass::reference::host::TensorFillIdentity(view);
    }
    else if (dist_kind == cutlass::Distribution::Gaussian) {
      cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5);
    }
    else if (dist_kind == cutlass::Distribution::Sequential) {
      cutlass::reference::host::BlockFillSequential(
        view.data(), view.capacity());
    }
    else if (dist_kind == cutlass::Distribution::AllZeros) {
      cutlass::reference::host::TensorFill(view, Element(0));
    }
    else if (dist_kind == cutlass::Distribution::AllOnes) {
      cutlass::reference::host::TensorFill(view, Element(1));
    }
    else {
      std::cerr << "Not implemented\n";
      return false;
    }

    return true;
  }

  /// Executes one test
  bool run(
    cutlass::gemm::GemmCoord problem_size,
    ElementCompute alpha0 = ElementCompute(1),
    ElementCompute beta0 = ElementCompute(1),
    ElementCompute alpha1 = ElementCompute(1),
    ElementCompute beta1 = ElementCompute(1),
    int batch_count = 1,
    bool broadcast_b1 = false,
    bool is_profiling = true,
    bool relu = false,
    int warm_ups = 1,
    int runs = 100) {

    //
    // Allocate the GEMM workspace
    //

    cutlass::HostTensor<
      typename DualGemm::ElementA,
      typename DualGemm::LayoutA> tensor_A0(
        cutlass::platform::is_same<typename DualGemm::LayoutA, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k()));

    cutlass::HostTensor<
      typename DualGemm::ElementB,
      typename DualGemm::LayoutB0> tensor_B0(
        cutlass::platform::is_same<typename DualGemm::LayoutB0, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_C0(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()});

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_D0(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> reference_D0(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementB,
      typename DualGemm::LayoutB1> tensor_B1(
        cutlass::platform::is_same<typename DualGemm::LayoutB1, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n()));
    if (broadcast_b1) {
      tensor_B1.resize({problem_size.k(), batch_count});
    }

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_C1(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()});
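    // When batch_count > 1, each tensor is stacked along its outermost (slowest-varying)
    // dimension so every batch occupies a contiguous slab matching the batch strides below.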
    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_D1(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> tensor_D2(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> reference_D1(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    cutlass::HostTensor<
      typename DualGemm::ElementC,
      typename DualGemm::LayoutC> reference_D2(
        cutlass::platform::is_same<typename DualGemm::LayoutC, cutlass::layout::RowMajor>::value ?
          cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) :
          cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n()));

    CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019));
    CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118));
    CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017));
    CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011));
    CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113));
    CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015));
    CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012));

    cutlass::reference::host::TensorFill(
      tensor_D0.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D1.host_view());
    cutlass::reference::host::TensorFill(
      tensor_D2.host_view());
    cutlass::reference::host::TensorFill(
      reference_D0.host_view());
    cutlass::reference::host::TensorFill(
      reference_D1.host_view());
    cutlass::reference::host::TensorFill(
      reference_D2.host_view());

    tensor_A0.sync_device();
    tensor_B0.sync_device();
    tensor_C0.sync_device();
    tensor_Bias0.sync_device();
    tensor_B1.sync_device();
    tensor_C1.sync_device();
    tensor_Bias1.sync_device();
    tensor_D0.sync_device();
    tensor_D1.sync_device();
    tensor_D2.sync_device();
    reference_D0.sync_device();
    reference_D1.sync_device();
    reference_D2.sync_device();

    //
    // Batch strides (irrelevant when batch_count == 1)
    //

    int64_t batch_stride_A = problem_size.m() * problem_size.k();
    int64_t batch_stride_B0 = problem_size.k() * problem_size.n();
    int64_t batch_stride_B1 = problem_size.k() * problem_size.n();
    if (broadcast_b1) {
      // B1 is a (column) vector
      batch_stride_B1 = problem_size.k();
    }
    int64_t batch_stride_Bias = problem_size.n();
    int64_t batch_stride_D = problem_size.m() * problem_size.n();

    //
    // Initialize the GEMM operator
    //

    int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1;
    typename cutlass::TensorRef<typename DualGemm::ElementC, typename DualGemm::LayoutC> nullptr_ref{};
    decltype(nullptr_ref) ref_B0, ref_B1;
    if (beta0 != ElementCompute(0)) {
      ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)};
    }
    if (beta1 != ElementCompute(0)) {
      ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)};
    }
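    // When kStoreD0/kStoreD1 are false, the fused kernel does not write the individual GEMM
    // outputs D0/D1 to global memory, so their TensorRefs below are left null. Likewise,
    // ref_B0/ref_B1 stay null when the matching beta is zero, which skips the bias load.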
    typename DualGemm::Arguments arguments{
      (batch_count > 1 ?
        cutlass::gemm::DualGemmMode::kBatched :
        cutlass::gemm::DualGemmMode::kGemm),
      problem_size,
      tensor_A0.device_ref(),
      tensor_B0.device_ref(),
      ref_B0,
      DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref,
      (broadcast_b1 ?
        typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) :
        tensor_B1.device_ref()),
      ref_B1,
      DualGemm::kStoreD1 ? tensor_D1.device_ref() : nullptr_ref,
      tensor_D2.device_ref(),
      {alpha0, beta0},
      {alpha1, beta1},
      {},
      split_k_slices,
      batch_count,
      batch_stride_A,
      batch_stride_B0,
      batch_stride_B1,
      batch_stride_Bias,
      batch_stride_D,
    };

    //
    // Run the GEMM
    //

    DualGemm b2b_gemm_op;

    cutlass::device_memory::allocation<uint8_t> workspace(b2b_gemm_op.get_workspace_size(arguments));

    cutlass::Status status = b2b_gemm_op.can_implement(arguments);
    CUTLASS_CHECK(status);

    status = b2b_gemm_op.initialize(arguments, workspace.get());
    CUTLASS_CHECK(status);

    for(int i = 0; i < warm_ups; i++) {
      status = b2b_gemm_op();
      CUTLASS_CHECK(status);
    }

    if (is_profiling) {
      //
      // Profile the GEMM
      //

      cudaEvent_t start, stop;
      cudaEventCreate(&start);
      cudaEventCreate(&stop);

      cudaEventRecord(start);

      for(int i = 0; i < runs; i++) {
        status = b2b_gemm_op();
        CUTLASS_CHECK(status);
      }

      cudaEventRecord(stop);
      cudaDeviceSynchronize();

      float gemmTime;
      cudaEventElapsedTime(&gemmTime, start, stop);
      std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n";
    }

    tensor_D0.sync_host();
    tensor_D1.sync_host();
    tensor_D2.sync_host();

    //
    // Verify
    //

    using GemmUniversal0 = cutlass::gemm::device::GemmUniversal<
      typename DualGemm::ElementA, typename DualGemm::LayoutA,
      typename DualGemm::ElementB, typename DualGemm::LayoutB0,
      typename DualGemm::ElementC, typename DualGemm::LayoutC,
      ElementAccumulator
    >;

    GemmUniversal0 reference_gemm0;

    typename GemmUniversal0::Arguments args0 {
      (batch_count > 1 ?
        cutlass::gemm::GemmUniversalMode::kBatched :
        cutlass::gemm::GemmUniversalMode::kGemm),
      problem_size,
      batch_count,
      {alpha0, beta0},
      tensor_A0.device_data(),
      tensor_B0.device_data(),
      tensor_Bias0.device_data(),
      reference_D0.device_data(),
      batch_stride_A,
      batch_stride_B0,
      batch_stride_Bias,
      batch_stride_D,
      tensor_A0.stride(0),
      tensor_B0.stride(0),
      0,  // zero stride for the bias vector
      reference_D0.stride(0),
    };
    status = reference_gemm0.can_implement(args0);
    CUTLASS_CHECK(status);
    status = reference_gemm0(args0);
    CUTLASS_CHECK(status);
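    // Reference for the second GEMM: when broadcast_b1 is set, B1 is a k-element vector and
    // a zero leading dimension (ldb) below reproduces the broadcast across all N columns.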
    using GemmUniversal1 = cutlass::gemm::device::GemmUniversal<
      typename DualGemm::ElementA, typename DualGemm::LayoutA,
      typename DualGemm::ElementB, typename DualGemm::LayoutB1,
      typename DualGemm::ElementC, typename DualGemm::LayoutC,
      ElementAccumulator
    >;

    GemmUniversal1 reference_gemm1;

    typename GemmUniversal1::Arguments args1 {
      (batch_count > 1 ?
        cutlass::gemm::GemmUniversalMode::kBatched :
        cutlass::gemm::GemmUniversalMode::kGemm),
      problem_size,
      batch_count,
      {alpha1, beta1},
      tensor_A0.device_data(),
      tensor_B1.device_data(),
      tensor_Bias1.device_data(),
      reference_D1.device_data(),
      batch_stride_A,
      batch_stride_B1,
      batch_stride_Bias,
      batch_stride_D,
      tensor_A0.stride(0),
      (broadcast_b1 ? 0 : tensor_B1.stride(0)),
      0,  // zero stride for the bias vector
      reference_D1.stride(0),
    };
    status = reference_gemm1.can_implement(args1);
    CUTLASS_CHECK(status);
    status = reference_gemm1(args1);
    CUTLASS_CHECK(status);

    if(relu) {
      cutlass::reference::device::TensorReLu(reference_D0.device_view());
      cutlass::reference::device::TensorReLu(reference_D1.device_view());
    }

    TensorEpilogueForEach<EpilogueOutputOp2, typename DualGemm::ElementC, typename DualGemm::LayoutC>(
      reference_D0.device_view(),
      reference_D1.device_view(),
      reference_D2.device_view());
    cudaDeviceSynchronize();
    reference_D0.sync_host();
    reference_D1.sync_host();
    reference_D2.sync_host();

    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0);
    CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0);

    bool passed_out0 = true;
    if (DualGemm::kStoreD0) {
      CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0);
      passed_out0 = cutlass::reference::host::TensorEquals(
        reference_D0.host_view(),
        tensor_D0.host_view());
    }
    CHECK_TRUE(passed_out0);

    bool passed_out1 = true;
    if (DualGemm::kStoreD1) {
      CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0);
      passed_out1 = cutlass::reference::host::TensorEquals(
        reference_D1.host_view(),
        tensor_D1.host_view());
    }
    CHECK_TRUE(passed_out1);

    bool passed_out2 = cutlass::reference::host::TensorEquals(
      reference_D2.host_view(),
      tensor_D2.host_view());
    CHECK_TRUE(passed_out2);

    bool passed = passed_out0 && passed_out1 && passed_out2;
    if (!passed) {

      std::stringstream fname;

      fname << "error_DualGemm_device_fused.txt";
      std::cerr << "Dumping results in " << fname.str() << "\n";

      std::ofstream file(fname.str());

      file
        << "A0 =\n" << tensor_A0.host_view()
        << "\nB0 =\n" << tensor_B0.host_view()
        << "\nC0 =\n" << tensor_C0.host_view()
        << "\nBias0:\n" << tensor_Bias0.host_view() << "\n"
        << "\nB1 =\n" << tensor_B1.host_view()
        << "\nC1 =\n" << tensor_C1.host_view()
        << "\nBias1:\n" << tensor_Bias1.host_view() << "\n"
        << "\n\nReference0 =\n" << reference_D0.host_view()
        << "\nComputed0 =\n" << tensor_D0.host_view()
        << "\n\nReference1 =\n" << reference_D1.host_view()
        << "\nComputed1 =\n" << tensor_D1.host_view()
        << "\n\nReference2 =\n" << reference_D2.host_view()
        << "\nComputed2 =\n" << tensor_D2.host_view();
    }

    return passed;
  }
};

////////////////////////////////////////////////////////////////////////////////
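// Illustrative usage (hypothetical sketch, not part of this header): the example's .cu file
// is expected to instantiate concrete Gemm / DualGemm device operators and drive the two
// harnesses roughly as follows. `Gemm0`, `Gemm1`, `DualGemm`, and the problem size below are
// assumptions for illustration, not names defined in this header.
//
//   cutlass::gemm::GemmCoord problem_size(1024, 64, 576);
//
//   NonFusedDualGemmRun<Gemm0, Gemm1> non_fused;
//   bool non_fused_passed = non_fused.run(problem_size, 1.0f, 0.0f, 1.0f, 0.0f);
//
//   DualFusedGemmRun<DualGemm> fused;
//   bool fused_passed = fused.run(problem_size, 1.0f, 0.0f, 1.0f, 0.0f);
//
//   std::cout << ((non_fused_passed && fused_passed) ? "Pass" : "Fail") << std::endl;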