/***************************************************************************************************
 * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

/*! \file
    \brief Grouped FMHA kernel
*/

#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/matrix_coord.h"
#include "cutlass/complex.h"
#include "cutlass/semaphore.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/trace.h"
#include "cutlass/gemm/kernel/gemm_transpose_operands.h"

#include "fmha_grouped_problem_visitor.h"
#include "gemm_kernel_utils.h"
#include "gemm/mma_accum_lambda_iterator.h"
#include "epilogue/epilogue_rescale_output.h"

namespace {

static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
  // source: https://stackoverflow.com/a/51549250
  return (value >= 0)
      ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
      : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
}

} // namespace
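// Why the bit-cast above works: for non-negative IEEE-754 floats, the raw bit
// pattern increases monotonically with the value, so a signed-integer
// atomicMax over the bits computes the float max (e.g. 1.0f = 0x3f800000 <
// 2.0f = 0x40000000). Negative floats sort in the opposite direction when
// reinterpreted, hence the unsigned atomicMin branch. A minimal host-side
// illustration of the invariant (sketch only; C++20 std::bit_cast assumed):
//
//   static_assert(std::bit_cast<int32_t>(1.0f) < std::bit_cast<int32_t>(2.0f));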
/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    typename MM0_, ///! Structure for computing P = Q @ K
    typename MM1_, ///! Structure for computing O = P @ V
    typename scalar_t_,
    typename accum_t_,
    typename output_t_,
    typename output_accum_t_,
    ///! Whether the intermediate output from MM0_ should be kept in the
    ///! register file
    bool kKeepOutputInRF,
    ///! Type of scheduling to perform
    GroupScheduleMode GroupScheduleMode_>
struct FMHAGrouped {
public:
  using MM0 = MM0_;
  using MM1 = MM1_;

  using scalar_t = scalar_t_;
  using accum_t = accum_t_;
  using output_t = output_t_;
  using output_accum_t = output_accum_t_;

  static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_;

  static constexpr bool kNeedsOutputAccumulatorBuffer =
      !kKeepOutputInRF &&
      !cutlass::platform::is_same<output_accum_t, output_t>::value;

  // Parameters to satisfy BaseGrouped
  using ElementA = scalar_t;
  using ElementB = scalar_t;
  using ElementC = accum_t;
  using LayoutA = typename MM0::LayoutA;
  using LayoutB = typename MM0::LayoutB;
  using LayoutC = typename MM1::LayoutC;
  static ComplexTransform const kTransformA = ComplexTransform::kNone;
  static ComplexTransform const kTransformB = ComplexTransform::kNone;
  static int const kAlignmentA = MM0::kAlignmentA;
  static int const kAlignmentB = MM0::kAlignmentB;
  static int const kAlignmentC = 1;
  using Mma = typename MM1::Mma;
  using EpilogueOutputOp = typename MM1::EpilogueOutputOp;
  using ThreadblockSwizzle = void;
  using Operator = typename MM1::Operator;
  using WarpShape = typename MM1::WarpShape;
  using InstructionShape = typename MM1::InstructionShape;

  using ElementQ = scalar_t;
  using ElementK = scalar_t;
  using ElementP = accum_t;
  using ElementV = scalar_t;
  using ElementO = output_t;
  using ElementOAccum = output_accum_t;
  using ElementAccumulator = accum_t;

  using LayoutQ = typename MM0::LayoutA;
  using LayoutK = typename MM0::LayoutB;
  using LayoutP = typename MM0::LayoutC;
  using LayoutV = typename MM1::LayoutB;
  using LayoutO = typename MM1::LayoutC;

  static bool const kPreloadV =
      (MM1::Mma::ArchTag::kMinComputeCapability >= 80 &&
       cutlass::sizeof_bits<ElementV>::value == 16);

  static int const kAlignmentQ = MM0::kAlignmentA;
  static int const kAlignmentK = MM0::kAlignmentB;
  static int const kAlignmentV = 1;

  using ThreadblockShape = typename MM0::ThreadblockShape;

  static int const kQueriesPerBlock = ThreadblockShape::kM;
  static int const kKeysPerBlock = ThreadblockShape::kN;

  static constexpr bool kSupportsDropout = false;
  static constexpr bool kSupportsBias = false;

  /// Warp count (concept: GemmShape)
  using WarpCount = typename MM1::WarpCount;
  static int const kThreadsPerWarp = 32;
  static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount;

  static constexpr int kNumWarpsPerBlock =
      kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp);

  using ProblemVisitor = FMHAGroupedProblemVisitor<
      ThreadblockShape,
      kGroupScheduleMode,
      kThreadCount,
      kThreadCount>;

  //
  // Structures
  //

  /// Argument structure
  struct Arguments {
    //
    // Data members
    //

    GemmCoord *problem_sizes0;
    GemmCoord *problem_sizes1;

    int problem_count;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementV ** ptr_V;
    ElementO ** ptr_O;
    ElementOAccum ** ptr_O_accum;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutV::Stride::LongIndex *ldv;
    typename LayoutO::Stride::LongIndex *ldo;

    // Scale
    ElementAccumulator scale;

    // Whether causal masking is to be performed
    bool causal;

    // Only used by device-level operator
    GemmCoord *host_problem_sizes;

    //
    // Methods
    //

    /// Default ctor
    CUTLASS_HOST_DEVICE
    Arguments():
      problem_count(0),
      threadblock_count(0),
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_O_accum(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      scale(0),
      causal(false),
      host_problem_sizes(nullptr)
    {
    }
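    // Note on the data members above: each pointer member is a device array
    // with one entry per problem in the group, e.g. ptr_Q[i] is the Q tensor
    // of problem i and ldq[i] its leading dimension. A hypothetical host-side
    // fill for `num_heads` equally-shaped problems (sketch only; q_base,
    // head_offset and head_dim are assumed names, and the arrays still need
    // to be copied to device memory before constructing Arguments):
    //
    //   std::vector<ElementQ*> q_ptrs(num_heads);
    //   std::vector<int64_t> q_strides(num_heads);
    //   for (int i = 0; i < num_heads; ++i) {
    //     q_ptrs[i] = q_base + i * head_offset;
    //     q_strides[i] = head_dim;
    //   }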
    /// Ctor
    CUTLASS_HOST_DEVICE
    Arguments(
      GemmCoord *problem_sizes0,
      GemmCoord *problem_sizes1,
      int problem_count,
      int threadblock_count,
      ElementQ ** ptr_Q,
      ElementK ** ptr_K,
      ElementP ** ptr_P,
      ElementV ** ptr_V,
      ElementO ** ptr_O,
      ElementOAccum ** ptr_O_accum,
      typename LayoutQ::Stride::LongIndex *ldq,
      typename LayoutK::Stride::LongIndex *ldk,
      typename LayoutP::Stride::LongIndex *ldp,
      typename LayoutV::Stride::LongIndex *ldv,
      typename LayoutO::Stride::LongIndex *ldo,
      bool causal,
      ElementAccumulator scale,
      GemmCoord *host_problem_sizes=nullptr
    ):
      problem_sizes0(problem_sizes0),
      problem_sizes1(problem_sizes1),
      problem_count(problem_count),
      threadblock_count(threadblock_count),
      ptr_Q(ptr_Q),
      ptr_K(ptr_K),
      ptr_P(ptr_P),
      ptr_V(ptr_V),
      ptr_O(ptr_O),
      ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O),
      ldq(ldq),
      ldk(ldk),
      ldv(ldv),
      ldo(ldo),
      causal(causal),
      scale(scale),
      host_problem_sizes(host_problem_sizes)
    {
    }

    bool __host__ check_supported() {
      CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ);
      CHECK_ALIGNED_PTR(ptr_K, kAlignmentK);
      CHECK_ALIGNED_PTR(ptr_V, kAlignmentV);
      XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned");
      XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned");
      XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned");
      return true;
    }
  };

  //
  // Structure for precomputing values in host memory and passing to kernels
  //

  /// Parameters structure
  struct Params {

    typename ProblemVisitor::Params problem_visitor;
    int threadblock_count;

    ElementQ ** ptr_Q;
    ElementK ** ptr_K;
    ElementP ** ptr_P;
    ElementV ** ptr_V;
    ElementO ** ptr_O;
    ElementOAccum ** ptr_O_accum;

    typename LayoutQ::Stride::LongIndex *ldq;
    typename LayoutK::Stride::LongIndex *ldk;
    typename LayoutV::Stride::LongIndex *ldv;
    typename LayoutO::Stride::LongIndex *ldo;

    ElementAccumulator scale;
    bool causal;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params():
      ptr_Q(nullptr),
      ptr_K(nullptr),
      ptr_P(nullptr),
      ptr_V(nullptr),
      ptr_O(nullptr),
      ptr_O_accum(nullptr),
      ldq(nullptr),
      ldk(nullptr),
      ldv(nullptr),
      ldo(nullptr),
      causal(false),
      scale(0)
    { }

    CUTLASS_HOST_DEVICE
    Params(Arguments const &args, void *workspace = nullptr, int tile_count = 0):
      problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count),
      threadblock_count(args.threadblock_count),
      ptr_Q(args.ptr_Q),
      ptr_K(args.ptr_K),
      ptr_P(args.ptr_P),
      ptr_V(args.ptr_V),
      ptr_O(args.ptr_O),
      ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O),
      ldq(args.ldq),
      ldk(args.ldk),
      ldv(args.ldv),
      ldo(args.ldo),
      causal(args.causal),
      scale(args.scale)
    {
    }
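    // Note on the ternary above: when the output element type can represent
    // the accumulator exactly (kNeedsOutputAccumulatorBuffer == false), no
    // separate accumulation buffer is needed, so ptr_O is reused (cast to
    // accum_t**) and partial results are written directly into the final
    // output tensors.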
    CUTLASS_HOST_DEVICE
    void update(
      Arguments const &args,
      void *workspace = nullptr,
      int tile_count = 0)
    {
      problem_visitor = typename ProblemVisitor::Params(
          args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count);
      threadblock_count = args.threadblock_count;
      ptr_Q = args.ptr_Q;
      ptr_K = args.ptr_K;
      ptr_P = args.ptr_P;
      ptr_V = args.ptr_V;
      ptr_O = args.ptr_O;
      ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O;
      ldq = args.ldq;
      ldk = args.ldk;
      ldv = args.ldv;
      ldo = args.ldo;
      causal = args.causal;
      scale = args.scale;
    }
  };

  // Shared storage - depends on kernel params
  struct ScalingCoefs {
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> m_prime;
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> s_prime;
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> mi;
    cutlass::Array<ElementAccumulator, kQueriesPerBlock> out_rescale;
    cutlass::Array<ElementAccumulator, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>
        addition_storage;
  };

  struct SharedStorageEpilogueAtEnd : ScalingCoefs {
    struct SharedStorageAfterMM0 {
      // Everything here might be overwritten during MM0
      typename MM0::AccumulatorSharedStorage si;
      typename MM1::Mma::SharedStorage mm1;
    };

    union {
      typename MM0::Mma::SharedStorage mm0;
      SharedStorageAfterMM0 after_mm0;
      typename MM1::DefaultEpilogue::SharedStorage epilogue;
    };

    CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
    epilogue_shared_storage() {
      return epilogue;
    }

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };

  struct SharedStorageEpilogueInLoop : ScalingCoefs {
    struct SharedStorageAfterMM0 {
      // Everything here might be overwritten during MM0
      typename MM0::AccumulatorSharedStorage si;
      typename MM1::Mma::SharedStorage mm1;
      typename MM1::DefaultEpilogue::SharedStorage epilogue;
    };

    union {
      typename MM0::Mma::SharedStorage mm0;
      SharedStorageAfterMM0 after_mm0;
    };

    CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
    epilogue_shared_storage() {
      return after_mm0.epilogue;
    }

    // ProblemVisitor shared storage can't be overlapped with others
    typename ProblemVisitor::SharedStorage problem_visitor;
  };

  using SharedStorage = typename cutlass::platform::conditional<
      kKeepOutputInRF,
      SharedStorageEpilogueAtEnd,
      SharedStorageEpilogueInLoop>::type;

private:
  // Parameters to be used by an individual tile
  struct TileParams {

    CUTLASS_HOST_DEVICE
    static int query_start(int threadblock_idx) {
      return threadblock_idx * kQueriesPerBlock;
    }

    // Returns whether this threadblock computes within the number of queries,
    // which is determined by the M dimension of problem 0
    CUTLASS_HOST_DEVICE
    static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) {
      return query_start(threadblock_idx) < problem_size0.m();
    }

    CUTLASS_HOST_DEVICE
    static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) {
      return problem_size0.m() - query_start(threadblock_idx);
    }

    CUTLASS_HOST_DEVICE
    static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) {
      int nk = problem_size0.n();
      if (causal) {
        nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk);
      }
      return nk;
    }
  };

public:

  //
  // Methods
  //

  CUTLASS_DEVICE
  FMHAGrouped() { }

  /// Determines whether kernel satisfies alignment
  static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) {
    return Status::kSuccess;
  }

  static Status can_implement(Arguments const &args) {
    return Status::kSuccess;
  }

  static CUTLASS_DEVICE int16_t thread_id() {
    return threadIdx.x;
  }

  static CUTLASS_DEVICE int8_t warp_id() {
    return threadIdx.x / kThreadsPerWarp;
  }

  static CUTLASS_DEVICE int8_t lane_id() {
    return threadIdx.x % kThreadsPerWarp;
  }

  /// Executes one GEMM
  CUTLASS_DEVICE
  void operator()(Params const &params, SharedStorage &shared_storage) {
    auto& m_prime = shared_storage.m_prime;
    auto& s_prime = shared_storage.s_prime;
    [[maybe_unused]] auto& si = shared_storage.after_mm0.si;
    auto& mi = shared_storage.mi;
    auto& out_rescale = shared_storage.out_rescale;

    ProblemVisitor problem_visitor(
        params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
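    // Scheduling: a fixed-size grid walks all (problem, tile) pairs
    // "persistently". Each threadblock claims tiles blockIdx.x,
    // blockIdx.x + gridDim.x, ... from the visitor, so the loop below behaves
    // roughly like this sketch:
    //
    //   while (problem_visitor.next_tile()) {
    //     if (tile is in range) { /* compute attention for this query block */ }
    //     problem_visitor.advance(gridDim.x);
    //   }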
    // Outer 'persistent' loop to iterate over tiles
    while (problem_visitor.next_tile()) {
      GemmCoord problem_size0 = problem_visitor.problem_size0();
      GemmCoord problem_size1 = problem_visitor.problem_size1();
      const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());

      if (!TileParams::can_compute(threadblock_idx, problem_size0)) {
        problem_visitor.advance(gridDim.x);
        continue;
      }

      const int32_t problem_idx = problem_visitor.problem_index();

      if (thread_id() < kQueriesPerBlock) {
        s_prime[thread_id()] = ElementAccumulator(0);
        out_rescale[thread_id()] = accum_t(1.0);
        m_prime[thread_id()] =
            -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
        mi[thread_id()] =
            -cutlass::platform::numeric_limits<ElementAccumulator>::infinity();
      }

      ElementO *ptr_O = params.ptr_O[problem_idx] +
          TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
      ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] +
          TileParams::query_start(threadblock_idx) * params.ldo[problem_idx];
      const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0);

      auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
        using OutputTileIterator = typename MM1::OutputTileIterator;
        return OutputTileIterator(
            typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]},
            ptr_O,
            typename OutputTileIterator::TensorCoord{num_queries, problem_size1.n()},
            thread_id(),
            {0, col});
      };

      auto createOutputAccumIter = [&](int col) -> typename MM1::OutputTileIteratorAccum {
        using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
        return OutputTileIteratorAccum(
            typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]},
            ptr_O_accum,
            typename OutputTileIteratorAccum::TensorCoord{num_queries, problem_size1.n()},
            thread_id(),
            {0, col});
      };

      typename MM1::Mma::FragmentC accum_o;
      accum_o.clear();

      const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal);

      for (int32_t iter_key_start = 0; iter_key_start < num_keys;
           iter_key_start += kKeysPerBlock) {
        int32_t problem_size_0_m =
            cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries);
        int32_t problem_size_0_n = cutlass::fast_min(
            (int32_t)kKeysPerBlock, num_keys - iter_key_start);
        int32_t const& problem_size_0_k = problem_size0.k();
        int32_t const& problem_size_1_n = problem_size1.n();
        int32_t const& problem_size_1_k = problem_size_0_n;

        auto prologueV = [&](int blockN) {
          typename MM1::Mma::IteratorB iterator_V(
              typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
              params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
              {problem_size_1_k, problem_size_1_n},
              thread_id(),
              cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

          MM1::Mma::prologue(
              shared_storage.after_mm0.mm1,
              iterator_V,
              thread_id(),
              problem_size_1_k);
        };

        __syncthreads(); // Need to have shared memory initialized, and `m_prime`
                         // updated from end of prev iter
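        // Shapes for this iteration, in GEMM terms: MM0 computes a
        // (problem_size_0_m x problem_size_0_n) tile of S = Q @ K^T over the
        // full head dimension problem_size_0_k; MM1 then multiplies that tile
        // by V, so its reduction depth problem_size_1_k equals the number of
        // keys visited here (problem_size_0_n). For instance, with 64x64
        // tiles, each trip through this loop folds 64 keys into the running
        // output.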
        //
        // MATMUL: Q.K_t
        //
        // Computes the block-matrix product of:
        // (a) query[query_start:query_end, :]
        // with
        // (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
        // and stores that into `shared_storage.si`
        //

        ElementQ *ptr_Q = params.ptr_Q[problem_idx] +
            TileParams::query_start(threadblock_idx) * params.ldq[problem_idx];

        // Construct iterators to A and B operands
        typename MM0::IteratorA iterator_A(
            typename MM0::IteratorA::Params(
                typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])),
            ptr_Q,
            {problem_size_0_m, problem_size_0_k},
            thread_id(),
            {0, 0});

        typename MM0::IteratorB iterator_B(
            typename MM0::IteratorB::Params(
                typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])),
            params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx],
            {problem_size_0_k, problem_size_0_n},
            thread_id(),
            {0, 0});

        // Construct thread-scoped matrix multiply
        typename MM0::Mma mma(
            shared_storage.mm0, thread_id(), warp_id(), lane_id());

        typename MM0::Mma::FragmentC accum;
        accum.clear();

        auto gemm_k_iterations =
            (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add
        mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
        __syncthreads();

        if (kPreloadV) {
          prologueV(0);
        } else {
          MM1::Mma::drain_cp_asyncs();
        }

        typename MM0::Mma::Operator::IteratorC::TensorCoord iteratorC_tile_offset = {
            (warp_id() % MM0::Mma::WarpCount::kM),
            (warp_id() / MM0::Mma::WarpCount::kM)};

        // Mask out last if causal
        if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) {
          auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset(
              lane_id(), warp_id(), iteratorC_tile_offset);
          int32_t last_col;
          MM0::AccumLambdaIterator::iterateRows(
              lane_offset,
              [&](int accum_m) {
                last_col = TileParams::query_start(threadblock_idx) + accum_m -
                    iter_key_start;
              },
              [&](int accum_m, int accum_n, int idx) {
                if (accum_n > last_col) {
                  accum[idx] =
                      -cutlass::platform::numeric_limits<accum_t>::infinity();
                }
              },
              [&](int accum_m) {});
        }

        // Update `mi` from accum stored in registers
        // Also does accum[i] <- exp(accum[i] - mi)
        iterative_softmax<typename MM0::Mma::Operator::IteratorC>(
            accum_o,
            accum,
            mi,
            m_prime,
            s_prime,
            out_rescale,
            shared_storage.addition_storage,
            lane_id(),
            thread_id(),
            warp_id(),
            num_keys - iter_key_start,
            iter_key_start == 0,
            iteratorC_tile_offset,
            kSupportsBias ? 1.0f : params.scale);

        // Output results to shared-memory
        int warp_idx_mn_0 = warp_id() %
            (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
        auto output_tile_coords = cutlass::MatrixCoord{
            warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
            warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};

        MM0::B2bGemm::accumToSmem(
            shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords);

        __syncthreads();
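        // Back-to-back GEMM hand-off: the current tile of exp(S - mi) now
        // lives in shared memory (shared_storage.after_mm0.si) while
        // s_prime / m_prime hold the running softmax statistics, so the
        // second matmul below can read the attention block as its A operand
        // without a round trip through global memory.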
        //
        // MATMUL: Attn . V
        // Run the matmul `attn @ V` for a block of attn and V.
        // `attn` is read from shared memory (in `shared_storage_si`)
        // `V` is read from global memory (with iterator_B)
        //

        const int64_t nBlockN = kKeepOutputInRF
            ? 1
            : ceil_div((int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));

        // Iterate over the N dimension of GEMM1
        for (int blockN = 0; blockN < nBlockN; ++blockN) {
          int gemm_k_iterations =
              (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;

          // Compute threadblock-scoped matrix multiply-add and store it in accum
          // (in registers)
          if (!kPreloadV) {
            __syncthreads(); // we share shmem between mma and epilogue
          }

          typename MM1::Mma::IteratorB iterator_V(
              typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv[problem_idx])},
              params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx],
              {problem_size_1_k, problem_size_1_n},
              thread_id(),
              cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});

          typename MM1::Mma mma_pv(
              // operand A: Pij_dropped in shared memory
              shared_storage.after_mm0.si.accum_ref(),
              // operand B: shared memory staging area for Vj, which is loaded
              // from global memory
              shared_storage.after_mm0.mm1.operand_B_ref(),
              (int)thread_id(),
              (int)warp_id(),
              (int)lane_id());

          mma_pv.set_prologue_done(kPreloadV);
          if (!kKeepOutputInRF) {
            accum_o.clear();
          }

          mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
          __syncthreads();

          if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) {
            prologueV(blockN + 1);
          }

          if (!kKeepOutputInRF) {
            MM1::Mma::drain_cp_asyncs();
            DISPATCH_BOOL(
                iter_key_start == 0, kIsFirst, ([&] {
                  DISPATCH_BOOL(
                      (iter_key_start + kKeysPerBlock) >= num_keys,
                      kIsLast,
                      ([&] {
                        using DefaultEpilogue = typename MM1::DefaultEpilogue;
                        using DefaultOp =
                            typename MM1::DefaultConfig::EpilogueOutputOp;
                        using ElementCompute = typename DefaultOp::ElementCompute;
                        using EpilogueOutputOp = typename cutlass::epilogue::
                            thread::MemoryEfficientAttentionNormalize<
                                typename cutlass::platform::conditional<
                                    kIsLast,
                                    output_t,
                                    output_accum_t>::type,
                                output_accum_t,
                                DefaultOp::kCount,
                                typename DefaultOp::ElementAccumulator,
                                output_accum_t,
                                kIsFirst,
                                kIsLast,
                                cutlass::Array<ElementCompute, kQueriesPerBlock>>;
                        using Epilogue = typename cutlass::epilogue::threadblock::
                            EpiloguePipelined<
                                typename DefaultEpilogue::Shape,
                                typename MM1::Mma::Operator,
                                DefaultEpilogue::kPartitionsK,
                                typename cutlass::platform::conditional<
                                    kIsLast,
                                    typename MM1::OutputTileIterator,
                                    typename MM1::OutputTileIteratorAccum>::type,
                                typename DefaultEpilogue::AccumulatorFragmentIterator,
                                typename DefaultEpilogue::WarpTileIterator,
                                typename DefaultEpilogue::SharedLoadIterator,
                                EpilogueOutputOp,
                                typename DefaultEpilogue::Padding,
                                DefaultEpilogue::kFragmentsPerIteration,
                                true, // IterationsUnroll
                                typename MM1::OutputTileIteratorAccum // Read iterator
                                >;

                        int col = blockN * MM1::Mma::Shape::kN;
                        auto source_iter = createOutputAccumIter(col);
                        auto dest_iter = gemm_kernel_utils::call_conditional<
                            kIsLast,
                            decltype(createOutputIter),
                            decltype(createOutputAccumIter)>::
                            apply(createOutputIter, createOutputAccumIter, col);
                        EpilogueOutputOp rescale(s_prime, out_rescale);
                        Epilogue epilogue(
                            shared_storage.epilogue_shared_storage(),
                            thread_id(),
                            warp_id(),
                            lane_id());
                        epilogue(rescale, dest_iter, accum_o, source_iter);
                      }));
                }));
            if (!kKeepOutputInRF) {
              __syncthreads();
            }
          }
        }
        __syncthreads(); // we modify `m_prime` after
      }
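      // Final normalization (kKeepOutputInRF path): throughout the key loop,
      // accum_o holds sum_j exp(S_ij - mi_i) * V_j without the softmax
      // denominator, while s_prime accumulates sum_j exp(S_ij - mi_i). The
      // epilogue below applies the missing 1 / s_prime[i] per row, i.e.
      //
      //   O_i = (1 / s_prime[i]) * sum_j exp(S_ij - mi_i) * V_j
      //
      // which equals softmax(S_i) @ V once every key block has been folded in.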
      if (kKeepOutputInRF) {
        const bool kIsFirst = true;
        const bool kIsLast = true;
        using DefaultEpilogue = typename MM1::DefaultEpilogue;
        using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
        using ElementCompute = typename DefaultOp::ElementCompute;
        using EpilogueOutputOp =
            typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
                output_t,       // output
                output_accum_t, // source
                DefaultOp::kCount,
                typename DefaultOp::ElementAccumulator, // accum
                output_accum_t,                         // compute
                kIsFirst,
                kIsLast,
                cutlass::Array<ElementCompute, kQueriesPerBlock>>;
        using Epilogue =
            typename cutlass::epilogue::threadblock::EpiloguePipelined<
                typename DefaultEpilogue::Shape,
                typename MM1::Mma::Operator,
                DefaultEpilogue::kPartitionsK,
                typename MM1::OutputTileIterator, // destination
                typename DefaultEpilogue::AccumulatorFragmentIterator,
                typename DefaultEpilogue::WarpTileIterator,
                typename DefaultEpilogue::SharedLoadIterator,
                EpilogueOutputOp,
                typename DefaultEpilogue::Padding,
                DefaultEpilogue::kFragmentsPerIteration,
                true,                                 // IterationsUnroll
                typename MM1::OutputTileIteratorAccum // source tile
                >;
        auto dest_iter = createOutputIter(0);
        EpilogueOutputOp rescale(s_prime, out_rescale);
        Epilogue epilogue(
            shared_storage.epilogue_shared_storage(),
            thread_id(),
            warp_id(),
            lane_id());
        MM1::Mma::drain_cp_asyncs();
        epilogue(rescale, dest_iter, accum_o);
      }

      // Next tile
      problem_visitor.advance(gridDim.x);
      __syncthreads(); // Don't start the next iteration until all threads are done using shared memory.
    }
  }

  template <typename WarpIteratorC>
  CUTLASS_DEVICE static void iterative_softmax(
      typename WarpIteratorC::Fragment& frag_o, // output so far
      typename WarpIteratorC::Fragment& frag,
      cutlass::Array<accum_t, kQueriesPerBlock>& mi,
      cutlass::Array<accum_t, kQueriesPerBlock>& m_prime,
      cutlass::Array<accum_t, kQueriesPerBlock>& s_prime,
      cutlass::Array<accum_t, kQueriesPerBlock>& out_rescale,
      cutlass::Array<accum_t, kQueriesPerBlock * MM0::MmaCore::WarpCount::kN>&
          addition_storage,
      int8_t lane_id,
      int8_t thread_id,
      int8_t warp_id,
      int max_col,
      bool is_first,
      typename WarpIteratorC::TensorCoord const& tile_offset,
      float scaling) {
    /* Iterates on the accumulator and corresponding position on result matrix

    (1) Update `mi[r]` to the max value of the row `r`
    (2) In a second iteration do the following:
        (a) accum   <- exp(accum - mi)
        (b) m_prime <- exp(m_prime - mi)
        (c) s_prime <- s_prime * m_prime + sum(accum)

    All of this is done on registers, before we store all of this
    on shared memory for the next matmul with Value.
    */
    using Fragment = typename WarpIteratorC::Fragment;
    using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
        WarpIteratorC, accum_t, kThreadsPerWarp>::Iterator;
    // Convert to `accum_t` (rather than double)
    constexpr float kLog2e = 1.4426950408889634074f; // log_2(e) = M_LOG2E

    static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, "");
    static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock;

    frag = cutlass::multiplies<Fragment>()(scaling * kLog2e, frag);

    auto lane_offset =
        LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);

    // First update `mi` to the max per-row
    {
      accum_t max;
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) {
            max = -cutlass::platform::numeric_limits<accum_t>::infinity();
          },
          [&](int accum_m, int accum_n, int idx) {
            if (accum_n < max_col) {
              max = cutlass::fast_max(max, frag[idx]);
            }
          },
          [&](int accum_m) {
            // Having 4x atomicMax seems faster than reduce within warp
            // first...
            atomicMaxFloat(&mi[accum_m], max);
          });
    }

    // Make sure we all share the update values for `mi`
    __syncthreads();
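    // exp() via exp2f(): `frag` was pre-scaled by scaling * log2(e) above, so
    // `mi` now holds row maxima in that base-2 domain, and every
    // exp(scaling * (s - m)) below reduces to the cheaper hardware exp2:
    //
    //   expf(x) == exp2f(x * 1.44269504f)   // the kLog2e factor
    //
    // The same identity covers the m_prime rescaling computed next.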
    // Doing this `exp` is quite expensive. Let's
    // split it across the warps
    bool restore_mi_to_minus_inf = false;
    if (lane_id < kLinesPerWarp) {
      int id = warp_id * kLinesPerWarp + lane_id;
      auto m_prime_id = m_prime[id];
      auto mi_id = mi[id];
      bool changed = m_prime_id < mi_id; // `false` if both are -inf
      if (changed) {
        auto m_prime_exp = exp2f(m_prime_id - mi_id);
        out_rescale[id] = m_prime_exp;
        s_prime[id] *= m_prime_exp;
      } else {
        // Only when bias is enabled, it's possible that all the first values
        // of attention are masked to `-inf`. In that case we want to avoid
        // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0
        if (kSupportsBias &&
            mi_id == -cutlass::platform::numeric_limits<accum_t>::infinity()) {
          restore_mi_to_minus_inf = true;
          mi[id] = 0.0f;
        }
        out_rescale[id] = 1.0f;
      }
    }
    __syncthreads();

    // Update output fragments
    if (kKeepOutputInRF && !is_first) {
      accum_t line_rescale;
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) { line_rescale = out_rescale[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
            frag_o[idx] = frag_o[idx] * line_rescale;
          },
          [&](int accum_m) {});
    }

    // Update accum_m, accum_n, ...
    {
      accum_t mi_row, total_row;
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) { mi_row = mi[accum_m]; },
          [&](int accum_m, int accum_n, int idx) {
            frag[idx] =
                (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
          },
          [&](int accum_m) {});
      LambdaIterator::iterateRows(
          lane_offset,
          [&](int accum_m) { total_row = 0.0; },
          [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; },
          [&](int accum_m) {
            if (LambdaIterator::reduceSameRow(
                    lane_id, total_row, [](accum_t a, accum_t b) {
                      return a + b;
                    })) {
              // NOTE: we could atomically add `total_row` to `s_prime`, but
              // it's faster (and deterministic) to avoid atomics here
              addition_storage
                  [accum_m + kQueriesPerBlock * tile_offset.column()] =
                      total_row;
            }
          });
    }

    __syncthreads();
    if (lane_id < kLinesPerWarp) {
      int id = warp_id * kLinesPerWarp + lane_id;
      accum_t total_row = s_prime[id];
      if (restore_mi_to_minus_inf) {
        // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true`
        mi[id] = -cutlass::platform::numeric_limits<accum_t>::infinity();
      } else {
        m_prime[id] = mi[id];
      }
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
        total_row += addition_storage[id + kQueriesPerBlock * i];
      }
      s_prime[id] = total_row;
    }
  }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
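// Usage sketch (hypothetical, host side): this kernel is normally driven by a
// grouped-GEMM style runner. Assuming populated `args` plus a scheduling
// workspace, a launch could look like the following; `workspace`, `tile_count`
// and `smem_size` are assumptions of this sketch, not declared in this header:
//
//   using FMHA = cutlass::gemm::kernel::FMHAGrouped<MM0, MM1,
//       cutlass::half_t, float, cutlass::half_t, float,
//       /*kKeepOutputInRF=*/false,
//       cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly>;
//   typename FMHA::Params params(args, workspace, tile_count);
//   dim3 grid(args.threadblock_count);
//   dim3 block(FMHA::kThreadCount);
//   cutlass::Kernel<FMHA><<<grid, block, smem_size>>>(params);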