/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include namespace cute { namespace detail { template struct supports_output_scaling { static constexpr bool value = false; }; template struct supports_output_scaling().accumulate_)>> { static constexpr bool value = true; }; } // end namespace detail /** * concept MMA_Traits * { * using ValTypeD = // Logical A-value type * using ValTypeA = // Logical B-value type * using ValTypeB = // Logical C-value type * using ValTypeC = // Logical D-value type (NOTE: Not used? Assumed == ValTypeD) * * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) * using FrgTypeC = // C_type consumed by MMA (if ommitted, same as ValTypeC) * * using Shape_MNK = // Logical MxNxK shape of the MMA * * using ThrID = // Logical thread id (tid) -> tidx * * using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord * using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord * using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord * }; */ template struct MMA_Traits { static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); }; template struct MMA_Traits> { using ValTypeD = D; using ValTypeA = A; using ValTypeB = B; using ValTypeC = C; // Logical shape of the MMA using Shape_MNK = Shape<_1,_1,_1>; // Logical thread id (tid) -> tidx using ThrID = Layout<_1>; // (Logical thread id (tid), Logical value id (vid)) -> coord // (tid,vid) -> (m,k) using ALayout = Layout>; // (tid,vid) -> (n,k) using BLayout = Layout>; // (tid,vid) -> (m,n) using CLayout = Layout>; }; // // Generic mma_unpack for any MMA_Traits // template CUTE_HOST_DEVICE constexpr void mma_unpack(MMA_Traits const& traits, Tensor & D, Tensor const& A, Tensor const& B, Tensor const& C) { static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); // Register value types from the MMA_Operation register arrays using RegTypeD = typename remove_extent::type; using RegTypeA = typename remove_extent::type; using RegTypeB = typename remove_extent::type; using RegTypeC = typename remove_extent::type; using MMATraits = MMA_Traits; [[maybe_unused]] constexpr int RegNumD = extent::value; constexpr int RegNumA = extent::value; constexpr int RegNumB = extent::value; constexpr int RegNumC = extent::value; Tensor rA = recast(A); Tensor rB = recast(B); CUTE_STATIC_ASSERT_V(size(rA) == Int{}); CUTE_STATIC_ASSERT_V(size(rB) == Int{}); if constexpr (is_same::value) { static_assert(is_same::value, "GMMA C and D value_type must match."); static_assert(is_same::value, "GMMA C and D layouts must match."); // assert((void*)&C == (void*)&D); Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); if constexpr (detail::supports_output_scaling::value) { detail::explode_with_d_scaling(MMA_Op::fma, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}, traits.accumulate_); } else { detail::explode(MMA_Op::fma, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}); } } else { Tensor rD = recast(D); Tensor rC = recast(C); CUTE_STATIC_ASSERT_V(size(rD) == Int{}); CUTE_STATIC_ASSERT_V(size(rC) == Int{}); if constexpr (detail::supports_output_scaling::value) { detail::explode_with_d_scaling(MMA_Op::fma, rD, make_int_sequence{}, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}, traits.accumulate_); } else { detail::explode(MMA_Op::fma, rD, make_int_sequence{}, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}); } } } // // Accept mutable temporaries // template CUTE_HOST_DEVICE constexpr void mma_unpack(MMA_Traits const& traits, Tensor && D, Tensor const& A, Tensor const& B, Tensor const& C) { mma_unpack(traits, D, A, B, C); } namespace detail { template struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; template struct FrgTypeA_or_Default> { using type = typename X::FrgTypeA; }; template struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; template struct FrgTypeB_or_Default> { using type = typename X::FrgTypeB; }; template struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; template struct FrgTypeC_or_Default> { using type = typename X::FrgTypeC; }; } // end namespace detail } // namespace cute