/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Functor performing elementwise operations used by epilogues. */ #pragma once #include "cutlass/cutlass.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/epilogue/collective/detail.hpp" #include "cute/tensor.hpp" #include "cute/numeric/int.hpp" #include "gather_tensor.hpp" namespace cutlass::epilogue::collective { /// Applies an element wise operation to all elements within the fragment /// and scatter-writes them out to destination storage. /// GatherC and ScatterD are types of user-defined functions that apply the /// transoformation of the strided coordinate (e.g. through an index array). template < class StrideC_, class StrideD_, class ThreadEpilogueOp_, class EpilogueSchedule_, class GatherC_, class ScatterD_ > class EpilogueGatherScatter { public: // // Type Aliases // using EpilogueSchedule = EpilogueSchedule_; // derived types of output thread level operator using ThreadEpilogueOp = ThreadEpilogueOp_; using ElementOutput = typename ThreadEpilogueOp::ElementOutput; using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; using ElementCompute = typename ThreadEpilogueOp::ElementCompute; using ElementScalar = ElementCompute; using ElementC = typename ThreadEpilogueOp::ElementC; using StrideC = StrideC_; using ElementD = typename ThreadEpilogueOp::ElementD; using StrideD = StrideD_; // Every epilogue needs these two GmemTiledCopy{C,D} aliases. // If you don't know what they should be, just use void. using GmemTiledCopyC = void; using GmemTiledCopyD = void; using GatherC = GatherC_; using ScatterD = ScatterD_; static const int kOutputAlignment = ThreadEpilogueOp::kCount; using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); struct SharedStorage { }; // Host side epilogue arguments struct Arguments { typename ThreadEpilogueOp::Params thread_params{}; ElementC const* ptr_C = nullptr; StrideC dC{}; ElementD* ptr_D = nullptr; StrideD dD{}; GatherC gather_C{}; ScatterD scatter_D{}; }; // Device side epilogue params using Params = Arguments; // // Methods // template static constexpr Params to_underlying_arguments( [[maybe_unused]] ProblemShape const& _, Arguments const& args, [[maybe_unused]] void* workspace) { return args; } template CUTLASS_HOST_DEVICE static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) { return true; } CUTLASS_HOST_DEVICE EpilogueGatherScatter(Params const& params_) : params(params_) { } template< class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout, class TiledMma, class ResidueMNK > CUTLASS_DEVICE void operator()( ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, ResidueMNK residue_mnk, int thread_idx, char* smem_buf) { using namespace cute; using X = Underscore; static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); static_assert(is_static::value, "ThreadBlock tile shape must be static"); static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); (void) smem_buf; ThreadEpilogueOp epilogue_op{params.thread_params}; // Separate out problem shape for convenience auto M = get<0>(problem_shape_mnkl); auto N = get<1>(problem_shape_mnkl); auto L = get<3>(problem_shape_mnkl); auto stride_c = detail::get_epilogue_stride(params.dC); auto stride_d = detail::get_epilogue_stride(params.dD); // Represent the full output tensor Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l) Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l) Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) // Slice to get the tile this CTA is responsible for auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) // Partition source and destination tiles to match the accumulator partitioning auto thr_mma = tiled_mma.get_thread_slice(thread_idx); Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) static_assert(is_static::value, "Accumulator layout must be static"); CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), "Source and destination must have the same number of elements."); CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), "Accumulator count must have the same destination element count."); // Make an identity coordinate tensor for predicating our output MN tile auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); Tensor tCcD = thr_mma.partition_C(cD); // source is needed if (epilogue_op.is_source_needed()) { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); } } } // source is not needed, avoid load else { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size(accumulators); ++i) { if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { tCgD(i) = epilogue_op(accumulators(i)); } } } } private: Params params; }; } // namespace cutlass::epilogue::collective