/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include // cute::true_type, cute::false_type #include #include //#include // Advanced optimizations // // cute::tuple is like std::tuple, with two differences. // // 1. It works on both host and device. // 2. Its template arguments must be semiregular types. // // Semiregular types are default constructible and copyable. // They include "value types" like int or float, // but do _not_ include references like int& or float&. // (See std::tie for an example of a tuple of references.) // // This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of // the conversion SFINAE, special overloading, and avoiding cvref template types. // Furthermore, the empty base optimization (EBO) is MORE aggressive by avoiding // construction calls, and ignoring any need for unique element addresses. // // Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. namespace cute { namespace detail { // EBO stands for "empty base optimization." // We use this technique to ensure that cute::tuple // doesn't need to waste space storing any template arguments // of cute::tuple that have no data (like integral_constant). // Otherwise, cute::tuple would need to spend at least 1 byte // for each of its template arguments. // // EBO always "holds" a single value of type T. // N is like an array index that TupleBase uses // to access the desired tuple element. template ::value> struct EBO; template CUTE_HOST_DEVICE constexpr C findt(EBO const&) { return {}; } // Specialization for types T that have no data; // the "static tuple leaf." Valid T here include // integral_constant, Int, // and any other semiregular type // for which std::is_empty_v is true. template struct EBO { CUTE_HOST_DEVICE constexpr EBO() {} CUTE_HOST_DEVICE constexpr EBO(T const&) {} }; template CUTE_HOST_DEVICE constexpr T getv(EBO const&) { return {}; } // Specialization for types T that are not empty; // the "dynamic tuple leaf." Valid T here include int, // any other integral or floating-point type, // or any semiregular type for which std::is_empty_v is false. template struct EBO { CUTE_HOST_DEVICE constexpr EBO() : t_{} {} template CUTE_HOST_DEVICE constexpr EBO(U const& u) : t_{u} {} T t_; }; template CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) { return x.t_; } template CUTE_HOST_DEVICE constexpr T& getv(EBO& x) { return x.t_; } template CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) { return static_cast(x.t_); } template struct TupleBase; // Base class of cute::tuple. // It inherits from EBO for each (i, t) in (I..., T...). // The actual storage (for nonempty t) lives in the base classes. // index_sequence is a way to wrap up a sequence of zero or more // compile-time integer values in a single type. // We only ever use index_sequence<0, 1, ..., sizeof...(T)> in practice, // as the type alias TupleBase below indicates. template struct TupleBase, T...> : EBO... { CUTE_HOST_DEVICE constexpr TupleBase() {} template CUTE_HOST_DEVICE constexpr explicit TupleBase(U const&... u) : EBO(u)... {} template CUTE_HOST_DEVICE constexpr TupleBase(TupleBase, U...> const& u) : EBO(getv(static_cast const&>(u)))... {} }; } // end namespace detail // Attempting to use the following commented-out alias // in the declaration of `struct tuple` causes MSVC 2022 build errors. // //template //using TupleBase = detail::TupleBase, T...>; // This is the actual cute::tuple class. // The storage (if any) lives in TupleBase's EBO base classes. // // Inheriting from the above alias TupleBase // causes MSVC 2022 build errors when assigning one tuple to another: // // illegal member initialization: // 'TupleBase< /* template arguments */ >' is not a base or member // // Not using the alias or any kind of alias fixed the errors. // In summary: this is verbose as a work-around for MSVC build errors. template struct tuple : detail::TupleBase, T...> { CUTE_HOST_DEVICE constexpr tuple() {} template CUTE_HOST_DEVICE constexpr tuple(U const&... u) : detail::TupleBase, T...>(u...) {} template CUTE_HOST_DEVICE constexpr tuple(tuple const& u) : detail::TupleBase, T...>(static_cast, U...> const&>(u)) {} }; // // get for cute::tuple (just like std::get for std::tuple) // template CUTE_HOST_DEVICE constexpr decltype(auto) get(tuple const& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); return detail::getv(t); } template CUTE_HOST_DEVICE constexpr decltype(auto) get(tuple& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); return detail::getv(t); } template CUTE_HOST_DEVICE constexpr decltype(auto) get(tuple&& t) noexcept { static_assert(I < sizeof...(T), "Index out of range"); return detail::getv(static_cast&&>(t)); } // // find a type X within a cute::tuple // Requires X to be unique in tuple // Returns a static integer // template CUTE_HOST_DEVICE constexpr auto find(tuple const& t) noexcept { return detail::findt(t); } // // Custom is_tuple trait simply checks the existence of tuple_size // and assumes std::get(.), std::tuple_element // namespace detail { template auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size::value)>; auto has_tuple_size(...) -> false_type; } // end namespace detail template struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; // // make_tuple (value-based implementation) // template CUTE_HOST_DEVICE constexpr tuple make_tuple(T const&... t) { return {t...}; } // // tuple_cat concatenates multiple cute::tuple into a single cute::tuple, // just like std::tuple_cat for std::tuple. // #if 0 // Original implementation namespace detail { template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)...); } } // end namespace detail CUTE_HOST_DEVICE constexpr tuple<> tuple_cat() { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr Tuple const& tuple_cat(Tuple const& t) { return t; } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1) { return detail::tuple_cat(t0, t1, make_index_sequence::value>{}, make_index_sequence::value>{}); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) { return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); } #endif #if 1 // Extended implementation namespace detail { template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)...); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, index_sequence, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, index_sequence, index_sequence, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, index_sequence, index_sequence, index_sequence, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); } template struct tuple_cat_static; template struct tuple_cat_static, tuple> { using type = tuple; }; } // end namespace detail CUTE_HOST_DEVICE constexpr tuple<> tuple_cat() { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr Tuple const& tuple_cat(Tuple const& t) { return t; } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1) { if constexpr (is_static::value && is_static::value && is_tuple::value && is_tuple::value) { return typename detail::tuple_cat_static::type{}; } else { return detail::tuple_cat(t0, t1, make_index_sequence::value>{}, make_index_sequence::value>{}); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) { return detail::tuple_cat(t0, t1, t2, make_index_sequence::value>{}, make_index_sequence::value>{}, make_index_sequence::value>{}); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) { return detail::tuple_cat(t0, t1, t2, t3, make_index_sequence::value>{}, make_index_sequence::value>{}, make_index_sequence::value>{}, make_index_sequence::value>{}); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) { return detail::tuple_cat(t0, t1, t2, t3, t4, make_index_sequence::value>{}, make_index_sequence::value>{}, make_index_sequence::value>{}, make_index_sequence::value>{}, make_index_sequence::value>{}); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) { return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), cute::tuple_cat(t5, ts...)); } #endif #if 0 // Outer-Inner indexing trick to concat all tuples at once namespace detail { template struct tuple_cat_helper { static constexpr cute::array ns = {Ns...}; static constexpr size_t total_size() { size_t sum = 0; for (size_t n : ns) sum += n; return sum; } static constexpr size_t total_size_ = total_size(); static constexpr auto values() { cute::array outer_inner = {}; size_t idx = 0; for (size_t i = 0; i < ns.size(); ++i) { for (size_t j = 0; j < ns[i]; ++j, ++idx) { outer_inner[idx][0] = i; outer_inner[idx][1] = j; } } return outer_inner; } static constexpr auto outer_inner_ = values(); using total_sequence = make_index_sequence; }; template CUTE_HOST_DEVICE constexpr auto tuple_cat(Tuple const& t, index_sequence) { return cute::make_tuple(get(get(t))...); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1, index_sequence, index_sequence) { return cute::make_tuple(get(t0)..., get(t1)...); } } // end namespace detail CUTE_HOST_DEVICE constexpr tuple<> tuple_cat() { return {}; } template ::value)> CUTE_HOST_DEVICE constexpr Tuple const& tuple_cat(Tuple const& t) { return t; } template CUTE_HOST_DEVICE constexpr auto tuple_cat(T0 const& t0, T1 const& t1) { return detail::tuple_cat(t0, t1, make_index_sequence::value>{}, make_index_sequence::value>{}); } template CUTE_HOST_DEVICE constexpr auto tuple_cat(Tuples const&... ts) { using Helper = detail::tuple_cat_helper::value...>; return detail::tuple_cat(cute::make_tuple(ts...), typename Helper::total_sequence{}); } #endif // // Equality operators // namespace detail { template CUTE_HOST_DEVICE constexpr auto equal_impl(TupleA const& a, TupleB const& b) { if constexpr (I == tuple_size::value) { return cute::true_type{}; // Terminal: TupleA is exhausted } else if constexpr (I == tuple_size::value) { return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted } else { return (get(a) == get(b)) && equal_impl(a,b); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template ::value && is_tuple::value)> CUTE_HOST_DEVICE constexpr auto operator==(TupleT const& t, TupleU const& u) { return detail::equal_impl<0>(t, u); } template ::value ^ is_tuple::value)> CUTE_HOST_DEVICE constexpr auto operator==(TupleT const& t, TupleU const& u) { return cute::false_type{}; } template ::value && is_tuple::value)> CUTE_HOST_DEVICE constexpr auto operator!=(TupleT const& t, TupleU const& u) { return !(t == u); } template ::value ^ is_tuple::value)> CUTE_HOST_DEVICE constexpr auto operator!=(TupleT const& t, TupleU const& u) { return cute::true_type{}; } // // Comparison operators // // // There are many ways to compare tuple of elements and because CuTe is built // on parameterizing layouts of coordinates, some comparisons are appropriate // only in certain cases. // -- lexicographical comparison [reverse, reflected, revref] // -- colexicographical comparison [reverse, reflected, revref] // -- element-wise comparison [any,all] // This can be very confusing. To avoid errors in selecting the appropriate // comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. // // That said, see int_tuple for more explicitly named common comparison ops. // // // Display utilities // namespace detail { template CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') { using eat = int[]; using cute::print; (void) eat {(print(s), 0), (print(Is == 0 ? "" : ","), print(get(t)), 0)..., (print(e), 0)}; } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') { using eat = int[]; (void) eat {(void(os << s), 0), (void(os << (Is == 0 ? "" : ",") << get(t)), 0)..., (void(os << e), 0)}; return os; } #endif // !defined(__CUDACC_RTC__) } // end namespace detail template ::value)> CUTE_HOST_DEVICE void print(Tuple const& t) { return detail::print_tuple(t, make_index_sequence::value>{}); } #if !defined(__CUDACC_RTC__) template ::value)> CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) { return detail::print_tuple_os(os, t, make_index_sequence::value>{}); } #endif // !defined(__CUDACC_RTC__) } // end namespace cute namespace CUTE_STL_NAMESPACE { template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; } // end namespace CUTE_STL_NAMESPACE // // std compatibility // #ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD namespace std { #if defined(__CUDACC_RTC__) template struct tuple_size; template struct tuple_element; #endif template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant {}; template struct tuple_element> : CUTE_STL_NAMESPACE::tuple_element> {}; } // end namepsace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD