/*************************************************************************************************** * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief GEMM Permute Example. This example computes batched GEMM operations with output results permuted as reshaped tensors. We provide layout plugin as a flexible tool for users to add any customized input/output tensor permute operation, or any other generalized global memory writeout address computation. To add a customized layout, add new class in include/cutlass/layout/permute.h In this example we use several permute operations (permute([0, 2, 1, 3])) In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on output matrix. The address computations are performed in compute(col_init, row_init, stride_init, BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op. (check include/cutlass/layout/permute.h) Tips: 1) Make sure to set batch_stride to zero for BMM permute; also the BMM GEMM should be in mode cutlass::gemm::GemmUniversalMode::kBatched instead of kArray. 2) When the contiguous dimension is touched in permute op (for example [0, 2, 3, 1] for row-major matrix or [1, 0, 2, 3] for column-major), Alignment should be set to 1 for the corresponding matrix. If the last dimension is untouched, one can set Alignment to be larger like 8 in our example. As a result, permute op without touching the unit stride dimension is recommended to obtain the best performance. Examples: # Runs a batched GEMM with 96 batches $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024) $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true # Execute batched GEMM and profile with NSight $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false */ ///////////////////////////////////////////////////////////////////////////////////////////////// #include #include #include #include #include #include #include "cutlass/cutlass.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/device/gemm_universal.h" #include "cutlass/util/command_line.h" #include "cutlass/util/distribution.h" #include "cutlass/util/device_memory.h" #include "cutlass/util/tensor_view_io.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm_complex.h" #include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/device/tensor_compare.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/layout/permute.h" #include "layouts.h" #include "permute_info.h" /// Tensor4DPermuteBMM0213 ---> /// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimension as [B, M, N]) reshaped /// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM tensor. int constexpr D1 = 12; /// Tensor5DPermute20314 ---> /// Permute layout function for 5-D permuted tensors with matrix (dimension as [M, N]) reshaped /// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding tensor. int constexpr T1 = 16; int constexpr T2 = 3; int constexpr T3 = 8; /// Tensor4DPermute0213 ---> /// Permute layout function for 4-D permuted tensors with matrix (dimension as [M, N]) reshaped /// as [M/S1, S1, S2, N/S2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. int constexpr S1 = 8; int constexpr S2 = 4; // // // Alignments int constexpr AlignmentA = 8; int constexpr AlignmentB = 8; int constexpr AlignmentC = 8; /// GEMM element types using ElementInput = cutlass::half_t; using ElementOutput = cutlass::half_t; using ElementAccumulator = float; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Useful macros #define CHECK_CUDA_CALL(call, handler) \ do { \ cudaError_t __err = (call); \ if (__err != cudaSuccess) { \ std::cerr << #call " failed: " << cudaGetErrorString(__err) << std::endl; \ handler; \ } \ } while(0) #define CHECK_CUTLASS_CALL(call, handler) \ do { \ cutlass::Status __status = (call); \ if (__status != cutlass::Status::kSuccess) { \ std::cerr << #call " failed: " << cutlass::cutlassGetStatusString(__status) << std::endl; \ handler; \ } \ } while(0) ///////////////////////////////////////////////////////////////////////////////////////////////// // Command line options parsing struct Options { bool help; bool error; bool reference_check; cutlass::gemm::GemmCoord problem_each; int batch_count; int iterations; int cuda_streams; bool verbose; float alpha; float beta; // // Methods // Options(): help(false), error(false), reference_check(true), batch_count(-1), iterations(20), cuda_streams(0), verbose(false), alpha(1), beta() { } // Parses the command line void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); if (cmd.check_cmd_line_flag("help")) { help = true; return; } cmd.get_cmd_line_argument("alpha", alpha, 1.0f); cmd.get_cmd_line_argument("beta", beta, 0.0f); cmd.get_cmd_line_argument("iterations", iterations, 20); cmd.get_cmd_line_argument("streams", cuda_streams, 0); cmd.get_cmd_line_argument("verbose", verbose, false); cmd.get_cmd_line_argument("reference-check", reference_check, true); int m, n, k; cmd.get_cmd_line_argument("m", m, 384); cmd.get_cmd_line_argument("n", n, 192); cmd.get_cmd_line_argument("k", k, 384); cmd.get_cmd_line_argument("batch-count", batch_count, 96); problem_each = cutlass::gemm::GemmCoord(m, n, k); } /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { out << "39_gemm_permute\n" "\n" " This example tests and profiles the performance of normal GEMM and batched GEMM with different" " combinations of fused permutations of input and output tensors." "\n" " Permutations considered in this example:\n" "\n" " Normal GEMM:\n" " 1) Tensor4DPermute0213: matrix of shape [X, Y] is reshaped as [X/S1, S1, S2, Y/S2] and has its dimensions" " permuted as [0, 2, 1, 3], resulting in shape [X/S1, S2, S1, Y/S2] viewed as matrix of shape [X*S2/S1, Y*S1/S2].\n" " 2) Tensor5DPermute20314: matrix of shape [X, Y] is reshaped as [X/T1, T1, T2, T3, Y/T2/T3] and has its dimensions" " permuted as [2, 0, 3, 1, 4], resulting in shape [T2, X/T1, T3, T1, Y/T2/T3] viewed as matrix of shape [X*T2/T1, Y*T1/T2].\n" "\n" " Batched GEMM:\n" " 3) Tensor4DPermuteBMM0213: batched tensor of 3D shape [B, X, Y] is reshaped as 4D shape [B/D1, D1, X, Y]" " and has its dimensions permuted as [0, 2, 1, 3], resulting in shape [B/D1, X, D1, Y] viewed as" " a matrix of shape [B/D1, X, Y*D1] for batched GEMM purposes.\n" "\n" " Note: S1, S2, D1, D2, T1, T2, T3 are compile-time constants defined in gemm_permute.cu." " Runtime specification of these values is not supported." " These values along with alignment requirements place constraints on supported matrix sizes.\n" "\n" " Note: X, Y above may refer to M, N or K dimensions of GEMM problem, depending on the tensor considered (A, B or D)." " For the output tensor D the values correspond directly to dimensions of D, whereas for A and B the original dimensions" " X', Y' are inferred from the ones supplied to the GEMM, taking into account the permute operation.\n" "\n" "Options:\n" "\n" " --help If specified, displays this usage statement.\n\n" " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=384)\n" " --alpha= Epilogue scalar alpha (real part)\n" " --beta= Epilogue scalar beta (real part)\n\n" " --iterations= Number of profiling iterations to perform.\n" " --reference-check= If true, performs reference check.\n" " --verbose= If true, prints problem sizes and batching structure.\n" "\n" "Examples:\n" "\n" "# Runs a batched GEMM with 96 batches\n" "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96\n" "\n" "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true\n" "\n" "# Execute batched GEMM and profile with NSight\n" "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n" "\n"; return out; } /// Compute performance in GFLOP/s double gflops(double runtime_s, bool batched) const { // Number of real-valued multiply-adds int64_t fmas = int64_t(); fmas += problem_each.product() * (batched ? batch_count : 1); // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } }; /////////////////////////////////////////////////////////////////////////////////////////////////// namespace { // (anonymous) /// Dimension-generic permutation loop template void permute_host_impl( cutlass::TensorView const & input, cutlass::TensorView const & output, PermuteOp && permute, Coord & coord ) { static_assert(Layout::kRank == Coord::kRank, "Incompatible Layout and Coord types"); if constexpr (I == Coord::kRank) { output.at(permute(coord)) = input.at(coord); } else { for (coord[I] = 0; coord[I] < input.extent(I); ++coord[I]) { permute_host_impl(input, output, std::forward(permute), coord); } } } } // namespace (anonymous) /// Perform a reference (host-based) permutation of an input tensor template void permute_host( cutlass::TensorView const &input, cutlass::TensorView const &output, int batch_count) { Layout layout = input.layout(); cutlass::MatrixCoord extent = input.extent(); std::size_t num_elems = layout.capacity(extent) * batch_count; std::vector h_input(num_elems); cutlass::device_memory::copy_to_host(h_input.data(), input.data(), num_elems); std::vector h_output(num_elems); using Info = PermuteInfo; using TensorLayout = typename Info::Layout; auto shape_orig = Info::original_shape(extent, batch_count); auto shape_perm = Info::permute(shape_orig); cutlass::TensorView view_input(h_input.data(), TensorLayout::packed(shape_orig), shape_orig); cutlass::TensorView view_output(h_output.data(), TensorLayout::packed(shape_perm), shape_perm); decltype(shape_orig) coord; permute_host_impl<0>(view_input, view_output, Info::permute, coord); cutlass::device_memory::copy_to_device(output.data(), h_output.data(), num_elems); } template struct LayoutInfo; template<> struct LayoutInfo { static std::string name() { return "RowMajor"; } }; template<> struct LayoutInfo { static std::string name() { return "ColumnMajor"; } }; /////////////////////////////////////////////////////////////////////////////////////////////////// template class Testbed { private: // // Data members // Options & options; /// Initialization cutlass::Distribution::Kind init_A; cutlass::Distribution::Kind init_B; cutlass::Distribution::Kind init_C; uint32_t seed; cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; cutlass::DeviceAllocation block_D; public: // // Methods // Testbed( Options &options_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint32_t seed_ = 3090 ): options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } private: /// Print permutation info for one tensor template void print_tensor_info( std::ostream & os, std::string const &tensor_name, int row_dim, int col_dim) { cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim)); using Info = PermuteInfo; os << "tensor " << tensor_name << ": " << Info::desc() << "\n"; os << " extent: [" << extent.row() << ", " << extent.column() << "]"; if (Info::kBatched) { os << ", batch count: " << options.batch_count; } os << "\n"; if (!cutlass::layout::is_trivial_permute) { auto shape_orig = Info::original_shape(extent, options.batch_count); auto shape_perm = Info::permute(shape_orig); os << " original: [" << shape_orig << "]\n"; os << " permuted: [" << shape_perm << "]\n"; } } /// Check shape compatibility for one tensor template bool check_tensor_shape( std::string const &tensor_name, int row_dim, int col_dim) { cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim)); using Info = PermuteInfo; auto rowAlign = cutlass::platform::is_same::value ? Alignment : 1; auto colAlign = cutlass::platform::is_same::value ? Alignment : 1; auto rowFactor = Info::kRowFactor * rowAlign; auto colFactor = Info::kColumnFactor * colAlign; // Assumes row-major layout bool const valid_row = extent.row() % rowFactor == 0; if (!valid_row) { std::cerr << "Invalid tensor " << tensor_name << " row size = " << extent.row() << ", " "must be divisible by " << rowFactor << ", " "required by " << Info::name() << (rowAlign > 1 ? (" and alignment of " + std::to_string(rowAlign)) : "") << std::endl; } bool const valid_col = extent.column() % colFactor == 0; if (!valid_col) { std::cerr << "Invalid tensor " << tensor_name << " column size = " << extent.column() << ", " "must be divisible by " << colFactor << ", " "required by " << Info::name() << (colAlign > 1 ? (" and alignment of " + std::to_string(colAlign)) : "") << std::endl; } bool const valid_bsz = options.batch_count % Info::kBatchFactor == 0; if (!valid_bsz) { std::cerr << "Invalid batch count = " << options.batch_count << ", " "must be divisible by " << Info::kBatchFactor << ", " "required by " << Info::name() << std::endl; } return valid_row && valid_col && valid_bsz; } /// Helper to initialize a tensor view template void initialize_tensor_( Element *ptr, size_t capacity, cutlass::Distribution::Kind dist_kind, uint32_t seed) { if (dist_kind == cutlass::Distribution::Uniform) { Element scope_max, scope_min; int bits_input = cutlass::sizeof_bits::value; int bits_output = cutlass::sizeof_bits::value; if (bits_input == 1) { scope_max = 2; scope_min = 0; } else if (bits_input <= 8) { scope_max = 2; scope_min = -2; } else if (bits_output == 16) { if (cutlass::sizeof_bits::value <= 16) { scope_max = 5; scope_min = -5; } else { scope_max = 8; scope_min = -8; } } else { scope_max = 8; scope_min = -8; } cutlass::reference::device::BlockFillRandomUniform( ptr, capacity, seed, scope_max, scope_min, 0); } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::device::BlockFillRandomGaussian( ptr, capacity, seed, Element(), Element(0.5f)); } else if (dist_kind == cutlass::Distribution::Sequential) { // Fill with increasing elements cutlass::reference::device::BlockFillSequential( ptr, capacity, Element(1), Element()); } else { // Fill with all 1s cutlass::reference::device::BlockFillSequential( ptr, capacity, Element(), Element(1)); } } /// Initializes data structures void initialize(int batch_count) { srand(seed); int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count; int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; int64_t total_elements_D = options.problem_each.m() * options.problem_each.n() * batch_count; // Allocate space block_A.reset(total_elements_A); block_B.reset(total_elements_B); block_C.reset(total_elements_C); block_D.reset(total_elements_D); // Initialize input tensors initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); cutlass::reference::device::BlockFillSequential( block_D.get(), total_elements_D, ElementC(), ElementC()); } /// Check device GEMM results against a reference implementation with separate host-based permutation template bool validate(Gemm const &gemm) { bool constexpr kBatched = PermuteInfo::kBatched || PermuteInfo::kBatched || PermuteInfo::kBatched; int const batch_count = kBatched ? options.batch_count : 1; cutlass::gemm::GemmCoord problem = options.problem_each; cutlass::MatrixCoord extent_A{problem.m(), problem.k()}; cutlass::MatrixCoord extent_B{problem.k(), problem.n()}; cutlass::MatrixCoord extent_C{problem.m(), problem.n()}; using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; LayoutA layout_A(LayoutA::packed(extent_A)); LayoutB layout_B(LayoutB::packed(extent_B)); LayoutC layout_C(LayoutC::packed(extent_C)); auto size_A = layout_A.capacity(extent_A) * batch_count; auto size_B = layout_B.capacity(extent_B) * batch_count; auto size_C = layout_C.capacity(extent_C) * batch_count; cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); cutlass::TensorView view_D(block_D.get(), layout_C, extent_C); cutlass::DeviceAllocation block_A_perm(size_A); cutlass::DeviceAllocation block_B_perm(size_B); cutlass::TensorView view_A_perm(block_A_perm.get(), layout_A, extent_A); cutlass::TensorView view_B_perm(block_B_perm.get(), layout_B, extent_B); permute_host(view_A.const_view(), view_A_perm, batch_count); permute_host(view_B.const_view(), view_B_perm, batch_count); cutlass::DeviceAllocation block_D_ref(size_C); cutlass::TensorView view_D_ref(block_D_ref.get(), layout_C, extent_C); using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; // Reference GEMM cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, typename EpilogueOutputOp::ElementCompute, typename Gemm::ElementAccumulator >( problem, options.alpha, view_A_perm, Gemm::kTransformA, view_B_perm, Gemm::kTransformB, options.beta, view_C, view_D_ref, ElementAccumulator(0), batch_count, options.problem_each.m() * options.problem_each.k(), options.problem_each.n() * options.problem_each.k(), options.problem_each.m() * options.problem_each.n(), options.problem_each.m() * options.problem_each.n() ); cutlass::DeviceAllocation block_D_perm(size_C); cutlass::TensorView view_D_perm(block_D_perm.get(), layout_C, extent_C); permute_host(view_D_ref.const_view(), view_D_perm, batch_count); // Reference check return cutlass::reference::device::BlockCompareEqual(view_D_perm.data(), view_D.data(), size_C); } public: template bool profile_GEMM_permute() { using LayoutA = typename Gemm::LayoutA; using LayoutB = typename Gemm::LayoutB; using LayoutC = typename Gemm::LayoutC; using PermuteALayout = typename Gemm::PermuteALayout; using PermuteBLayout = typename Gemm::PermuteBLayout; using PermuteDLayout = typename Gemm::PermuteDLayout; bool constexpr kBatched = PermuteInfo::kBatched || PermuteInfo::kBatched || PermuteInfo::kBatched; std::cout << "\n" "====================================================\n" << (kBatched ? "Batched" : "Normal") << " GEMM:" << "\n A=" << LayoutInfo::name() << "," << PermuteInfo::name() << "\n B=" << LayoutInfo::name() << "," << PermuteInfo::name() << "\n D=" << LayoutInfo::name() << "," << PermuteInfo::name() << "\n" "====================================================\n"; if (options.verbose) { print_tensor_info(std::cout, "A", 0, 2); print_tensor_info(std::cout, "B", 2, 1); print_tensor_info(std::cout, "D", 0, 1); } std::cout << std::endl; bool valid = true; valid &= check_tensor_shape("A", 0, 2); valid &= check_tensor_shape("B", 2, 1); valid &= check_tensor_shape("D", 0, 1); if (!valid) { std::cout << "Skipped test" << std::endl; return true; } int const batch_count = kBatched ? options.batch_count : 1; // Initialize the problem initialize(batch_count); // Configure the GEMM arguments using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); // Please make sure all problem_sizes are the same for kBatched mode auto problem = options.problem_each; cutlass::MatrixCoord extent_A{problem.m(), problem.k()}; cutlass::MatrixCoord extent_B{problem.k(), problem.n()}; cutlass::MatrixCoord extent_C{problem.m(), problem.n()}; LayoutA layout_A(LayoutA::packed(extent_A)); LayoutB layout_B(LayoutB::packed(extent_B)); LayoutC layout_C(LayoutC::packed(extent_C)); // Configure GEMM arguments typename Gemm::Arguments arguments{ kBatched ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm, problem, batch_count, epilogue_op, (void*)block_A.get(), (void*)block_B.get(), (void*)block_C.get(), (void*)block_D.get(), // For any non-trivial permute the batch stride must be set to 0 cutlass::layout::is_trivial_permute ? layout_A.capacity(extent_A) : 0, cutlass::layout::is_trivial_permute ? layout_B.capacity(extent_B) : 0, layout_C.capacity(extent_C), cutlass::layout::is_trivial_permute ? layout_C.capacity(extent_C) : 0, layout_A.stride(0), layout_B.stride(0), layout_C.stride(0), layout_C.stride(0), }; // Initialize the GEMM object Gemm gemm_normal; CHECK_CUTLASS_CALL(gemm_normal.initialize(arguments, nullptr), return false); // Run the normal GEMM object CHECK_CUTLASS_CALL(gemm_normal.run(), return false); // Wait for completion CHECK_CUDA_CALL(cudaDeviceSynchronize(), return false); // // Verify correctness // if (options.reference_check) { if (validate(gemm_normal)) { std::cout << "\nPassed verification\n" << std::endl; } else { std::cerr << "\n*** Error - problem failed the QA check ***\n" << std::endl; return false; } } // Warm-up run of the normal GEMM object CHECK_CUTLASS_CALL(gemm_normal.run(), return false); // Construct events cudaEvent_t events[2]; for (auto & event : events) { CHECK_CUDA_CALL(cudaEventCreate(&event), return false); } // Record an event at the start of a series of GEMM operations CHECK_CUDA_CALL(cudaEventRecord(events[0]), return false); // Run profiling loop for (int iter = 0; iter < options.iterations; ++iter) { gemm_normal(); } // Record an event when the GEMM operations have been launched. CHECK_CUDA_CALL(cudaEventRecord(events[1]), return false); // Wait for work on the device to complete. CHECK_CUDA_CALL(cudaEventSynchronize(events[1]), return false); // Measure elapsed runtime float runtime_total_ms = 0; CHECK_CUDA_CALL(cudaEventElapsedTime(&runtime_total_ms, events[0], events[1]), return false); // Compute average runtime and GFLOPs. double runtime_avg_ms = double(runtime_total_ms) / double(options.iterations); double gflops = options.gflops(runtime_avg_ms / 1000.0, kBatched); // Cleanup for (auto event : events) { CHECK_CUDA_CALL(cudaEventDestroy(event), return false); } std::cout << " Runtime: " << runtime_avg_ms << " ms\n" " GFLOPs: " << gflops << std::endl; return true; } }; /// Shorthand alist for GEMM instantiations template using GemmPermute = cutlass::gemm::device::GemmUniversal< ElementInput, LayoutA, ElementInput, LayoutB, ElementOutput, LayoutC, ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, AlignmentC, //128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, 4, /*kStages*/ AlignmentA, /*AlignmentA*/ AlignmentB, /*AlignmentB*/ cutlass::arch::OpMultiplyAdd, cutlass::ComplexTransform::kNone, cutlass::ComplexTransform::kNone, false, /*GatherA*/ false, /*GatherB*/ false, /*ScatterD*/ PermuteDLayout, /*PermuteDLayout*/ typename cutlass::layout::InversePermute::type, /*PermuteALayout*/ typename cutlass::layout::InversePermute::type /*PermuteBLayout*/ >; /////////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { // // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. // cudaDeviceProp props; CHECK_CUDA_CALL(cudaGetDeviceProperties(&props, 0), return EXIT_FAILURE); if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { // // This example requires an NVIDIA Ampere-architecture GPU. // std::cout << "CUTLASS's GEMM+Permute example requires a GPU of NVIDIA's Ampere Architecture " "or later (compute capability 80 or greater).\n"; return EXIT_SUCCESS; } // // Parse options // Options options; options.parse(argc, args); if (options.help) { options.print_usage(std::cout) << std::endl; return EXIT_SUCCESS; } if (options.error) { std::cerr << "Aborting execution." << std::endl; return EXIT_FAILURE; } // // Define GEMM types to test // // // TTT (Row-major) GEMMs // using TTTGemmNormalPermuteNone = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using TTTGemmNormalPermuteA = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using TTTGemmNormalPermuteAD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; using TTTGemmNormalPermuteB = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using TTTGemmNormalPermuteBD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; using TTTGemmNormalPermuteD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; using TTTGemmNormalPermuteAB = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using TTTGemmNormalPermuteABD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; // // NNN (Col-major) GEMMs // using NNNGemmNormalPermuteNone = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmNormalPermuteA = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmNormalPermuteAD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; using NNNGemmNormalPermuteB = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmNormalPermuteBD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; using NNNGemmNormalPermuteD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; using NNNGemmNormalPermuteAB = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmNormalPermuteABD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; // // NNT (Col-major inputs, row-major output) GEMMs // using NNTGemmNormalPermuteNone = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using NNTGemmNormalPermuteA = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using NNTGemmNormalPermuteAD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; using NNTGemmNormalPermuteB = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using NNTGemmNormalPermuteBD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; using NNTGemmNormalPermuteD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; using NNTGemmNormalPermuteAB = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using NNTGemmNormalPermuteABD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor >; // // TTN (Row-major inputs, col-major output) GEMMs // using TTNGemmNormalPermuteNone = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using TTNGemmNormalPermuteA = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using TTNGemmNormalPermuteAD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; using TTNGemmNormalPermuteB = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using TTNGemmNormalPermuteBD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; using TTNGemmNormalPermuteD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; using TTNGemmNormalPermuteAB = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using TTNGemmNormalPermuteABD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor >; // // TTT (Row-major) BMMs // using TTTGemmBatchedPermuteA = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using TTTGemmBatchedPermuteAD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor >; using TTTGemmBatchedPermuteB = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::NoPermute >; using TTTGemmBatchedPermuteBD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor >; using TTTGemmBatchedPermuteD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor >; using TTTGemmBatchedPermuteAB = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::NoPermute, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor >; using TTTGemmBatchedPermuteABD = GemmPermute< cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor >; // // NNN (Col-major) BMMs // using NNNGemmBatchedPermuteA = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmBatchedPermuteAD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor >; using NNNGemmBatchedPermuteB = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmBatchedPermuteBD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor >; using NNNGemmBatchedPermuteD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor >; using NNNGemmBatchedPermuteAB = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::NoPermute >; using NNNGemmBatchedPermuteABD = GemmPermute< cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor >; // // Profile it // Testbed testbed(options); bool result = true; result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); result &= testbed.profile_GEMM_permute(); std::cout << "\n" "====================================================\n" "Finished (" << (result ? "PASS" : "FAIL") << ")\n" "====================================================" << std::endl; return result ? EXIT_SUCCESS : EXIT_FAILURE; } /////////////////////////////////////////////////////////////////////////////////////////////////