/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
* **************************************************************************************************/ #pragma once #include #include #include #include #include #include /// @file tuple_algorithms.hpp /// @brief Common algorithms on (hierarchical) tuples /// /// Code guidelines and style preferences: /// /// For perfect forwarding, don't use std::forward, because it may not /// be defined in device code when compiling with NVRTC. Instead, use /// `static_cast<T&&>(parameter_name)`. /// /// CuTe generally does not bother forwarding functions, as /// reference-qualified member functions are rare in this code base. /// /// Throughout CUTLASS, cute::make_tuple always needs to be called /// namespace-qualified, EVEN if inside the cute namespace and/or in /// scope of a "using namespace cute" declaration. Otherwise, the /// compiler may select std::make_tuple instead of cute::make_tuple, /// due to argument-dependent lookup. Two problems may result from /// that. /// /// 1. Functions have an unexpected return type (std::tuple instead of /// cute::tuple), so functions that take cute::tuple parameters /// fail to compile (generally inside functions that have template /// parameters expected to be cute::tuple). /// /// 2. std::tuple does not have the required __host__ __device__ /// markings, so the CUDA compiler complains if you use it in /// device code. /// /// cute::make_tuple will occur more often than std::make_tuple would /// in modern C++ code, because cute::tuple's design deprioritizes /// correct operation of CTAD (class template argument /// deduction) in favor of implementation simplicity.
namespace cute { // // Apply (Unpack) // (t, f) => f(t_0,t_1,...,t_n) // namespace detail { template CUTE_HOST_DEVICE constexpr auto apply(T&& t, F&& f, seq) { return f(get(static_cast(t))...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto apply(T&& t, F&& f) { return detail::apply(static_cast(t), f, tuple_seq{}); } // // Transform Apply // (t, f, g) => g(f(t_0),f(t_1),...) // namespace detail { template CUTE_HOST_DEVICE constexpr auto tapply(T&& t, F&& f, G&& g, seq) { return g(f(get(static_cast(t)))...); } template CUTE_HOST_DEVICE constexpr auto tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) { return g(f(get(static_cast(t0)), get(static_cast(t1)))...); } template CUTE_HOST_DEVICE constexpr auto tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) { return g(f(get(static_cast(t0)), get(static_cast(t1)), get(static_cast(t2)))...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto transform_apply(T&& t, F&& f, G&& g) { if constexpr (is_tuple>::value) { return detail::tapply(static_cast(t), f, g, tuple_seq{}); } else { return g(f(static_cast(t))); } } template CUTE_HOST_DEVICE constexpr auto transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) { if constexpr (is_tuple>::value) { return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); } else { return g(f(static_cast(t0), static_cast(t1))); } } template CUTE_HOST_DEVICE constexpr auto transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) { if constexpr (is_tuple>::value) { return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); } else { return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); } } // // For Each // (t, f) => f(t_0),f(t_1),...,f(t_n) // template CUTE_HOST_DEVICE constexpr void for_each(T&& t, F&& f) { if constexpr (is_tuple>::value) { return detail::apply(t, [&](auto&&... 
a) { (f(static_cast(a)), ...); }, tuple_seq{}); } else { return f(static_cast(t)); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto for_each_leaf(T&& t, F&& f) { if constexpr (is_tuple>::value) { return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); } else { return f(static_cast(t)); } CUTE_GCC_UNREACHABLE; } // // For Sequence // (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n])) // namespace detail { template CUTE_HOST_DEVICE constexpr void for_sequence(seq const&, F&& f) { (f(Int{}), ...); } }; // end namespace detail template CUTE_HOST_DEVICE constexpr void for_sequence(seq const& s, T&& t, F&& f) { detail::for_sequence(s, [&](auto&& i){ f(get::value>(static_cast(t))); }); } template CUTE_HOST_DEVICE constexpr void for_sequence(T&& t, F&& f) { for_sequence(make_seq{}, static_cast(t), static_cast(f)); } // // Transform // (t, f) => (f(t_0),f(t_1),...,f(t_n)) // template CUTE_HOST_DEVICE constexpr auto transform(T const& t, F&& f) { if constexpr (is_tuple::value) { return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); } else { return f(t); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto transform(T0 const& t0, T1 const& t1, F&& f) { if constexpr (is_tuple::value) { static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); } else { return f(t0, t1); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) { if constexpr (is_tuple::value) { static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); return detail::tapply(t0, t1, t2, f, [](auto const&... 
a){ return cute::make_tuple(a...); }, tuple_seq{}); } else { return f(t0, t1, t2); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto transform_leaf(T const& t, F&& f) { if constexpr (is_tuple::value) { return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); } else { return f(t); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto transform_leaf(T0 const& t0, T1 const& t1, F&& f) { if constexpr (is_tuple::value) { return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); }); } else { return f(t0, t1); } CUTE_GCC_UNREACHABLE; } // // find and find_if // namespace detail { template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f, seq) { if constexpr (decltype(f(get(t)))::value) { return cute::C{}; } else if constexpr (sizeof...(Is) == 0) { return cute::C{}; } else { return find_if(t, f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto find_if(T const& t, F&& f) { if constexpr (is_tuple::value) { return detail::find_if(t, f, tuple_seq{}); } else { return cute::C{}; } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto find(T const& t, X const& x) { return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false } template CUTE_HOST_DEVICE constexpr auto any_of(T const& t, F&& f) { if constexpr (is_tuple::value) { return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq{}); } else { return f(t); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto all_of(T const& t, F&& f) { if constexpr (is_tuple::value) { return detail::apply(t, [&] (auto const&... a) { return (true_type{} && ... 
&& f(a)); }, tuple_seq{}); } else { return f(t); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto none_of(T const& t, F&& f) { return not any_of(t, f); } // // Filter // (t, f) => // template CUTE_HOST_DEVICE constexpr auto filter_tuple(T const& t, F&& f) { return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); } template CUTE_HOST_DEVICE constexpr auto filter_tuple(T0 const& t0, T1 const& t1, F&& f) { return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); } template CUTE_HOST_DEVICE constexpr auto filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) { return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); }); } // // Fold (Reduce, Accumulate) // (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) // namespace detail { // This impl compiles much faster than cute::apply and variadic args template CUTE_HOST_DEVICE constexpr decltype(auto) fold(T&& t, V&& v, F&& f, seq<>) { return static_cast(v); } template CUTE_HOST_DEVICE constexpr decltype(auto) fold(T&& t, V&& v, F&& f, seq) { if constexpr (sizeof...(Is) == 0) { return f(static_cast(v), get(static_cast(t))); } else { return fold(static_cast(t), f(static_cast(v), get(static_cast(t))), f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto fold(T&& t, V&& v, F&& f) { if constexpr (is_tuple>::value) { return detail::fold(static_cast(t), static_cast(v), f, tuple_seq{}); } else { return f(static_cast(v), static_cast(t)); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr decltype(auto) fold_first(T&& t, F&& f) { if constexpr (is_tuple>::value) { return detail::fold(static_cast(t), get<0>(static_cast(t)), f, make_range<1,tuple_size>::value>{}); } else { return static_cast(t); } CUTE_GCC_UNREACHABLE; } // // front, back, take, select, unwrap // // Get the first non-tuple element in a hierarchical tuple template CUTE_HOST_DEVICE constexpr 
decltype(auto) front(T&& t) { if constexpr (is_tuple>::value) { return front(get<0>(static_cast(t))); } else { return static_cast(t); } CUTE_GCC_UNREACHABLE; } // Get the last non-tuple element in a hierarchical tuple template CUTE_HOST_DEVICE constexpr decltype(auto) back(T&& t) { if constexpr (is_tuple>::value) { constexpr int N = tuple_size>::value; // MSVC needs a bit of extra help here deducing return types. // We help it by peeling off the nonrecursive case a level "early." if constexpr (! is_tuple(static_cast(t)))>>::value) { return get(static_cast(t)); } else { return back(get(static_cast(t))); } } else { return static_cast(t); } CUTE_GCC_UNREACHABLE; } // Takes the elements in the range [B,E) template CUTE_HOST_DEVICE constexpr auto take(T const& t) { return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); } // // Select tuple elements with given indices. // template CUTE_HOST_DEVICE constexpr auto select(T const & t) { return cute::make_tuple(get(t)...); } template CUTE_HOST_DEVICE constexpr auto select(T const & t, Indices const & indices) { if constexpr (is_tuple::value) { return cute::transform(indices, [&t](auto i) { return select(t, i); }); } else { static_assert(is_static::value, "Order must be static"); return get(t); } } // Wrap non-tuples into rank-1 tuples or forward template CUTE_HOST_DEVICE constexpr auto wrap(T const& t) { if constexpr (is_tuple::value) { return t; } else { return cute::make_tuple(t); } CUTE_GCC_UNREACHABLE; } // Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple template CUTE_HOST_DEVICE constexpr auto unwrap(T const& t) { if constexpr (is_tuple::value) { if constexpr (tuple_size::value == 1) { return unwrap(get<0>(t)); } else { return t; } } else { return t; } CUTE_GCC_UNREACHABLE; } // // Flatten and Unflatten // template struct is_flat : true_type {}; template struct is_flat> : bool_constant<(true && ... 
&& (not is_tuple::value))> {}; // Flatten a hierarchical tuple to a tuple of depth one // and wrap non-tuples into a rank-1 tuple. template CUTE_HOST_DEVICE constexpr auto flatten_to_tuple(T const& t) { if constexpr (is_tuple::value) { if constexpr (is_flat::value) { // Shortcut for perf return t; } else { return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); } } else { return cute::make_tuple(t); } CUTE_GCC_UNREACHABLE; } // Flatten a hierarchical tuple to a tuple of depth one // and leave non-tuple untouched. template CUTE_HOST_DEVICE constexpr auto flatten(T const& t) { if constexpr (is_tuple::value) { if constexpr (is_flat::value) { // Shortcut for perf return t; } else { return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); } } else { return t; } CUTE_GCC_UNREACHABLE; } namespace detail { template CUTE_HOST_DEVICE constexpr auto unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) { if constexpr (is_tuple::value) { return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { auto [result, remaining_tuple] = v; auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); return cute::make_tuple(append(result, sub_result), sub_tuple); }); } else { return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); } CUTE_GCC_UNREACHABLE; } } // end namespace detail // Unflatten a flat tuple into a hierarchical tuple // @pre flatten(@a flat_tuple) == @a flat_tuple // @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) // @post congruent(@a result, @a target_profile) // @post flatten(@a result) == @a flat_tuple template CUTE_HOST_DEVICE constexpr auto unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) { auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); return unflatten_tuple; 
} // // insert and remove and replace // namespace detail { // Shortcut around cute::tuple_cat for common insert/remove/repeat cases template CUTE_HOST_DEVICE constexpr auto construct(T const& t, X const& x, seq, seq, seq) { return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); } } // end namespace detail // Insert x into the Nth position of the tuple template CUTE_HOST_DEVICE constexpr auto insert(T const& t, X const& x) { return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); } // Remove the Nth element of the tuple template CUTE_HOST_DEVICE constexpr auto remove(T const& t) { return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); } // Replace the Nth element of the tuple with x template CUTE_HOST_DEVICE constexpr auto replace(T const& t, X const& x) { return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); } // Replace the first element of the tuple with x template CUTE_HOST_DEVICE constexpr auto replace_front(T const& t, X const& x) { if constexpr (is_tuple::value) { return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); } else { return x; } CUTE_GCC_UNREACHABLE; } // Replace the last element of the tuple with x template CUTE_HOST_DEVICE constexpr auto replace_back(T const& t, X const& x) { if constexpr (is_tuple::value) { return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); } else { return x; } CUTE_GCC_UNREACHABLE; } // // Make a tuple of Xs of tuple_size N // template CUTE_HOST_DEVICE constexpr auto tuple_repeat(X const& x) { return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); } // // Make repeated Xs of rank N // template CUTE_HOST_DEVICE constexpr auto repeat(X const& x) { if constexpr (N == 1) { return x; } else { return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); } CUTE_GCC_UNREACHABLE; } // // Make a tuple of Xs the same profile as tuple T // template CUTE_HOST_DEVICE constexpr auto repeat_like(T const& t, X 
const& x) { if constexpr (is_tuple::value) { return transform(t, [&](auto const& a) { return repeat_like(a,x); }); } else { return x; } CUTE_GCC_UNREACHABLE; } // Group the elements [B,E) of a T into a single element // e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) // => T<_1,_2,T<_3,_4>,_5,_6>{} template CUTE_HOST_DEVICE constexpr auto group(T const& t) { if constexpr (not is_tuple::value) { if constexpr (E == -1) { return group(t); } else { return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range{}); } } else if constexpr (E == -1) { return group::value>(t); } else if constexpr (B <= E) { return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range::value>{}); } else { static_assert(B <= E); } CUTE_GCC_UNREACHABLE; } // // Extend a T to rank N by appending/prepending an element // template CUTE_HOST_DEVICE constexpr auto append(T const& a, X const& x) { if constexpr (is_tuple::value) { if constexpr (N == tuple_size::value) { return a; } else { static_assert(N > tuple_size::value); return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); } } else { if constexpr (N == 1) { return a; } else { return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); } } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto append(T const& a, X const& x) { if constexpr (is_tuple::value) { return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); } else { return cute::make_tuple(a, x); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto prepend(T const& a, X const& x) { if constexpr (is_tuple::value) { if constexpr (N == tuple_size::value) { return a; } else { static_assert(N > tuple_size::value); return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); } } else { if constexpr (N == 1) { return a; } else { static_assert(N > 1); return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); } } CUTE_GCC_UNREACHABLE; } 
template CUTE_HOST_DEVICE constexpr auto prepend(T const& a, X const& x) { if constexpr (is_tuple::value) { return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); } else { return cute::make_tuple(x, a); } CUTE_GCC_UNREACHABLE; } // // Inclusive scan (prefix sum) // namespace detail { template CUTE_HOST_DEVICE constexpr auto iscan(T const& t, V const& v, F&& f, seq) { // Apply the function to v and the element at I auto v_next = f(v, get(t)); // Replace I with v_next auto t_next = replace(t, v_next); #if 0 std::cout << "ISCAN i" << I << std::endl; std::cout << " t " << t << std::endl; std::cout << " i " << v << std::endl; std::cout << " f(i,t) " << v_next << std::endl; std::cout << " t_n " << t_next << std::endl; #endif if constexpr (sizeof...(Is) == 0) { return t_next; } else { return iscan(t_next, v_next, f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto iscan(T const& t, V const& v, F&& f) { return detail::iscan(t, v, f, tuple_seq{}); } // // Exclusive scan (prefix sum) // namespace detail { template CUTE_HOST_DEVICE constexpr auto escan(T const& t, V const& v, F&& f, seq) { if constexpr (sizeof...(Is) == 0) { // Replace I with v return replace(t, v); } else { // Apply the function to v and the element at I auto v_next = f(v, get(t)); // Replace I with v auto t_next = replace(t, v); #if 0 std::cout << "ESCAN i" << I << std::endl; std::cout << " t " << t << std::endl; std::cout << " i " << v << std::endl; std::cout << " f(i,t) " << v_next << std::endl; std::cout << " t_n " << t_next << std::endl; #endif // Recurse return escan(t_next, v_next, f, seq{}); } CUTE_GCC_UNREACHABLE; } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto escan(T const& t, V const& v, F&& f) { return detail::escan(t, v, f, tuple_seq{}); } // // Zip (Transpose) // // Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input // to produce ((a,x,...),(b,y,...),(c,z,...),...) 
rank-R1 x rank-R0 output namespace detail { template CUTE_HOST_DEVICE constexpr auto zip_(Ts const&... ts) { return cute::make_tuple(get(ts)...); } template CUTE_HOST_DEVICE constexpr auto zip(T const& t, seq, seq) { static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); return cute::make_tuple(zip_(get(t)...)...); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto zip(T const& t) { if constexpr (is_tuple::value) { if constexpr (is_tuple>::value) { return detail::zip(t, tuple_seq{}, tuple_seq>{}); } else { return cute::make_tuple(t); } } else { return t; } CUTE_GCC_UNREACHABLE; } // Convenient to pass them in separately template CUTE_HOST_DEVICE constexpr auto zip(T0 const& t0, T1 const& t1, Ts const&... ts) { return zip(cute::make_tuple(t0, t1, ts...)); } // // zip2_by -- A guided zip for rank-2 tuples // Take a tuple like ((A,a),((B,b),(C,c)),d) // and produce a tuple ((A,(B,C)),(a,(b,c),d)) // where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) // namespace detail { template CUTE_HOST_DEVICE constexpr auto zip2_by(T const& t, TG const& guide, seq, seq) { // zip2_by produces the modes like ((A,a),(B,b),...) auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) return cute::make_tuple(cute::make_tuple(get<0>(get(split))...), cute::make_tuple(get<1>(get(split))..., get(t)...)); } } // end namespace detail template CUTE_HOST_DEVICE constexpr auto zip2_by(T const& t, TG const& guide) { if constexpr (is_tuple::value) { constexpr int TR = tuple_size::value; constexpr int GR = tuple_size::value; static_assert(TR >= GR, "Mismatched ranks"); return detail::zip2_by(t, guide, make_range< 0, GR>{}, make_range{}); } else { static_assert(tuple_size::value == 2, "Mismatched ranks"); return t; } CUTE_GCC_UNREACHABLE; } /// @return A tuple of the elements of @c t in reverse order. 
template CUTE_HOST_DEVICE constexpr auto reverse(T const& t) { if constexpr (is_tuple::value) { return detail::apply(t, [] (auto const&... a) { return cute::make_tuple(a...); }, tuple_rseq{}); } else { return t; } } } // end namespace cute