/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include // iterator_traits #include #include #include /* This implements a swizzle pointer of the form * InvolutionFn o PtrAdd * where the InvolutionFn need not be linear. * * This differs subtly from swizzle_layout because the smem pointer is used * as the offset. That means that swizzle_layout will implement position-independent * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. * Arch chose to design hardware with position-dependent swizzles. * * For clarity: * NormalLayout : DeRef <- PtrAdd <- [Layout] * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout * * Furthermore, for known swizzles, this pointer attempts to decay itself * to a normal-pointer with a new layout containing dynamic or static strides. * This is possible by determining the subdomain of the InvolutionFn * that is identity and testing if the Layout's codomain is contained * within it. */ namespace cute { // concept SwizzleFn { // CUTE_HOST_DEVICE constexpr static uint apply(uint); // } // See Swizzle in swizzle.hpp for common swizzle-functions. template struct swizzle_ptr : iter_adaptor> { using iterator = Iterator; using reference = typename iterator_traits::reference; using element_type = typename iterator_traits::element_type; using value_type = typename iterator_traits::value_type; using iter_adaptor>::iter_adaptor; template CUTE_HOST_DEVICE constexpr static Iter apply_swizzle(Iter ptr) { return {apply_swizzle(ptr.get())}; } template CUTE_HOST_DEVICE constexpr static T* apply_swizzle(T* ptr) { return reinterpret_cast(SwizzleFn::apply(reinterpret_cast(ptr))); } template CUTE_HOST_DEVICE constexpr static subbyte_iterator apply_swizzle(subbyte_iterator ptr) { return {apply_swizzle(ptr.ptr_), ptr.idx_}; } CUTE_HOST_DEVICE constexpr reference operator*() const { return *apply_swizzle(this->get()); } template CUTE_HOST_DEVICE constexpr reference operator[](Int const& i) const { return *apply_swizzle(this->get() + i); } }; template // Default No-Swizzle struct get_swizzle { using type = Swizzle<0,4,3>; }; template // Found the SwizzleFn struct get_swizzle> { using type = SwizzleFn; }; template // Recurse into anything with a ::iterator struct get_swizzle> : get_swizzle {}; template using get_swizzle_t = typename get_swizzle::type; template CUTE_HOST_DEVICE constexpr swizzle_ptr make_swizzle_ptr(Iterator ptr, SwizzleFn) { return {ptr}; } // Swizzle-0 specialization for immediate decay template CUTE_HOST_DEVICE constexpr Iterator make_swizzle_ptr(Iterator ptr, Swizzle<0,M,S>) { return ptr; } // // Recast // template CUTE_HOST_DEVICE constexpr auto raw_pointer_cast(swizzle_ptr const& ptr) { return raw_pointer_cast(ptr.get()); } // SwizzleFn operates on the pointer address, so it doesn't care about the type template CUTE_HOST_DEVICE constexpr auto recast_ptr(swizzle_ptr const& ptr) { return make_swizzle_ptr(recast_ptr(ptr.get()), SwizzleFn{}); } // // Display utilities // template CUTE_HOST_DEVICE void print(swizzle_ptr ptr) { print(SwizzleFn{}); printf("_"); print(ptr.get()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, swizzle_ptr ptr) { return os << SwizzleFn{} << "_" << ptr.get(); } #endif } // end namespace cute