/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include /* This implements a ComposedLayout of the form * LayoutA o Offset o LayoutB * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. * For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB). * * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. * For example, this layout provides shape() and size() functions, but does not provide stride() functions. * Mostly, the similar functionality is accomplished by applying each operation to LayoutB only * as LayoutB defines the domain. */ namespace cute { // A Layout of non-trivially composable functions: F o I o L template struct ComposedLayout : private cute::tuple // EBO for static layouts { CUTE_HOST_DEVICE constexpr ComposedLayout(LayoutA const& layoutA = {}, Offset const& offset = {}, LayoutB const& layoutB = {}) : cute::tuple(layoutA, offset, layoutB) {} // // Accessors // static constexpr int rank = LayoutB::rank; CUTE_HOST_DEVICE constexpr decltype(auto) layout_a() const { return get<0>(static_cast const&>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) offset() const { return get<1>(static_cast const&>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) layout_b() const { return get<2>(static_cast const&>(*this)); } CUTE_HOST_DEVICE constexpr decltype(auto) layout() const { return *this; } CUTE_HOST_DEVICE constexpr decltype(auto) shape() const { return layout_b().shape(); } // Doesn't really make sense to ask for the strides of this "layout" CUTE_HOST_DEVICE constexpr decltype(auto) stride() const = delete; // // Mappings // // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) // OR // Slice the layout and return the sublayout (Coord has an Underscore slice op) template CUTE_HOST_DEVICE constexpr auto operator()(Coord const& coord) const { if constexpr (has_underscore::value) { return slice(coord, *this); } else { return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c) } CUTE_GCC_UNREACHABLE; } // Convenience function for multi-dimensional coordinates template CUTE_HOST_DEVICE constexpr auto operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { return operator()(make_coord(c0,c1,cs...)); } // // Compose // template CUTE_HOST_DEVICE constexpr auto compose(OtherLayout const& other) const { return composition(*this, other); } template CUTE_HOST_DEVICE constexpr auto compose(Layouts const&... layouts) const { return composition(*this, make_tile(layouts...)); } template CUTE_HOST_DEVICE constexpr auto with_shape(OtherShape const& shape) const { return composition(*this, make_layout(shape)); } template CUTE_HOST_DEVICE constexpr auto with_shape(Shapes const&... shapes) const { return composition(*this, make_layout(make_shape(shapes...))); } // // Tile // template CUTE_HOST_DEVICE constexpr auto tile(OtherLayout const& other) const { return tiled_divide(*this, other); } template CUTE_HOST_DEVICE constexpr auto tile(Layouts const&... layouts) const { return tiled_divide(*this, make_tile(layouts...)); } }; template struct is_layout> : true_type {}; template struct is_composed_layout : false_type {}; template struct is_composed_layout> : true_type {}; // // Constructors // template CUTE_HOST_DEVICE constexpr auto make_composed_layout(LayoutA const& layoutA, Offset const& offset, LayoutB const& layoutB) { return ComposedLayout{layoutA, offset, layoutB}; } // // Utilities // // Return the layout of a mode template CUTE_HOST_DEVICE constexpr decltype(auto) layout(ComposedLayout const& clayout) { return composition(clayout.layout_a(), clayout.offset(), layout(clayout.layout_b())); } // Return the shape of a mode template CUTE_HOST_DEVICE constexpr decltype(auto) shape(ComposedLayout const& layout) { return shape(layout.layout_b()); } // Doesn't make sense to directly ask for the strides of this "layout" template CUTE_HOST_DEVICE constexpr decltype(auto) stride(ComposedLayout const& layout) = delete; // Return the number of elements in a mode template CUTE_HOST_DEVICE constexpr decltype(auto) size(ComposedLayout const& layout) { return size(layout.layout_b()); } // Return the number of modes template CUTE_HOST_DEVICE constexpr auto rank(ComposedLayout const& layout) { return rank(layout.layout_b()); } // Return the depth of the layout template CUTE_HOST_DEVICE constexpr auto depth(ComposedLayout const& layout) { return depth(layout.layout_b()); } // Return the codomain size of a mode template CUTE_HOST_DEVICE constexpr auto cosize(ComposedLayout const& layout) { return cosize(layout.layout_b()); } // // Operations to manipulate Layouts like a tuple of pairs // template CUTE_HOST_DEVICE constexpr auto get(ComposedLayout const& a) { return composition(a.layout_a(), a.offset(), get(a.layout_b())); } template CUTE_HOST_DEVICE constexpr auto take(ComposedLayout const& a) { return composition(a.layout_a(), a.offset(), take(a.layout_b())); } template CUTE_HOST_DEVICE constexpr auto flatten(ComposedLayout const& a) { return composition(a.layout_a(), a.offset(), flatten(a.layout_b())); } template CUTE_HOST_DEVICE constexpr auto append(ComposedLayout const& a, X const& x) { return composition(a.layout_a(), a.offset(), append(a.layout_b(), x)); } template CUTE_HOST_DEVICE constexpr auto group(ComposedLayout const& a) { return composition(a.layout_a(), a.offset(), group(a.layout_b())); } // // Slice a ComposedLayout // template CUTE_HOST_DEVICE constexpr auto slice_and_offset(Coord const& coord, ComposedLayout const& layout) { auto [slice, offset] = slice_and_offset(coord, layout.layout_b()); return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{}); } template CUTE_HOST_DEVICE constexpr auto slice(Coord const& coord, ComposedLayout const& layout) { return get<0>(slice_and_offset(coord, layout)); } // Compute a pointer offset and (potentially modified) layout from a coordinate // For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated template CUTE_HOST_DEVICE constexpr auto domain_offset(Coord const& coord, ComposedLayout const& layout) { return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{}); } // // composition // template CUTE_HOST_DEVICE constexpr auto composition(LayoutA const& layoutA, Offset const& offset, LayoutB const& layoutB) { return ComposedLayout{layoutA, offset, layoutB}; } template CUTE_HOST_DEVICE constexpr auto composition(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto composition(Layout const& a, ComposedLayout const& b) { CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0."); return composition(composition(a, b.layout_a()), b.layout_b()); } // // complement // template CUTE_HOST_DEVICE constexpr auto complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) { return complement(layout.layout_b(), cosize_hi); } template CUTE_HOST_DEVICE constexpr auto complement(ComposedLayout const& layout) { return complement(layout, cosize(layout)); } // // inverse // template CUTE_HOST_DEVICE constexpr auto right_inverse(ComposedLayout const& layout) { return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); } template CUTE_HOST_DEVICE constexpr auto left_inverse(ComposedLayout const& layout) { return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); } // // Other operations // template CUTE_HOST_DEVICE constexpr auto zip(ComposedLayout const& a) { return composition(a.layout_a(), a.offset(), zip(a.layout_b())); } // Partitions template CUTE_HOST_DEVICE constexpr auto logical_divide(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto tile_unzip(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto tiled_divide(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto zipped_divide(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto flat_divide(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto logical_product(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto zipped_product(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto tiled_product(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto flat_product(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto blocked_product(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto raked_product(ComposedLayout const& a, Tiler const& b) { return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); } template CUTE_HOST_DEVICE constexpr auto tile_to_shape(ComposedLayout const& layout, Shape const& trg_shape, ModeOrder const& ord_shape = {}) { return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape)); } template CUTE_HOST_DEVICE constexpr auto filter(ComposedLayout const& layout, Shape const& trg_profile) { return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile)); } template CUTE_HOST_DEVICE constexpr auto coalesce(ComposedLayout const& layout) { return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b())); } template CUTE_HOST_DEVICE constexpr auto coalesce(ComposedLayout const& layout, Shape const& trg_profile) { return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); } // // Upcast and Downcast // template CUTE_HOST_DEVICE constexpr auto upcast(ComposedLayout const& layout) { return composition(upcast(layout.layout_a()), upcast(layout.offset()), upcast(layout.layout_b())); } template CUTE_HOST_DEVICE constexpr auto downcast(ComposedLayout const& layout) { return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); } template CUTE_HOST_DEVICE constexpr auto recast_layout(ComposedLayout const& layout) { using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); if constexpr (scale::num == 1 && scale::den == 1) { return layout; } else if constexpr (scale::num == 1) { return downcast(layout); } else if constexpr (scale::den == 1) { return upcast(layout); } else { static_assert(dependent_false, "Recast not supported."); } CUTE_GCC_UNREACHABLE; } // // Display utilities // template CUTE_HOST_DEVICE void print(ComposedLayout const& layout) { print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) { return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b(); } #endif } // end namespace cute