/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include #include #include #include namespace cute { template struct ArithmeticTuple : tuple { template CUTE_HOST_DEVICE constexpr ArithmeticTuple(ArithmeticTuple const& u) : tuple(static_cast const&>(u)) {} template CUTE_HOST_DEVICE constexpr ArithmeticTuple(tuple const& u) : tuple(u) {} template CUTE_HOST_DEVICE constexpr ArithmeticTuple(U const&... u) : tuple(u...) {} }; template struct is_tuple> : true_type {}; template CUTE_HOST_DEVICE constexpr auto make_arithmetic_tuple(T const&... t) { return ArithmeticTuple(t...); } template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(tuple const& t) { return ArithmeticTuple(t); } template ::value)> CUTE_HOST_DEVICE constexpr T const& as_arithmetic_tuple(T const& t) { return t; } template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(ArithmeticTuple const& t) { return t; } // // Numeric operators // // Addition template CUTE_HOST_DEVICE constexpr auto operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); } template CUTE_HOST_DEVICE constexpr auto operator+(ArithmeticTuple const& t, tuple const& u) { constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); } template CUTE_HOST_DEVICE constexpr auto operator+(tuple const& t, ArithmeticTuple const& u) { constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); } // // Special cases // template CUTE_HOST_DEVICE constexpr ArithmeticTuple const& operator+(C, ArithmeticTuple const& u) { static_assert(t == 0, "Artihmetic tuple op+ error!"); return u; } template CUTE_HOST_DEVICE constexpr ArithmeticTuple const& operator+(ArithmeticTuple const& t, C) { static_assert(u == 0, "Artihmetic tuple op+ error!"); return t; } // // ArithmeticTupleIterator // template struct ArithmeticTupleIterator { using value_type = ArithTuple; using element_type = ArithTuple; using reference = ArithTuple; ArithTuple coord_; CUTE_HOST_DEVICE constexpr ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} CUTE_HOST_DEVICE constexpr ArithTuple const& operator*() const { return coord_; } template CUTE_HOST_DEVICE constexpr auto operator[](Coord const& c) const { return *(*this + c); } template CUTE_HOST_DEVICE constexpr auto operator+(Coord const& c) const { return ArithmeticTupleIterator(coord_ + c); } }; template CUTE_HOST_DEVICE constexpr auto make_inttuple_iter(Tuple const& t) { return ArithmeticTupleIterator(as_arithmetic_tuple(t)); } template CUTE_HOST_DEVICE constexpr auto make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) { return make_inttuple_iter(cute::make_tuple(t0, t1, ts...)); } // // ArithmeticTuple "basis" elements // A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: // (_0,_0,...,T,_0,...) // with value T in the Nth mode template struct ScaledBasis : private tuple { CUTE_HOST_DEVICE constexpr ScaledBasis(T const& t = {}) : tuple(t) {} CUTE_HOST_DEVICE constexpr decltype(auto) value() { return get<0>(static_cast &>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } CUTE_HOST_DEVICE static constexpr auto mode() { return Int{}; } }; template struct is_scaled_basis : false_type {}; template struct is_scaled_basis> : true_type {}; template struct is_integral> : true_type {}; // Get the scalar T out of a ScaledBasis template CUTE_HOST_DEVICE constexpr auto basis_value(SB const& e) { if constexpr (is_scaled_basis::value) { return basis_value(e.value()); } else { return e; } CUTE_GCC_UNREACHABLE; } // Apply the N... pack to another Tuple template CUTE_HOST_DEVICE constexpr auto basis_get(SB const& e, Tuple const& t) { if constexpr (is_scaled_basis::value) { return basis_get(e.value(), get(t)); } else { return t; } CUTE_GCC_UNREACHABLE; } namespace detail { template struct Basis; template <> struct Basis<> { using type = Int<1>; }; template struct Basis { using type = ScaledBasis::type, N>; }; } // end namespace detail // Shortcut for writing ScaledBasis, N0>, N1>, ...> // E<> := _1 // E<0> := (_1,_0,_0,...) // E<1> := (_0,_1,_0,...) // E<0,0> := ((_1,_0,_0,...),_0,_0,...) // E<0,1> := ((_0,_1,_0,...),_0,_0,...) // E<1,0> := (_0,(_1,_0,_0,...),_0,...) // E<1,1> := (_0,(_0,_1,_0,...),_0,...) template using E = typename detail::Basis::type; namespace detail { template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(T const& t, seq, seq) { return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); } template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); } } // end namespace detail // Turn a ScaledBases into a rank-M ArithmeticTuple // with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(ScaledBasis const& t) { static_assert(M > N, "Mismatched ranks"); return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); } // Turn a ScaledBases into a rank-N ArithmeticTuple // with N prefix 0s: (_0,_0,...N...,_0,T) template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(ScaledBasis const& t) { return as_arithmetic_tuple(t); } // Turn an ArithmeticTuple into a rank-M ArithmeticTuple // with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) template CUTE_HOST_DEVICE constexpr auto as_arithmetic_tuple(ArithmeticTuple const& t) { static_assert(M >= sizeof...(T), "Mismatched ranks"); return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); } template CUTE_HOST_DEVICE constexpr auto safe_div(ScaledBasis const& b, U const& u) { auto t = safe_div(b.value(), u); return ScaledBasis{t}; } template CUTE_HOST_DEVICE constexpr auto shape_div(ScaledBasis const& b, U const& u) { auto t = shape_div(b.value(), u); return ScaledBasis{t}; } template CUTE_HOST_DEVICE constexpr auto make_basis_like(Shape const& shape) { if constexpr (is_integral::value) { return Int<1>{}; } else { // Generate bases for each rank of shape return transform(tuple_seq{}, shape, [](auto I, auto si) { // Generate bases for each rank of si and add an i on front using I_type = decltype(I); return transform_leaf(make_basis_like(si), [](auto e) { // MSVC has trouble capturing variables as constexpr, // so that they can be used as template arguments. // This is exactly what the code needs to do with i, unfortunately. // The work-around is to define i inside the inner lambda, // by using just the type from the enclosing scope. constexpr int i = I_type::value; return ScaledBasis{}; }); }); } CUTE_GCC_UNREACHABLE; } // Equality template CUTE_HOST_DEVICE constexpr auto operator==(ScaledBasis const& t, ScaledBasis const& u) { return bool_constant{} && t.value() == u.value(); } // Not equal to anything else template CUTE_HOST_DEVICE constexpr false_type operator==(ScaledBasis const&, U const&) { return {}; } template CUTE_HOST_DEVICE constexpr false_type operator==(T const&, ScaledBasis const&) { return {}; } // Abs template CUTE_HOST_DEVICE constexpr auto abs(ScaledBasis const& e) { return ScaledBasis{abs(e.value())}; } // Multiplication template CUTE_HOST_DEVICE constexpr auto operator*(A const& a, ScaledBasis const& e) { auto r = a * e.value(); return ScaledBasis{r}; } template CUTE_HOST_DEVICE constexpr auto operator*(ScaledBasis const& e, B const& b) { auto r = e.value() * b; return ScaledBasis{r}; } // Addition template CUTE_HOST_DEVICE constexpr auto operator+(ScaledBasis const& t, ArithmeticTuple const& u) { constexpr int R = cute::max(N+1, int(sizeof...(U))); return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } template CUTE_HOST_DEVICE constexpr auto operator+(ArithmeticTuple const& t, ScaledBasis const& u) { constexpr int R = cute::max(int(sizeof...(T)), M+1); return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } template CUTE_HOST_DEVICE constexpr auto operator+(ScaledBasis const& t, tuple const& u) { constexpr int R = cute::max(N+1, int(sizeof...(U))); return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } template CUTE_HOST_DEVICE constexpr auto operator+(tuple const& t, ScaledBasis const& u) { constexpr int R = cute::max(int(sizeof...(T)), M+1); return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } template CUTE_HOST_DEVICE constexpr auto operator+(ScaledBasis const& t, ScaledBasis const& u) { constexpr int R = cute::max(N+1,M+1); return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } template CUTE_HOST_DEVICE constexpr auto operator+(C, ScaledBasis const& u) { static_assert(t == 0, "ScaledBasis op+ error!"); return u; } template CUTE_HOST_DEVICE constexpr auto operator+(ScaledBasis const& t, C) { static_assert(u == 0, "ScaledBasis op+ error!"); return t; } // // Display utilities // template CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) { printf("ArithTuple"); print(iter.coord_); } template CUTE_HOST_DEVICE void print(ScaledBasis const& e) { print(e.value()); printf("@%d", N); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator const& iter) { return os << "ArithTuple" << iter.coord_; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { return os << e.value() << "@" << N; } #endif } // end namespace cute namespace CUTE_STL_NAMESPACE { template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; } // end namespace CUTE_STL_NAMESPACE #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { #if defined(__CUDACC_RTC__) template struct tuple_size; template struct tuple_element; #endif template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; } // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD