/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include // sizeof_bits #include #include #include #include #include #include namespace cute { // // recast_ptr -- Create an iterator over values of type T. // For most types this will simply be T*, but certain types require more care. // Subbyte Types: uint2_t, uint4_t, etc // Requires construction of a subbyte_iterator in order to properly // resolve each element in byte-addressed memory. // template CUTE_HOST_DEVICE constexpr auto recast_ptr(void* ptr) { if constexpr (is_subbyte::value) { return subbyte_iterator(ptr); } else { return reinterpret_cast(ptr); } CUTE_GCC_UNREACHABLE; } template CUTE_HOST_DEVICE constexpr auto recast_ptr(void const* ptr) { if constexpr (is_subbyte::value) { return subbyte_iterator(ptr); } else { return reinterpret_cast(ptr); } CUTE_GCC_UNREACHABLE; } // Disambiguate nullptr template CUTE_HOST_DEVICE constexpr auto recast_ptr(decltype(nullptr)) { // nullptr_t return recast_ptr(static_cast(nullptr)); } // // gmem_ptr // template struct gmem_ptr : iter_adaptor> { using iter_adaptor>::iter_adaptor; }; template struct is_gmem : false_type {}; template // Found the gmem struct is_gmem> : true_type {}; template // Recurse on ::iterator, if possible struct is_gmem> : is_gmem {}; // Idempotent gmem tag on an iterator template CUTE_HOST_DEVICE constexpr auto make_gmem_ptr(Iterator iter) { if constexpr (is_gmem::value) { return iter; } else { return gmem_ptr{iter}; } CUTE_GCC_UNREACHABLE; } // Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr auto make_gmem_ptr(void* ptr) { return make_gmem_ptr(recast_ptr(ptr)); } // Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr auto make_gmem_ptr(void const* ptr) { return make_gmem_ptr(recast_ptr(ptr)); } // nullptr_t overload for make_gmem_ptr(nullptr) disambiguation template CUTE_HOST_DEVICE constexpr auto make_gmem_ptr(decltype(nullptr)) { // nullptr_t return make_gmem_ptr(recast_ptr(nullptr)); } // The gmem tag is invariant over type-recast template CUTE_HOST_DEVICE constexpr auto recast_ptr(gmem_ptr

const& ptr) { return make_gmem_ptr(recast_ptr(ptr.get())); } // // smem_ptr // template struct smem_ptr : iter_adaptor> { using iter_adaptor>::iter_adaptor; }; template struct is_smem : false_type {}; template // Found the smem struct is_smem> : true_type {}; template // Recurse on ::iterator, if possible struct is_smem> : is_smem {}; // Idempotent smem tag on an iterator template CUTE_HOST_DEVICE constexpr auto make_smem_ptr(Iterator iter) { if constexpr (is_smem::value) { return iter; } else { return smem_ptr{iter}; } CUTE_GCC_UNREACHABLE; } // Make a smem swizzle pointer, common operation template CUTE_HOST_DEVICE constexpr auto make_smem_ptr(Iterator ptr, Swizzle sw) { return make_swizzle_ptr(make_smem_ptr(ptr), sw); } // Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr auto make_smem_ptr(void* ptr) { return make_smem_ptr(recast_ptr(ptr)); } // Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr auto make_smem_ptr(void const* ptr) { return make_smem_ptr(recast_ptr(ptr)); } // The smem tag is invariant over type-recast template CUTE_HOST_DEVICE constexpr auto recast_ptr(smem_ptr

const& ptr) { return make_smem_ptr(recast_ptr(ptr.get())); } // // rmem_ptr // template struct rmem_ptr : iter_adaptor> { using iter_adaptor>::iter_adaptor; }; // Anything that is not gmem or smem is rmem template struct is_rmem : bool_constant::value || is_smem::value)> {}; template struct is_rmem> : true_type {}; // Idempotent rmem tag on an iterator template CUTE_HOST_DEVICE constexpr auto make_rmem_ptr(Iterator iter) { if constexpr (is_rmem::value) { return iter; } else { return rmem_ptr{iter}; } CUTE_GCC_UNREACHABLE; } // Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr auto make_rmem_ptr(void* ptr) { return make_rmem_ptr(recast_ptr(ptr)); } // Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr auto make_rmem_ptr(void const* ptr) { return make_rmem_ptr(recast_ptr(ptr)); } // The rmem tag is invariant over type-recast template CUTE_HOST_DEVICE constexpr auto recast_ptr(rmem_ptr

const& ptr) { return make_rmem_ptr(recast_ptr(ptr.get())); } // // Display utilities // template CUTE_HOST_DEVICE void print(gmem_ptr ptr) { printf("gmem_"); print(ptr.get()); } template CUTE_HOST_DEVICE void print(smem_ptr ptr) { printf("smem_"); print(ptr.get()); } template CUTE_HOST_DEVICE void print(rmem_ptr ptr) { printf("rmem_"); print(ptr.get()); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr ptr) { return os << "gmem_[" << int(sizeof_bits>::value) << "b]"; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr ptr) { return os << "smem_[" << int(sizeof_bits>::value) << "b]"; } template CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr ptr) { return os << "rmem_[" << int(sizeof_bits>::value) << "b]"; } #endif // !defined(__CUDACC_RTC__) } // end namespace cute