/***************************************************************************************************
 * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#include <cute/config.hpp>

#include <cute/int_tuple.hpp>

namespace cute
{

/** crd2idx(c,s,d) maps a coordinate within <Shape,Stride> to an index
 *
 * This is computed as follows:
 *  [coord, shape, and stride are all integers => step forward by stride]
 *   op(c, s, d)              =>  c * d
 *  [coord is integer, shape and stride are tuple => divmod coord for each mode]
 *   op(c, (s,S), (d,D))      =>  op(c % prod(s), s, d) + op(c / prod(s), (S), (D))
 *  [coord, shape, and stride are all tuples => consider each mode independently]
 *   op((c,C), (s,S), (d,D))  =>  op(c, s, d) + op((C), (S), (D))
 */
template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord  const& coord,
        Shape  const& shape,
        Stride const& stride);

namespace detail {

template <class Coord, class Shape, class Stride, int... Is>
CUTE_HOST_DEVICE constexpr
auto
crd2idx_ttt(Coord  const& coord,
            Shape  const& shape,
            Stride const& stride, seq<Is...>)
{
  return (... + crd2idx(get<Is>(coord), get<Is>(shape), get<Is>(stride)));
}

template <class CInt, class STuple, class DTuple, int I0, int... Is>
CUTE_HOST_DEVICE constexpr
auto
crd2idx_itt(CInt   const& coord,
            STuple const& shape,
            DTuple const& stride, seq<I0,Is...>)
{
  if constexpr (sizeof...(Is) == 0) {  // Avoid recursion and mod on single/last iter
    return crd2idx(coord, get<I0>(shape), get<I0>(stride));
  } else if constexpr (is_constant<0, CInt>::value) {
    return crd2idx(_0{}, get<I0>(shape), get<I0>(stride))
         + (_0{} + ... + crd2idx(_0{}, get<Is>(shape), get<Is>(stride)));
  } else {                             // General case
    return crd2idx(coord % product(get<I0>(shape)), get<I0>(shape), get<I0>(stride))
         + crd2idx_itt(coord / product(get<I0>(shape)), shape, stride, seq<Is...>{});
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail
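
// Illustrative usage of crd2idx(c,s,d), worked by hand from the rules above.
// make_coord/make_shape/make_stride are the usual CuTe constructors; the results
// shown are hand-computed sketches, not compile-checked here.
//
//   crd2idx(make_coord(1,2), make_shape(4,8), make_stride(1,4))   // 1*1 + 2*4        = 9
//   crd2idx(5,               make_shape(4,8), make_stride(8,1))   // (5%4)*8 + (5/4)*1 = 9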

template <class Coord, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord  const& coord,
        Shape  const& shape,
        Stride const& stride)
{
  if constexpr (is_tuple<Coord>::value) {
    if constexpr (is_tuple<Shape>::value) {      // tuple tuple tuple
      static_assert(tuple_size<Coord>::value == tuple_size< Shape>::value, "Mismatched Ranks");
      static_assert(tuple_size<Coord>::value == tuple_size<Stride>::value, "Mismatched Ranks");
      return detail::crd2idx_ttt(coord, shape, stride, tuple_seq<Coord>{});
    } else {                                     // tuple "int" "int"
      static_assert(sizeof(Coord) == 0, "Invalid parameters");
    }
  } else {
    if constexpr (is_tuple<Shape>::value) {      // "int" tuple tuple
      static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
      return detail::crd2idx_itt(coord, shape, stride, tuple_seq<Shape>{});
    } else {                                     // "int" "int" "int"
      return coord * stride;
    }
  }

  CUTE_GCC_UNREACHABLE;
}

namespace detail {

template <class CTuple, class STuple, int I0, int... Is>
CUTE_HOST_DEVICE constexpr
auto
crd2idx_horner(CTuple const& coord,
               STuple const& shape, seq<I0,Is...>)
{
  if constexpr (sizeof...(Is) == 0) {  // No recursion on single/last iter
    return get<I0>(coord);
  } else {                             // General case
    return get<I0>(coord) + get<I0>(shape) * crd2idx_horner(coord, shape, seq<Is...>{});
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail

/** crd2idx(c,s) maps a coordinate within Shape to an index
 * via a colexicographical enumeration of coordinates in Shape.
 * i = c0 + s0 * (c1 + s1 * (c2 + s2 * ...))
 */
template <class Coord, class Shape>
CUTE_HOST_DEVICE constexpr
auto
crd2idx(Coord const& coord, Shape const& shape)
{
  if constexpr (is_integral<Coord>::value) {     // Coord is already an index
    return coord;
  } else if constexpr (is_integral<Shape>::value) {
    static_assert(dependent_false<Shape>, "Invalid parameters");
  } else {                  // Make congruent, flatten, and apply Horner's method
    static_assert(tuple_size<Coord>::value == tuple_size<Shape>::value, "Mismatched Ranks");
    auto flat_coord = flatten(coord);
    auto flat_shape = flatten(product_like(shape, coord));
    return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq<decltype(flat_shape)>{});
  }

  CUTE_GCC_UNREACHABLE;
}
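
// Illustrative usage of crd2idx(c,s), worked by hand from the Horner form above
// (a sketch, not compile-checked here):
//
//   crd2idx(make_coord(1,2,1), make_shape(4,8,2))   // 1 + 4*(2 + 8*1) = 41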

/** idx2crd(i,s,d) splits an index into a coordinate within <Shape,Stride>.
 *
 * This is computed as follows:
 *  [index, shape, and stride are all integers => determine 1D coord]
 *   op(i, s, d)              =>  (i / d) % s
 *  [index is integer, shape and stride are tuple => determine component for each mode]
 *   op(i, (s,S), (d,D))      =>  (op(i, s, d), op(i, S, D)...)
 *  [index, shape, and stride are all tuples => consider each mode independently]
 *   op((i,I), (s,S), (d,D))  =>  (op(i, s, d), op((I), (S), (D)))
 *
 * NOTE: This only works for compact shape+stride layouts. A more general version
 *       would apply to all surjective layouts.
 */
template <class Index, class Shape, class Stride>
CUTE_HOST_DEVICE constexpr
auto
idx2crd(Index  const& idx,
        Shape  const& shape,
        Stride const& stride)
{
  if constexpr (is_tuple<Index>::value) {
    if constexpr (is_tuple<Shape>::value) {      // tuple tuple tuple
      static_assert(tuple_size<Index>::value == tuple_size< Shape>::value, "Mismatched Ranks");
      static_assert(tuple_size<Index>::value == tuple_size<Stride>::value, "Mismatched Ranks");
      return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); });
    } else {                                     // tuple "int" "int"
      static_assert(sizeof(Index) == 0, "Invalid parameters");
    }
  } else {
    if constexpr (is_tuple<Shape>::value) {
      if constexpr (is_tuple<Stride>::value) {   // "int" tuple tuple
        static_assert(tuple_size<Shape>::value == tuple_size<Stride>::value, "Mismatched Ranks");
        return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); });
      } else {                                   // "int" tuple "int"
        return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); });
      }
    } else {                                     // "int" "int" "int"
      if constexpr (is_constant<1, Shape>::value) {
        // Skip potential stride-0 division
        return Int<0>{};
      } else {
        return (idx / stride) % shape;
      }
    }
  }

  CUTE_GCC_UNREACHABLE;
}

/** idx2crd(i,s) splits an index into a coordinate within Shape
 * via a colexicographical enumeration of coordinates in Shape.
 * c0 = (idx / 1) % s0
 * c1 = (idx / s0) % s1
 * c2 = (idx / (s0 * s1)) % s2
 * ...
 */
template <class Index, class Shape>
CUTE_HOST_DEVICE constexpr
auto
idx2crd(Index const& idx, Shape const& shape)
{
  if constexpr (is_tuple<Index>::value) {
    if constexpr (is_tuple<Shape>::value) {      // tuple tuple
      static_assert(tuple_size<Index>::value == tuple_size<Shape>::value, "Mismatched Ranks");
      return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); });
    } else {                                     // tuple "int"
      static_assert(sizeof(Index) == 0, "Invalid parameters");
    }
  } else {
    if constexpr (is_tuple<Shape>::value) {      // "int" tuple
      return idx2crd(idx, shape, compact_col_major(shape));
    } else {                                     // "int" "int"
      return idx;
    }
  }

  CUTE_GCC_UNREACHABLE;
}

//
// crd2crd
//

template <class Coord, class SShape, class DShape>
CUTE_HOST_DEVICE constexpr
auto
crd2crd(Coord  const& coord,
        SShape const& src_shape,
        DShape const& dst_shape)
{
  if constexpr (is_tuple<Coord>::value && is_tuple<SShape>::value && is_tuple<DShape>::value) {
    static_assert(tuple_size<Coord>::value == tuple_size<SShape>::value, "Mismatched Ranks");
    static_assert(tuple_size<Coord>::value == tuple_size<DShape>::value, "Mismatched Ranks");
    return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); });
  } else {
    // assert(size(src_shape) == size(dst_shape))
    return idx2crd(crd2idx(coord, src_shape), dst_shape);
  }

  CUTE_GCC_UNREACHABLE;
}
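
// Illustrative usage of idx2crd and crd2crd, worked by hand from the rules above
// (a sketch, not compile-checked here):
//
//   idx2crd(9, make_shape(4,8), make_stride(1,4))            // ((9/1)%4, (9/4)%8) = (1,2)
//   crd2crd(make_coord(1,2), make_shape(4,8), make_shape(2,2,8))
//     // via index 9 and col-major strides (1,2,4): ((9/1)%2, (9/2)%2, (9/4)%8) = (1,0,2)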

//
// Compact Major
//

// Tags for common layouts and dispatching
struct LayoutLeft;               // Col-major layout mapping; leftmost extent has stride 1
using GenColMajor = LayoutLeft;  // Alias

struct LayoutRight;              // Row-major layout mapping; rightmost extent has stride 1
using GenRowMajor = LayoutRight; // Alias

namespace detail {

// For GCC8.5 -- Use of lambdas in unevaluated contexts. Instead use function objects.
template <class Major>
struct CompactLambda;

// @pre is_integral<Current>
// Return (result, current * product(shape)) to enable recurrence
template <class Major, class Shape, class Current>
CUTE_HOST_DEVICE constexpr
auto
compact(Shape   const& shape,
        Current const& current)
{
  if constexpr (is_tuple<Shape>::value) { // Shape::tuple Current::int
    using Lambda = CompactLambda<Major>;                  // Append or Prepend
    using Seq    = typename Lambda::template seq<Shape>;  // Seq or RSeq
    return cute::detail::fold(shape, cute::make_tuple(cute::make_tuple(), current), Lambda{}, Seq{});
  } else {                                // Shape::int Current::int
    if constexpr (is_constant<1, Shape>::value) {
      return cute::make_tuple(Int<0>{}, current); // If current is dynamic, this could save a reg
    } else {
      return cute::make_tuple(current, current * shape);
    }
  }

  CUTE_GCC_UNREACHABLE;
}

// For GCC8.5 -- Specialization LayoutLeft
template <>
struct CompactLambda<LayoutLeft>
{
  template <class Init, class Shape>
  CUTE_HOST_DEVICE constexpr auto
  operator()(Init const& init, Shape const& si) {
    auto result = detail::compact<LayoutLeft>(si, get<1>(init));
    return cute::make_tuple(append(get<0>(init), get<0>(result)), get<1>(result));  // Append
  }

  template <class Shape>
  using seq = tuple_seq<Shape>;                                                     // Seq
};

// For GCC8.5 -- Specialization LayoutRight
template <>
struct CompactLambda<LayoutRight>
{
  template <class Init, class Shape>
  CUTE_HOST_DEVICE constexpr auto
  operator()(Init const& init, Shape const& si) {
    auto result = detail::compact<LayoutRight>(si, get<1>(init));
    return cute::make_tuple(prepend(get<0>(init), get<0>(result)), get<1>(result)); // Prepend
  }

  template <class Shape>
  using seq = tuple_rseq<Shape>;                                                    // RSeq
};

} // end namespace detail

template <class Major, class Shape, class Current = Int<1>,
          __CUTE_REQUIRES(is_tuple<Shape>::value || is_integral<Shape>::value)>
CUTE_HOST_DEVICE constexpr
auto
compact_major(Shape   const& shape,
              Current const& current = {})
{
  if constexpr (is_tuple<Current>::value) {   // Shape::tuple Current::tuple
    static_assert(is_tuple<Shape>::value, "Invalid parameters");
    static_assert(tuple_size<Shape>::value == tuple_size<Current>::value, "Mismatched Ranks");
    // Recurse to apply to the terminals of current
    return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major<Major>(s,c); });
  } else {
    return get<0>(detail::compact<Major>(shape, current));
  }

  CUTE_GCC_UNREACHABLE;
}

//
// Compact Col Major
//

struct LayoutLeft {
  template <class Shape>
  using Apply = decltype(compact_major<LayoutLeft>(declval<Shape>()));
};

template <class Shape, class Current = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
compact_col_major(Shape   const& shape,
                  Current const& current = {})
{
  return compact_major<LayoutLeft>(shape, current);
}

//
// Compact Row Major
//

struct LayoutRight {
  template <class Shape>
  using Apply = decltype(compact_major<LayoutRight>(declval<Shape>()));
};

template <class Shape, class Current = Int<1>>
CUTE_HOST_DEVICE constexpr
auto
compact_row_major(Shape   const& shape,
                  Current const& current = {})
{
  return compact_major<LayoutRight>(shape, current);
}
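
// Illustrative usage of compact_col_major and compact_row_major, worked by hand
// from the definitions above (a sketch, not compile-checked here):
//
//   compact_col_major(make_shape(4,8,2))   // strides ( 1, 4, 32)
//   compact_row_major(make_shape(4,8,2))   // strides (16, 2,  1)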

//
// Compact Order -- compute a compact stride based on an ordering of the modes
//

namespace detail {

template <class Shape, class Order, class OrigShape, class OrigOrder>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, Order const& order,
              OrigShape const& orig_shape, OrigOrder const& orig_order)
{
  if constexpr (is_tuple<Order>::value) {
    return transform(shape, order, [&](auto const& x, auto const& y) { return compact_order(x, y, orig_shape, orig_order); });
  } else {
    // Start this mode's stride at the product of all shapes with lesser order
    auto d = product(transform(orig_shape, orig_order,
                               [&](auto const& s, auto const& o) {
                                 return conditional_return(o < order, product(s), Int<1>{});
                               }));
    return compact_col_major(shape, d);
  }

  CUTE_GCC_UNREACHABLE;
}

} // end namespace detail

template <class Shape, class Order>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, Order const& order)
{
  if constexpr (is_congruent<Shape,Order>::value) {
    return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order));
  } else {
    // Here we only want to apply order to top-level subshapes and default (col-major) order on other levels
    static_assert(rank(Shape{}) == rank(Order{}), "Need equal rank of shape and order");
    return detail::compact_order(shape, order, shape, order);
  }
}

template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, GenColMajor const& major)
{
  return compact_major<LayoutLeft>(shape);
}

template <class Shape>
CUTE_HOST_DEVICE constexpr
auto
compact_order(Shape const& shape, GenRowMajor const& major)
{
  return compact_major<LayoutRight>(shape);
}

} // end namespace cute
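
// Illustrative usage of compact_order, worked by hand from the definition of
// detail::compact_order above: each mode's stride is the product of the shapes
// of all modes with a smaller order value (a sketch, not compile-checked here):
//
//   compact_order(make_shape(2,3,4), make_tuple(_1{},_0{},_2{}))   // strides (3, 1, 6)
//   compact_order(make_shape(2,3,4), GenRowMajor{})                // strides (12, 4, 1)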