/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include #include // sizeof_bits namespace cute { // // C++20 iterator_traits // namespace detail { // Default reference type of an iterator template struct iter_ref { using type = decltype(*declval()); }; // Prefer to propagate ::reference template struct iter_ref> { using type = typename T::reference; }; } // end namespace detail template using iter_reference = detail::iter_ref; template using iter_reference_t = typename iter_reference::type; namespace detail { // Default element_type of an iterator template struct iter_e { using type = remove_reference_t::type>; }; // Prefer to propagate ::element_type template struct iter_e> { using type = typename T::element_type; }; } // end namespace detail template using iter_element = detail::iter_e; template using iter_element_t = typename iter_element::type; namespace detail { // Default value_type of an iterator template struct iter_v { using type = remove_cv_t::type>; }; // Prefer to propagate ::value_type template struct iter_v> { using type = typename T::value_type; }; } // end namespace detail template using iter_value = detail::iter_v; template using iter_value_t = typename iter_value::type; template struct iterator_traits { using reference = iter_reference_t; using element_type = iter_element_t; using value_type = iter_value_t; }; // // has_dereference to determine if a type is an iterator concept // namespace detail { template struct has_dereference : CUTE_STL_NAMESPACE::false_type {}; template struct has_dereference())>> : CUTE_STL_NAMESPACE::true_type {}; } // end namespace detail template using has_dereference = detail::has_dereference; // // raw_pointer_cast // template CUTE_HOST_DEVICE constexpr T* raw_pointer_cast(T* ptr) { return ptr; } // // A very simplified iterator adaptor. // Derived classed may override methods, but be careful to reproduce interfaces exactly. // Clients should never have an instance of this class. Do not write methods that take this as a param. // template struct iter_adaptor { using iterator = Iterator; using reference = typename iterator_traits::reference; using element_type = typename iterator_traits::element_type; using value_type = typename iterator_traits::value_type; iterator ptr_; CUTE_HOST_DEVICE constexpr iter_adaptor(iterator ptr = {}) : ptr_(ptr) {} CUTE_HOST_DEVICE constexpr reference operator*() const { return *ptr_; } template CUTE_HOST_DEVICE constexpr reference operator[](Index const& i) const { return ptr_[i]; } template CUTE_HOST_DEVICE constexpr DerivedType operator+(Index const& i) const { return {ptr_ + i}; } CUTE_HOST_DEVICE constexpr iterator get() const { return ptr_; } CUTE_HOST_DEVICE constexpr friend bool operator==(DerivedType const& x, DerivedType const& y) { return x.ptr_ == y.ptr_; } CUTE_HOST_DEVICE constexpr friend bool operator!=(DerivedType const& x, DerivedType const& y) { return x.ptr_ != y.ptr_; } CUTE_HOST_DEVICE constexpr friend bool operator< (DerivedType const& x, DerivedType const& y) { return x.ptr_ < y.ptr_; } CUTE_HOST_DEVICE constexpr friend bool operator<=(DerivedType const& x, DerivedType const& y) { return x.ptr_ <= y.ptr_; } CUTE_HOST_DEVICE constexpr friend bool operator> (DerivedType const& x, DerivedType const& y) { return x.ptr_ > y.ptr_; } CUTE_HOST_DEVICE constexpr friend bool operator>=(DerivedType const& x, DerivedType const& y) { return x.ptr_ >= y.ptr_; } }; template CUTE_HOST_DEVICE constexpr auto raw_pointer_cast(iter_adaptor const& x) { return raw_pointer_cast(x.ptr_); } // // counting iterator -- quick and dirty // template struct counting_iterator { using index_type = T; using value_type = T; using reference = T; index_type n_; CUTE_HOST_DEVICE constexpr counting_iterator(index_type n = 0) : n_(n) {} CUTE_HOST_DEVICE constexpr index_type operator*() const { return n_; } CUTE_HOST_DEVICE constexpr index_type operator[](index_type i) const { return n_ + i; } CUTE_HOST_DEVICE constexpr counting_iterator operator+(index_type i) const { return {n_ + i}; } CUTE_HOST_DEVICE constexpr counting_iterator& operator++() { ++n_; return *this; } CUTE_HOST_DEVICE constexpr counting_iterator operator++(int) { counting_iterator ret = *this; ++n_; return ret; } CUTE_HOST_DEVICE constexpr friend bool operator==(counting_iterator const& x, counting_iterator const& y) { return x.n_ == y.n_; } CUTE_HOST_DEVICE constexpr friend bool operator!=(counting_iterator const& x, counting_iterator const& y) { return x.n_ != y.n_; } CUTE_HOST_DEVICE constexpr friend bool operator< (counting_iterator const& x, counting_iterator const& y) { return x.n_ < y.n_; } CUTE_HOST_DEVICE constexpr friend bool operator<=(counting_iterator const& x, counting_iterator const& y) { return x.n_ <= y.n_; } CUTE_HOST_DEVICE constexpr friend bool operator> (counting_iterator const& x, counting_iterator const& y) { return x.n_ > y.n_; } CUTE_HOST_DEVICE constexpr friend bool operator>=(counting_iterator const& x, counting_iterator const& y) { return x.n_ >= y.n_; } }; template CUTE_HOST_DEVICE constexpr T raw_pointer_cast(counting_iterator const& x) { return x.n_; } // // Display utilities // template CUTE_HOST_DEVICE void print(T const* const ptr) { printf("ptr["); print(sizeof_bits::value); printf("b](%p)", ptr); } template CUTE_HOST_DEVICE void print(counting_iterator ptr) { printf("counting_iter("); print(ptr.n_); printf(")"); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator ptr) { return os << "counting_iter(" << ptr.n_ << ")"; } #endif // !defined(__CUDACC_RTC__) } // end namespace cute