/*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /** */ #include #include #include #include #include "cutlass/cutlass.h" #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/gemm/device/gemm_complex.h" #include "cutlass/util/command_line.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm_complex.h" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/host/tensor_reduce.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/error_metrics.h" #include "cutlass/util/tensor_view_io.h" #include "cutlass/layout/matrix.h" #include "cutlass/epilogue/thread/linear_combination.h" ///////////////////////////////////////////////////////////////////////////////////////////////// #include "gemm_with_softmax.h" ///////////////////////////////////////////////////////////////////////////////////////////////// #define TRACE(x) { std::cout << "gemm_softmax.cu:" << __LINE__ << " " << x << std::endl; } ///////////////////////////////////////////////////////////////////////////////////////////////// enum class Disposition { kPassed, kIncorrect, kNotVerified }; ///////////////////////////////////////////////////////////////////////////////////////////////// // Command line options parsing struct Options { bool help; cutlass::gemm::GemmCoord problem_size; int batch_count; int iterations; unsigned seed; float alpha; float beta; bool verification_enabled; float tolerance; Options(): help(false), problem_size({16, 24, 64}), batch_count(16), iterations(20), seed(2022), alpha(1), beta(0), verification_enabled(true), tolerance(1e-5f) { } bool valid() { return true; } // Parses the command line void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); if (cmd.check_cmd_line_flag("help")) { help = true; } cmd.get_cmd_line_argument("m", problem_size.m()); cmd.get_cmd_line_argument("n", problem_size.n()); cmd.get_cmd_line_argument("k", problem_size.k()); cmd.get_cmd_line_argument("batch_count", batch_count); cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("verify", verification_enabled); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("tolerance", tolerance); } /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { out << "35_gemm_softmax example\n\n" << " This example uses the CUTLASS Library to compute GEMM + Softmax for arbitrary problem sizes.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --m= GEMM M dimension\n" << " --n= GEMM N dimension\n" << " --k= GEMM K dimension\n" << " --batch_count= Batch number\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" << " --seed= Random number seed (1*)\n\n" << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" << " --verify= If true, performs reference calculation.\n\n" << " --tolerance Error tolerance\n" ; out << "\n\nExamples:\n\n" << "$ ./examples/35_gemm_softmax/35_gemm_softmax --m=1024 --n=512 \\\n" << " --alpha=2 --beta=0.707 \n\n"; return out; } /// Returns true if the environment and Toolkit support this bool supported(bool verbose = true) const { // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { if (verbose) { std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; } return false; } cudaDeviceProp props; cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { if (verbose) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; } return false; } if (!((props.major * 10 + props.minor) >= 80)) { if (verbose) { std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." << std::endl; } return false; } return true; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// struct Testbed { // // Type definitions // using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; using ElementCompute = float; using ElementD = ElementC; using ElementSoftmax = ElementC; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; using OperatorClass = cutlass::arch::OpClassTensorOp; using ArchTag = cutlass::arch::Sm80; // ApplyShape impacts the final Softmax performance a lot. // Set ApplyShape::kColumn to be the next multiple of 32 number that is after // (gemm_N / alignment). // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn). using ApplyShape = cutlass::MatrixShape<1, 1024>; static int const kStages = 3; /// Linear scaling operator using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< ElementC, 128 / cutlass::sizeof_bits::value, ElementCompute, ElementCompute >; using GemmSoftmax = cutlass::GemmSoftmax< ElementA, LayoutA, ElementB, LayoutB, ElementC, ElementCompute, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueFunctorOp, kStages, ApplyShape >; using ElementNorm = typename GemmSoftmax::ElementNorm; using ElementSum = typename GemmSoftmax::ElementSum; using LayoutC = typename GemmSoftmax::LayoutC; using LayoutN = typename GemmSoftmax::LayoutN; using LayoutS = typename GemmSoftmax::LayoutS; using MatrixCoord = typename LayoutC::TensorCoord; // // Data members // Options const &options; cutlass::HostTensor reference_N; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; cutlass::DeviceAllocation block_Ref; cutlass::DeviceAllocation block_Softmax; cutlass::DeviceAllocation block_Norm; cutlass::DeviceAllocation block_Sum; int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN; cutlass::gemm::GemmCoord problem = options.problem_size; int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0); // fixed rowmajor for norm and sum int64_t ldn = problem.m(); int64_t lds = ldn; int64_t total_elements_A_per_batch = problem.m() * problem.k(); int64_t total_elements_B_per_batch = problem.k() * problem.n(); int64_t total_elements_C_per_batch = problem.m() * problem.n(); int64_t total_elements_D_per_batch = problem.m() * problem.n(); int64_t total_elements_partial_norm_per_batch = block_num * problem.m(); int64_t total_elements_A = total_elements_A_per_batch * options.batch_count; int64_t total_elements_B = total_elements_B_per_batch * options.batch_count; int64_t total_elements_C = total_elements_C_per_batch * options.batch_count; int64_t total_elements_D = total_elements_D_per_batch * options.batch_count; int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count; // // Methods // Testbed( Options const &options_ ): options(options_) { reference_N.reset({options.problem_size.m(), 1}, false); } /// Run Disposition run() { Disposition disposition = Disposition::kNotVerified; // // Initialize the workspace // initialize(); // // Launch device kernel // cutlass::Status status = cutlass::Status::kSuccess; status = execute_device_kernel(); if (status != cutlass::Status::kSuccess) { std::cerr << "Device execution failed." << std::endl; return disposition; } cudaError_t result = cudaDeviceSynchronize(); if (result != cudaSuccess) { std::cerr << "Device synchronize failed with error " << cudaGetErrorString(result) << std::endl; return disposition; } // // Verify // if (options.verification_enabled) { bool passed = verify(); if (passed) { disposition = Disposition::kPassed; } else { disposition = Disposition::kIncorrect; } } // // Profiling // if (options.iterations) { profile(); } return disposition; } /// Random initialization void initialize() { block_A.reset(total_elements_A); block_B.reset(total_elements_B); block_C.reset(total_elements_C); block_D.reset(total_elements_D); block_Softmax.reset(total_elements_D); block_Ref.reset(total_elements_D_per_batch); block_Norm.reset(total_elements_partial_norm); block_Sum.reset(total_elements_partial_norm); cutlass::reference::device::BlockFillRandomUniform( block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0); cutlass::reference::device::BlockFillRandomUniform( block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0); cutlass::reference::device::BlockFillRandomUniform( block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0); cutlass::reference::device::BlockFillRandomUniform( block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0); cutlass::reference::device::BlockFillRandomUniform( block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0); cutlass::reference::device::BlockFillRandomUniform( block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0); cutlass::reference::host::TensorFill( reference_N.host_view(), ElementNorm() ); } cutlass::Status execute_device_kernel() { cutlass::Status status = cutlass::Status::kSuccess; // // Setup arguments // GemmSoftmax::Arguments args( options.problem_size, options.batch_count, {block_A.get(), lda}, {block_B.get(), ldb}, {block_C.get(), ldc}, {block_D.get(), ldc}, { ElementCompute(options.alpha), ElementCompute(options.beta) }, {block_Norm.get(), ldn}, {block_Sum.get(), lds}, {block_Softmax.get(), ldc}, total_elements_A_per_batch, total_elements_B_per_batch, total_elements_C_per_batch, total_elements_D_per_batch, total_elements_partial_norm_per_batch, total_elements_partial_norm_per_batch, total_elements_D_per_batch ); // // Launch // GemmSoftmax gemm_softmax; // Initialize status = gemm_softmax.initialize(args); if (status != cutlass::Status::kSuccess) { return status; } // Run status = gemm_softmax(); return status; } template bool verify_tensor(std::vector vector_Input, \ std::vector vector_Input_Ref) { int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); float abs_tol = options.tolerance; float rel_tol = options.tolerance; for (int64_t i = 0; i < size; ++i) { float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); float abs_diff = fabs(diff); float abs_ref = fabs((float)vector_Input_Ref.at(i)); float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0; if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } } return true; } /// Verifies the reference matches bool verify() { LayoutA layout_A(lda); LayoutB layout_B(ldb); LayoutC layout_C(ldc); LayoutN Layout_N(ldn); LayoutS Layout_S(lds); MatrixCoord extent_A{problem.m(), problem.k()}; MatrixCoord extent_B{problem.k(), problem.n()}; MatrixCoord extent_C{problem.m(), problem.n()}; for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) { cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); cutlass::TensorView view_Ref_device(block_Ref.get(), layout_C, extent_C); cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ElementCompute, ElementCompute >( problem, options.alpha, view_A, cutlass::ComplexTransform::kNone, view_B, cutlass::ComplexTransform::kNone, options.beta, view_C, view_Ref_device, ElementCompute(0) ); // Copy reference results to host memory for verification std::vector matrix_D_Ref(layout_C.capacity(extent_C)); cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size()); cutlass::TensorView view_Ref(matrix_D_Ref.data(), layout_C, extent_C); std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); cutlass::TensorView view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); // Copy computed results to host memory std::vector matrix_D(layout_C.capacity(extent_C)); cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); std::vector matrix_Softmax(layout_C.capacity(extent_C)); cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); // Compute the norm for (int m = 0; m < options.problem_size.m(); ++m) { reference_N.at({m, 0}) = view_Ref.ref().at({m, 0}); for (int n = 1; n < options.problem_size.n(); ++n) { reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n}))); } } // Compute softmax for (int m = 0; m < options.problem_size.m(); ++m) { float sum = float(); for (int n = 0; n < options.problem_size.n(); ++n) { sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); } float inv_sum = float(1.0f / sum); for (int n = 0; n < options.problem_size.n(); ++n) { view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum ); } } // Verification checks - set any of these to 'true' to override the verification checks. bool verified_D = false; bool verified_Softmax = false; // Verify softmax output if (!verified_D) { verified_D = verify_tensor(matrix_D, matrix_D_Ref); } if (!verified_Softmax) { verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref); } if (!verified_D || !verified_Softmax) { std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; // Summarize which checks failed if (!verified_D) { std::cerr << "Verification of D tensor failed\n"; } if (!verified_Softmax) { std::cerr << "Verification of Softmax tensor failed\n"; } return false; } } return true; } /// Profiles bool profile() { // // Profile // cutlass::Status status = cutlass::Status::kSuccess; cudaError_t result; cudaEvent_t events[2]; int const kIterations = options.iterations; for (cudaEvent_t &evt : events) { result = cudaEventCreate(&evt); if (result != cudaSuccess) { std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; return false; } } result = cudaEventRecord(events[0]); if (result != cudaSuccess) { std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; return false; } for (int iter = 0; iter < kIterations; ++iter) { status = execute_device_kernel(); if (status != cutlass::Status::kSuccess) { std::cerr << "Device execution failed." << std::endl; return false; } } result = cudaEventRecord(events[1]); if (result != cudaSuccess) { std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; return false; } result = cudaDeviceSynchronize(); if (result != cudaSuccess) { std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; return false; } float elapsed_ms = 0; result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); if (result != cudaSuccess) { std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; return false; } for (cudaEvent_t &evt : events) { result = cudaEventDestroy(evt); if (result != cudaSuccess) { std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; return false; } } int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); double elapsed_ms_per_iter = double(elapsed_ms) / kIterations; std::cout << " Problem: " << options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k() << ", batch size: " << options.batch_count << std::endl; std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl; std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl; return true; } }; ///////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, const char **argv) { // Options parsing Options options; options.parse(argc, argv); if (options.help) { options.print_usage(std::cout) << std::endl; return 0; } if (!options.supported()) { return 0; } // Run Testbed testbed(options); Disposition disposition = testbed.run(); std::cout << std::endl; switch (disposition) { case Disposition::kPassed: std::cout << "Passed" << std::endl; break; case Disposition::kIncorrect: std::cout << "Incorrect" << std::endl; break; case Disposition::kNotVerified: std::cout << "Not verified" << std::endl; break; } return (disposition == Disposition::kPassed ? 0 : -1); } /////////////////////////////////////////////////////////////////////////////////////////////////