// Copyright 2022 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Single-element vectors and operations.
// External include guard in highway.h - see comment there.

#include "hwy/base.h"

#ifndef HWY_NO_LIBCXX
#include <math.h>  // sqrtf
#endif

#include "hwy/ops/shared-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename T>
using Full128 = Simd<T, 16 / sizeof(T), 0>;

// (Wrapper class required for overloading comparison operators.)
template <typename T, size_t N = 16 / sizeof(T)>
struct Vec128 {
  using PrivateT = T;                     // only for DFromV
  static constexpr size_t kPrivateN = N;  // only for DFromV

  HWY_INLINE Vec128() = default;
  Vec128(const Vec128&) = default;
  Vec128& operator=(const Vec128&) = default;

  HWY_INLINE Vec128& operator*=(const Vec128 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec128& operator/=(const Vec128 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec128& operator+=(const Vec128 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec128& operator-=(const Vec128 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec128& operator%=(const Vec128 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec128& operator&=(const Vec128 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec128& operator|=(const Vec128 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec128& operator^=(const Vec128 other) {
    return *this = (*this ^ other);
  }

  // Behave like wasm128 (vectors can always hold 128 bits). generic_ops-inl.h
  // relies on this for LoadInterleaved*. CAVEAT: this method of padding
  // prevents using range for, especially in SumOfLanes, where it would be
  // incorrect. Moving padding to another field would require handling the
  // case where N = 16 / sizeof(T) (i.e. there is no padding), which is also
  // awkward.
  T raw[16 / sizeof(T)] = {};
};

// 0 or FF..FF, same size as Vec128.
template <typename T, size_t N = 16 / sizeof(T)>
struct Mask128 {
  using Raw = hwy::MakeUnsigned<T>;
  static HWY_INLINE Raw FromBool(bool b) {
    return b ? static_cast<Raw>(~Raw{0}) : 0;
  }

  // Must match the size of Vec128.
  Raw bits[16 / sizeof(T)] = {};
};

template <class V>
using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>;

template <class V>
using TFromV = typename V::PrivateT;

// ------------------------------ Zero

// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero.
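// Usage sketch (hypothetical caller code, not part of this header): because
// VFromD<D> is defined below as decltype(Zero(D())), the usual way to name
// this target's vector type for a given tag is via Zero:
//
//   const Full128<uint32_t> d;
//   VFromD<decltype(d)> v = Zero(d);  // Vec128<uint32_t, 4>, all lanes zero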
template HWY_API Vec128, HWY_MAX_LANES_D(D)> Zero(D /* tag */) { Vec128, HWY_MAX_LANES_D(D)> v; // zero-initialized return v; } template using VFromD = decltype(Zero(D())); // ------------------------------ Tuple (VFromD) #include "hwy/ops/tuple-inl.h" // ------------------------------ BitCast template HWY_API VFromD BitCast(D /* tag */, VFrom v) { VFromD to; CopySameSize(&v.raw, &to.raw); return to; } // ------------------------------ ResizeBitCast template HWY_API VFromD ResizeBitCast(D d, VFrom v) { using DFrom = DFromV; using TFrom = TFromD; using TTo = TFromD; constexpr size_t kFromByteLen = sizeof(TFrom) * HWY_MAX_LANES_D(DFrom); constexpr size_t kToByteLen = sizeof(TTo) * HWY_MAX_LANES_D(D); constexpr size_t kCopyByteLen = HWY_MIN(kFromByteLen, kToByteLen); VFromD to = Zero(d); CopyBytes(&v.raw, &to.raw); return to; } namespace detail { // ResizeBitCast on the HWY_EMU128 target has zero-extending semantics if // VFromD is a larger vector than FromV template HWY_INLINE VFromD ZeroExtendResizeBitCast(FromSizeTag /* from_size_tag */, ToSizeTag /* to_size_tag */, DTo d_to, DFrom /* d_from */, VFromD v) { return ResizeBitCast(d_to, v); } } // namespace detail // ------------------------------ Set template HWY_API VFromD Set(D d, const T2 t) { VFromD v; for (size_t i = 0; i < MaxLanes(d); ++i) { v.raw[i] = ConvertScalarTo>(t); } return v; } // ------------------------------ Undefined template HWY_API VFromD Undefined(D d) { return Zero(d); } // ------------------------------ Dup128VecFromValues template HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, TFromD t2, TFromD t3, TFromD t4, TFromD t5, TFromD t6, TFromD t7, TFromD t8, TFromD t9, TFromD t10, TFromD t11, TFromD t12, TFromD t13, TFromD t14, TFromD t15) { VFromD result; result.raw[0] = t0; result.raw[1] = t1; result.raw[2] = t2; result.raw[3] = t3; result.raw[4] = t4; result.raw[5] = t5; result.raw[6] = t6; result.raw[7] = t7; result.raw[8] = t8; result.raw[9] = t9; result.raw[10] = t10; result.raw[11] = t11; result.raw[12] = t12; result.raw[13] = t13; result.raw[14] = t14; result.raw[15] = t15; return result; } template HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, TFromD t2, TFromD t3, TFromD t4, TFromD t5, TFromD t6, TFromD t7) { VFromD result; result.raw[0] = t0; result.raw[1] = t1; result.raw[2] = t2; result.raw[3] = t3; result.raw[4] = t4; result.raw[5] = t5; result.raw[6] = t6; result.raw[7] = t7; return result; } template HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1, TFromD t2, TFromD t3) { VFromD result; result.raw[0] = t0; result.raw[1] = t1; result.raw[2] = t2; result.raw[3] = t3; return result; } template HWY_API VFromD Dup128VecFromValues(D /*d*/, TFromD t0, TFromD t1) { VFromD result; result.raw[0] = t0; result.raw[1] = t1; return result; } // ------------------------------ Iota template , typename T2> HWY_API VFromD Iota(D d, T2 first) { VFromD v; for (size_t i = 0; i < MaxLanes(d); ++i) { v.raw[i] = AddWithWraparound(static_cast(first), i); } return v; } // ================================================== LOGICAL // ------------------------------ Not template HWY_API Vec128 Not(Vec128 v) { const DFromV d; const RebindToUnsigned du; using TU = TFromD; VFromD vu = BitCast(du, v); for (size_t i = 0; i < N; ++i) { vu.raw[i] = static_cast(~vu.raw[i]); } return BitCast(d, vu); } // ------------------------------ And template HWY_API Vec128 And(Vec128 a, Vec128 b) { const DFromV d; const RebindToUnsigned du; auto au = BitCast(du, a); auto bu = BitCast(du, b); for 
(size_t i = 0; i < N; ++i) { au.raw[i] &= bu.raw[i]; } return BitCast(d, au); } template HWY_API Vec128 operator&(Vec128 a, Vec128 b) { return And(a, b); } // ------------------------------ AndNot template HWY_API Vec128 AndNot(Vec128 a, Vec128 b) { return And(Not(a), b); } // ------------------------------ Or template HWY_API Vec128 Or(Vec128 a, Vec128 b) { const DFromV d; const RebindToUnsigned du; auto au = BitCast(du, a); auto bu = BitCast(du, b); for (size_t i = 0; i < N; ++i) { au.raw[i] |= bu.raw[i]; } return BitCast(d, au); } template HWY_API Vec128 operator|(Vec128 a, Vec128 b) { return Or(a, b); } // ------------------------------ Xor template HWY_API Vec128 Xor(Vec128 a, Vec128 b) { const DFromV d; const RebindToUnsigned du; auto au = BitCast(du, a); auto bu = BitCast(du, b); for (size_t i = 0; i < N; ++i) { au.raw[i] ^= bu.raw[i]; } return BitCast(d, au); } template HWY_API Vec128 operator^(Vec128 a, Vec128 b) { return Xor(a, b); } // ------------------------------ Xor3 template HWY_API Vec128 Xor3(Vec128 x1, Vec128 x2, Vec128 x3) { return Xor(x1, Xor(x2, x3)); } // ------------------------------ Or3 template HWY_API Vec128 Or3(Vec128 o1, Vec128 o2, Vec128 o3) { return Or(o1, Or(o2, o3)); } // ------------------------------ OrAnd template HWY_API Vec128 OrAnd(Vec128 o, Vec128 a1, Vec128 a2) { return Or(o, And(a1, a2)); } // ------------------------------ IfVecThenElse template HWY_API Vec128 IfVecThenElse(Vec128 mask, Vec128 yes, Vec128 no) { return Or(And(mask, yes), AndNot(mask, no)); } // ------------------------------ CopySign template HWY_API Vec128 CopySign(Vec128 magn, Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const DFromV d; return BitwiseIfThenElse(SignBit(d), sign, magn); } // ------------------------------ CopySignToAbs template HWY_API Vec128 CopySignToAbs(Vec128 abs, Vec128 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const DFromV d; return OrAnd(abs, SignBit(d), sign); } // ------------------------------ BroadcastSignBit template HWY_API Vec128 BroadcastSignBit(Vec128 v) { // This is used inside ShiftRight, so we cannot implement in terms of it. for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] < 0 ? -1 : 0); } return v; } // ------------------------------ Mask // v must be 0 or FF..FF. template HWY_API Mask128 MaskFromVec(Vec128 v) { Mask128 mask; CopySameSize(&v.raw, &mask.bits); return mask; } template using MFromD = decltype(MaskFromVec(VFromD())); template HWY_API MFromD RebindMask(DTo /* tag */, MFrom mask) { MFromD to; CopySameSize(&mask.bits, &to.bits); return to; } template VFromD VecFromMask(D /* tag */, MFromD mask) { VFromD v; CopySameSize(&mask.bits, &v.raw); return v; } template HWY_API MFromD FirstN(D d, size_t n) { MFromD m; for (size_t i = 0; i < MaxLanes(d); ++i) { m.bits[i] = MFromD::FromBool(i < n); } return m; } // Returns mask ? yes : no. 
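// Usage sketch (hypothetical caller code, not part of this header): masks
// from comparisons or IsNaN select lanes without branching, e.g. replacing
// NaN lanes with zero and negative lanes with -1:
//
//   const auto no_nan = IfThenZeroElse(IsNaN(v), v);
//   const auto fixed = IfThenElse(no_nan < Zero(d), Set(d, -1.0f), no_nan);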
template HWY_API Vec128 IfThenElse(Mask128 mask, Vec128 yes, Vec128 no) { const DFromV d; return IfVecThenElse(VecFromMask(d, mask), yes, no); } template HWY_API Vec128 IfThenElseZero(Mask128 mask, Vec128 yes) { const DFromV d; return IfVecThenElse(VecFromMask(d, mask), yes, Zero(d)); } template HWY_API Vec128 IfThenZeroElse(Mask128 mask, Vec128 no) { const DFromV d; return IfVecThenElse(VecFromMask(d, mask), Zero(d), no); } template HWY_API Vec128 IfNegativeThenElse(Vec128 v, Vec128 yes, Vec128 no) { const DFromV d; const RebindToSigned di; const auto vi = BitCast(di, v); for (size_t i = 0; i < N; ++i) { v.raw[i] = vi.raw[i] < 0 ? yes.raw[i] : no.raw[i]; } return v; } template HWY_API Vec128 ZeroIfNegative(Vec128 v) { const DFromV d; return IfNegativeThenElse(v, Zero(d), v); } // ------------------------------ Mask logical template HWY_API Mask128 Not(Mask128 m) { const Simd d; return MaskFromVec(Not(VecFromMask(d, m))); } template HWY_API Mask128 And(Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 AndNot(Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Or(Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 Xor(Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask128 ExclusiveNeither(Mask128 a, Mask128 b) { const Simd d; return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b)))); } // ================================================== SHIFTS // ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit) template HWY_API Vec128 ShiftLeft(Vec128 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU raw_u = static_cast(v.raw[i]); const auto shifted = raw_u << kBits; // separate line to avoid MSVC warning v.raw[i] = static_cast(shifted); } return v; } template HWY_API Vec128 ShiftRight(Vec128 v) { static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift"); #if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> kBits); } #else if (IsSigned()) { // Emulate arithmetic shift using only logical (unsigned) shifts, because // signed shifts are still implementation-defined. using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU shifted = static_cast(static_cast(v.raw[i]) >> kBits); const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; const size_t sign_shift = static_cast(static_cast(sizeof(TU)) * 8 - 1 - kBits); const TU upper = static_cast(sign << sign_shift); v.raw[i] = static_cast(shifted | upper); } } else { // T is unsigned for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> kBits); } } #endif return v; } // ------------------------------ RotateRight (ShiftRight) template HWY_API Vec128 RotateRight(const Vec128 v) { constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); } // ------------------------------ ShiftLeftSame template HWY_API Vec128 ShiftLeftSame(Vec128 v, int bits) { for (size_t i = 0; i < N; ++i) { const auto shifted = static_cast>(v.raw[i]) << bits; v.raw[i] = static_cast(shifted); } return v; } template HWY_API Vec128 ShiftRightSame(Vec128 v, int bits) { #if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits); } #else if (IsSigned()) { // Emulate arithmetic shift using only logical (unsigned) shifts, because // signed shifts are still implementation-defined. using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU shifted = static_cast(static_cast(v.raw[i]) >> bits); const TU sign = v.raw[i] < 0 ? static_cast(~TU{0}) : 0; const size_t sign_shift = static_cast(static_cast(sizeof(TU)) * 8 - 1 - bits); const TU upper = static_cast(sign << sign_shift); v.raw[i] = static_cast(shifted | upper); } } else { for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits); // unsigned, logical shift } } #endif return v; } // ------------------------------ Shl template HWY_API Vec128 operator<<(Vec128 v, Vec128 bits) { for (size_t i = 0; i < N; ++i) { const auto shifted = static_cast>(v.raw[i]) << bits.raw[i]; v.raw[i] = static_cast(shifted); } return v; } template HWY_API Vec128 operator>>(Vec128 v, Vec128 bits) { #if __cplusplus >= 202002L // Signed right shift is now guaranteed to be arithmetic (rounding toward // negative infinity, i.e. shifting in the sign bit). for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); } #else if (IsSigned()) { // Emulate arithmetic shift using only logical (unsigned) shifts, because // signed shifts are still implementation-defined. using TU = hwy::MakeUnsigned; for (size_t i = 0; i < N; ++i) { const TU shifted = static_cast(static_cast(v.raw[i]) >> bits.raw[i]); const TU sign = v.raw[i] < 0 ? 
static_cast(~TU{0}) : 0; const size_t sign_shift = static_cast( static_cast(sizeof(TU)) * 8 - 1 - bits.raw[i]); const TU upper = static_cast(sign << sign_shift); v.raw[i] = static_cast(shifted | upper); } } else { // T is unsigned for (size_t i = 0; i < N; ++i) { v.raw[i] = static_cast(v.raw[i] >> bits.raw[i]); } } #endif return v; } // ================================================== ARITHMETIC // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Add(hwy::NonFloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { const uint64_t a64 = static_cast(a.raw[i]); const uint64_t b64 = static_cast(b.raw[i]); a.raw[i] = static_cast((a64 + b64) & static_cast(~T(0))); } return a; } template HWY_INLINE Vec128 Sub(hwy::NonFloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { const uint64_t a64 = static_cast(a.raw[i]); const uint64_t b64 = static_cast(b.raw[i]); a.raw[i] = static_cast((a64 - b64) & static_cast(~T(0))); } return a; } template HWY_INLINE Vec128 Add(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] += b.raw[i]; } return a; } template HWY_INLINE Vec128 Sub(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] -= b.raw[i]; } return a; } } // namespace detail template HWY_API Vec128 operator-(Vec128 a, Vec128 b) { return detail::Sub(hwy::IsFloatTag(), a, b); } template HWY_API Vec128 operator+(Vec128 a, Vec128 b) { return detail::Add(hwy::IsFloatTag(), a, b); } // ------------------------------ SumsOf8 template HWY_API Vec128 SumsOf8(Vec128 v) { Vec128 sums; for (size_t i = 0; i < N; ++i) { sums.raw[i / 8] += v.raw[i]; } return sums; } template HWY_API Vec128 SumsOf8(Vec128 v) { Vec128 sums; for (size_t i = 0; i < N; ++i) { sums.raw[i / 8] += v.raw[i]; } return sums; } // ------------------------------ SaturatedAdd template HWY_API Vec128 SaturatedAdd(Vec128 a, Vec128 b) { using TW = MakeSigned>; for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast(HWY_MIN( HWY_MAX(hwy::LowestValue(), static_cast(a.raw[i]) + b.raw[i]), hwy::HighestValue())); } return a; } // ------------------------------ SaturatedSub template HWY_API Vec128 SaturatedSub(Vec128 a, Vec128 b) { using TW = MakeSigned>; for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast(HWY_MIN( HWY_MAX(hwy::LowestValue(), static_cast(a.raw[i]) - b.raw[i]), hwy::HighestValue())); } return a; } // ------------------------------ AverageRound template HWY_API Vec128 AverageRound(Vec128 a, Vec128 b) { static_assert(!IsSigned(), "Only for unsigned"); for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast((a.raw[i] + b.raw[i] + 1) / 2); } return a; } // ------------------------------ Abs template HWY_API Vec128 Abs(Vec128 a) { for (size_t i = 0; i < N; ++i) { a.raw[i] = ScalarAbs(a.raw[i]); } return a; } // ------------------------------ Min/Max // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Min(hwy::NonFloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = HWY_MIN(a.raw[i], b.raw[i]); } return a; } template HWY_INLINE Vec128 Max(hwy::NonFloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); } return a; } template HWY_INLINE Vec128 Min(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { if (ScalarIsNaN(a.raw[i])) { a.raw[i] = b.raw[i]; } else if (ScalarIsNaN(b.raw[i])) { // no change } else { a.raw[i] = 
HWY_MIN(a.raw[i], b.raw[i]); } } return a; } template HWY_INLINE Vec128 Max(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { if (ScalarIsNaN(a.raw[i])) { a.raw[i] = b.raw[i]; } else if (ScalarIsNaN(b.raw[i])) { // no change } else { a.raw[i] = HWY_MAX(a.raw[i], b.raw[i]); } } return a; } } // namespace detail template HWY_API Vec128 Min(Vec128 a, Vec128 b) { return detail::Min(hwy::IsFloatTag(), a, b); } template HWY_API Vec128 Max(Vec128 a, Vec128 b) { return detail::Max(hwy::IsFloatTag(), a, b); } // ------------------------------ Neg // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_API Vec128 Neg(hwy::NonFloatTag /*tag*/, Vec128 v) { const DFromV d; return Zero(d) - v; } template HWY_API Vec128 Neg(hwy::FloatTag /*tag*/, Vec128 v) { const DFromV d; return Xor(v, SignBit(d)); } template HWY_API Vec128 Neg(hwy::SpecialTag /*tag*/, Vec128 v) { const DFromV d; return Xor(v, SignBit(d)); } } // namespace detail template HWY_API Vec128 Neg(Vec128 v) { return detail::Neg(hwy::IsFloatTag(), v); } // ------------------------------ Mul/Div // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_INLINE Vec128 Mul(hwy::FloatTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] *= b.raw[i]; } return a; } template HWY_INLINE Vec128 Mul(SignedTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast(static_cast(a.raw[i]) * static_cast(b.raw[i])); } return a; } template HWY_INLINE Vec128 Mul(UnsignedTag /*tag*/, Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast(static_cast(a.raw[i]) * static_cast(b.raw[i])); } return a; } } // namespace detail // Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*. #ifdef HWY_NATIVE_MUL_8 #undef HWY_NATIVE_MUL_8 #else #define HWY_NATIVE_MUL_8 #endif #ifdef HWY_NATIVE_MUL_64 #undef HWY_NATIVE_MUL_64 #else #define HWY_NATIVE_MUL_64 #endif template HWY_API Vec128 operator*(Vec128 a, Vec128 b) { return detail::Mul(hwy::TypeTag(), a, b); } template HWY_API Vec128 operator/(Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = (b.raw[i] == T{0}) ? 0 : a.raw[i] / b.raw[i]; } return a; } // Returns the upper 16 bits of a * b in each lane. template HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast((int32_t{a.raw[i]} * b.raw[i]) >> 16); } return a; } template HWY_API Vec128 MulHigh(Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { // Cast to uint32_t first to prevent overflow. Otherwise the result of // uint16_t * uint16_t is in "int" which may overflow. In practice the // result is the same but this way it is also defined. a.raw[i] = static_cast( (static_cast(a.raw[i]) * static_cast(b.raw[i])) >> 16); } return a; } template HWY_API Vec128 MulFixedPoint15(Vec128 a, Vec128 b) { for (size_t i = 0; i < N; ++i) { a.raw[i] = static_cast((a.raw[i] * b.raw[i] + 16384) >> 15); } return a; } // Multiplies even lanes (0, 2, ..) and returns the double-wide result. template HWY_API Vec128, (N + 1) / 2> MulEven(Vec128 a, Vec128 b) { using TW = MakeWide; Vec128 mul; for (size_t i = 0; i < N; i += 2) { const TW a_wide = a.raw[i]; mul.raw[i / 2] = static_cast(a_wide * b.raw[i]); } return mul; } // Multiplies odd lanes (1, 3, ..) and returns the double-wide result. 
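// Worked example for MulEven above and MulOdd below (illustrative only): for
// u32 lanes a = {a0, a1, a2, a3} and b = {b0, b1, b2, b3}, MulEven returns
// the two u64 products {a0*b0, a2*b2} and MulOdd returns {a1*b1, a3*b3};
// together they cover every lane at double width.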
template HWY_API Vec128, (N + 1) / 2> MulOdd(Vec128 a, Vec128 b) { using TW = MakeWide; Vec128 mul; for (size_t i = 0; i < N; i += 2) { const TW a_wide = a.raw[i + 1]; mul.raw[i / 2] = static_cast(a_wide * b.raw[i + 1]); } return mul; } template HWY_API Vec128 ApproximateReciprocal(Vec128 v) { for (size_t i = 0; i < N; ++i) { // Zero inputs are allowed, but callers are responsible for replacing the // return value with something else (typically using IfThenElse). This check // avoids a ubsan error. The result is arbitrary. v.raw[i] = (ScalarAbs(v.raw[i]) == 0.0f) ? 0.0f : 1.0f / v.raw[i]; } return v; } // generic_ops takes care of integer T. template HWY_API Vec128 AbsDiff(Vec128 a, Vec128 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants template HWY_API Vec128 MulAdd(Vec128 mul, Vec128 x, Vec128 add) { return mul * x + add; } template HWY_API Vec128 NegMulAdd(Vec128 mul, Vec128 x, Vec128 add) { return add - mul * x; } template HWY_API Vec128 MulSub(Vec128 mul, Vec128 x, Vec128 sub) { return mul * x - sub; } template HWY_API Vec128 NegMulSub(Vec128 mul, Vec128 x, Vec128 sub) { return Neg(mul) * x - sub; } // ------------------------------ Floating-point square root template HWY_API Vec128 ApproximateReciprocalSqrt(Vec128 v) { for (size_t i = 0; i < N; ++i) { const float half = v.raw[i] * 0.5f; // Initial guess based on log2(f) v.raw[i] = BitCastScalar(static_cast( 0x5F3759DF - (BitCastScalar(v.raw[i]) >> 1))); // One Newton-Raphson iteration v.raw[i] = v.raw[i] * (1.5f - (half * v.raw[i] * v.raw[i])); } return v; } namespace detail { static HWY_INLINE float ScalarSqrt(float v) { #if defined(HWY_NO_LIBCXX) #if HWY_COMPILER_GCC_ACTUAL return __builtin_sqrt(v); #else uint32_t bits = BitCastScalar(v); // Coarse approximation, letting the exponent LSB leak into the mantissa bits = (1 << 29) + (bits >> 1) - (1 << 22); return BitCastScalar(bits); #endif // !HWY_COMPILER_GCC_ACTUAL #else return sqrtf(v); #endif // !HWY_NO_LIBCXX } static HWY_INLINE double ScalarSqrt(double v) { #if defined(HWY_NO_LIBCXX) #if HWY_COMPILER_GCC_ACTUAL return __builtin_sqrt(v); #else uint64_t bits = BitCastScalar(v); // Coarse approximation, letting the exponent LSB leak into the mantissa bits = (1ULL << 61) + (bits >> 1) - (1ULL << 51); return BitCastScalar(bits); #endif // !HWY_COMPILER_GCC_ACTUAL #else return sqrt(v); #endif // HWY_NO_LIBCXX } } // namespace detail template HWY_API Vec128 Sqrt(Vec128 v) { for (size_t i = 0; i < N; ++i) { v.raw[i] = detail::ScalarSqrt(v.raw[i]); } return v; } // ------------------------------ Floating-point rounding template HWY_API Vec128 Round(Vec128 v) { using TI = MakeSigned; const T k0 = ConvertScalarTo(0); const Vec128 a = Abs(v); for (size_t i = 0; i < N; ++i) { if (!(a.raw[i] < MantissaEnd())) { // Huge or NaN continue; } const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); const TI rounded = ConvertScalarTo(v.raw[i] + bias); if (rounded == 0) { v.raw[i] = v.raw[i] < 0 ? ConvertScalarTo(-0) : k0; continue; } const T rounded_f = ConvertScalarTo(rounded); // Round to even if ((rounded & 1) && ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { v.raw[i] = ConvertScalarTo(rounded - (v.raw[i] < k0 ? -1 : 1)); continue; } v.raw[i] = rounded_f; } return v; } // Round-to-nearest even. 
template HWY_API Vec128 NearestInt(Vec128 v) { using T = float; using TI = int32_t; const T k0 = ConvertScalarTo(0); const Vec128 abs = Abs(v); Vec128 ret; for (size_t i = 0; i < N; ++i) { const bool signbit = ScalarSignBit(v.raw[i]); if (!(abs.raw[i] < MantissaEnd())) { // Huge or NaN // Check if too large to cast or NaN if (!(abs.raw[i] <= ConvertScalarTo(LimitsMax()))) { ret.raw[i] = signbit ? LimitsMin() : LimitsMax(); continue; } ret.raw[i] = static_cast(v.raw[i]); continue; } const T bias = ConvertScalarTo(v.raw[i] < k0 ? -0.5 : 0.5); const TI rounded = ConvertScalarTo(v.raw[i] + bias); if (rounded == 0) { ret.raw[i] = 0; continue; } const T rounded_f = ConvertScalarTo(rounded); // Round to even if ((rounded & 1) && ScalarAbs(rounded_f - v.raw[i]) == ConvertScalarTo(0.5)) { ret.raw[i] = rounded - (signbit ? -1 : 1); continue; } ret.raw[i] = rounded; } return ret; } template HWY_API Vec128 Trunc(Vec128 v) { using TI = MakeSigned; const Vec128 abs = Abs(v); for (size_t i = 0; i < N; ++i) { if (!(abs.raw[i] <= MantissaEnd())) { // Huge or NaN continue; } const TI truncated = static_cast(v.raw[i]); if (truncated == 0) { v.raw[i] = v.raw[i] < 0 ? -T{0} : T{0}; continue; } v.raw[i] = static_cast(truncated); } return v; } // Toward +infinity, aka ceiling template Vec128 Ceil(Vec128 v) { constexpr int kMantissaBits = MantissaBits(); using Bits = MakeUnsigned; const Bits kExponentMask = MaxExponentField(); const Bits kMantissaMask = MantissaMask(); const Bits kBias = kExponentMask / 2; for (size_t i = 0; i < N; ++i) { const bool positive = v.raw[i] > Float(0.0); Bits bits = BitCastScalar(v.raw[i]); const int exponent = static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); // Already an integer. if (exponent >= kMantissaBits) continue; // |v| <= 1 => 0 or 1. if (exponent < 0) { v.raw[i] = positive ? Float{1} : Float{-0.0}; continue; } const Bits mantissa_mask = kMantissaMask >> exponent; // Already an integer if ((bits & mantissa_mask) == 0) continue; // Clear fractional bits and round up if (positive) bits += (kMantissaMask + 1) >> exponent; bits &= ~mantissa_mask; v.raw[i] = BitCastScalar(bits); } return v; } // Toward -infinity, aka floor template Vec128 Floor(Vec128 v) { constexpr int kMantissaBits = MantissaBits(); using Bits = MakeUnsigned; const Bits kExponentMask = MaxExponentField(); const Bits kMantissaMask = MantissaMask(); const Bits kBias = kExponentMask / 2; for (size_t i = 0; i < N; ++i) { const bool negative = v.raw[i] < Float(0.0); Bits bits = BitCastScalar(v.raw[i]); const int exponent = static_cast(((bits >> kMantissaBits) & kExponentMask) - kBias); // Already an integer. if (exponent >= kMantissaBits) continue; // |v| <= 1 => -1 or 0. if (exponent < 0) { v.raw[i] = negative ? Float(-1.0) : Float(0.0); continue; } const Bits mantissa_mask = kMantissaMask >> exponent; // Already an integer if ((bits & mantissa_mask) == 0) continue; // Clear fractional bits and round down if (negative) bits += (kMantissaMask + 1) >> exponent; bits &= ~mantissa_mask; v.raw[i] = BitCastScalar(bits); } return v; } // ------------------------------ Floating-point classification template HWY_API Mask128 IsNaN(Vec128 v) { Mask128 ret; for (size_t i = 0; i < N; ++i) { // std::isnan returns false for 0x7F..FF in clang AVX3 builds, so DIY. 
ret.bits[i] = Mask128::FromBool(ScalarIsNaN(v.raw[i])); } return ret; } // ================================================== COMPARE template HWY_API Mask128 operator==(Vec128 a, Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] == b.raw[i]); } return m; } template HWY_API Mask128 operator!=(Vec128 a, Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] != b.raw[i]); } return m; } template HWY_API Mask128 TestBit(Vec128 v, Vec128 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return (v & bit) == bit; } template HWY_API Mask128 operator<(Vec128 a, Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] < b.raw[i]); } return m; } template HWY_API Mask128 operator>(Vec128 a, Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] > b.raw[i]); } return m; } template HWY_API Mask128 operator<=(Vec128 a, Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] <= b.raw[i]); } return m; } template HWY_API Mask128 operator>=(Vec128 a, Vec128 b) { Mask128 m; for (size_t i = 0; i < N; ++i) { m.bits[i] = Mask128::FromBool(a.raw[i] >= b.raw[i]); } return m; } // ------------------------------ Lt128 // Only makes sense for full vectors of u64. template HWY_API MFromD Lt128(D /* tag */, Vec128 a, Vec128 b) { const bool lt = (a.raw[1] < b.raw[1]) || (a.raw[1] == b.raw[1] && a.raw[0] < b.raw[0]); Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); return ret; } template HWY_API MFromD Lt128Upper(D /* tag */, Vec128 a, Vec128 b) { const bool lt = a.raw[1] < b.raw[1]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(lt); return ret; } // ------------------------------ Eq128 // Only makes sense for full vectors of u64. 
template HWY_API MFromD Eq128(D /* tag */, Vec128 a, Vec128 b) { const bool eq = a.raw[1] == b.raw[1] && a.raw[0] == b.raw[0]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); return ret; } template HWY_API Mask128 Ne128(D /* tag */, Vec128 a, Vec128 b) { const bool ne = a.raw[1] != b.raw[1] || a.raw[0] != b.raw[0]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); return ret; } template HWY_API MFromD Eq128Upper(D /* tag */, Vec128 a, Vec128 b) { const bool eq = a.raw[1] == b.raw[1]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(eq); return ret; } template HWY_API MFromD Ne128Upper(D /* tag */, Vec128 a, Vec128 b) { const bool ne = a.raw[1] != b.raw[1]; Mask128 ret; ret.bits[0] = ret.bits[1] = Mask128::FromBool(ne); return ret; } // ------------------------------ Min128, Max128 (Lt128) template HWY_API VFromD Min128(D d, VFromD a, VFromD b) { return IfThenElse(Lt128(d, a, b), a, b); } template HWY_API VFromD Max128(D d, VFromD a, VFromD b) { return IfThenElse(Lt128(d, b, a), a, b); } template HWY_API VFromD Min128Upper(D d, VFromD a, VFromD b) { return IfThenElse(Lt128Upper(d, a, b), a, b); } template HWY_API VFromD Max128Upper(D d, VFromD a, VFromD b) { return IfThenElse(Lt128Upper(d, b, a), a, b); } // ================================================== MEMORY // ------------------------------ Load template HWY_API VFromD Load(D d, const TFromD* HWY_RESTRICT aligned) { VFromD v; CopyBytes(aligned, v.raw); // copy from array return v; } template HWY_API VFromD MaskedLoad(MFromD m, D d, const TFromD* HWY_RESTRICT p) { return IfThenElseZero(m, LoadU(d, p)); } template HWY_API VFromD MaskedLoadOr(VFromD v, MFromD m, D d, const TFromD* HWY_RESTRICT p) { return IfThenElse(m, LoadU(d, p), v); } template HWY_API VFromD LoadU(D d, const TFromD* HWY_RESTRICT p) { return Load(d, p); } // In some use cases, "load single lane" is sufficient; otherwise avoid this. 
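// Usage note (illustrative only): on this target a vector is at most one
// 128-bit block, so LoadDup128 below simply forwards to Load; the
// "broadcast the same 128-bit block to every block" behavior only becomes
// visible on targets with wider vectors.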
template HWY_API VFromD LoadDup128(D d, const TFromD* HWY_RESTRICT aligned) { return Load(d, aligned); } #ifdef HWY_NATIVE_LOAD_N #undef HWY_NATIVE_LOAD_N #else #define HWY_NATIVE_LOAD_N #endif template HWY_API VFromD LoadN(D d, const TFromD* HWY_RESTRICT p, size_t max_lanes_to_load) { VFromD v = Zero(d); const size_t N = Lanes(d); const size_t num_of_lanes_to_load = HWY_MIN(max_lanes_to_load, N); CopyBytes(p, v.raw, num_of_lanes_to_load * sizeof(TFromD)); return v; } template HWY_API VFromD LoadNOr(VFromD no, D d, const TFromD* HWY_RESTRICT p, size_t max_lanes_to_load) { VFromD v = no; const size_t N = Lanes(d); const size_t num_of_lanes_to_load = HWY_MIN(max_lanes_to_load, N); CopyBytes(p, v.raw, num_of_lanes_to_load * sizeof(TFromD)); return v; } // ------------------------------ Store template HWY_API void Store(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { CopyBytes(v.raw, aligned); // copy to array } template HWY_API void StoreU(VFromD v, D d, TFromD* HWY_RESTRICT p) { Store(v, d, p); } template HWY_API void BlendedStore(VFromD v, MFromD m, D d, TFromD* HWY_RESTRICT p) { for (size_t i = 0; i < MaxLanes(d); ++i) { if (m.bits[i]) p[i] = v.raw[i]; } } #ifdef HWY_NATIVE_STORE_N #undef HWY_NATIVE_STORE_N #else #define HWY_NATIVE_STORE_N #endif template HWY_API void StoreN(VFromD v, D d, TFromD* HWY_RESTRICT p, size_t max_lanes_to_store) { const size_t N = Lanes(d); const size_t num_of_lanes_to_store = HWY_MIN(max_lanes_to_store, N); CopyBytes(v.raw, p, num_of_lanes_to_store * sizeof(TFromD)); } // ------------------------------ LoadInterleaved2/3/4 // Per-target flag to prevent generic_ops-inl.h from defining LoadInterleaved2. // We implement those here because scalar code is likely faster than emulation // via shuffles. #ifdef HWY_NATIVE_LOAD_STORE_INTERLEAVED #undef HWY_NATIVE_LOAD_STORE_INTERLEAVED #else #define HWY_NATIVE_LOAD_STORE_INTERLEAVED #endif template > HWY_API void LoadInterleaved2(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1) { alignas(16) T buf0[MaxLanes(d)]; alignas(16) T buf1[MaxLanes(d)]; for (size_t i = 0; i < MaxLanes(d); ++i) { buf0[i] = *unaligned++; buf1[i] = *unaligned++; } v0 = Load(d, buf0); v1 = Load(d, buf1); } template > HWY_API void LoadInterleaved3(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2) { alignas(16) T buf0[MaxLanes(d)]; alignas(16) T buf1[MaxLanes(d)]; alignas(16) T buf2[MaxLanes(d)]; for (size_t i = 0; i < MaxLanes(d); ++i) { buf0[i] = *unaligned++; buf1[i] = *unaligned++; buf2[i] = *unaligned++; } v0 = Load(d, buf0); v1 = Load(d, buf1); v2 = Load(d, buf2); } template > HWY_API void LoadInterleaved4(D d, const T* HWY_RESTRICT unaligned, VFromD& v0, VFromD& v1, VFromD& v2, VFromD& v3) { alignas(16) T buf0[MaxLanes(d)]; alignas(16) T buf1[MaxLanes(d)]; alignas(16) T buf2[MaxLanes(d)]; alignas(16) T buf3[MaxLanes(d)]; for (size_t i = 0; i < MaxLanes(d); ++i) { buf0[i] = *unaligned++; buf1[i] = *unaligned++; buf2[i] = *unaligned++; buf3[i] = *unaligned++; } v0 = Load(d, buf0); v1 = Load(d, buf1); v2 = Load(d, buf2); v3 = Load(d, buf3); } // ------------------------------ StoreInterleaved2/3/4 template HWY_API void StoreInterleaved2(VFromD v0, VFromD v1, D d, TFromD* HWY_RESTRICT unaligned) { for (size_t i = 0; i < MaxLanes(d); ++i) { *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; } } template HWY_API void StoreInterleaved3(VFromD v0, VFromD v1, VFromD v2, D d, TFromD* HWY_RESTRICT unaligned) { for (size_t i = 0; i < MaxLanes(d); ++i) { *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; *unaligned++ = 
v2.raw[i]; } } template HWY_API void StoreInterleaved4(VFromD v0, VFromD v1, VFromD v2, VFromD v3, D d, TFromD* HWY_RESTRICT unaligned) { for (size_t i = 0; i < MaxLanes(d); ++i) { *unaligned++ = v0.raw[i]; *unaligned++ = v1.raw[i]; *unaligned++ = v2.raw[i]; *unaligned++ = v3.raw[i]; } } // ------------------------------ Stream template HWY_API void Stream(VFromD v, D d, TFromD* HWY_RESTRICT aligned) { Store(v, d, aligned); } // ------------------------------ Scatter in generic_ops-inl.h // ------------------------------ Gather in generic_ops-inl.h // ================================================== CONVERT // ConvertTo and DemoteTo with floating-point input and integer output truncate // (rounding toward zero). namespace detail { template HWY_INLINE ToT CastValueForF2IConv(FromT val) { // Prevent ubsan errors when converting float to narrower integer using FromTU = MakeUnsigned; using ToTU = MakeUnsigned; constexpr unsigned kMaxExpField = static_cast(MaxExponentField()); constexpr unsigned kExpBias = kMaxExpField >> 1; constexpr unsigned kMinOutOfRangeExpField = static_cast(HWY_MIN( kExpBias + sizeof(ToT) * 8 - static_cast(IsSigned()), kMaxExpField)); // If ToT is signed, compare only the exponent bits of val against // kMinOutOfRangeExpField. // // Otherwise, if ToT is unsigned, compare the sign bit plus exponent bits of // val against kMinOutOfRangeExpField as a negative value is outside of the // range of an unsigned integer type. const FromT val_to_compare = static_cast(IsSigned() ? ScalarAbs(val) : val); // val is within the range of ToT if // (BitCastScalar(val_to_compare) >> MantissaBits()) is less // than kMinOutOfRangeExpField // // Otherwise, val is either outside of the range of ToT or equal to // LimitsMin() if // (BitCastScalar(val_to_compare) >> MantissaBits()) is greater // than or equal to kMinOutOfRangeExpField. return (static_cast(BitCastScalar(val_to_compare) >> MantissaBits()) < kMinOutOfRangeExpField) ? static_cast(val) : static_cast(static_cast(LimitsMax()) + static_cast(ScalarSignBit(val))); } template HWY_INLINE ToT CastValueForPromoteTo(ToTypeTag /* to_type_tag */, FromT val) { return ConvertScalarTo(val); } template HWY_INLINE ToT CastValueForPromoteTo(hwy::SignedTag /*to_type_tag*/, float val) { return CastValueForF2IConv(val); } template HWY_INLINE ToT CastValueForPromoteTo(hwy::UnsignedTag /*to_type_tag*/, float val) { return CastValueForF2IConv(val); } } // namespace detail template HWY_API VFromD PromoteTo(DTo d, Vec128 from) { static_assert(sizeof(TFromD) > sizeof(TFrom), "Not promoting"); VFromD ret; for (size_t i = 0; i < MaxLanes(d); ++i) { // For bits Y > X, floatX->floatY and intX->intY are always representable. ret.raw[i] = detail::CastValueForPromoteTo>( hwy::TypeTag>(), from.raw[i]); } return ret; } // MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(TFrom) is here, // so we overload for TFrom=double and ToT={float,int32_t}. template HWY_API VFromD DemoteTo(D d, VFromD> from) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); ++i) { // Prevent ubsan errors when converting float to narrower integer/float if (ScalarIsInf(from.raw[i]) || ScalarAbs(from.raw[i]) > static_cast(HighestValue())) { ret.raw[i] = ScalarSignBit(from.raw[i]) ? 
LowestValue() : HighestValue(); continue; } ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API VFromD DemoteTo(D d, VFromD> from) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); ++i) { // Prevent ubsan errors when converting double to narrower integer/int32_t ret.raw[i] = detail::CastValueForF2IConv>(from.raw[i]); } return ret; } template )> HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { using TTo = TFromD; static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); VFromD ret; for (size_t i = 0; i < N; ++i) { // Int to int: choose closest value in ToT to `from` (avoids UB) from.raw[i] = HWY_MIN(HWY_MAX(LimitsMin(), from.raw[i]), LimitsMax()); ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { using TTo = TFromD; static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); VFromD ret; for (size_t i = 0; i < N; ++i) { // Int to int: choose closest value in ToT to `from` (avoids UB) from.raw[i] = HWY_MIN(from.raw[i], LimitsMax()); ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API VFromD DemoteTo(DTo /* tag */, Vec128 from) { using TTo = TFromD; static_assert(sizeof(TTo) < sizeof(TFrom), "Not demoting"); VFromD ret; for (size_t i = 0; i < N; ++i) { // int64_t/uint64_t to float: okay to cast to float as an int64_t/uint64_t // value is always within the range of a float ret.raw[i] = static_cast(from.raw[i]); } return ret; } template HWY_API VFromD ReorderDemote2To(DBF16 dbf16, VF32 a, VF32 b) { const Repartition du32; const VFromD b_in_lower = ShiftRight<16>(BitCast(du32, b)); // Avoid OddEven - we want the upper half of `a` even on big-endian systems. const VFromD a_mask = Set(du32, 0xFFFF0000); return BitCast(dbf16, IfVecThenElse(a_mask, BitCast(du32, a), b_in_lower)); } template ), class V, HWY_IF_SIGNED_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { const RepartitionToWide dw; const size_t NW = Lanes(dw); using TN = TFromD; const TN min = LimitsMin(); const TN max = LimitsMax(); VFromD ret; for (size_t i = 0; i < NW; ++i) { ret.raw[i] = static_cast(HWY_MIN(HWY_MAX(min, a.raw[i]), max)); } for (size_t i = 0; i < NW; ++i) { ret.raw[NW + i] = static_cast(HWY_MIN(HWY_MAX(min, b.raw[i]), max)); } return ret; } template ) * 2), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD ReorderDemote2To(DN dn, V a, V b) { const RepartitionToWide dw; const size_t NW = Lanes(dw); using TN = TFromD; const TN max = LimitsMax(); VFromD ret; for (size_t i = 0; i < NW; ++i) { ret.raw[i] = static_cast(HWY_MIN(a.raw[i], max)); } for (size_t i = 0; i < NW; ++i) { ret.raw[NW + i] = static_cast(HWY_MIN(b.raw[i], max)); } return ret; } template ), class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), HWY_IF_T_SIZE_V(V, sizeof(TFromD) * 2), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { return ReorderDemote2To(dn, a, b); } template ), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD OrderedDemote2To(DN dn, V a, V b) { const size_t NW = Lanes(dn) / 2; using TN = TFromD; VFromD ret; for (size_t i = 0; i < NW; ++i) { ret.raw[i] = ConvertScalarTo(a.raw[i]); } for (size_t i = 0; i < NW; ++i) { ret.raw[NW + i] = ConvertScalarTo(b.raw[i]); } return ret; } namespace detail { HWY_INLINE void StoreU16ToF16(const uint16_t val, hwy::float16_t* HWY_RESTRICT to) { CopySameSize(&val, to); } HWY_INLINE uint16_t U16FromF16(const hwy::float16_t* 
HWY_RESTRICT from) { uint16_t bits16; CopySameSize(from, &bits16); return bits16; } } // namespace detail template HWY_API VFromD PromoteTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = F32FromBF16(v.raw[i]); } return ret; } template HWY_API VFromD DemoteTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = BF16FromF32(v.raw[i]); } return ret; } // Tag dispatch instead of SFINAE for MSVC 2017 compatibility namespace detail { template HWY_API VFromD ConvertTo(hwy::FloatTag /*tag*/, DTo /*tag*/, Vec128 from) { using ToT = TFromD; static_assert(sizeof(ToT) == sizeof(TFrom), "Should have same size"); VFromD ret; constexpr size_t N = HWY_MAX_LANES_D(DTo); for (size_t i = 0; i < N; ++i) { // float## -> int##: return closest representable value ret.raw[i] = CastValueForF2IConv(from.raw[i]); } return ret; } template HWY_API VFromD ConvertTo(hwy::NonFloatTag /*tag*/, DTo /* tag */, Vec128 from) { using ToT = TFromD; static_assert(sizeof(ToT) == sizeof(TFrom), "Should have same size"); VFromD ret; constexpr size_t N = HWY_MAX_LANES_D(DTo); for (size_t i = 0; i < N; ++i) { // int## -> float##: no check needed ret.raw[i] = static_cast(from.raw[i]); } return ret; } } // namespace detail template HWY_API VFromD ConvertTo(DTo d, Vec128 from) { return detail::ConvertTo(hwy::IsFloatTag(), d, from); } template HWY_API Vec128 U8FromU32(Vec128 v) { return DemoteTo(Simd(), v); } // ------------------------------ Truncations template HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFF); } return ret; } template HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); } return ret; } template HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFFFFFFFFu); } return ret; } template HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFF); } return ret; } template HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFFFF); } return ret; } template HWY_API VFromD TruncateTo(D /* tag */, Vec128 v) { VFromD ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = static_cast(v.raw[i] & 0xFF); } return ret; } #ifdef HWY_NATIVE_ORDERED_TRUNCATE_2_TO #undef HWY_NATIVE_ORDERED_TRUNCATE_2_TO #else #define HWY_NATIVE_ORDERED_TRUNCATE_2_TO #endif template ) * 2), HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV) * 2)> HWY_API VFromD OrderedTruncate2To(DN dn, V a, V b) { const RepartitionToWide dw; const size_t NW = Lanes(dw); using TW = TFromD; using TN = TFromD; VFromD ret; constexpr TW max_val{LimitsMax()}; for (size_t i = 0; i < NW; ++i) { ret.raw[i] = static_cast(a.raw[i] & max_val); } for (size_t i = 0; i < NW; ++i) { ret.raw[NW + i] = static_cast(b.raw[i] & max_val); } return ret; } // ================================================== COMBINE template HWY_API Vec128 LowerHalf(Vec128 v) { Vec128 ret; CopyBytes(v.raw, ret.raw); return ret; } template HWY_API VFromD LowerHalf(D /* tag */, VFromD> v) { return LowerHalf(v); } template HWY_API VFromD UpperHalf(D d, VFromD> v) { VFromD ret; CopyBytes(&v.raw[MaxLanes(d)], ret.raw); return ret; } template HWY_API VFromD ZeroExtendVector(D d, VFromD> v) { const Half dh; VFromD ret; // zero-initialized CopyBytes(v.raw, 
ret.raw); return ret; } template >> HWY_API VFromD Combine(D d, VH hi_half, VH lo_half) { const Half dh; VFromD ret; CopyBytes(lo_half.raw, &ret.raw[0]); CopyBytes(hi_half.raw, &ret.raw[MaxLanes(dh)]); return ret; } template HWY_API VFromD ConcatLowerLower(D d, VFromD hi, VFromD lo) { const Half dh; VFromD ret; CopyBytes(lo.raw, &ret.raw[0]); CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); return ret; } template HWY_API VFromD ConcatUpperUpper(D d, VFromD hi, VFromD lo) { const Half dh; VFromD ret; CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); return ret; } template HWY_API VFromD ConcatLowerUpper(D d, VFromD hi, VFromD lo) { const Half dh; VFromD ret; CopyBytes(&lo.raw[MaxLanes(dh)], &ret.raw[0]); CopyBytes(hi.raw, &ret.raw[MaxLanes(dh)]); return ret; } template HWY_API VFromD ConcatUpperLower(D d, VFromD hi, VFromD lo) { const Half dh; VFromD ret; CopyBytes(lo.raw, &ret.raw[0]); CopyBytes(&hi.raw[MaxLanes(dh)], &ret.raw[MaxLanes(dh)]); return ret; } template HWY_API VFromD ConcatEven(D d, VFromD hi, VFromD lo) { const Half dh; VFromD ret; for (size_t i = 0; i < MaxLanes(dh); ++i) { ret.raw[i] = lo.raw[2 * i]; } for (size_t i = 0; i < MaxLanes(dh); ++i) { ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i]; } return ret; } // 2023-11-23: workaround for incorrect codegen (reduction_test fails for // SumsOf2 because PromoteOddTo, which uses ConcatOdd, returns zero). #if HWY_ARCH_RVV && HWY_TARGET == HWY_EMU128 && HWY_COMPILER_CLANG #define HWY_EMU128_CONCAT_INLINE HWY_NOINLINE #else #define HWY_EMU128_CONCAT_INLINE HWY_API #endif template HWY_EMU128_CONCAT_INLINE VFromD ConcatOdd(D d, VFromD hi, VFromD lo) { const Half dh; VFromD ret; for (size_t i = 0; i < MaxLanes(dh); ++i) { ret.raw[i] = lo.raw[2 * i + 1]; } for (size_t i = 0; i < MaxLanes(dh); ++i) { ret.raw[MaxLanes(dh) + i] = hi.raw[2 * i + 1]; } return ret; } // ------------------------------ CombineShiftRightBytes template HWY_API VFromD CombineShiftRightBytes(D d, VFromD hi, VFromD lo) { VFromD ret; const uint8_t* HWY_RESTRICT lo8 = reinterpret_cast(lo.raw); uint8_t* HWY_RESTRICT ret8 = reinterpret_cast(ret.raw); CopyBytes(lo8 + kBytes, ret8); CopyBytes(hi.raw, ret8 + d.MaxBytes() - kBytes); return ret; } // ------------------------------ ShiftLeftBytes template HWY_API VFromD ShiftLeftBytes(D d, VFromD v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); VFromD ret; uint8_t* HWY_RESTRICT ret8 = reinterpret_cast(ret.raw); ZeroBytes(ret8); CopyBytes(v.raw, ret8 + kBytes); return ret; } template HWY_API Vec128 ShiftLeftBytes(Vec128 v) { return ShiftLeftBytes(DFromV(), v); } // ------------------------------ ShiftLeftLanes template > HWY_API VFromD ShiftLeftLanes(D d, VFromD v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec128 ShiftLeftLanes(Vec128 v) { return ShiftLeftLanes(DFromV(), v); } // ------------------------------ ShiftRightBytes template HWY_API VFromD ShiftRightBytes(D d, VFromD v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); VFromD ret; const uint8_t* HWY_RESTRICT v8 = reinterpret_cast(v.raw); uint8_t* HWY_RESTRICT ret8 = reinterpret_cast(ret.raw); CopyBytes(v8 + kBytes, ret8); ZeroBytes(ret8 + d.MaxBytes() - kBytes); return ret; } // ------------------------------ ShiftRightLanes template HWY_API VFromD ShiftRightLanes(D d, VFromD v) { const Repartition d8; constexpr size_t kBytes = kLanes * sizeof(TFromD); return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // 
================================================== SWIZZLE template HWY_API T GetLane(Vec128 v) { return v.raw[0]; } template HWY_API Vec128 InsertLane(Vec128 v, size_t i, T t) { v.raw[i] = t; return v; } template HWY_API T ExtractLane(Vec128 v, size_t i) { return v.raw[i]; } template HWY_API Vec128 DupEven(Vec128 v) { for (size_t i = 0; i < N; i += 2) { v.raw[i + 1] = v.raw[i]; } return v; } template HWY_API Vec128 DupOdd(Vec128 v) { for (size_t i = 0; i < N; i += 2) { v.raw[i] = v.raw[i + 1]; } return v; } template HWY_API Vec128 OddEven(Vec128 odd, Vec128 even) { for (size_t i = 0; i < N; i += 2) { odd.raw[i] = even.raw[i]; } return odd; } template HWY_API Vec128 OddEvenBlocks(Vec128 /* odd */, Vec128 even) { return even; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec128 SwapAdjacentBlocks(Vec128 v) { return v; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices for use by TableLookupLanes. template struct Indices128 { MakeSigned raw[N]; }; template HWY_API Indices128, N> IndicesFromVec(D d, Vec128 vec) { static_assert(sizeof(TFromD) == sizeof(TI), "Index/lane size must match"); Indices128, N> ret; CopyBytes(vec.raw, ret.raw); return ret; } template HWY_API Indices128, HWY_MAX_LANES_D(D)> SetTableIndices( D d, const TI* idx) { return IndicesFromVec(d, LoadU(Rebind(), idx)); } template HWY_API Vec128 TableLookupLanes(Vec128 v, Indices128 idx) { Vec128 ret; for (size_t i = 0; i < N; ++i) { ret.raw[i] = v.raw[idx.raw[i]]; } return ret; } template HWY_API Vec128 TwoTablesLookupLanes(Vec128 a, Vec128 b, Indices128 idx) { using TI = MakeSigned; Vec128 ret; constexpr TI kVecLaneIdxMask = static_cast(N - 1); for (size_t i = 0; i < N; ++i) { const auto src_idx = idx.raw[i]; const auto masked_src_lane_idx = src_idx & kVecLaneIdxMask; ret.raw[i] = (src_idx < static_cast(N)) ? a.raw[masked_src_lane_idx] : b.raw[masked_src_lane_idx]; } return ret; } // ------------------------------ ReverseBlocks template HWY_API VFromD ReverseBlocks(D /* tag */, VFromD v) { return v; // Single block: no change } // ------------------------------ Reverse template HWY_API VFromD Reverse(D d, VFromD v) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); ++i) { ret.raw[i] = v.raw[MaxLanes(d) - 1 - i]; } return ret; } // Per-target flag to prevent generic_ops-inl.h defining 8-bit Reverse2/4/8. 
#ifdef HWY_NATIVE_REVERSE2_8 #undef HWY_NATIVE_REVERSE2_8 #else #define HWY_NATIVE_REVERSE2_8 #endif template HWY_API VFromD Reverse2(D d, VFromD v) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); i += 2) { ret.raw[i + 0] = v.raw[i + 1]; ret.raw[i + 1] = v.raw[i + 0]; } return ret; } template HWY_API VFromD Reverse4(D d, VFromD v) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); i += 4) { ret.raw[i + 0] = v.raw[i + 3]; ret.raw[i + 1] = v.raw[i + 2]; ret.raw[i + 2] = v.raw[i + 1]; ret.raw[i + 3] = v.raw[i + 0]; } return ret; } template HWY_API VFromD Reverse8(D d, VFromD v) { VFromD ret; for (size_t i = 0; i < MaxLanes(d); i += 8) { ret.raw[i + 0] = v.raw[i + 7]; ret.raw[i + 1] = v.raw[i + 6]; ret.raw[i + 2] = v.raw[i + 5]; ret.raw[i + 3] = v.raw[i + 4]; ret.raw[i + 4] = v.raw[i + 3]; ret.raw[i + 5] = v.raw[i + 2]; ret.raw[i + 6] = v.raw[i + 1]; ret.raw[i + 7] = v.raw[i + 0]; } return ret; } // ------------------------------ SlideUpLanes template HWY_API VFromD SlideUpLanes(D d, VFromD v, size_t amt) { VFromD ret = Zero(d); constexpr size_t N = HWY_MAX_LANES_D(D); const size_t clamped_amt = HWY_MIN(amt, N); CopyBytes(v.raw, ret.raw + clamped_amt, (N - clamped_amt) * sizeof(TFromD)); return ret; } // ------------------------------ SlideDownLanes template HWY_API VFromD SlideDownLanes(D d, VFromD v, size_t amt) { VFromD ret = Zero(d); constexpr size_t N = HWY_MAX_LANES_D(D); const size_t clamped_amt = HWY_MIN(amt, N); CopyBytes(v.raw + clamped_amt, ret.raw, (N - clamped_amt) * sizeof(TFromD)); return ret; } // ================================================== BLOCKWISE // ------------------------------ Shuffle* // Swap 32-bit halves in 64-bit halves. template HWY_API Vec128 Shuffle2301(Vec128 v) { static_assert(sizeof(T) == 4, "Only for 32-bit"); static_assert(N == 2 || N == 4, "Does not make sense for N=1"); return Reverse2(DFromV(), v); } // Swap 64-bit halves template HWY_API Vec128 Shuffle1032(Vec128 v) { static_assert(sizeof(T) == 4, "Only for 32-bit"); Vec128 ret; ret.raw[3] = v.raw[1]; ret.raw[2] = v.raw[0]; ret.raw[1] = v.raw[3]; ret.raw[0] = v.raw[2]; return ret; } template HWY_API Vec128 Shuffle01(Vec128 v) { static_assert(sizeof(T) == 8, "Only for 64-bit"); return Reverse2(DFromV(), v); } // Rotate right 32 bits template HWY_API Vec128 Shuffle0321(Vec128 v) { Vec128 ret; ret.raw[3] = v.raw[0]; ret.raw[2] = v.raw[3]; ret.raw[1] = v.raw[2]; ret.raw[0] = v.raw[1]; return ret; } // Rotate left 32 bits template HWY_API Vec128 Shuffle2103(Vec128 v) { Vec128 ret; ret.raw[3] = v.raw[2]; ret.raw[2] = v.raw[1]; ret.raw[1] = v.raw[0]; ret.raw[0] = v.raw[3]; return ret; } template HWY_API Vec128 Shuffle0123(Vec128 v) { return Reverse4(DFromV(), v); } // ------------------------------ Broadcast template HWY_API Vec128 Broadcast(Vec128 v) { for (size_t i = 0; i < N; ++i) { v.raw[i] = v.raw[kLane]; } return v; } // ------------------------------ TableLookupBytes, TableLookupBytesOr0 template HWY_API Vec128 TableLookupBytes(Vec128 v, Vec128 indices) { const uint8_t* HWY_RESTRICT v_bytes = reinterpret_cast(v.raw); const uint8_t* HWY_RESTRICT idx_bytes = reinterpret_cast(indices.raw); Vec128 ret; uint8_t* HWY_RESTRICT ret_bytes = reinterpret_cast(ret.raw); for (size_t i = 0; i < NI * sizeof(TI); ++i) { const size_t idx = idx_bytes[i]; // Avoid out of bounds reads. ret_bytes[i] = idx < sizeof(T) * N ? v_bytes[idx] : 0; } return ret; } template HWY_API Vec128 TableLookupBytesOr0(Vec128 v, Vec128 indices) { // Same as TableLookupBytes, which already returns 0 if out of bounds. 
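  // (Indices with the most significant bit set are >= 0x80, which exceeds
  // the at most 16 valid byte positions, so the bounds check above already
  // maps them to zero as required.)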
return TableLookupBytes(v, indices); } // ------------------------------ InterleaveLower/InterleaveUpper template HWY_API Vec128 InterleaveLower(Vec128 a, Vec128 b) { Vec128 ret; for (size_t i = 0; i < N / 2; ++i) { ret.raw[2 * i + 0] = a.raw[i]; ret.raw[2 * i + 1] = b.raw[i]; } return ret; } // Additional overload for the optional tag. template HWY_API VFromD InterleaveLower(D /* tag */, VFromD a, VFromD b) { return InterleaveLower(a, b); } template HWY_API VFromD InterleaveUpper(D d, VFromD a, VFromD b) { const Half dh; VFromD ret; for (size_t i = 0; i < MaxLanes(dh); ++i) { ret.raw[2 * i + 0] = a.raw[MaxLanes(dh) + i]; ret.raw[2 * i + 1] = b.raw[MaxLanes(dh) + i]; } return ret; } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. template >> HWY_API VFromD ZipLower(V a, V b) { return BitCast(DW(), InterleaveLower(a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipLower(DW dw, V a, V b) { return BitCast(dw, InterleaveLower(D(), a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipUpper(DW dw, V a, V b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } // ================================================== MASK template HWY_API bool AllFalse(D d, MFromD mask) { typename MFromD::Raw or_sum = 0; for (size_t i = 0; i < MaxLanes(d); ++i) { or_sum |= mask.bits[i]; } return or_sum == 0; } template HWY_API bool AllTrue(D d, MFromD mask) { constexpr uint64_t kAll = LimitsMax::Raw>(); uint64_t and_sum = kAll; for (size_t i = 0; i < MaxLanes(d); ++i) { and_sum &= mask.bits[i]; } return and_sum == kAll; } // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API MFromD LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) { MFromD m; for (size_t i = 0; i < MaxLanes(d); ++i) { const size_t bit = size_t{1} << (i & 7); const size_t idx_byte = i >> 3; m.bits[i] = MFromD::FromBool((bits[idx_byte] & bit) != 0); } return m; } template HWY_API MFromD Dup128MaskFromMaskBits(D d, unsigned mask_bits) { MFromD m; for (size_t i = 0; i < MaxLanes(d); ++i) { m.bits[i] = MFromD::FromBool(((mask_bits >> i) & 1u) != 0); } return m; } // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(D d, MFromD mask, uint8_t* bits) { bits[0] = 0; if (MaxLanes(d) > 8) bits[1] = 0; // MaxLanes(d) <= 16, so max two bytes for (size_t i = 0; i < MaxLanes(d); ++i) { const size_t bit = size_t{1} << (i & 7); const size_t idx_byte = i >> 3; if (mask.bits[i]) { bits[idx_byte] = static_cast(bits[idx_byte] | bit); } } return MaxLanes(d) > 8 ? 
template <class D>
HWY_API size_t CountTrue(D d, MFromD<D> mask) {
  size_t count = 0;
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    count += mask.bits[i] != 0;
  }
  return count;
}

template <class D>
HWY_API size_t FindKnownFirstTrue(D d, MFromD<D> mask) {
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    if (mask.bits[i] != 0) return i;
  }
  HWY_DASSERT(false);
  return 0;
}

template <class D>
HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) {
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    if (mask.bits[i] != 0) return static_cast<intptr_t>(i);
  }
  return intptr_t{-1};
}

template <class D>
HWY_API size_t FindKnownLastTrue(D d, MFromD<D> mask) {
  for (intptr_t i = static_cast<intptr_t>(MaxLanes(d) - 1); i >= 0; i--) {
    if (mask.bits[i] != 0) return static_cast<size_t>(i);
  }
  HWY_DASSERT(false);
  return 0;
}

template <class D>
HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
  for (intptr_t i = static_cast<intptr_t>(MaxLanes(d) - 1); i >= 0; i--) {
    if (mask.bits[i] != 0) return i;
  }
  return intptr_t{-1};
}

// ------------------------------ Compress

template <typename T>
struct CompressIsPartition {
  enum { value = (sizeof(T) != 1) };
};

template <typename T, size_t N>
HWY_API Vec128<T, N> Compress(Vec128<T, N> v, Mask128<T, N> mask) {
  size_t count = 0;
  Vec128<T, N> ret;
  for (size_t i = 0; i < N; ++i) {
    if (mask.bits[i]) {
      ret.raw[count++] = v.raw[i];
    }
  }
  for (size_t i = 0; i < N; ++i) {
    if (!mask.bits[i]) {
      ret.raw[count++] = v.raw[i];
    }
  }
  HWY_DASSERT(count == N);
  return ret;
}

// ------------------------------ Expand

// Could also just allow generic_ops-inl.h to implement these, but use our
// simple implementation below to ensure the test is correct.
#ifdef HWY_NATIVE_EXPAND
#undef HWY_NATIVE_EXPAND
#else
#define HWY_NATIVE_EXPAND
#endif

template <typename T, size_t N>
HWY_API Vec128<T, N> Expand(Vec128<T, N> v, const Mask128<T, N> mask) {
  size_t in_pos = 0;
  Vec128<T, N> ret;
  for (size_t i = 0; i < N; ++i) {
    if (mask.bits[i]) {
      ret.raw[i] = v.raw[in_pos++];
    } else {
      ret.raw[i] = ConvertScalarTo<T>(0);
    }
  }
  return ret;
}

// ------------------------------ LoadExpand

template <class D>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
  size_t in_pos = 0;
  VFromD<D> ret;
  for (size_t i = 0; i < Lanes(d); ++i) {
    if (mask.bits[i]) {
      ret.raw[i] = unaligned[in_pos++];
    } else {
      ret.raw[i] = TFromD<D>();  // zero, also works for float16_t
    }
  }
  return ret;
}

// ------------------------------ CompressNot

template <typename T, size_t N>
HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) {
  size_t count = 0;
  Vec128<T, N> ret;
  for (size_t i = 0; i < N; ++i) {
    if (!mask.bits[i]) {
      ret.raw[count++] = v.raw[i];
    }
  }
  for (size_t i = 0; i < N; ++i) {
    if (mask.bits[i]) {
      ret.raw[count++] = v.raw[i];
    }
  }
  HWY_DASSERT(count == N);
  return ret;
}

// ------------------------------ CompressBlocksNot
HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v,
                                           Mask128<uint64_t> /* m */) {
  return v;
}

// ------------------------------ CompressBits

template <typename T, size_t N>
HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v,
                                  const uint8_t* HWY_RESTRICT bits) {
  return Compress(v, LoadMaskBits(Simd<T, N, 0>(), bits));
}
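// Illustrative sketch, not part of the API (the helper name is hypothetical):
// Expand scatters the lanes packed by Compress back to the positions selected
// by the mask and zeroes the unselected lanes, so the composition keeps
// exactly the selected lanes of `v`.
template <typename T, size_t N>
Vec128<T, N> Example_CompressThenExpand(Vec128<T, N> v, Mask128<T, N> mask) {
  return Expand(Compress(v, mask), mask);
}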
// ------------------------------ CompressStore

// generic_ops-inl defines the 8-bit versions.
template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressStore(VFromD<D> v, MFromD<D> mask, D d,
                             TFromD<D>* HWY_RESTRICT unaligned) {
  size_t count = 0;
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    if (mask.bits[i]) {
      unaligned[count++] = v.raw[i];
    }
  }
  return count;
}

// ------------------------------ CompressBlendedStore

template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> mask, D d,
                                    TFromD<D>* HWY_RESTRICT unaligned) {
  return CompressStore(v, mask, d, unaligned);
}

// ------------------------------ CompressBitsStore

template <class D, HWY_IF_NOT_T_SIZE_D(D, 1)>
HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
                                 D d, TFromD<D>* HWY_RESTRICT unaligned) {
  const MFromD<D> mask = LoadMaskBits(d, bits);
  StoreU(Compress(v, mask), d, unaligned);
  return CountTrue(d, mask);
}

// ------------------------------ Additional mask logical operations

template <typename T>
HWY_API Mask128<T, 1> SetAtOrAfterFirst(Mask128<T, 1> mask) {
  return mask;
}
template <typename T, size_t N>
HWY_API Mask128<T, N> SetAtOrAfterFirst(Mask128<T, N> mask) {
  using TU = hwy::MakeUnsigned<T>;
  Mask128<T, N> result;
  TU result_lane_mask{0};
  for (size_t i = 0; i < N; i++) {
    result_lane_mask = static_cast<TU>(result_lane_mask | mask.bits[i]);
    result.bits[i] = result_lane_mask;
  }
  return result;
}

template <typename T, size_t N>
HWY_API Mask128<T, N> SetBeforeFirst(Mask128<T, N> mask) {
  return Not(SetAtOrAfterFirst(mask));
}

template <typename T, size_t N>
HWY_API Mask128<T, N> SetOnlyFirst(Mask128<T, N> mask) {
  using TU = hwy::MakeUnsigned<T>;
  using TI = hwy::MakeSigned<T>;
  Mask128<T, N> result;
  TU result_lane_mask = static_cast<TU>(~TU{0});
  for (size_t i = 0; i < N; i++) {
    const auto curr_lane_mask_bits = mask.bits[i];
    result.bits[i] = static_cast<TU>(curr_lane_mask_bits & result_lane_mask);
    result_lane_mask = static_cast<TU>(
        result_lane_mask &
        static_cast<TU>(-static_cast<TI>(mask.bits[i] == 0)));
  }
  return result;
}

template <typename T, size_t N>
HWY_API Mask128<T, N> SetAtOrBeforeFirst(Mask128<T, N> mask) {
  using TU = hwy::MakeUnsigned<T>;
  using TI = hwy::MakeSigned<T>;
  Mask128<T, N> result;
  TU result_lane_mask = static_cast<TU>(~TU{0});
  for (size_t i = 0; i < N; i++) {
    result.bits[i] = result_lane_mask;
    result_lane_mask = static_cast<TU>(
        result_lane_mask &
        static_cast<TU>(-static_cast<TI>(mask.bits[i] == 0)));
  }
  return result;
}
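// Illustrative sketch, not part of the API (the helper name is hypothetical):
// for a non-empty mask, SetOnlyFirst keeps exactly one true lane, namely the
// first true lane of the input; for an all-false mask it returns all-false.
template <class D>
bool Example_SetOnlyFirstProperty(D d, MFromD<D> mask) {
  const MFromD<D> only_first = SetOnlyFirst(mask);
  if (AllFalse(d, mask)) return AllFalse(d, only_first);
  return CountTrue(d, only_first) == 1 &&
         FindKnownFirstTrue(d, only_first) == FindKnownFirstTrue(d, mask);
}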
// ------------------------------ WidenMulPairwiseAdd

template <class D, HWY_IF_F32_D(D), class VBF16>
HWY_API VFromD<D> WidenMulPairwiseAdd(D df32, VBF16 a, VBF16 b) {
  const Rebind<uint32_t, decltype(df32)> du32;
  using VU32 = VFromD<decltype(du32)>;
  const VU32 odd = Set(du32, 0xFFFF0000u);  // bfloat16 is the upper half of f32
  // Avoid ZipLower/Upper so this also works on big-endian systems.
  const VU32 ae = ShiftLeft<16>(BitCast(du32, a));
  const VU32 ao = And(BitCast(du32, a), odd);
  const VU32 be = ShiftLeft<16>(BitCast(du32, b));
  const VU32 bo = And(BitCast(du32, b), odd);
  return Mul(BitCast(df32, ae), BitCast(df32, be)) +
         Mul(BitCast(df32, ao), BitCast(df32, bo));
}

template <class D, HWY_IF_I32_D(D), class VI16>
HWY_API VFromD<D> WidenMulPairwiseAdd(D d32, VI16 a, VI16 b) {
  using VI32 = VFromD<decltype(d32)>;
  // Manual sign extension requires two shifts for even lanes.
  const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a)));
  const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b)));
  const VI32 ao = ShiftRight<16>(BitCast(d32, a));
  const VI32 bo = ShiftRight<16>(BitCast(d32, b));
  return Add(Mul(ae, be), Mul(ao, bo));
}

template <class D, HWY_IF_U32_D(D), class VU16>
HWY_API VFromD<D> WidenMulPairwiseAdd(D du32, VU16 a, VU16 b) {
  const auto lo16_mask = Set(du32, 0x0000FFFFu);
  const auto a0 = And(BitCast(du32, a), lo16_mask);
  const auto b0 = And(BitCast(du32, b), lo16_mask);
  const auto a1 = ShiftRight<16>(BitCast(du32, a));
  const auto b1 = ShiftRight<16>(BitCast(du32, b));
  return Add(Mul(a0, b0), Mul(a1, b1));
}

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

template <class D, HWY_IF_F32_D(D), class VBF16>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D df32, VBF16 a, VBF16 b,
                                            const VFromD<D> sum0,
                                            VFromD<D>& sum1) {
  const Rebind<uint32_t, decltype(df32)> du32;
  using VU32 = VFromD<decltype(du32)>;
  const VU32 odd = Set(du32, 0xFFFF0000u);  // bfloat16 is the upper half of f32
  // Avoid ZipLower/Upper so this also works on big-endian systems.
  const VU32 ae = ShiftLeft<16>(BitCast(du32, a));
  const VU32 ao = And(BitCast(du32, a), odd);
  const VU32 be = ShiftLeft<16>(BitCast(du32, b));
  const VU32 bo = And(BitCast(du32, b), odd);
  sum1 = MulAdd(BitCast(df32, ao), BitCast(df32, bo), sum1);
  return MulAdd(BitCast(df32, ae), BitCast(df32, be), sum0);
}

template <class D, HWY_IF_I32_D(D), class VI16>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D d32, VI16 a, VI16 b,
                                            const VFromD<D> sum0,
                                            VFromD<D>& sum1) {
  using VI32 = VFromD<decltype(d32)>;
  // Manual sign extension requires two shifts for even lanes.
  const VI32 ae = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, a)));
  const VI32 be = ShiftRight<16>(ShiftLeft<16>(BitCast(d32, b)));
  const VI32 ao = ShiftRight<16>(BitCast(d32, a));
  const VI32 bo = ShiftRight<16>(BitCast(d32, b));
  sum1 = Add(Mul(ao, bo), sum1);
  return Add(Mul(ae, be), sum0);
}

template <class D, HWY_IF_U32_D(D), class VU16>
HWY_API VFromD<D> ReorderWidenMulAccumulate(D du32, VU16 a, VU16 b,
                                            const VFromD<D> sum0,
                                            VFromD<D>& sum1) {
  using VU32 = VFromD<decltype(du32)>;
  const VU32 lo16_mask = Set(du32, uint32_t{0x0000FFFFu});
  const VU32 ae = And(BitCast(du32, a), lo16_mask);
  const VU32 be = And(BitCast(du32, b), lo16_mask);
  const VU32 ao = ShiftRight<16>(BitCast(du32, a));
  const VU32 bo = ShiftRight<16>(BitCast(du32, b));
  sum1 = Add(Mul(ao, bo), sum1);
  return Add(Mul(ae, be), sum0);
}

// ------------------------------ RearrangeToOddPlusEven

template <class VW>
HWY_API VW RearrangeToOddPlusEven(VW sum0, VW sum1) {
  return Add(sum0, sum1);
}

// ================================================== REDUCTIONS

#ifdef HWY_NATIVE_REDUCE_SCALAR
#undef HWY_NATIVE_REDUCE_SCALAR
#else
#define HWY_NATIVE_REDUCE_SCALAR
#endif

template <class D, typename T = TFromD<D>, HWY_IF_REDUCE_D(D)>
HWY_API T ReduceSum(D d, VFromD<D> v) {
  T sum = T{0};
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    sum += v.raw[i];
  }
  return sum;
}

template <class D, typename T = TFromD<D>, HWY_IF_REDUCE_D(D)>
HWY_API T ReduceMin(D d, VFromD<D> v) {
  T min = HighestValue<T>();
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    min = HWY_MIN(min, v.raw[i]);
  }
  return min;
}

template <class D, typename T = TFromD<D>, HWY_IF_REDUCE_D(D)>
HWY_API T ReduceMax(D d, VFromD<D> v) {
  T max = LowestValue<T>();
  for (size_t i = 0; i < MaxLanes(d); ++i) {
    max = HWY_MAX(max, v.raw[i]);
  }
  return max;
}

// ------------------------------ SumOfLanes

template <class D>
HWY_API VFromD<D> SumOfLanes(D d, VFromD<D> v) {
  return Set(d, ReduceSum(d, v));
}
template <class D>
HWY_API VFromD<D> MinOfLanes(D d, VFromD<D> v) {
  return Set(d, ReduceMin(d, v));
}
template <class D>
HWY_API VFromD<D> MaxOfLanes(D d, VFromD<D> v) {
  return Set(d, ReduceMax(d, v));
}
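// Illustrative sketch, not part of the API (the helper name is hypothetical):
// a 16-bit dot product built from the ops above, pairing WidenMulPairwiseAdd
// with ReduceSum over the widened int32 lanes.
template <class D32, class V16 = VFromD<Repartition<int16_t, D32>>>
int32_t Example_DotProductI16(D32 d32, V16 a, V16 b) {
  return ReduceSum(d32, WidenMulPairwiseAdd(d32, a, b));
}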
// ================================================== OPS WITH DEPENDENCIES

// ------------------------------ MulEven/Odd 64x64 (UpperHalf)

HWY_INLINE Vec128<uint64_t> MulEven(Vec128<uint64_t> a, Vec128<uint64_t> b) {
  alignas(16) uint64_t mul[2];
  mul[0] = Mul128(GetLane(a), GetLane(b), &mul[1]);
  return Load(Full128<uint64_t>(), mul);
}

HWY_INLINE Vec128<uint64_t> MulOdd(Vec128<uint64_t> a, Vec128<uint64_t> b) {
  alignas(16) uint64_t mul[2];
  const Half<Full128<uint64_t>> d2;
  mul[0] =
      Mul128(GetLane(UpperHalf(d2, a)), GetLane(UpperHalf(d2, b)), &mul[1]);
  return Load(Full128<uint64_t>(), mul);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();