/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include #include namespace cute { // // Common Operations // template ::value && is_arithmetic::value)> CUTE_HOST_DEVICE constexpr auto max(T const& t, U const& u) { return t < u ? u : t; } template ::value && is_arithmetic::value)> CUTE_HOST_DEVICE constexpr auto min(T const& t, U const& u) { return t < u ? t : u; } template ::value)> CUTE_HOST_DEVICE constexpr auto abs(T const& t) { if constexpr (is_signed::value) { return t < T(0) ? -t : t; } else { return t; } CUTE_GCC_UNREACHABLE; } // Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. template ::value)> CUTE_HOST_DEVICE constexpr int signum(T const& x) { if constexpr (is_signed::value) { return (T(0) < x) - (x < T(0)); } else { return T(0) < x; } CUTE_GCC_UNREACHABLE; } // // C++17 operations // // Greatest common divisor of two positive integers template ::value && is_std_integral::value)> CUTE_HOST_DEVICE constexpr cute::common_type_t gcd(T t, U u) { while (true) { if (t == 0) { return u; } u %= t; if (u == 0) { return t; } t %= u; } } // Least common multiple of two positive integers template ::value && is_std_integral::value)> CUTE_HOST_DEVICE constexpr cute::common_type_t lcm(T const& t, U const& u) { return (t / gcd(t,u)) * u; } // // C++20 operations // // Checks if a number is an integral power of two template CUTE_HOST_DEVICE constexpr bool has_single_bit(T x) { return x != 0 && (x & (x - 1)) == 0; } // Smallest number of bits needed to represent the given value // For x == 0, this is 0 // For x != 0, this is 1 + floor(log2(x)) // bit_width( 0b0000 ) = 0 // bit_width( 0b0001 ) = 1 // bit_width( 0b0010 ) = 2 // bit_width( 0b0011 ) = 2 // bit_width( 0b0100 ) = 3 // bit_width( 0b0101 ) = 3 // bit_width( 0b0110 ) = 3 // bit_width( 0b0111 ) = 3 template CUTE_HOST_DEVICE constexpr T bit_width(T x) { static_assert(is_unsigned::value, "Only to be used for unsigned types."); constexpr int N = (numeric_limits::digits == 64 ? 6 : (numeric_limits::digits == 32 ? 5 : (numeric_limits::digits == 16 ? 4 : (numeric_limits::digits == 8 ? 3 : (assert(false),0))))); T r = 0; for (int i = N - 1; i >= 0; --i) { T shift = (x > ((T(1) << (T(1) << i))-1)) << i; x >>= shift; r |= shift; } return r + (x != 0); } // Smallest integral power of two not less than the given value // bit_ceil( 0b00000000 ) = 0b00000001 // bit_ceil( 0b00000001 ) = 0b00000001 // bit_ceil( 0b00000010 ) = 0b00000010 // bit_ceil( 0b00000011 ) = 0b00000100 // bit_ceil( 0b00000100 ) = 0b00000100 // bit_ceil( 0b00000101 ) = 0b00001000 // bit_ceil( 0b00000110 ) = 0b00001000 // bit_ceil( 0b00000111 ) = 0b00001000 // bit_ceil( 0b00001000 ) = 0b00001000 // bit_ceil( 0b00001001 ) = 0b00010000 template CUTE_HOST_DEVICE constexpr T bit_ceil(T x) { return x == 0 ? T(1) : (T(1) << bit_width(x - 1)); } // Largest integral power of two not greater than the given value // bit_floor( 0b00000000 ) = 0b00000000 // bit_floor( 0b00000001 ) = 0b00000001 // bit_floor( 0b00000010 ) = 0b00000010 // bit_floor( 0b00000011 ) = 0b00000010 // bit_floor( 0b00000100 ) = 0b00000100 // bit_floor( 0b00000101 ) = 0b00000100 // bit_floor( 0b00000110 ) = 0b00000100 // bit_floor( 0b00000111 ) = 0b00000100 // bit_floor( 0b00001000 ) = 0b00001000 // bit_floor( 0b00001001 ) = 0b00001000 template CUTE_HOST_DEVICE constexpr T bit_floor(T x) { return x == 0 ? 0 : (T(1) << (bit_width(x) - 1)); } template CUTE_HOST_DEVICE constexpr T rotl(T x, int s); template CUTE_HOST_DEVICE constexpr T rotr(T x, int s); // Computes the result of circular bitwise left-rotation template CUTE_HOST_DEVICE constexpr T rotl(T x, int s) { constexpr int N = numeric_limits::digits; return static_cast(s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s)); } // Computes the result of circular bitwise right-rotation template CUTE_HOST_DEVICE constexpr T rotr(T x, int s) { constexpr int N = numeric_limits::digits; return static_cast(s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s)); } // Counts the number of consecutive 0 bits, starting from the most significant bit // countl_zero( 0b00000000 ) = 8 // countl_zero( 0b11111111 ) = 0 // countl_zero( 0b00011100 ) = 3 template CUTE_HOST_DEVICE constexpr T countl_zero(T x) { return numeric_limits::digits - bit_width(x); } // Counts the number of consecutive 1 bits, starting from the most significant bit // countl_one( 0b00000000 ) = 0 // countl_one( 0b11111111 ) = 8 // countl_one( 0b11100011 ) = 3 template CUTE_HOST_DEVICE constexpr T countl_one(T x) { return countl_zero(~x); } // Counts the number of consecutive 0 bits, starting from the least significant bit // countr_zero( 0b00000000 ) = 8 // countr_zero( 0b11111111 ) = 0 // countr_zero( 0b00011100 ) = 2 template CUTE_HOST_DEVICE constexpr T countr_zero(T x) { return x == 0 ? numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB } // Counts the number of consecutive 1 bits, starting from the least significant bit // countr_one( 0b00000000 ) = 0 // countr_one( 0b11111111 ) = 8 // countr_one( 0b11100011 ) = 2 template CUTE_HOST_DEVICE constexpr T countr_one(T x) { return countr_zero(~x); } // Counts the number of 1 bits in an unsigned integer // popcount( 0b00000000 ) = 0 // popcount( 0b11111111 ) = 8 // popcount( 0b00011101 ) = 4 template CUTE_HOST_DEVICE constexpr int popcount(T x) { int c = 0; while (x) { ++c; x &= x - 1; // clear the least significant bit set } return c; } // // Custom operations // // Computes the result of bitwise left-shift template CUTE_HOST_DEVICE constexpr T shiftl(T x, int s) { return s >= 0 ? (x << s) : (x >> -s); } // Computes the result of bitwise right-shift template CUTE_HOST_DEVICE constexpr T shiftr(T x, int s) { return s >= 0 ? (x >> s) : (x << -s); } // Safe divide // @pre t % u == 0 // @result t / u template ::value && is_std_integral::value)> CUTE_HOST_DEVICE constexpr auto safe_div(T const& t, U const& u) { //assert(t % u == 0); return t / u; } /** * log2 computation */ template CUTE_HOST_DEVICE constexpr auto log_2(T x) { assert(x > 0); static_assert(is_unsigned::value, "Only to be used for unsigned integral types."); return bit_width(x) - 1; } } // namespace cute