/*************************************************************************************************** * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ #pragma once #include "cute/util/print.hpp" #include "cute/util/type_traits.hpp" #include "cute/numeric/math.hpp" namespace cute { // A constant value: short name and type-deduction for fast compilation template struct C { using type = C; static constexpr auto value = v; using value_type = decltype(v); CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; // Deprecate template using constant = C; template using bool_constant = C; using true_type = bool_constant; using false_type = bool_constant; // A more std:: conforming integral_constant that enforces type but interops with C template struct integral_constant : C { using type = integral_constant; static constexpr T value = v; using value_type = T; // Disambiguate C::operator value_type() //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; // // Traits // // Use cute::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) // Use cute::is_integral to match both built-in integral types AND static integral types. template struct is_integral : bool_constant::value> {}; template struct is_integral > : true_type {}; template struct is_integral> : true_type {}; // is_static detects if an (abstract) value is defined completely by it's type (no members) template struct is_static : bool_constant>::value> {}; template constexpr bool is_static_v = is_static::value; // is_constant detects if a type is a static integral type and if v is equal to a value template struct is_constant : false_type {}; template struct is_constant : is_constant {}; template struct is_constant : is_constant {}; template struct is_constant : is_constant {}; template struct is_constant : is_constant {}; template struct is_constant > : bool_constant {}; template struct is_constant> : bool_constant {}; // // Specializations // template using Int = C; using _m32 = Int<-32>; using _m24 = Int<-24>; using _m16 = Int<-16>; using _m12 = Int<-12>; using _m10 = Int<-10>; using _m9 = Int<-9>; using _m8 = Int<-8>; using _m7 = Int<-7>; using _m6 = Int<-6>; using _m5 = Int<-5>; using _m4 = Int<-4>; using _m3 = Int<-3>; using _m2 = Int<-2>; using _m1 = Int<-1>; using _0 = Int<0>; using _1 = Int<1>; using _2 = Int<2>; using _3 = Int<3>; using _4 = Int<4>; using _5 = Int<5>; using _6 = Int<6>; using _7 = Int<7>; using _8 = Int<8>; using _9 = Int<9>; using _10 = Int<10>; using _12 = Int<12>; using _16 = Int<16>; using _24 = Int<24>; using _32 = Int<32>; using _64 = Int<64>; using _96 = Int<96>; using _128 = Int<128>; using _192 = Int<192>; using _256 = Int<256>; using _384 = Int<384>; using _512 = Int<512>; using _768 = Int<768>; using _1024 = Int<1024>; using _2048 = Int<2048>; using _4096 = Int<4096>; using _8192 = Int<8192>; using _16384 = Int<16384>; using _32768 = Int<32768>; using _65536 = Int<65536>; using _131072 = Int<131072>; using _262144 = Int<262144>; using _524288 = Int<524288>; /***************/ /** Operators **/ /***************/ #define CUTE_LEFT_UNARY_OP(OP) \ template \ CUTE_HOST_DEVICE constexpr \ C<(OP t)> operator OP (C) { \ return {}; \ } #define CUTE_RIGHT_UNARY_OP(OP) \ template \ CUTE_HOST_DEVICE constexpr \ C<(t OP)> operator OP (C) { \ return {}; \ } #define CUTE_BINARY_OP(OP) \ template \ CUTE_HOST_DEVICE constexpr \ C<(t OP u)> operator OP (C, C) { \ return {}; \ } CUTE_LEFT_UNARY_OP(+); CUTE_LEFT_UNARY_OP(-); CUTE_LEFT_UNARY_OP(~); CUTE_LEFT_UNARY_OP(!); CUTE_LEFT_UNARY_OP(*); CUTE_BINARY_OP( +); CUTE_BINARY_OP( -); CUTE_BINARY_OP( *); CUTE_BINARY_OP( /); CUTE_BINARY_OP( %); CUTE_BINARY_OP( &); CUTE_BINARY_OP( |); CUTE_BINARY_OP( ^); CUTE_BINARY_OP(<<); CUTE_BINARY_OP(>>); CUTE_BINARY_OP(&&); CUTE_BINARY_OP(||); CUTE_BINARY_OP(==); CUTE_BINARY_OP(!=); CUTE_BINARY_OP( >); CUTE_BINARY_OP( <); CUTE_BINARY_OP(>=); CUTE_BINARY_OP(<=); #undef CUTE_BINARY_OP #undef CUTE_LEFT_UNARY_OP #undef CUTE_RIGHT_UNARY_OP // // Mixed static-dynamic special cases // template ::value && t == 0)> CUTE_HOST_DEVICE constexpr C<0> operator*(C, U) { return {}; } template ::value && t == 0)> CUTE_HOST_DEVICE constexpr C<0> operator*(U, C) { return {}; } template ::value && t == 0)> CUTE_HOST_DEVICE constexpr C<0> operator/(C, U) { return {}; } template ::value && (t == 1 || t == -1))> CUTE_HOST_DEVICE constexpr C<0> operator%(U, C) { return {}; } template ::value && t == 0)> CUTE_HOST_DEVICE constexpr C<0> operator%(C, U) { return {}; } template ::value && t == 0)> CUTE_HOST_DEVICE constexpr C<0> operator&(C, U) { return {}; } template ::value && t == 0)> CUTE_HOST_DEVICE constexpr C<0> operator&(U, C) { return {}; } template ::value && !bool(t))> CUTE_HOST_DEVICE constexpr C operator&&(C, U) { return {}; } template ::value && !bool(t))> CUTE_HOST_DEVICE constexpr C operator&&(U, C) { return {}; } template ::value && bool(t))> CUTE_HOST_DEVICE constexpr C operator||(C, U) { return {}; } template ::value && bool(t))> CUTE_HOST_DEVICE constexpr C operator||(U, C) { return {}; } // // Named functions from math.hpp // #define CUTE_NAMED_UNARY_FN(OP) \ template \ CUTE_HOST_DEVICE constexpr \ C OP (C) { \ return {}; \ } #define CUTE_NAMED_BINARY_FN(OP) \ template \ CUTE_HOST_DEVICE constexpr \ C OP (C, C) { \ return {}; \ } \ template ::value)> \ CUTE_HOST_DEVICE constexpr \ auto OP (C, U u) { \ return OP(t,u); \ } \ template ::value)> \ CUTE_HOST_DEVICE constexpr \ auto OP (T t, C) { \ return OP(t,u); \ } CUTE_NAMED_UNARY_FN(abs); CUTE_NAMED_UNARY_FN(signum); CUTE_NAMED_UNARY_FN(has_single_bit); CUTE_NAMED_BINARY_FN(max); CUTE_NAMED_BINARY_FN(min); CUTE_NAMED_BINARY_FN(shiftl); CUTE_NAMED_BINARY_FN(shiftr); CUTE_NAMED_BINARY_FN(gcd); CUTE_NAMED_BINARY_FN(lcm); #undef CUTE_NAMED_UNARY_FN #undef CUTE_NAMED_BINARY_FN // // Other functions // template CUTE_HOST_DEVICE constexpr C safe_div(C, C) { static_assert(t % u == 0, "Static safe_div requires t % u == 0"); return {}; } template ::value)> CUTE_HOST_DEVICE constexpr auto safe_div(C, U u) { return t / u; } template ::value)> CUTE_HOST_DEVICE constexpr auto safe_div(T t, C) { return t / u; } template CUTE_HOST_DEVICE constexpr decltype(auto) conditional_return(true_type, TrueType&& t, FalseType&&) { return static_cast(t); } template CUTE_HOST_DEVICE constexpr decltype(auto) conditional_return(false_type, TrueType&&, FalseType&& f) { return static_cast(f); } // TrueType and FalseType must have a common type template CUTE_HOST_DEVICE constexpr auto conditional_return(bool b, TrueType const& t, FalseType const& f) { return b ? t : f; } // TrueType and FalseType don't require a common type template CUTE_HOST_DEVICE constexpr auto conditional_return(TrueType const& t, FalseType const& f) { if constexpr (b) { return t; } else { return f; } } template CUTE_HOST_DEVICE constexpr auto static_value() { if constexpr (is_std_integral::value) { return Int{}; } else { return Trait::value; } CUTE_GCC_UNREACHABLE; } // // Display utilities // template CUTE_HOST_DEVICE void print(C) { printf("_"); ::cute::print(Value); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, C const&) { return os << "_" << t; } #endif } // end namespace cute