// Copyright 2020 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef HIGHWAY_HWY_BASE_H_
#define HIGHWAY_HWY_BASE_H_

// For SIMD module implementations and their callers, target-independent.

// IWYU pragma: begin_exports
#include <stddef.h>
#include <stdint.h>

#include "hwy/detect_compiler_arch.h"
#include "hwy/highway_export.h"

#if HWY_COMPILER_MSVC && defined(_MSVC_LANG) && _MSVC_LANG > __cplusplus
#define HWY_CXX_LANG _MSVC_LANG
#else
#define HWY_CXX_LANG __cplusplus
#endif

// "IWYU pragma: keep" does not work for these includes, so hide from the IDE.
#if !HWY_IDE

#if !defined(HWY_NO_LIBCXX)
#ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS  // before inttypes.h
#endif
#include <inttypes.h>
#endif

#if (HWY_ARCH_X86 && !defined(HWY_NO_LIBCXX)) || HWY_COMPILER_MSVC
#include <atomic>
#endif

#endif  // !HWY_IDE

#if !defined(HWY_NO_LIBCXX) && HWY_CXX_LANG > 201703L && \
    __cpp_impl_three_way_comparison >= 201907L && defined(__has_include) && \
    !defined(HWY_DISABLE_CXX20_THREE_WAY_COMPARE)
#if __has_include(<compare>)
#include <compare>
#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 1
#endif
#endif

// IWYU pragma: end_exports

#if HWY_COMPILER_MSVC
#include <string.h>  // memcpy
#endif

//------------------------------------------------------------------------------
// Compiler-specific definitions

#define HWY_STR_IMPL(macro) #macro
#define HWY_STR(macro) HWY_STR_IMPL(macro)

#if HWY_COMPILER_MSVC

#include <intrin.h>

#define HWY_RESTRICT __restrict
#define HWY_INLINE __forceinline
#define HWY_NOINLINE __declspec(noinline)
#define HWY_FLATTEN
#define HWY_NORETURN __declspec(noreturn)
#define HWY_LIKELY(expr) (expr)
#define HWY_UNLIKELY(expr) (expr)
#define HWY_PRAGMA(tokens) __pragma(tokens)
#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(warning(tokens))
#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(msc)
#define HWY_MAYBE_UNUSED
#define HWY_HAS_ASSUME_ALIGNED 0
#if (_MSC_VER >= 1700)
#define HWY_MUST_USE_RESULT _Check_return_
#else
#define HWY_MUST_USE_RESULT
#endif

#else

#define HWY_RESTRICT __restrict__
// Forced inlining without optimization enabled creates very inefficient code
// that can cause compiler timeouts.
#ifdef __OPTIMIZE__
#define HWY_INLINE inline __attribute__((always_inline))
#else
#define HWY_INLINE inline
#endif
#define HWY_NOINLINE __attribute__((noinline))
#define HWY_FLATTEN __attribute__((flatten))
#define HWY_NORETURN __attribute__((noreturn))
#define HWY_LIKELY(expr) __builtin_expect(!!(expr), 1)
#define HWY_UNLIKELY(expr) __builtin_expect(!!(expr), 0)
#define HWY_PRAGMA(tokens) _Pragma(#tokens)
#define HWY_DIAGNOSTICS(tokens) HWY_PRAGMA(GCC diagnostic tokens)
#define HWY_DIAGNOSTICS_OFF(msc, gcc) HWY_DIAGNOSTICS(gcc)
// Encountered "attribute list cannot appear here" when using the C++17
// [[maybe_unused]], so only use the old-style attribute for now.
#define HWY_MAYBE_UNUSED __attribute__((unused))
#define HWY_MUST_USE_RESULT __attribute__((warn_unused_result))
#endif  // !HWY_COMPILER_MSVC

//------------------------------------------------------------------------------
// Builtin/attributes (no more #include after this point due to namespace!)

namespace hwy {

// Enables error-checking of format strings.
#if HWY_HAS_ATTRIBUTE(__format__)
#define HWY_FORMAT(idx_fmt, idx_arg) \
  __attribute__((__format__(__printf__, idx_fmt, idx_arg)))
#else
#define HWY_FORMAT(idx_fmt, idx_arg)
#endif

// Returns a void* pointer which the compiler then assumes is N-byte aligned.
// Example: float* HWY_RESTRICT aligned = (float*)HWY_ASSUME_ALIGNED(in, 32);
//
// The assignment semantics are required by GCC/Clang. ICC provides an in-place
// __assume_aligned, whereas MSVC's __assume appears unsuitable.
#if HWY_HAS_BUILTIN(__builtin_assume_aligned)
#define HWY_ASSUME_ALIGNED(ptr, align) __builtin_assume_aligned((ptr), (align))
#else
#define HWY_ASSUME_ALIGNED(ptr, align) (ptr) /* not supported */
#endif

// Special case to increase the required alignment.
#define HWY_RCAST_ALIGNED(type, ptr) \
  reinterpret_cast<type*>(HWY_ASSUME_ALIGNED((ptr), alignof(type)))

// Clang and GCC require attributes on each function into which SIMD intrinsics
// are inlined. Support both per-function annotation (HWY_ATTR) for lambdas and
// automatic annotation via pragmas.
#if HWY_COMPILER_ICC
// As of ICC 2021.{1-9} the pragma is neither implemented nor required.
#define HWY_PUSH_ATTRIBUTES(targets_str)
#define HWY_POP_ATTRIBUTES
#elif HWY_COMPILER_CLANG
#define HWY_PUSH_ATTRIBUTES(targets_str)                                \
  HWY_PRAGMA(clang attribute push(__attribute__((target(targets_str))), \
                                  apply_to = function))
#define HWY_POP_ATTRIBUTES HWY_PRAGMA(clang attribute pop)
#elif HWY_COMPILER_GCC_ACTUAL
#define HWY_PUSH_ATTRIBUTES(targets_str) \
  HWY_PRAGMA(GCC push_options) HWY_PRAGMA(GCC target targets_str)
#define HWY_POP_ATTRIBUTES HWY_PRAGMA(GCC pop_options)
#else
#define HWY_PUSH_ATTRIBUTES(targets_str)
#define HWY_POP_ATTRIBUTES
#endif

//------------------------------------------------------------------------------
// Macros

#define HWY_API static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED

#define HWY_CONCAT_IMPL(a, b) a##b
#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b)

#define HWY_MIN(a, b) ((a) < (b) ? (a) : (b))
#define HWY_MAX(a, b) ((a) > (b) ? (a) : (b))

#if HWY_COMPILER_GCC_ACTUAL
// nielskm: GCC does not support '#pragma GCC unroll' without the factor.
#define HWY_UNROLL(factor) HWY_PRAGMA(GCC unroll factor)
#define HWY_DEFAULT_UNROLL HWY_UNROLL(4)
#elif HWY_COMPILER_CLANG || HWY_COMPILER_ICC || HWY_COMPILER_ICX
#define HWY_UNROLL(factor) HWY_PRAGMA(unroll factor)
#define HWY_DEFAULT_UNROLL HWY_UNROLL()
#else
#define HWY_UNROLL(factor)
#define HWY_DEFAULT_UNROLL
#endif

// Tells the compiler that the expression always evaluates to true.
// The expression should be free of side effects.
// Some older compilers may have trouble with complex expressions, so it is
// advisable to split multiple conditions into separate assume statements and
// to manually check the generated code.
// OK but could fail:
//   HWY_ASSUME(x == 2 && y == 3);
// Better:
//   HWY_ASSUME(x == 2);
//   HWY_ASSUME(y == 3);
#if HWY_HAS_CPP_ATTRIBUTE(assume)
#define HWY_ASSUME(expr) [[assume(expr)]]
#elif HWY_COMPILER_MSVC || HWY_COMPILER_ICC
#define HWY_ASSUME(expr) __assume(expr)
// __builtin_assume() was added in clang 3.6.
#elif HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_assume)
#define HWY_ASSUME(expr) __builtin_assume(expr)
// __builtin_unreachable() was added in GCC 4.5, but __has_builtin() was added
// later, so check the compiler version directly.
#elif HWY_COMPILER_GCC_ACTUAL >= 405
#define HWY_ASSUME(expr) \
  ((expr) ? static_cast<void>(0) : __builtin_unreachable())
#else
#define HWY_ASSUME(expr) static_cast<void>(0)
#endif

// Compile-time fence to prevent undesirable code reordering. On Clang x86, the
// typical asm volatile("" : : : "memory") has no effect, whereas atomic fence
// does, without generating code.
#if HWY_ARCH_X86 && !defined(HWY_NO_LIBCXX)
#define HWY_FENCE std::atomic_thread_fence(std::memory_order_acq_rel)
#else
// TODO(janwas): investigate alternatives. On Arm, the above generates barriers.
#define HWY_FENCE
#endif

// 4 instances of a given literal value, useful as input to LoadDup128.
#define HWY_REP4(literal) literal, literal, literal, literal

HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4)
    Abort(const char* file, int line, const char* format, ...);

#define HWY_ABORT(format, ...) \
  ::hwy::Abort(__FILE__, __LINE__, format, ##__VA_ARGS__)

// Always enabled.
#define HWY_ASSERT(condition)             \
  do {                                    \
    if (!(condition)) {                   \
      HWY_ABORT("Assert %s", #condition); \
    }                                     \
  } while (0)

#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER)
#define HWY_IS_MSAN 1
#else
#define HWY_IS_MSAN 0
#endif

#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER)
#define HWY_IS_ASAN 1
#else
#define HWY_IS_ASAN 0
#endif

#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER)
#define HWY_IS_TSAN 1
#else
#define HWY_IS_TSAN 0
#endif

// MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo.
// You can disable MSAN by adding this attribute to the function that fails.
#if HWY_IS_MSAN
#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory))
#else
#define HWY_ATTR_NO_MSAN
#endif

// For enabling HWY_DASSERT and shortening tests in slower debug builds
#if !defined(HWY_IS_DEBUG_BUILD)
// Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent
// MSVC defines NDEBUG (if not, could instead check _DEBUG).
#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || HWY_IS_ASAN || \
    HWY_IS_MSAN || HWY_IS_TSAN || defined(__clang_analyzer__)
#define HWY_IS_DEBUG_BUILD 1
#else
#define HWY_IS_DEBUG_BUILD 0
#endif
#endif  // HWY_IS_DEBUG_BUILD

#if HWY_IS_DEBUG_BUILD
#define HWY_DASSERT(condition) HWY_ASSERT(condition)
#else
#define HWY_DASSERT(condition) \
  do {                         \
  } while (0)
#endif

#if __cpp_constexpr >= 201304L
#define HWY_CXX14_CONSTEXPR constexpr
#else
#define HWY_CXX14_CONSTEXPR
#endif

#ifndef HWY_HAVE_CXX20_THREE_WAY_COMPARE
#define HWY_HAVE_CXX20_THREE_WAY_COMPARE 0
#endif

//------------------------------------------------------------------------------
// CopyBytes / ZeroBytes

#if HWY_COMPILER_MSVC
#pragma intrinsic(memcpy)
#pragma intrinsic(memset)
#endif

// The source/destination must not overlap/alias.
template <size_t kBytes, typename From, typename To>
HWY_API void CopyBytes(const From* from, To* to) {
#if HWY_COMPILER_MSVC
  memcpy(to, from, kBytes);
#else
  __builtin_memcpy(to, from, kBytes);
#endif
}

HWY_API void CopyBytes(const void* HWY_RESTRICT from, void* HWY_RESTRICT to,
                       size_t num_of_bytes_to_copy) {
#if HWY_COMPILER_MSVC
  memcpy(to, from, num_of_bytes_to_copy);
#else
  __builtin_memcpy(to, from, num_of_bytes_to_copy);
#endif
}

// Same as CopyBytes, but for same-sized objects; avoids a size argument.
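// Example (illustrative, local variable names are hypothetical): copy a
// float's bit pattern into a uint32_t without violating strict aliasing:
//   uint32_t bits;
//   CopySameSize(&f, &bits);  // same as CopyBytes<sizeof(float)>(&f, &bits)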
template HWY_API void CopySameSize(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) { static_assert(sizeof(From) == sizeof(To), ""); CopyBytes(from, to); } template HWY_API void ZeroBytes(To* to) { #if HWY_COMPILER_MSVC memset(to, 0, kBytes); #else __builtin_memset(to, 0, kBytes); #endif } HWY_API void ZeroBytes(void* to, size_t num_bytes) { #if HWY_COMPILER_MSVC memset(to, 0, num_bytes); #else __builtin_memset(to, 0, num_bytes); #endif } //------------------------------------------------------------------------------ // kMaxVectorSize (undocumented, pending removal) #if HWY_ARCH_X86 static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 64; // AVX-512 #elif HWY_ARCH_RVV && defined(__riscv_v_intrinsic) && \ __riscv_v_intrinsic >= 11000 // Not actually an upper bound on the size. static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 4096; #else static constexpr HWY_MAYBE_UNUSED size_t kMaxVectorSize = 16; #endif //------------------------------------------------------------------------------ // Alignment // Potentially useful for LoadDup128 and capped vectors. In other cases, arrays // should be allocated dynamically via aligned_allocator.h because Lanes() may // exceed the stack size. #if HWY_ARCH_X86 #define HWY_ALIGN_MAX alignas(64) #elif HWY_ARCH_RVV && defined(__riscv_v_intrinsic) && \ __riscv_v_intrinsic >= 11000 #define HWY_ALIGN_MAX alignas(8) // only elements need be aligned #else #define HWY_ALIGN_MAX alignas(16) #endif //------------------------------------------------------------------------------ // Lane types // hwy::float16_t and hwy::bfloat16_t are forward declared here to allow // BitCastScalar to be implemented before the implementations of the // hwy::float16_t and hwy::bfloat16_t types struct float16_t; struct bfloat16_t; using float32_t = float; using float64_t = double; #pragma pack(push, 1) // Aligned 128-bit type. Cannot use __int128 because clang doesn't yet align it: // https://reviews.llvm.org/D86310 struct alignas(16) uint128_t { uint64_t lo; // little-endian layout uint64_t hi; }; // 64 bit key plus 64 bit value. Faster than using uint128_t when only the key // field is to be compared (Lt128Upper instead of Lt128). struct alignas(16) K64V64 { uint64_t value; // little-endian layout uint64_t key; }; // 32 bit key plus 32 bit value. Allows vqsort recursions to terminate earlier // than when considering both to be a 64-bit key. struct alignas(8) K32V32 { uint32_t value; // little-endian layout uint32_t key; }; #pragma pack(pop) static inline HWY_MAYBE_UNUSED bool operator<(const uint128_t& a, const uint128_t& b) { return (a.hi == b.hi) ? a.lo < b.lo : a.hi < b.hi; } // Required for std::greater. static inline HWY_MAYBE_UNUSED bool operator>(const uint128_t& a, const uint128_t& b) { return b < a; } static inline HWY_MAYBE_UNUSED bool operator==(const uint128_t& a, const uint128_t& b) { return a.lo == b.lo && a.hi == b.hi; } static inline HWY_MAYBE_UNUSED bool operator<(const K64V64& a, const K64V64& b) { return a.key < b.key; } // Required for std::greater. static inline HWY_MAYBE_UNUSED bool operator>(const K64V64& a, const K64V64& b) { return b < a; } static inline HWY_MAYBE_UNUSED bool operator==(const K64V64& a, const K64V64& b) { return a.key == b.key; } static inline HWY_MAYBE_UNUSED bool operator<(const K32V32& a, const K32V32& b) { return a.key < b.key; } // Required for std::greater. 
static inline HWY_MAYBE_UNUSED bool operator>(const K32V32& a, const K32V32& b) { return b < a; } static inline HWY_MAYBE_UNUSED bool operator==(const K32V32& a, const K32V32& b) { return a.key == b.key; } //------------------------------------------------------------------------------ // Controlling overload resolution (SFINAE) template struct EnableIfT {}; template <> struct EnableIfT { using type = void; }; template using EnableIf = typename EnableIfT::type; template struct IsSameT { enum { value = 0 }; }; template struct IsSameT { enum { value = 1 }; }; template HWY_API constexpr bool IsSame() { return IsSameT::value; } // Returns whether T matches either of U1 or U2 template HWY_API constexpr bool IsSameEither() { return IsSameT::value || IsSameT::value; } template struct IfT { using type = Then; }; template struct IfT { using type = Else; }; template using If = typename IfT::type; template struct IsConstT { enum { value = 0 }; }; template struct IsConstT { enum { value = 1 }; }; template HWY_API constexpr bool IsConst() { return IsConstT::value; } template struct RemoveConstT { using type = T; }; template struct RemoveConstT { using type = T; }; template using RemoveConst = typename RemoveConstT::type; template struct RemoveVolatileT { using type = T; }; template struct RemoveVolatileT { using type = T; }; template using RemoveVolatile = typename RemoveVolatileT::type; template struct RemoveRefT { using type = T; }; template struct RemoveRefT { using type = T; }; template struct RemoveRefT { using type = T; }; template using RemoveRef = typename RemoveRefT::type; template using RemoveCvRef = RemoveConst>>; // Insert into template/function arguments to enable this overload only for // vectors of exactly, at most (LE), or more than (GT) this many bytes. // // As an example, checking for a total size of 16 bytes will match both // Simd and Simd. #define HWY_IF_V_SIZE(T, kN, bytes) \ hwy::EnableIf* = nullptr #define HWY_IF_V_SIZE_LE(T, kN, bytes) \ hwy::EnableIf* = nullptr #define HWY_IF_V_SIZE_GT(T, kN, bytes) \ hwy::EnableIf<(kN * sizeof(T) > bytes)>* = nullptr #define HWY_IF_LANES(kN, lanes) hwy::EnableIf<(kN == lanes)>* = nullptr #define HWY_IF_LANES_LE(kN, lanes) hwy::EnableIf<(kN <= lanes)>* = nullptr #define HWY_IF_LANES_GT(kN, lanes) hwy::EnableIf<(kN > lanes)>* = nullptr #define HWY_IF_UNSIGNED(T) hwy::EnableIf()>* = nullptr #define HWY_IF_SIGNED(T) \ hwy::EnableIf() && !hwy::IsFloat() && \ !hwy::IsSpecialFloat()>* = nullptr #define HWY_IF_FLOAT(T) hwy::EnableIf()>* = nullptr #define HWY_IF_NOT_FLOAT(T) hwy::EnableIf()>* = nullptr #define HWY_IF_FLOAT3264(T) hwy::EnableIf()>* = nullptr #define HWY_IF_NOT_FLOAT3264(T) hwy::EnableIf()>* = nullptr #define HWY_IF_SPECIAL_FLOAT(T) \ hwy::EnableIf()>* = nullptr #define HWY_IF_NOT_SPECIAL_FLOAT(T) \ hwy::EnableIf()>* = nullptr #define HWY_IF_FLOAT_OR_SPECIAL(T) \ hwy::EnableIf() || hwy::IsSpecialFloat()>* = nullptr #define HWY_IF_NOT_FLOAT_NOR_SPECIAL(T) \ hwy::EnableIf() && !hwy::IsSpecialFloat()>* = nullptr #define HWY_IF_INTEGER(T) hwy::EnableIf()>* = nullptr #define HWY_IF_T_SIZE(T, bytes) hwy::EnableIf* = nullptr #define HWY_IF_NOT_T_SIZE(T, bytes) \ hwy::EnableIf* = nullptr // bit_array = 0x102 means 1 or 8 bytes. There is no NONE_OF because it sounds // too similar. If you want the opposite of this (2 or 4 bytes), ask for those // bits explicitly (0x14) instead of attempting to 'negate' 0x102. 
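// Example (illustrative): to accept only 2- or 4-byte lanes, request those
// bits explicitly, e.g. HWY_IF_T_SIZE_ONE_OF(T, (1 << 2) | (1 << 4)),
// i.e. bit_array = 0x14.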
#define HWY_IF_T_SIZE_ONE_OF(T, bit_array) \ hwy::EnableIf<((size_t{1} << sizeof(T)) & (bit_array)) != 0>* = nullptr #define HWY_IF_T_SIZE_LE(T, bytes) \ hwy::EnableIf<(sizeof(T) <= (bytes))>* = nullptr #define HWY_IF_T_SIZE_GT(T, bytes) \ hwy::EnableIf<(sizeof(T) > (bytes))>* = nullptr #define HWY_IF_SAME(T, expected) \ hwy::EnableIf, expected>()>* = nullptr #define HWY_IF_NOT_SAME(T, expected) \ hwy::EnableIf, expected>()>* = nullptr // One of two expected types #define HWY_IF_SAME2(T, expected1, expected2) \ hwy::EnableIf< \ hwy::IsSameEither, expected1, expected2>()>* = \ nullptr #define HWY_IF_U8(T) HWY_IF_SAME(T, uint8_t) #define HWY_IF_U16(T) HWY_IF_SAME(T, uint16_t) #define HWY_IF_U32(T) HWY_IF_SAME(T, uint32_t) #define HWY_IF_U64(T) HWY_IF_SAME(T, uint64_t) #define HWY_IF_I8(T) HWY_IF_SAME(T, int8_t) #define HWY_IF_I16(T) HWY_IF_SAME(T, int16_t) #define HWY_IF_I32(T) HWY_IF_SAME(T, int32_t) #define HWY_IF_I64(T) HWY_IF_SAME(T, int64_t) #define HWY_IF_BF16(T) HWY_IF_SAME(T, hwy::bfloat16_t) #define HWY_IF_NOT_BF16(T) HWY_IF_NOT_SAME(T, hwy::bfloat16_t) #define HWY_IF_F16(T) HWY_IF_SAME(T, hwy::float16_t) #define HWY_IF_NOT_F16(T) HWY_IF_NOT_SAME(T, hwy::float16_t) #define HWY_IF_F32(T) HWY_IF_SAME(T, float) #define HWY_IF_F64(T) HWY_IF_SAME(T, double) // Use instead of HWY_IF_T_SIZE to avoid ambiguity with float16_t/float/double // overloads. #define HWY_IF_UI8(T) HWY_IF_SAME2(T, uint8_t, int8_t) #define HWY_IF_UI16(T) HWY_IF_SAME2(T, uint16_t, int16_t) #define HWY_IF_UI32(T) HWY_IF_SAME2(T, uint32_t, int32_t) #define HWY_IF_UI64(T) HWY_IF_SAME2(T, uint64_t, int64_t) #define HWY_IF_LANES_PER_BLOCK(T, N, LANES) \ hwy::EnableIf* = nullptr // Empty struct used as a size tag type. template struct SizeTag {}; template class DeclValT { private: template static URef TryAddRValRef(int); template static U TryAddRValRef(Arg); public: using type = decltype(TryAddRValRef(0)); enum { kDisableDeclValEvaluation = 1 }; }; // hwy::DeclVal() can only be used in unevaluated contexts such as within an // expression of a decltype specifier. 
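// For example (illustrative): decltype(DeclVal<T>() + DeclVal<U>()) names the
// type of the sum without requiring either operand to be constructed.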
// hwy::DeclVal() does not require that T have a public default constructor template HWY_API typename DeclValT::type DeclVal() noexcept { static_assert(!DeclValT::kDisableDeclValEvaluation, "DeclVal() cannot be used in an evaluated context"); } template struct IsArrayT { enum { value = 0 }; }; template struct IsArrayT { enum { value = 1 }; }; template struct IsArrayT { enum { value = 1 }; }; template static constexpr bool IsArray() { return IsArrayT::value; } #if HWY_COMPILER_MSVC HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4180, ignored "-Wignored-qualifiers") #endif template class IsConvertibleT { private: template static hwy::SizeTag<1> TestFuncWithToArg(T); template static decltype(IsConvertibleT::template TestFuncWithToArg( DeclVal())) TryConvTest(int); template static hwy::SizeTag<0> TryConvTest(Arg); public: enum { value = (IsSame>, void>() && IsSame>, void>()) || (!IsArray() && (IsSame())>() || !IsSame, RemoveConst>()) && IsSame(0)), hwy::SizeTag<1>>()) }; }; #if HWY_COMPILER_MSVC HWY_DIAGNOSTICS(pop) #endif template HWY_API constexpr bool IsConvertible() { return IsConvertibleT::value; } template class IsStaticCastableT { private: template (DeclVal()))> static hwy::SizeTag<1> TryStaticCastTest(int); template static hwy::SizeTag<0> TryStaticCastTest(Arg); public: enum { value = IsSame(0)), hwy::SizeTag<1>>() }; }; template static constexpr bool IsStaticCastable() { return IsStaticCastableT::value; } #define HWY_IF_CASTABLE(From, To) \ hwy::EnableIf()>* = nullptr #define HWY_IF_OP_CASTABLE(op, T, Native) \ HWY_IF_CASTABLE(decltype(DeclVal() op DeclVal()), Native) template class IsAssignableT { private: template () = DeclVal())> static hwy::SizeTag<1> TryAssignTest(int); template static hwy::SizeTag<0> TryAssignTest(Arg); public: enum { value = IsSame(0)), hwy::SizeTag<1>>() }; }; template static constexpr bool IsAssignable() { return IsAssignableT::value; } #define HWY_IF_ASSIGNABLE(T, From) \ hwy::EnableIf()>* = nullptr // ---------------------------------------------------------------------------- // IsSpecialFloat // These types are often special-cased and not supported in all ops. template HWY_API constexpr bool IsSpecialFloat() { return IsSameEither, hwy::float16_t, hwy::bfloat16_t>(); } // ----------------------------------------------------------------------------- // IsIntegerLaneType and IsInteger template HWY_API constexpr bool IsIntegerLaneType() { return false; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template <> HWY_INLINE constexpr bool IsIntegerLaneType() { return true; } template HWY_API constexpr bool IsInteger() { // NOTE: Do not add a IsInteger() specialization below as it is // possible for IsSame() to be true when compiled with MSVC // with the /Zc:wchar_t- option. 
return IsIntegerLaneType() || IsSame, wchar_t>() || IsSameEither, size_t, ptrdiff_t>() || IsSameEither, intptr_t, uintptr_t>(); } template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { // NOLINT return true; } template <> HWY_INLINE constexpr bool IsInteger() { // NOLINT return true; } template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { // NOLINT return true; } template <> HWY_INLINE constexpr bool IsInteger() { // NOLINT return true; } template <> HWY_INLINE constexpr bool IsInteger() { // NOLINT return true; } template <> HWY_INLINE constexpr bool IsInteger() { // NOLINT return true; } #if defined(__cpp_char8_t) && __cpp_char8_t >= 201811L template <> HWY_INLINE constexpr bool IsInteger() { return true; } #endif template <> HWY_INLINE constexpr bool IsInteger() { return true; } template <> HWY_INLINE constexpr bool IsInteger() { return true; } // ----------------------------------------------------------------------------- // BitCastScalar #if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 #define HWY_BITCASTSCALAR_CONSTEXPR constexpr #else #define HWY_BITCASTSCALAR_CONSTEXPR #endif #if __cpp_constexpr >= 201304L #define HWY_BITCASTSCALAR_CXX14_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR #else #define HWY_BITCASTSCALAR_CXX14_CONSTEXPR #endif #if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 namespace detail { template struct BitCastScalarSrcCastHelper { static HWY_INLINE constexpr const From& CastSrcValRef(const From& val) { return val; } }; #if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 // Workaround for Clang 9 constexpr __builtin_bit_cast bug template >() && hwy::IsInteger>()>* = nullptr> static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To BuiltinBitCastScalar(const From& val) { static_assert(sizeof(To) == sizeof(From), "sizeof(To) == sizeof(From) must be true"); return static_cast(val); } template >() && hwy::IsInteger>())>* = nullptr> static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR To BuiltinBitCastScalar(const From& val) { return __builtin_bit_cast(To, val); } #endif // HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 } // namespace detail template HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { // If From is hwy::float16_t or hwy::bfloat16_t, first cast val to either // const typename From::Native& or const uint16_t& using // detail::BitCastScalarSrcCastHelper>::CastSrcValRef to // allow BitCastScalar from hwy::float16_t or hwy::bfloat16_t to be constexpr // if To is not a pointer type, union type, or a struct/class containing a // pointer, union, or reference subobject #if HWY_COMPILER_CLANG >= 900 && HWY_COMPILER_CLANG < 1000 return detail::BuiltinBitCastScalar( detail::BitCastScalarSrcCastHelper>::CastSrcValRef( val)); #else return __builtin_bit_cast( To, detail::BitCastScalarSrcCastHelper>::CastSrcValRef( val)); #endif } template HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { // If To is hwy::float16_t or hwy::bfloat16_t, first do a BitCastScalar of val // to uint16_t, and then bit cast the uint16_t value to To using To::FromBits // as hwy::float16_t::FromBits and hwy::bfloat16_t::FromBits are 
guaranteed to // be constexpr if the __builtin_bit_cast intrinsic is available. return To::FromBits(BitCastScalar(val)); } #else template HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) { To result; CopySameSize(&val, &result); return result; } #endif //------------------------------------------------------------------------------ // F16 lane type #pragma pack(push, 1) // Compiler supports __fp16 and load/store/conversion NEON intrinsics, which are // included in Armv8 and VFPv4 (except with MSVC). On Armv7 Clang requires // __ARM_FP & 2 whereas Armv7 GCC requires -mfp16-format=ieee. #if (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC) || \ (HWY_COMPILER_CLANG && defined(__ARM_FP) && (__ARM_FP & 2)) || \ (HWY_COMPILER_GCC_ACTUAL && defined(__ARM_FP16_FORMAT_IEEE)) #define HWY_NEON_HAVE_F16C 1 #else #define HWY_NEON_HAVE_F16C 0 #endif // RVV with f16 extension supports _Float16 and f16 vector ops. If set, implies // HWY_HAVE_FLOAT16. #if HWY_ARCH_RVV && defined(__riscv_zvfh) && HWY_COMPILER_CLANG >= 1600 #define HWY_RVV_HAVE_F16_VEC 1 #else #define HWY_RVV_HAVE_F16_VEC 0 #endif // x86 compiler supports _Float16, not necessarily with operators. // Avoid clang-cl because it lacks __extendhfsf2. #if HWY_ARCH_X86 && defined(__SSE2__) && defined(__FLT16_MAX__) && \ ((HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL) || \ HWY_COMPILER_GCC_ACTUAL >= 1200) #define HWY_SSE2_HAVE_F16_TYPE 1 #else #define HWY_SSE2_HAVE_F16_TYPE 0 #endif #ifndef HWY_HAVE_SCALAR_F16_TYPE // Compiler supports _Float16, not necessarily with operators. #if HWY_NEON_HAVE_F16C || HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE #define HWY_HAVE_SCALAR_F16_TYPE 1 #else #define HWY_HAVE_SCALAR_F16_TYPE 0 #endif #endif // HWY_HAVE_SCALAR_F16_TYPE #ifndef HWY_HAVE_SCALAR_F16_OPERATORS // Recent enough compiler also has operators. #if HWY_HAVE_SCALAR_F16_TYPE && \ (HWY_COMPILER_CLANG >= 1800 || HWY_COMPILER_GCC_ACTUAL >= 1200 || \ (HWY_COMPILER_CLANG >= 1500 && !HWY_COMPILER_CLANGCL && \ !defined(_WIN32)) || \ (HWY_ARCH_ARM && \ (HWY_COMPILER_CLANG >= 900 || HWY_COMPILER_GCC_ACTUAL >= 800))) #define HWY_HAVE_SCALAR_F16_OPERATORS 1 #else #define HWY_HAVE_SCALAR_F16_OPERATORS 0 #endif #endif // HWY_HAVE_SCALAR_F16_OPERATORS namespace detail { template , bool = IsSpecialFloat()> struct SpecialFloatUnwrapArithOpOperandT {}; template struct SpecialFloatUnwrapArithOpOperandT { using type = T; }; template using SpecialFloatUnwrapArithOpOperand = typename SpecialFloatUnwrapArithOpOperandT::type; template > struct NativeSpecialFloatToWrapperT { using type = T; }; template using NativeSpecialFloatToWrapper = typename NativeSpecialFloatToWrapperT::type; } // namespace detail // Match [u]int##_t naming scheme so rvv-inl.h macros can obtain the type name // by concatenating base type and bits. We use a wrapper class instead of a // typedef to the native type to ensure that the same symbols, e.g. for VQSort, // are generated regardless of F16 support; see #1684. struct alignas(2) float16_t { #if HWY_HAVE_SCALAR_F16_TYPE #if HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE using Native = _Float16; #elif HWY_NEON_HAVE_F16C using Native = __fp16; #else #error "Logic error: condition should be 'all but NEON_HAVE_F16C'" #endif #endif // HWY_HAVE_SCALAR_F16_TYPE union { #if HWY_HAVE_SCALAR_F16_TYPE // Accessed via NativeLaneType, and used directly if // HWY_HAVE_SCALAR_F16_OPERATORS. Native native; #endif // Only accessed via NativeLaneType or U16LaneType. uint16_t bits; }; // Default init and copying. 
float16_t() noexcept = default; constexpr float16_t(const float16_t&) noexcept = default; constexpr float16_t(float16_t&&) noexcept = default; float16_t& operator=(const float16_t&) noexcept = default; float16_t& operator=(float16_t&&) noexcept = default; #if HWY_HAVE_SCALAR_F16_TYPE // NEON vget/set_lane intrinsics and SVE `svaddv` could use explicit // float16_t(intrinsic()), but user code expects implicit conversions. constexpr float16_t(Native arg) noexcept : native(arg) {} constexpr operator Native() const noexcept { return native; } #endif #if HWY_HAVE_SCALAR_F16_TYPE static HWY_BITCASTSCALAR_CONSTEXPR float16_t FromBits(uint16_t bits) { return float16_t(BitCastScalar(bits)); } #else private: struct F16FromU16BitsTag {}; constexpr float16_t(F16FromU16BitsTag /*tag*/, uint16_t u16_bits) : bits(u16_bits) {} public: static constexpr float16_t FromBits(uint16_t bits) { return float16_t(F16FromU16BitsTag(), bits); } #endif // When backed by a native type, ensure the wrapper behaves like the native // type by forwarding all operators. Unfortunately it seems difficult to reuse // this code in a base class, so we repeat it in float16_t. #if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE template , float16_t>() && IsConvertible()>* = nullptr> constexpr float16_t(T&& arg) noexcept : native(static_cast(static_cast(arg))) {} template , float16_t>() && !IsConvertible() && IsStaticCastable()>* = nullptr> explicit constexpr float16_t(T&& arg) noexcept : native(static_cast(static_cast(arg))) {} // pre-decrement operator (--x) HWY_CXX14_CONSTEXPR float16_t& operator--() noexcept { native = static_cast(native - Native{1}); return *this; } // post-decrement operator (x--) HWY_CXX14_CONSTEXPR float16_t operator--(int) noexcept { float16_t result = *this; native = static_cast(native - Native{1}); return result; } // pre-increment operator (++x) HWY_CXX14_CONSTEXPR float16_t& operator++() noexcept { native = static_cast(native + Native{1}); return *this; } // post-increment operator (x++) HWY_CXX14_CONSTEXPR float16_t operator++(int) noexcept { float16_t result = *this; native = static_cast(native + Native{1}); return result; } constexpr float16_t operator-() const noexcept { return float16_t(static_cast(-native)); } constexpr float16_t operator+() const noexcept { return *this; } // Reduce clutter by generating `operator+` and `operator+=` etc. Note that // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. 
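  // Illustrative sketch (only when HWY_HAVE_SCALAR_F16_OPERATORS): arithmetic
  // forwards to the native type, e.g.
  //   hwy::float16_t x(1.5f);
  //   x += 0.25f;  // uses _Float16/__fp16 arithmetic
  //   x *= x;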
#define HWY_FLOAT16_BINARY_OP(op, op_func, assign_func) \ constexpr float16_t op_func(const float16_t& rhs) const noexcept { \ return float16_t(static_cast(native op rhs.native)); \ } \ template , \ typename RawResultT = \ decltype(DeclVal() op DeclVal()), \ typename ResultT = \ detail::NativeSpecialFloatToWrapper, \ HWY_IF_CASTABLE(RawResultT, ResultT)> \ constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ static_cast(DeclVal() op DeclVal()))) { \ return static_cast(native op static_cast(rhs)); \ } \ HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func( \ const hwy::float16_t& rhs) noexcept { \ native = static_cast(native op rhs.native); \ return *this; \ } \ template () op DeclVal()))> \ HWY_CXX14_CONSTEXPR hwy::float16_t& assign_func(const T& rhs) noexcept( \ noexcept( \ static_cast(DeclVal() op DeclVal()))) { \ native = static_cast(native op rhs); \ return *this; \ } HWY_FLOAT16_BINARY_OP(+, operator+, operator+=) HWY_FLOAT16_BINARY_OP(-, operator-, operator-=) HWY_FLOAT16_BINARY_OP(*, operator*, operator*=) HWY_FLOAT16_BINARY_OP(/, operator/, operator/=) #undef HWY_FLOAT16_BINARY_OP #endif // HWY_HAVE_SCALAR_F16_OPERATORS }; static_assert(sizeof(hwy::float16_t) == 2, "Wrong size of float16_t"); #if HWY_HAVE_SCALAR_F16_TYPE namespace detail { #if HWY_HAVE_SCALAR_F16_OPERATORS template struct SpecialFloatUnwrapArithOpOperandT { using type = hwy::float16_t::Native; }; #endif template struct NativeSpecialFloatToWrapperT { using type = hwy::float16_t; }; } // namespace detail #endif // HWY_HAVE_SCALAR_F16_TYPE #if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 namespace detail { template <> struct BitCastScalarSrcCastHelper { #if HWY_HAVE_SCALAR_F16_TYPE static HWY_INLINE constexpr const hwy::float16_t::Native& CastSrcValRef( const hwy::float16_t& val) { return val.native; } #else static HWY_INLINE constexpr const uint16_t& CastSrcValRef( const hwy::float16_t& val) { return val.bits; } #endif }; } // namespace detail #endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 #if HWY_HAVE_SCALAR_F16_OPERATORS #define HWY_F16_CONSTEXPR constexpr #else #define HWY_F16_CONSTEXPR HWY_BITCASTSCALAR_CXX14_CONSTEXPR #endif // HWY_HAVE_SCALAR_F16_OPERATORS HWY_API HWY_F16_CONSTEXPR float F32FromF16(float16_t f16) { #if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE return static_cast(f16); #endif #if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE const uint16_t bits16 = BitCastScalar(f16); const uint32_t sign = static_cast(bits16 >> 15); const uint32_t biased_exp = (bits16 >> 10) & 0x1F; const uint32_t mantissa = bits16 & 0x3FF; // Subnormal or zero if (biased_exp == 0) { const float subnormal = (1.0f / 16384) * (static_cast(mantissa) * (1.0f / 1024)); return sign ? -subnormal : subnormal; } // Normalized, infinity or NaN: convert the representation directly // (faster than ldexp/tables). const uint32_t biased_exp32 = biased_exp == 31 ? 0xFF : biased_exp + (127 - 15); const uint32_t mantissa32 = mantissa << (23 - 10); const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32; return BitCastScalar(bits32); #endif // !HWY_HAVE_SCALAR_F16_OPERATORS } #if HWY_IS_DEBUG_BUILD && \ (HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926) #if defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L // If C++23 if !consteval support is available, only execute // HWY_DASSERT(condition) if F16FromF32 is not called from a constant-evaluated // context to avoid compilation errors. 
#define HWY_F16_FROM_F32_DASSERT(condition) \ do { \ if !consteval { \ HWY_DASSERT(condition); \ } \ } while (0) #elif HWY_HAS_BUILTIN(__builtin_is_constant_evaluated) || \ HWY_COMPILER_MSVC >= 1926 // If the __builtin_is_constant_evaluated() intrinsic is available, // only do HWY_DASSERT(condition) if __builtin_is_constant_evaluated() returns // false to avoid compilation errors if F16FromF32 is called from a // constant-evaluated context. #define HWY_F16_FROM_F32_DASSERT(condition) \ do { \ if (!__builtin_is_constant_evaluated()) { \ HWY_DASSERT(condition); \ } \ } while (0) #else // If C++23 if !consteval support is not available, // the __builtin_is_constant_evaluated() intrinsic is not available, // HWY_IS_DEBUG_BUILD is 1, and the __builtin_bit_cast intrinsic is available, // do not do a HWY_DASSERT to avoid compilation errors if F16FromF32 is // called from a constant-evaluated context. #define HWY_F16_FROM_F32_DASSERT(condition) \ do { \ } while (0) #endif // defined(__cpp_if_consteval) && __cpp_if_consteval >= 202106L #else // If HWY_IS_DEBUG_BUILD is 0 or the __builtin_bit_cast intrinsic is not // available, define HWY_F16_FROM_F32_DASSERT(condition) as // HWY_DASSERT(condition) #define HWY_F16_FROM_F32_DASSERT(condition) HWY_DASSERT(condition) #endif // HWY_IS_DEBUG_BUILD && (HWY_HAS_BUILTIN(__builtin_bit_cast) || // HWY_COMPILER_MSVC >= 1926) HWY_API HWY_F16_CONSTEXPR float16_t F16FromF32(float f32) { #if HWY_HAVE_SCALAR_F16_OPERATORS && !HWY_IDE return float16_t(static_cast(f32)); #endif #if !HWY_HAVE_SCALAR_F16_OPERATORS || HWY_IDE const uint32_t bits32 = BitCastScalar(f32); const uint32_t sign = bits32 >> 31; const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF; constexpr uint32_t kMantissaMask = 0x7FFFFF; const uint32_t mantissa32 = bits32 & kMantissaMask; // Before shifting (truncation), round to nearest even to reduce bias. If // the lowest remaining mantissa bit is odd, increase the offset. Example // with the lowest remaining bit (left) and next lower two bits; the // latter, plus two more, will be truncated. // 0[00] + 1 = 0[01] // 0[01] + 1 = 0[10] // 0[10] + 1 = 0[11] (round down toward even) // 0[11] + 1 = 1[00] (round up) // 1[00] + 10 = 1[10] // 1[01] + 10 = 1[11] // 1[10] + 10 = C0[00] (round up toward even with C=1 carry out) // 1[11] + 10 = C0[01] (round up toward even with C=1 carry out) const uint32_t odd_bit = (mantissa32 >> 13) & 1; const uint32_t rounded = mantissa32 + odd_bit + 0xFFF; const bool carry = rounded >= (1u << 23); const int32_t exp = static_cast(biased_exp32) - 127 + carry; // Tiny or zero => zero. if (exp < -24) { // restore original sign return float16_t::FromBits(static_cast(sign << 15)); } // If biased_exp16 would be >= 31, first check whether the input was NaN so we // can set the mantissa to nonzero. const bool is_nan = (biased_exp32 == 255) && mantissa32 != 0; const bool overflowed = exp >= 16; const uint32_t biased_exp16 = static_cast(HWY_MIN(HWY_MAX(0, exp + 15), 31)); // exp = [-24, -15] => subnormal, shift the mantissa. const uint32_t sub_exp = static_cast(HWY_MAX(-14 - exp, 0)); HWY_F16_FROM_F32_DASSERT(sub_exp < 11); const uint32_t shifted_mantissa = (rounded & kMantissaMask) >> (23 - 10 + sub_exp); const uint32_t leading = sub_exp == 0u ? 0u : (1024u >> sub_exp); const uint32_t mantissa16 = is_nan ? 0x3FF : overflowed ? 
0u : (leading + shifted_mantissa); #if HWY_IS_DEBUG_BUILD if (exp < -14) { HWY_F16_FROM_F32_DASSERT(biased_exp16 == 0); HWY_F16_FROM_F32_DASSERT(sub_exp >= 1); } else if (exp <= 15) { HWY_F16_FROM_F32_DASSERT(1 <= biased_exp16 && biased_exp16 < 31); HWY_F16_FROM_F32_DASSERT(sub_exp == 0); } #endif HWY_F16_FROM_F32_DASSERT(mantissa16 < 1024); const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16; HWY_F16_FROM_F32_DASSERT(bits16 < 0x10000); const uint16_t narrowed = static_cast(bits16); // big-endian safe return float16_t::FromBits(narrowed); #endif // !HWY_HAVE_SCALAR_F16_OPERATORS } HWY_API HWY_F16_CONSTEXPR float16_t F16FromF64(double f64) { #if HWY_HAVE_SCALAR_F16_OPERATORS return float16_t(static_cast(f64)); #else // The mantissa bits of f64 are first rounded using round-to-odd rounding // to the nearest f64 value that has the lower 29 bits zeroed out to // ensure that the result is correctly rounded to a F16. // The F64 round-to-odd operation below will round a normal F64 value // (using round-to-odd rounding) to a F64 value that has 24 bits of precision. // It is okay if the magnitude of a denormal F64 value is rounded up in the // F64 round-to-odd step below as the magnitude of a denormal F64 value is // much smaller than 2^(-24) (the smallest positive denormal F16 value). // It is also okay if bit 29 of a NaN F64 value is changed by the F64 // round-to-odd step below as the lower 13 bits of a F32 NaN value are usually // discarded or ignored by the conversion of a F32 NaN value to a F16. // If f64 is a NaN value, the result of the F64 round-to-odd step will be a // NaN value as the result of the F64 round-to-odd step will have at least one // mantissa bit if f64 is a NaN value. // The F64 round-to-odd step will ensure that the F64 to F32 conversion is // exact if the magnitude of the rounded F64 value (using round-to-odd // rounding) is between 2^(-126) (the smallest normal F32 value) and // HighestValue() (the largest finite F32 value) // It is okay if the F64 to F32 conversion is inexact for F64 values that have // a magnitude that is less than 2^(-126) as the magnitude of a denormal F32 // value is much smaller than 2^(-24) (the smallest positive denormal F16 // value). return F16FromF32( static_cast(BitCastScalar(static_cast( (BitCastScalar(f64) & 0xFFFFFFFFE0000000ULL) | ((BitCastScalar(f64) + 0x000000001FFFFFFFULL) & 0x0000000020000000ULL))))); #endif } // More convenient to define outside float16_t because these may use // F32FromF16, which is defined after the struct. 
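// Note (illustrative): without native F16 operators, these comparisons promote
// both sides via F32FromF16, so a NaN compares unequal to itself and ordered
// comparisons involving NaN return false, matching float semantics.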
HWY_F16_CONSTEXPR inline bool operator==(float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native == rhs.native; #else return F32FromF16(lhs) == F32FromF16(rhs); #endif } HWY_F16_CONSTEXPR inline bool operator!=(float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native != rhs.native; #else return F32FromF16(lhs) != F32FromF16(rhs); #endif } HWY_F16_CONSTEXPR inline bool operator<(float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native < rhs.native; #else return F32FromF16(lhs) < F32FromF16(rhs); #endif } HWY_F16_CONSTEXPR inline bool operator<=(float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native <= rhs.native; #else return F32FromF16(lhs) <= F32FromF16(rhs); #endif } HWY_F16_CONSTEXPR inline bool operator>(float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native > rhs.native; #else return F32FromF16(lhs) > F32FromF16(rhs); #endif } HWY_F16_CONSTEXPR inline bool operator>=(float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native >= rhs.native; #else return F32FromF16(lhs) >= F32FromF16(rhs); #endif } #if HWY_HAVE_CXX20_THREE_WAY_COMPARE HWY_F16_CONSTEXPR inline std::partial_ordering operator<=>( float16_t lhs, float16_t rhs) noexcept { #if HWY_HAVE_SCALAR_F16_OPERATORS return lhs.native <=> rhs.native; #else return F32FromF16(lhs) <=> F32FromF16(rhs); #endif } #endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE //------------------------------------------------------------------------------ // BF16 lane type // Compiler supports ACLE __bf16, not necessarily with operators. // Disable the __bf16 type on AArch64 with GCC 13 or earlier as there is a bug // in GCC 13 and earlier that sometimes causes BF16 constant values to be // incorrectly loaded on AArch64, and this GCC bug on AArch64 is // described at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111867. #if HWY_ARCH_ARM_A64 && \ (HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400) #define HWY_ARM_HAVE_SCALAR_BF16_TYPE 1 #else #define HWY_ARM_HAVE_SCALAR_BF16_TYPE 0 #endif // x86 compiler supports __bf16, not necessarily with operators. #ifndef HWY_SSE2_HAVE_SCALAR_BF16_TYPE #if HWY_ARCH_X86 && defined(__SSE2__) && \ ((HWY_COMPILER_CLANG >= 1700 && !HWY_COMPILER_CLANGCL) || \ HWY_COMPILER_GCC_ACTUAL >= 1300) #define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 1 #else #define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 0 #endif #endif // HWY_SSE2_HAVE_SCALAR_BF16_TYPE // Compiler supports __bf16, not necessarily with operators. #if HWY_ARM_HAVE_SCALAR_BF16_TYPE || HWY_SSE2_HAVE_SCALAR_BF16_TYPE #define HWY_HAVE_SCALAR_BF16_TYPE 1 #else #define HWY_HAVE_SCALAR_BF16_TYPE 0 #endif #ifndef HWY_HAVE_SCALAR_BF16_OPERATORS // Recent enough compiler also has operators. aarch64 clang 18 hits internal // compiler errors on bf16 ToString, hence only enable on GCC for now. #if HWY_HAVE_SCALAR_BF16_TYPE && (HWY_COMPILER_GCC_ACTUAL >= 1300) #define HWY_HAVE_SCALAR_BF16_OPERATORS 1 #else #define HWY_HAVE_SCALAR_BF16_OPERATORS 0 #endif #endif // HWY_HAVE_SCALAR_BF16_OPERATORS #if HWY_HAVE_SCALAR_BF16_OPERATORS #define HWY_BF16_CONSTEXPR constexpr #else #define HWY_BF16_CONSTEXPR HWY_BITCASTSCALAR_CONSTEXPR #endif struct alignas(2) bfloat16_t { #if HWY_HAVE_SCALAR_BF16_TYPE using Native = __bf16; #endif union { #if HWY_HAVE_SCALAR_BF16_TYPE // Accessed via NativeLaneType, and used directly if // HWY_HAVE_SCALAR_BF16_OPERATORS. 
Native native; #endif // Only accessed via NativeLaneType or U16LaneType. uint16_t bits; }; // Default init and copying bfloat16_t() noexcept = default; constexpr bfloat16_t(bfloat16_t&&) noexcept = default; constexpr bfloat16_t(const bfloat16_t&) noexcept = default; bfloat16_t& operator=(bfloat16_t&& arg) noexcept = default; bfloat16_t& operator=(const bfloat16_t& arg) noexcept = default; // Only enable implicit conversions if we have a native type. #if HWY_HAVE_SCALAR_BF16_TYPE constexpr bfloat16_t(Native arg) noexcept : native(arg) {} constexpr operator Native() const noexcept { return native; } #endif #if HWY_HAVE_SCALAR_BF16_TYPE static HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t FromBits(uint16_t bits) { return bfloat16_t(BitCastScalar(bits)); } #else private: struct BF16FromU16BitsTag {}; constexpr bfloat16_t(BF16FromU16BitsTag /*tag*/, uint16_t u16_bits) : bits(u16_bits) {} public: static constexpr bfloat16_t FromBits(uint16_t bits) { return bfloat16_t(BF16FromU16BitsTag(), bits); } #endif // When backed by a native type, ensure the wrapper behaves like the native // type by forwarding all operators. Unfortunately it seems difficult to reuse // this code in a base class, so we repeat it in float16_t. #if HWY_HAVE_SCALAR_BF16_OPERATORS || HWY_IDE template , Native>() && !IsSame, bfloat16_t>() && IsConvertible()>* = nullptr> constexpr bfloat16_t(T&& arg) noexcept( noexcept(static_cast(DeclVal()))) : native(static_cast(static_cast(arg))) {} template , Native>() && !IsSame, bfloat16_t>() && !IsConvertible() && IsStaticCastable()>* = nullptr> explicit constexpr bfloat16_t(T&& arg) noexcept( noexcept(static_cast(DeclVal()))) : native(static_cast(static_cast(arg))) {} HWY_CXX14_CONSTEXPR bfloat16_t& operator=(Native arg) noexcept { native = arg; return *this; } // pre-decrement operator (--x) HWY_CXX14_CONSTEXPR bfloat16_t& operator--() noexcept { native = static_cast(native - Native{1}); return *this; } // post-decrement operator (x--) HWY_CXX14_CONSTEXPR bfloat16_t operator--(int) noexcept { bfloat16_t result = *this; native = static_cast(native - Native{1}); return result; } // pre-increment operator (++x) HWY_CXX14_CONSTEXPR bfloat16_t& operator++() noexcept { native = static_cast(native + Native{1}); return *this; } // post-increment operator (x++) HWY_CXX14_CONSTEXPR bfloat16_t operator++(int) noexcept { bfloat16_t result = *this; native = static_cast(native + Native{1}); return result; } constexpr bfloat16_t operator-() const noexcept { return bfloat16_t(static_cast(-native)); } constexpr bfloat16_t operator+() const noexcept { return *this; } // Reduce clutter by generating `operator+` and `operator+=` etc. Note that // we cannot token-paste `operator` and `+`, so pass it in as `op_func`. 
#define HWY_BFLOAT16_BINARY_OP(op, op_func, assign_func) \ constexpr bfloat16_t op_func(const bfloat16_t& rhs) const noexcept { \ return bfloat16_t(static_cast(native op rhs.native)); \ } \ template , \ typename RawResultT = \ decltype(DeclVal() op DeclVal()), \ typename ResultT = \ detail::NativeSpecialFloatToWrapper, \ HWY_IF_CASTABLE(RawResultT, ResultT)> \ constexpr ResultT op_func(const T& rhs) const noexcept(noexcept( \ static_cast(DeclVal() op DeclVal()))) { \ return static_cast(native op static_cast(rhs)); \ } \ HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func( \ const hwy::bfloat16_t& rhs) noexcept { \ native = static_cast(native op rhs.native); \ return *this; \ } \ template () op DeclVal()))> \ HWY_CXX14_CONSTEXPR hwy::bfloat16_t& assign_func(const T& rhs) noexcept( \ noexcept( \ static_cast(DeclVal() op DeclVal()))) { \ native = static_cast(native op rhs); \ return *this; \ } HWY_BFLOAT16_BINARY_OP(+, operator+, operator+=) HWY_BFLOAT16_BINARY_OP(-, operator-, operator-=) HWY_BFLOAT16_BINARY_OP(*, operator*, operator*=) HWY_BFLOAT16_BINARY_OP(/, operator/, operator/=) #undef HWY_BFLOAT16_BINARY_OP #endif // HWY_HAVE_SCALAR_BF16_OPERATORS }; static_assert(sizeof(hwy::bfloat16_t) == 2, "Wrong size of bfloat16_t"); #pragma pack(pop) #if HWY_HAVE_SCALAR_BF16_TYPE namespace detail { #if HWY_HAVE_SCALAR_BF16_OPERATORS template struct SpecialFloatUnwrapArithOpOperandT { using type = hwy::bfloat16_t::Native; }; #endif template struct NativeSpecialFloatToWrapperT { using type = hwy::bfloat16_t; }; } // namespace detail #endif // HWY_HAVE_SCALAR_BF16_TYPE #if HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 namespace detail { template <> struct BitCastScalarSrcCastHelper { #if HWY_HAVE_SCALAR_BF16_TYPE static HWY_INLINE constexpr const hwy::bfloat16_t::Native& CastSrcValRef( const hwy::bfloat16_t& val) { return val.native; } #else static HWY_INLINE constexpr const uint16_t& CastSrcValRef( const hwy::bfloat16_t& val) { return val.bits; } #endif }; } // namespace detail #endif // HWY_HAS_BUILTIN(__builtin_bit_cast) || HWY_COMPILER_MSVC >= 1926 HWY_API HWY_BF16_CONSTEXPR float F32FromBF16(bfloat16_t bf) { #if HWY_HAVE_SCALAR_BF16_OPERATORS return static_cast(bf); #else return BitCastScalar(static_cast( static_cast(BitCastScalar(bf)) << 16)); #endif } HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF32(float f) { #if HWY_HAVE_SCALAR_BF16_OPERATORS return static_cast(f); #else return bfloat16_t::FromBits( static_cast(BitCastScalar(f) >> 16)); #endif } HWY_API HWY_BF16_CONSTEXPR bfloat16_t BF16FromF64(double f64) { #if HWY_HAVE_SCALAR_BF16_OPERATORS return static_cast(f64); #else // The mantissa bits of f64 are first rounded using round-to-odd rounding // to the nearest f64 value that has the lower 38 bits zeroed out to // ensure that the result is correctly rounded to a BF16. // The F64 round-to-odd operation below will round a normal F64 value // (using round-to-odd rounding) to a F64 value that has 15 bits of precision. // It is okay if the magnitude of a denormal F64 value is rounded up in the // F64 round-to-odd step below as the magnitude of a denormal F64 value is // much smaller than 2^(-133) (the smallest positive denormal BF16 value). // It is also okay if bit 38 of a NaN F64 value is changed by the F64 // round-to-odd step below as the lower 16 bits of a F32 NaN value are usually // discarded or ignored by the conversion of a F32 NaN value to a BF16. 
// If f64 is a NaN value, the result of the F64 round-to-odd step will be a // NaN value as the result of the F64 round-to-odd step will have at least one // mantissa bit if f64 is a NaN value. // The F64 round-to-odd step below will ensure that the F64 to F32 conversion // is exact if the magnitude of the rounded F64 value (using round-to-odd // rounding) is between 2^(-135) (one-fourth of the smallest positive denormal // BF16 value) and HighestValue() (the largest finite F32 value). // If |f64| is less than 2^(-135), the magnitude of the result of the F64 to // F32 conversion is guaranteed to be less than or equal to 2^(-135), which // ensures that the F32 to BF16 conversion is correctly rounded, even if the // conversion of a rounded F64 value whose magnitude is less than 2^(-135) // to a F32 is inexact. return BF16FromF32( static_cast(BitCastScalar(static_cast( (BitCastScalar(f64) & 0xFFFFFFC000000000ULL) | ((BitCastScalar(f64) + 0x0000003FFFFFFFFFULL) & 0x0000004000000000ULL))))); #endif } // More convenient to define outside bfloat16_t because these may use // F32FromBF16, which is defined after the struct. HWY_BF16_CONSTEXPR inline bool operator==(bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native == rhs.native; #else return F32FromBF16(lhs) == F32FromBF16(rhs); #endif } HWY_BF16_CONSTEXPR inline bool operator!=(bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native != rhs.native; #else return F32FromBF16(lhs) != F32FromBF16(rhs); #endif } HWY_BF16_CONSTEXPR inline bool operator<(bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native < rhs.native; #else return F32FromBF16(lhs) < F32FromBF16(rhs); #endif } HWY_BF16_CONSTEXPR inline bool operator<=(bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native <= rhs.native; #else return F32FromBF16(lhs) <= F32FromBF16(rhs); #endif } HWY_BF16_CONSTEXPR inline bool operator>(bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native > rhs.native; #else return F32FromBF16(lhs) > F32FromBF16(rhs); #endif } HWY_BF16_CONSTEXPR inline bool operator>=(bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native >= rhs.native; #else return F32FromBF16(lhs) >= F32FromBF16(rhs); #endif } #if HWY_HAVE_CXX20_THREE_WAY_COMPARE HWY_BF16_CONSTEXPR inline std::partial_ordering operator<=>( bfloat16_t lhs, bfloat16_t rhs) noexcept { #if HWY_HAVE_SCALAR_BF16_OPERATORS return lhs.native <=> rhs.native; #else return F32FromBF16(lhs) <=> F32FromBF16(rhs); #endif } #endif // HWY_HAVE_CXX20_THREE_WAY_COMPARE //------------------------------------------------------------------------------ // Type relations namespace detail { template struct Relations; template <> struct Relations { using Unsigned = uint8_t; using Signed = int8_t; using Wide = uint16_t; enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint8_t; using Signed = int8_t; using Wide = int16_t; enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint16_t; using Signed = int16_t; using Float = float16_t; using Wide = uint32_t; using Narrow = uint8_t; enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint16_t; using Signed = int16_t; using Float = float16_t; using Wide = int32_t; using Narrow = int8_t; enum { is_signed = 
1, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint32_t; using Signed = int32_t; using Float = float; using Wide = uint64_t; using Narrow = uint16_t; enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint32_t; using Signed = int32_t; using Float = float; using Wide = int64_t; using Narrow = int16_t; enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint64_t; using Signed = int64_t; using Float = double; using Wide = uint128_t; using Narrow = uint32_t; enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint64_t; using Signed = int64_t; using Float = double; using Narrow = int32_t; enum { is_signed = 1, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint128_t; using Narrow = uint64_t; enum { is_signed = 0, is_float = 0, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint16_t; using Signed = int16_t; using Float = float16_t; using Wide = float; enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint16_t; using Signed = int16_t; using Wide = float; enum { is_signed = 1, is_float = 1, is_bf16 = 1 }; }; template <> struct Relations { using Unsigned = uint32_t; using Signed = int32_t; using Float = float; using Wide = double; using Narrow = float16_t; enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; }; template <> struct Relations { using Unsigned = uint64_t; using Signed = int64_t; using Float = double; using Narrow = float; enum { is_signed = 1, is_float = 1, is_bf16 = 0 }; }; template struct TypeFromSize; template <> struct TypeFromSize<1> { using Unsigned = uint8_t; using Signed = int8_t; }; template <> struct TypeFromSize<2> { using Unsigned = uint16_t; using Signed = int16_t; using Float = float16_t; }; template <> struct TypeFromSize<4> { using Unsigned = uint32_t; using Signed = int32_t; using Float = float; }; template <> struct TypeFromSize<8> { using Unsigned = uint64_t; using Signed = int64_t; using Float = double; }; template <> struct TypeFromSize<16> { using Unsigned = uint128_t; }; } // namespace detail // Aliases for types of a different category, but the same size. template using MakeUnsigned = typename detail::Relations::Unsigned; template using MakeSigned = typename detail::Relations::Signed; template using MakeFloat = typename detail::Relations::Float; // Aliases for types of the same category, but different size. template using MakeWide = typename detail::Relations::Wide; template using MakeNarrow = typename detail::Relations::Narrow; // Obtain type from its size [bytes]. template using UnsignedFromSize = typename detail::TypeFromSize::Unsigned; template using SignedFromSize = typename detail::TypeFromSize::Signed; template using FloatFromSize = typename detail::TypeFromSize::Float; // Avoid confusion with SizeTag where the parameter is a lane size. using UnsignedTag = SizeTag<0>; using SignedTag = SizeTag<0x100>; // integer using FloatTag = SizeTag<0x200>; using SpecialTag = SizeTag<0x300>; template > constexpr auto TypeTag() -> hwy::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)> { return hwy::SizeTag<((R::is_signed + R::is_float + R::is_bf16) << 8)>(); } // For when we only want to distinguish FloatTag from everything else. using NonFloatTag = SizeTag<0x400>; template > constexpr auto IsFloatTag() -> hwy::SizeTag<(R::is_float ? 
0x200 : 0x400)> { return hwy::SizeTag<(R::is_float ? 0x200 : 0x400)>(); } //------------------------------------------------------------------------------ // Type traits template HWY_API constexpr bool IsFloat3264() { return IsSameEither, float, double>(); } template HWY_API constexpr bool IsFloat() { // Cannot use T(1.25) != T(1) for float16_t, which can only be converted to or // from a float, not compared. Include float16_t in case HWY_HAVE_FLOAT16=1. return IsSame, float16_t>() || IsFloat3264(); } template HWY_API constexpr bool IsSigned() { return static_cast(0) > static_cast(-1); } template <> constexpr bool IsSigned() { return true; } template <> constexpr bool IsSigned() { return true; } template <> constexpr bool IsSigned() { return false; } template <> constexpr bool IsSigned() { return false; } template <> constexpr bool IsSigned() { return false; } template () && !IsIntegerLaneType()> struct MakeLaneTypeIfIntegerT { using type = T; }; template struct MakeLaneTypeIfIntegerT { using type = hwy::If(), SignedFromSize, UnsignedFromSize>; }; template using MakeLaneTypeIfInteger = typename MakeLaneTypeIfIntegerT::type; // Largest/smallest representable integer values. template HWY_API constexpr T LimitsMax() { static_assert(IsInteger(), "Only for integer types"); using TU = UnsignedFromSize; return static_cast(IsSigned() ? (static_cast(~TU(0)) >> 1) : static_cast(~TU(0))); } template HWY_API constexpr T LimitsMin() { static_assert(IsInteger(), "Only for integer types"); return IsSigned() ? static_cast(-1) - LimitsMax() : static_cast(0); } // Largest/smallest representable value (integer or float). This naming avoids // confusion with numeric_limits::min() (the smallest positive value). // Cannot be constexpr because we use CopySameSize for [b]float16_t. template HWY_API HWY_BITCASTSCALAR_CONSTEXPR T LowestValue() { return LimitsMin(); } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t LowestValue() { return bfloat16_t::FromBits(uint16_t{0xFF7Fu}); // -1.1111111 x 2^127 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t LowestValue() { return float16_t::FromBits(uint16_t{0xFBFFu}); // -1.1111111111 x 2^15 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float LowestValue() { return -3.402823466e+38F; } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double LowestValue() { return -1.7976931348623158e+308; } template HWY_API HWY_BITCASTSCALAR_CONSTEXPR T HighestValue() { return LimitsMax(); } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t HighestValue() { return bfloat16_t::FromBits(uint16_t{0x7F7Fu}); // 1.1111111 x 2^127 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t HighestValue() { return float16_t::FromBits(uint16_t{0x7BFFu}); // 1.1111111111 x 2^15 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float HighestValue() { return 3.402823466e+38F; } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double HighestValue() { return 1.7976931348623158e+308; } // Difference between 1.0 and the next representable value. Equal to // 1 / (1ULL << MantissaBits()), but hard-coding ensures precision. 
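// A brief usage sketch (illustrative; the values follow from the hard-coded
// constants below): Epsilon<float>() == 1.192092896e-7f, i.e. 2^-23, and
// Epsilon<double>() == 2^-52. For example:
//   const float next_after_one = 1.0f + hwy::Epsilon<float>();
// yields the smallest representable float greater than 1.0f.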
template HWY_API HWY_BITCASTSCALAR_CONSTEXPR T Epsilon() { return 1; } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t Epsilon() { return bfloat16_t::FromBits(uint16_t{0x3C00u}); // 0.0078125 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t Epsilon() { return float16_t::FromBits(uint16_t{0x1400u}); // 0.0009765625 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float Epsilon() { return 1.192092896e-7f; } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double Epsilon() { return 2.2204460492503131e-16; } // Returns width in bits of the mantissa field in IEEE binary16/32/64. template constexpr int MantissaBits() { static_assert(sizeof(T) == 0, "Only instantiate the specializations"); return 0; } template <> constexpr int MantissaBits() { return 7; } template <> constexpr int MantissaBits() { return 10; } template <> constexpr int MantissaBits() { return 23; } template <> constexpr int MantissaBits() { return 52; } // Returns the (left-shifted by one bit) IEEE binary16/32/64 representation with // the largest possible (biased) exponent field. Used by IsInf. template constexpr MakeSigned MaxExponentTimes2() { return -(MakeSigned{1} << (MantissaBits() + 1)); } // Returns bitmask of the sign bit in IEEE binary16/32/64. template constexpr MakeUnsigned SignMask() { return MakeUnsigned{1} << (sizeof(T) * 8 - 1); } // Returns bitmask of the exponent field in IEEE binary16/32/64. template constexpr MakeUnsigned ExponentMask() { return (~(MakeUnsigned{1} << MantissaBits()) + 1) & static_cast>(~SignMask()); } // Returns bitmask of the mantissa field in IEEE binary16/32/64. template constexpr MakeUnsigned MantissaMask() { return (MakeUnsigned{1} << MantissaBits()) - 1; } // Returns 1 << mantissa_bits as a floating-point number. All integers whose // absolute value are less than this can be represented exactly. template HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T MantissaEnd() { static_assert(sizeof(T) == 0, "Only instantiate the specializations"); return 0; } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bfloat16_t MantissaEnd() { return bfloat16_t::FromBits(uint16_t{0x4300u}); // 1.0 x 2^7 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float16_t MantissaEnd() { return float16_t::FromBits(uint16_t{0x6400u}); // 1.0 x 2^10 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float MantissaEnd() { return 8388608.0f; // 1 << 23 } template <> HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double MantissaEnd() { // floating point literal with p52 requires C++17. return 4503599627370496.0; // 1 << 52 } // Returns width in bits of the exponent field in IEEE binary16/32/64. template constexpr int ExponentBits() { // Exponent := remaining bits after deducting sign and mantissa. return 8 * sizeof(T) - 1 - MantissaBits(); } // Returns largest value of the biased exponent field in IEEE binary16/32/64, // right-shifted so that the LSB is bit zero. Example: 0xFF for float. // This is expressed as a signed integer for more efficient comparison. 
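// Illustrative decomposition of float (binary32) using the helpers above,
// consistent with the 0xFF example: SignMask<float>() == 0x80000000u,
// ExponentMask<float>() == 0x7F800000u, MantissaMask<float>() == 0x007FFFFFu,
// and ExponentBits<float>() == 8, so the three masks together cover all
// 32 bits exactly once.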
template constexpr MakeSigned MaxExponentField() { return (MakeSigned{1} << ExponentBits()) - 1; } //------------------------------------------------------------------------------ // Additional F16/BF16 operators #if HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS #define HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T2) \ template < \ typename T1, \ hwy::EnableIf>() || \ hwy::IsFloat3264>()>* = nullptr, \ typename RawResultT = decltype(DeclVal() op DeclVal()), \ typename ResultT = detail::NativeSpecialFloatToWrapper, \ HWY_IF_CASTABLE(RawResultT, ResultT)> \ static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ return static_cast(a op b.native); \ } #define HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(op, op_func, T1) \ HWY_RHS_SPECIAL_FLOAT_ARITH_OP(op, op_func, T1) \ template < \ typename T2, \ hwy::EnableIf>() || \ hwy::IsFloat3264>()>* = nullptr, \ typename RawResultT = decltype(DeclVal() op DeclVal()), \ typename ResultT = detail::NativeSpecialFloatToWrapper, \ HWY_IF_CASTABLE(RawResultT, ResultT)> \ static HWY_INLINE constexpr ResultT op_func(T1 a, T2 b) noexcept { \ return static_cast(a.native op b); \ } #if HWY_HAVE_SCALAR_F16_OPERATORS HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, float16_t) HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, float16_t) HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, float16_t) HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, float16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, float16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, float16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, float16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, float16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, float16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, float16_t) #if HWY_HAVE_CXX20_THREE_WAY_COMPARE HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, float16_t) #endif #endif // HWY_HAVE_SCALAR_F16_OPERATORS #if HWY_HAVE_SCALAR_BF16_OPERATORS HWY_RHS_SPECIAL_FLOAT_ARITH_OP(+, operator+, bfloat16_t) HWY_RHS_SPECIAL_FLOAT_ARITH_OP(-, operator-, bfloat16_t) HWY_RHS_SPECIAL_FLOAT_ARITH_OP(*, operator*, bfloat16_t) HWY_RHS_SPECIAL_FLOAT_ARITH_OP(/, operator/, bfloat16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(==, operator==, bfloat16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(!=, operator!=, bfloat16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<, operator<, bfloat16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=, operator<=, bfloat16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>, operator>, bfloat16_t) HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(>=, operator>=, bfloat16_t) #if HWY_HAVE_CXX20_THREE_WAY_COMPARE HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP(<=>, operator<=>, bfloat16_t) #endif #endif // HWY_HAVE_SCALAR_BF16_OPERATORS #undef HWY_RHS_SPECIAL_FLOAT_ARITH_OP #undef HWY_SPECIAL_FLOAT_CMP_AGAINST_NON_SPECIAL_OP #endif // HWY_HAVE_SCALAR_F16_OPERATORS || HWY_HAVE_SCALAR_BF16_OPERATORS //------------------------------------------------------------------------------ // Type conversions (after IsSpecialFloat) HWY_API float F32FromF16Mem(const void* ptr) { float16_t f16; CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &f16); return F32FromF16(f16); } HWY_API float F32FromBF16Mem(const void* ptr) { bfloat16_t bf; CopyBytes<2>(HWY_ASSUME_ALIGNED(ptr, 2), &bf); return F32FromBF16(bf); } #if HWY_HAVE_SCALAR_F16_OPERATORS #define HWY_BF16_TO_F16_CONSTEXPR HWY_BF16_CONSTEXPR #else #define HWY_BF16_TO_F16_CONSTEXPR 
HWY_F16_CONSTEXPR #endif // For casting from TFrom to TTo template HWY_API constexpr TTo ConvertScalarTo(const TFrom in) { return static_cast(in); } template HWY_API constexpr TTo ConvertScalarTo(const TFrom in) { return F16FromF32(static_cast(in)); } template HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo ConvertScalarTo(const hwy::bfloat16_t in) { return F16FromF32(F32FromBF16(in)); } template HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const double in) { return F16FromF64(in); } template HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) { return BF16FromF32(static_cast(in)); } template HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo ConvertScalarTo(const hwy::float16_t in) { return BF16FromF32(F32FromF16(in)); } template HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const double in) { return BF16FromF64(in); } template HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) { return static_cast(F32FromF16(in)); } template HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(TFrom in) { return static_cast(F32FromBF16(in)); } // Same: return unchanged template HWY_API constexpr TTo ConvertScalarTo(TTo in) { return in; } //------------------------------------------------------------------------------ // Helper functions template constexpr inline T1 DivCeil(T1 a, T2 b) { return (a + b - 1) / b; } // Works for any `align`; if a power of two, compiler emits ADD+AND. constexpr inline size_t RoundUpTo(size_t what, size_t align) { return DivCeil(what, align) * align; } // Undefined results for x == 0. HWY_API size_t Num0BitsBelowLS1Bit_Nonzero32(const uint32_t x) { HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC unsigned long index; // NOLINT _BitScanForward(&index, x); return index; #else // HWY_COMPILER_MSVC return static_cast(__builtin_ctz(x)); #endif // HWY_COMPILER_MSVC } HWY_API size_t Num0BitsBelowLS1Bit_Nonzero64(const uint64_t x) { HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC #if HWY_ARCH_X86_64 unsigned long index; // NOLINT _BitScanForward64(&index, x); return index; #else // HWY_ARCH_X86_64 // _BitScanForward64 not available uint32_t lsb = static_cast(x & 0xFFFFFFFF); unsigned long index; // NOLINT if (lsb == 0) { uint32_t msb = static_cast(x >> 32u); _BitScanForward(&index, msb); return 32 + index; } else { _BitScanForward(&index, lsb); return index; } #endif // HWY_ARCH_X86_64 #else // HWY_COMPILER_MSVC return static_cast(__builtin_ctzll(x)); #endif // HWY_COMPILER_MSVC } // Undefined results for x == 0. 
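// Worked examples (illustrative): Num0BitsBelowLS1Bit_Nonzero32(0x28) == 3,
// because the lowest set bit of 0b101000 is bit 3. For the helpers above,
// DivCeil(37, 16) == 3 and RoundUpTo(37, 16) == 48. The functions below
// count leading zeros instead, e.g. Num0BitsAboveMS1Bit_Nonzero32(1) == 31;
// as noted, x must be nonzero for both families.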
HWY_API size_t Num0BitsAboveMS1Bit_Nonzero32(const uint32_t x) { HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC unsigned long index; // NOLINT _BitScanReverse(&index, x); return 31 - index; #else // HWY_COMPILER_MSVC return static_cast(__builtin_clz(x)); #endif // HWY_COMPILER_MSVC } HWY_API size_t Num0BitsAboveMS1Bit_Nonzero64(const uint64_t x) { HWY_DASSERT(x != 0); #if HWY_COMPILER_MSVC #if HWY_ARCH_X86_64 unsigned long index; // NOLINT _BitScanReverse64(&index, x); return 63 - index; #else // HWY_ARCH_X86_64 // _BitScanReverse64 not available const uint32_t msb = static_cast(x >> 32u); unsigned long index; // NOLINT if (msb == 0) { const uint32_t lsb = static_cast(x & 0xFFFFFFFF); _BitScanReverse(&index, lsb); return 63 - index; } else { _BitScanReverse(&index, msb); return 31 - index; } #endif // HWY_ARCH_X86_64 #else // HWY_COMPILER_MSVC return static_cast(__builtin_clzll(x)); #endif // HWY_COMPILER_MSVC } template ), HWY_IF_T_SIZE_ONE_OF(RemoveCvRef, (1 << 1) | (1 << 2) | (1 << 4))> HWY_API size_t PopCount(T x) { uint32_t u32_x = static_cast( static_cast)>>(x)); #if HWY_COMPILER_GCC || HWY_COMPILER_CLANG return static_cast(__builtin_popcountl(u32_x)); #elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) return static_cast(_mm_popcnt_u32(u32_x)); #else u32_x -= ((u32_x >> 1) & 0x55555555u); u32_x = (((u32_x >> 2) & 0x33333333u) + (u32_x & 0x33333333u)); u32_x = (((u32_x >> 4) + u32_x) & 0x0F0F0F0Fu); u32_x += (u32_x >> 8); u32_x += (u32_x >> 16); return static_cast(u32_x & 0x3Fu); #endif } template ), HWY_IF_T_SIZE(RemoveCvRef, 8)> HWY_API size_t PopCount(T x) { uint64_t u64_x = static_cast( static_cast)>>(x)); #if HWY_COMPILER_GCC || HWY_COMPILER_CLANG return static_cast(__builtin_popcountll(u64_x)); #elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 && defined(__AVX__) return _mm_popcnt_u64(u64_x); #elif HWY_COMPILER_MSVC && HWY_ARCH_X86_32 && defined(__AVX__) return _mm_popcnt_u32(static_cast(u64_x & 0xFFFFFFFFu)) + _mm_popcnt_u32(static_cast(u64_x >> 32)); #else u64_x -= ((u64_x >> 1) & 0x5555555555555555ULL); u64_x = (((u64_x >> 2) & 0x3333333333333333ULL) + (u64_x & 0x3333333333333333ULL)); u64_x = (((u64_x >> 4) + u64_x) & 0x0F0F0F0F0F0F0F0FULL); u64_x += (u64_x >> 8); u64_x += (u64_x >> 16); u64_x += (u64_x >> 32); return static_cast(u64_x & 0x7Fu); #endif } // Skip HWY_API due to GCC "function not considered for inlining". Previously // such errors were caused by underlying type mismatches, but it's not clear // what is still mismatched despite all the casts. template /*HWY_API*/ constexpr size_t FloorLog2(TI x) { return x == TI{1} ? 0 : static_cast(FloorLog2(static_cast(x >> 1)) + 1); } template /*HWY_API*/ constexpr size_t CeilLog2(TI x) { return x == TI{1} ? 0 : static_cast(FloorLog2(static_cast(x - 1)) + 1); } template HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { return t + static_cast(increment); } template HWY_INLINE constexpr T AddWithWraparound(T t, T2 increment) { return ConvertScalarTo(ConvertScalarTo(t) + ConvertScalarTo(increment)); } template HWY_INLINE constexpr T AddWithWraparound(T t, T2 n) { using TU = MakeUnsigned; // Sub-int types would promote to int, not unsigned, which would trigger // warnings, so first promote to the largest unsigned type. Due to // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87519, which affected GCC 8 // until fixed in 9.3, we use built-in types rather than uint64_t. 
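  // Worked example (illustrative): for T = uint8_t with t = 250 and n = 10,
  // the 64-bit sum 260 is masked with the maximum of the corresponding
  // unsigned type (255), yielding 4, which matches modular (wraparound)
  // uint8_t addition.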
return static_cast(static_cast( static_cast(static_cast(t) + static_cast(n)) & uint64_t{hwy::LimitsMax()})); } #if HWY_COMPILER_MSVC && HWY_ARCH_X86_64 #pragma intrinsic(_umul128) #endif // 64 x 64 = 128 bit multiplication HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) { #if defined(__SIZEOF_INT128__) __uint128_t product = (__uint128_t)a * (__uint128_t)b; *upper = (uint64_t)(product >> 64); return (uint64_t)(product & 0xFFFFFFFFFFFFFFFFULL); #elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64 return _umul128(a, b, upper); #else constexpr uint64_t kLo32 = 0xFFFFFFFFU; const uint64_t lo_lo = (a & kLo32) * (b & kLo32); const uint64_t hi_lo = (a >> 32) * (b & kLo32); const uint64_t lo_hi = (a & kLo32) * (b >> 32); const uint64_t hi_hi = (a >> 32) * (b >> 32); const uint64_t t = (lo_lo >> 32) + (hi_lo & kLo32) + lo_hi; *upper = (hi_lo >> 32) + (t >> 32) + hi_hi; return (t << 32) | (lo_lo & kLo32); #endif } namespace detail { template static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::FloatTag /*tag*/, T val) { using TU = MakeUnsigned; return BitCastScalar( static_cast(BitCastScalar(val) & (~SignMask()))); } template static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::SpecialTag /*tag*/, T val) { return ScalarAbs(hwy::FloatTag(), val); } template static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::SignedTag /*tag*/, T val) { using TU = MakeUnsigned; return (val < T{0}) ? static_cast(TU{0} - static_cast(val)) : val; } template static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR T ScalarAbs(hwy::UnsignedTag /*tag*/, T val) { return val; } } // namespace detail template HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarAbs(T val) { using TVal = MakeLaneTypeIfInteger< detail::NativeSpecialFloatToWrapper>>; return detail::ScalarAbs(hwy::TypeTag(), static_cast(val)); } template HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsNaN(T val) { using TF = detail::NativeSpecialFloatToWrapper>; using TU = MakeUnsigned; return (BitCastScalar(ScalarAbs(val)) > ExponentMask()); } template HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsInf(T val) { using TF = detail::NativeSpecialFloatToWrapper>; using TU = MakeUnsigned; return static_cast(BitCastScalar(static_cast(val)) << 1) == static_cast(MaxExponentTimes2()); } namespace detail { template static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( hwy::FloatTag /*tag*/, T val) { using TU = MakeUnsigned; return (BitCastScalar(hwy::ScalarAbs(val)) < ExponentMask()); } template static HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite( hwy::NonFloatTag /*tag*/, T /*val*/) { // Integer values are always finite return true; } } // namespace detail template HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarIsFinite(T val) { using TVal = MakeLaneTypeIfInteger< detail::NativeSpecialFloatToWrapper>>; return detail::ScalarIsFinite(hwy::IsFloatTag(), static_cast(val)); } template HWY_API HWY_BITCASTSCALAR_CONSTEXPR RemoveCvRef ScalarCopySign(T magn, T sign) { using TF = RemoveCvRef>>; using TU = MakeUnsigned; return BitCastScalar(static_cast( (BitCastScalar(static_cast(magn)) & (~SignMask())) | (BitCastScalar(static_cast(sign)) & SignMask()))); } template HWY_API HWY_BITCASTSCALAR_CONSTEXPR bool ScalarSignBit(T val) { using TVal = MakeLaneTypeIfInteger< detail::NativeSpecialFloatToWrapper>>; using TU = MakeUnsigned; return ((BitCastScalar(static_cast(val)) & SignMask()) != 0); } // Prevents the compiler from eliding the computations that led to "output". 
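// Typical benchmarking use (a hedged sketch; Compute and kIters are
// hypothetical user code, not part of this header):
//   uint64_t sum = 0;
//   for (size_t i = 0; i < kIters; ++i) sum += Compute(i);
//   hwy::PreventElision(sum);  // ensure the loop is not optimized away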
#if HWY_ARCH_PPC && (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \ !defined(_SOFT_FLOAT) // Workaround to avoid test failures on PPC if compiled with Clang template HWY_API void PreventElision(T&& output) { asm volatile("" : "+f"(output)::"memory"); } template HWY_API void PreventElision(T&& output) { asm volatile("" : "+d"(output)::"memory"); } template HWY_API void PreventElision(T&& output) { asm volatile("" : "+r"(output)::"memory"); } #else template HWY_API void PreventElision(T&& output) { #if HWY_COMPILER_MSVC // MSVC does not support inline assembly anymore (and never supported GCC's // RTL constraints). Self-assignment with #pragma optimize("off") might be // expected to prevent elision, but it does not with MSVC 2015. Type-punning // with volatile pointers generates inefficient code on MSVC 2017. static std::atomic> sink; sink.store(output, std::memory_order_relaxed); #else // Works by indicating to the compiler that "output" is being read and // modified. The +r constraint avoids unnecessary writes to memory, but only // works for built-in types (typically FuncOutput). asm volatile("" : "+r"(output) : : "memory"); #endif } #endif } // namespace hwy #endif // HIGHWAY_HWY_BASE_H_