/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/**
  GEMM + bias + ReLU example: computes D = ReLU(alpha * A * B + bias), with the bias add and the
  ReLU fused into the CUTLASS epilogue, then verifies the result against an unfused reference.
*/

#include <algorithm>
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/tensor_view_io.h"

#include "helper.h"

// The code section below describes the datatypes of the input and output matrices and of the
// computation between elements of the input matrices.
using ElementAccumulator = float;                   // <- data type of accumulator
using ElementComputeEpilogue = ElementAccumulator;  // <- data type of epilogue operations
using ElementInputA = cutlass::half_t;              // <- data type of elements in input matrix A
using ElementInputB = cutlass::half_t;              // <- data type of elements in input matrix B
using ElementOutput = float;                        // <- data type of elements in output matrix D

// Note that if the output is column major, the bias has to be per row, i.e. every row has a
// different bias. If the output is row major, the bias has to be per column, i.e. every column
// has a different bias. Some other notes:
//
// This example only works for ColumnMajor output because
//   1) we only have a row major epilogue.
//   2) we swap A and B if the output is column major, so we can still use the
//      row major epilogue.
//   3) the Mx1 bias vector becomes 1xM after the swapping/transposing.
//   4) we can use the existing OutputIterator to load the 1xM bias vector,
//      as sketched below.
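//
// A minimal sketch of the swap described above (this only restates the reasoning; the exact
// internal mechanics of the kernel are an assumption). A column-major D occupies the same memory
// as a row-major D^T, so
//
//    D (M x N, column-major) = ReLU(alpha * A * B + bias),   bias is M x 1 (one value per row)
//
// can be produced by the row-major epilogue as
//
//    D^T (N x M, row-major) = ReLU(alpha * B^T * A^T + bias^T),   bias^T is 1 x M (one value per column)
//
// which is why the M x 1 bias vector can be loaded as a 1 x M row by the existing OutputIterator.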
// The code section below describes the layouts of the input and output matrices.
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::ColumnMajor;
using LayoutOutput = cutlass::layout::ColumnMajor;

// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM
using MMAOp = cutlass::arch::OpClassTensorOp;

// This code section describes the CUDA SM architecture number
using SmArch = cutlass::arch::Sm75;

// This code section describes the tile size a thread block will compute
using ShapeMMAThreadBlock =
    cutlass::gemm::GemmShape<128, 128, 32>;  // <- threadblock tile M = 128, N = 128, K = 32
// This code section describes the tile size a warp will compute
using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>;  // <- warp tile M = 64, N = 64, K = 32
// This code section describes the size of the MMA op
using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>;  // <- MMA op tile M = 16, N = 8, K = 8

// This code section describes how threadblocks are scheduled on the GPU
using SwizzleThreadBlock =
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;  // <- default identity swizzling function

// Define the epilogue operation as LinearCombinationRelu. This is approximately equal to
//
//    d_ij = max(0, alpha * sum_k(a_ik * b_kj) + c_ij)
//
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementOutput,                                     // <- data type of output matrix
    128 / cutlass::sizeof_bits<ElementOutput>::value,  // <- number of elements per vectorized
                                                       //    memory access. For the float output
                                                       //    used here this is 4 elements (it would
                                                       //    be 8 for half precision). This also
                                                       //    becomes the vector width of math
                                                       //    instructions in the epilogue.
    ElementAccumulator,                                // <- data type of accumulator
    ElementComputeEpilogue,                            // <- data type for alpha in the linear combination
    cutlass::epilogue::thread::ScaleType::NoBetaScaling>;  // <- D = alpha * accumulator + bias;
                                                           //    the source term is not scaled by beta

// Number of pipeline stages you want to use
constexpr int NumStages = 2;

using Gemm = cutlass::gemm::device::Gemm<ElementInputA,
                                         LayoutInputA,
                                         ElementInputB,
                                         LayoutInputB,
                                         ElementOutput,
                                         LayoutOutput,
                                         ElementAccumulator,
                                         MMAOp,
                                         SmArch,
                                         ShapeMMAThreadBlock,
                                         ShapeMMAWarp,
                                         ShapeMMAOp,
                                         EpilogueOp,
                                         SwizzleThreadBlock,
                                         NumStages>;

int run() {

  const int length_m = 5120;
  const int length_n = 4096;
  const int length_k = 4096;

  // Create a tuple of problem size for matrix multiplication
  cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k);

  // Initialize tensors using CUTLASS helper functions
  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(
      problem_size.mk());  // <- Create matrix A with dimensions M x K
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(
      problem_size.kn());  // <- Create matrix B with dimensions K x N

  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c_bias(
      {problem_size.m(), 1});  // <- Create the bias vector C with dimensions M x 1

  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(
      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
                           //    CUTLASS kernel
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(
      problem_size.mn());  // <- Create matrix D with dimensions M x N used to store output from
                           //    reference kernel

  // Fill input and output matrices on host using CUTLASS helper functions
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_a.host_view(),
      1,
      ElementInputA(4),
      ElementInputA(-4),
      0);  // <- Fill matrix A on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_b.host_view(),
      1,
      ElementInputB(4),
      ElementInputB(-4),
      0);  // <- Fill matrix B on host with uniform-distribution random data
  cutlass::reference::host::TensorFillRandomUniform(
      tensor_c_bias.host_view(),
      1,
      ElementOutput(4),
      ElementOutput(-4),
      0);  // <- Fill the bias vector on host with uniform-distribution random data
  cutlass::reference::host::TensorFill(
      tensor_d.host_view());  // <- fill matrix D on host with zeros
  cutlass::reference::host::TensorFill(
      tensor_ref_d.host_view());  // <- fill matrix D for reference on host with zeros
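  // Added note on the fill calls above (a reading of the TensorFillRandomUniform helper, not part
  // of the original comments): the arguments are (view, seed, max, min, bits). With bits = 0 the
  // random values are truncated to integers in [-4, 4], so the half-precision inputs are
  // represented exactly and the float accumulation below stays exact; this is what allows the
  // final check to use TensorEquals (bitwise equality) rather than an approximate comparison.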
  // Copy data from host to GPU
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c_bias.sync_device();
  tensor_d.sync_device();
  tensor_ref_d.sync_device();

  // Initialize alpha for dot product computation
  ElementComputeEpilogue alpha = ElementComputeEpilogue(1);

  // Split K dimension into 1 partition
  int split_k_slices = 1;

  // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch the
  // instantiated CUTLASS kernel
  typename Gemm::Arguments arguments{
      problem_size,           // <- problem size of matrix multiplication
      tensor_a.device_ref(),  // <- reference to matrix A on device
      tensor_b.device_ref(),  // <- reference to matrix B on device

      {tensor_c_bias.device_data(), 0},  // <- the C matrix is treated as the bias vector. We can
                                         //    enable the GEMM to project away the N dimension by
                                         //    setting the stride to zero.

      tensor_d.device_ref(),  // <- reference to matrix D on device
      {alpha},                // <- alpha
      split_k_slices};        // <- k-dimension split factor

  // Using the arguments, query for the extra workspace required for the matrix multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  // Instantiate CUTLASS kernel depending on templates
  Gemm gemm_op;

  // Check whether the problem size is supported or not
  cutlass::Status status = gemm_op.can_implement(arguments);
  CUTLASS_CHECK(status);

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm_op.initialize(arguments, workspace.get());
  CUTLASS_CHECK(status);

  // Launch initialized CUTLASS kernel
  status = gemm_op();
  CUTLASS_CHECK(status);

  //
  // Create instantiation for device reference gemm kernel
  //

  cutlass::reference::device::Gemm<ElementInputA,
                                   LayoutInputA,
                                   ElementInputB,
                                   LayoutInputB,
                                   ElementOutput,
                                   LayoutOutput,
                                   ElementComputeEpilogue,
                                   ElementComputeEpilogue>
      gemm_device_reference;

  // Launch the device reference kernel to compute strictly the product A * B (beta = 0: no bias, no ReLU)
  gemm_device_reference(
      problem_size,
      alpha,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      0,
      tensor_ref_d.device_ref());

  // Wait for kernels to finish
  cudaDeviceSynchronize();

  // Copy output data from the CUTLASS and reference kernels to host for comparison
  tensor_d.sync_host();
  tensor_ref_d.sync_host();

  // Compute bias + relu in host code
  for (int i = 0; i < problem_size.m(); ++i) {
    for (int j = 0; j < problem_size.n(); ++j) {
      tensor_ref_d.at({i, j}) = std::max(
          ElementOutput(0),
          ElementOutput(tensor_ref_d.at({i, j}) + tensor_c_bias.at({i, 0})));
    }
  }

  // Check whether the outputs from the CUTLASS kernel and the reference kernel are equal
  std::cout << (cutlass::reference::host::TensorEquals(tensor_d.host_view(),
                                                       tensor_ref_d.host_view())
                    ? "Passed"
                    : "Failed")
            << std::endl;

  CUTLASS_CHECK(status);
  return 0;
}

int main() {

  bool notSupported = false;

  // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2.
  //
  // CUTLASS must be compiled with the CUDA 10.2 Toolkit (or later) to run these examples.
  if (!(__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2))) {
    std::cerr << "Turing Tensor Core operations must be compiled with CUDA 10.2 Toolkit or later."
              << std::endl;
    notSupported = true;
  }

  cudaDeviceProp props;

  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error)
              << std::endl;
    return -1;
  }

  if (!(props.major * 10 + props.minor >= 75)) {
    std::cerr << "Turing Tensor Ops must be run on a machine with compute capability at least 75."
              << std::endl;
    notSupported = true;
  }

  if (notSupported) {
    // Return zero so this test passes on older toolkits; in that case the example is a no-op.
    return 0;
  }

  return run();
}
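// On a supported configuration (compute capability 75 or newer, built with CUDA 10.2 or later),
// the program prints "Passed" when the fused CUTLASS kernel matches the reference result with the
// bias add and ReLU applied on the host, and "Failed" otherwise.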