// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Single-element vectors and operations.
// External include guard in highway.h - see comment there.

#include <stddef.h>
#include <stdint.h>

#include "hwy/base.h"
#include "hwy/ops/shared-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

// Single instruction, single data.
template <typename T>
using Sisd = Simd<T, 1>;

// (Wrapper class required for overloading comparison operators.)
template <typename T>
struct Vec1 {
  HWY_INLINE Vec1() = default;
  Vec1(const Vec1&) = default;
  Vec1& operator=(const Vec1&) = default;
  HWY_INLINE explicit Vec1(const T t) : raw(t) {}

  HWY_INLINE Vec1& operator*=(const Vec1 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec1& operator/=(const Vec1 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec1& operator+=(const Vec1 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec1& operator-=(const Vec1 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec1& operator&=(const Vec1 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec1& operator|=(const Vec1 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec1& operator^=(const Vec1 other) {
    return *this = (*this ^ other);
  }

  T raw;
};

// 0 or FF..FF, same size as Vec1.
template <typename T>
class Mask1 {
  using Raw = hwy::MakeUnsigned<T>;

 public:
  static HWY_INLINE Mask1<T> FromBool(bool b) {
    Mask1<T> mask;
    mask.bits = b ? static_cast<Raw>(~Raw(0)) : 0;
    return mask;
  }

  Raw bits;
};

namespace detail {

// Deduce Sisd<T> from Vec1<T>
struct Deduce1 {
  template <typename T>
  Sisd<T> operator()(Vec1<T>) const {
    return Sisd<T>();
  }
};

}  // namespace detail

template <class V>
using DFromV = decltype(detail::Deduce1()(V()));

template <class V>
using TFromV = TFromD<DFromV<V>>;

// ------------------------------ BitCast

template <typename T, typename FromT>
HWY_API Vec1<T> BitCast(Sisd<T> /* tag */, Vec1<FromT> v) {
  static_assert(sizeof(T) <= sizeof(FromT), "Promoting is undefined");
  T to;
  CopyBytes<sizeof(T)>(&v.raw, &to);
  return Vec1<T>(to);
}

// ------------------------------ Set

template <typename T>
HWY_API Vec1<T> Zero(Sisd<T> /* tag */) {
  return Vec1<T>(T(0));
}

template <typename T, typename T2>
HWY_API Vec1<T> Set(Sisd<T> /* tag */, const T2 t) {
  return Vec1<T>(static_cast<T>(t));
}

template <typename T>
HWY_API Vec1<T> Undefined(Sisd<T> d) {
  return Zero(d);
}

template <typename T, typename T2>
HWY_API Vec1<T> Iota(const Sisd<T> /* tag */, const T2 first) {
  return Vec1<T>(static_cast<T>(first));
}

// ================================================== LOGICAL

// ------------------------------ Not

template <typename T>
HWY_API Vec1<T> Not(const Vec1<T> v) {
  using TU = MakeUnsigned<T>;
  const Sisd<TU> du;
  return BitCast(Sisd<T>(), Vec1<TU>(static_cast<TU>(~BitCast(du, v).raw)));
}

// ------------------------------ And

template <typename T>
HWY_API Vec1<T> And(const Vec1<T> a, const Vec1<T> b) {
  using TU = MakeUnsigned<T>;
  const Sisd<TU> du;
  return BitCast(Sisd<T>(), Vec1<TU>(BitCast(du, a).raw & BitCast(du, b).raw));
}
template <typename T>
HWY_API Vec1<T> operator&(const Vec1<T> a, const Vec1<T> b) {
  return And(a, b);
}

// ------------------------------ AndNot

template <typename T>
HWY_API Vec1<T> AndNot(const Vec1<T> a, const Vec1<T> b) {
  using TU = MakeUnsigned<T>;
  const Sisd<TU> du;
  return BitCast(Sisd<T>(), Vec1<TU>(static_cast<TU>(~BitCast(du, a).raw &
                                                     BitCast(du, b).raw)));
}

// ------------------------------ Or

template <typename T>
HWY_API Vec1<T> Or(const Vec1<T> a, const Vec1<T> b) {
  using TU = MakeUnsigned<T>;
  const Sisd<TU> du;
  return BitCast(Sisd<T>(), Vec1<TU>(BitCast(du, a).raw | BitCast(du, b).raw));
}
template <typename T>
HWY_API Vec1<T> operator|(const Vec1<T> a, const Vec1<T> b) {
  return Or(a, b);
}

// ------------------------------ Xor

template <typename T>
HWY_API Vec1<T> Xor(const Vec1<T> a, const Vec1<T> b) {
  using TU = MakeUnsigned<T>;
  const Sisd<TU> du;
  return BitCast(Sisd<T>(), Vec1<TU>(BitCast(du, a).raw ^ BitCast(du, b).raw));
}
template <typename T>
HWY_API Vec1<T> operator^(const Vec1<T> a, const Vec1<T> b) {
  return Xor(a, b);
}

// ------------------------------ OrAnd

template <typename T>
HWY_API Vec1<T> OrAnd(const Vec1<T> o, const Vec1<T> a1, const Vec1<T> a2) {
  return Or(o, And(a1, a2));
}

// ------------------------------ IfVecThenElse

template <typename T>
HWY_API Vec1<T> IfVecThenElse(Vec1<T> mask, Vec1<T> yes, Vec1<T> no) {
  return IfThenElse(MaskFromVec(mask), yes, no);
}

// ------------------------------ CopySign

template <typename T>
HWY_API Vec1<T> CopySign(const Vec1<T> magn, const Vec1<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  const auto msb = SignBit(Sisd<T>());
  return Or(AndNot(msb, magn), And(msb, sign));
}

template <typename T>
HWY_API Vec1<T> CopySignToAbs(const Vec1<T> abs, const Vec1<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  return Or(abs, And(SignBit(Sisd<T>()), sign));
}

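// Illustrative sketch (hypothetical helper, not part of the Highway API):
// CopySign merges the magnitude of one input with the sign of the other via
// the SignBit mask, so for T = float, (1.5, -0.0) yields -1.5. Kept as an
// uninstantiated template so it has no effect unless explicitly called.
template <typename T>
HWY_API Vec1<T> ExampleCopySignUsage() {
  const Sisd<T> d;
  return CopySign(Set(d, T(1.5)), Set(d, T(-0.0)));  // == Set(d, T(-1.5))
}
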
// ------------------------------ BroadcastSignBit

template <typename T>
HWY_API Vec1<T> BroadcastSignBit(const Vec1<T> v) {
  // This is used inside ShiftRight, so we cannot implement in terms of it.
  return v.raw < 0 ? Vec1<T>(T(-1)) : Vec1<T>(0);
}

// ------------------------------ PopulationCount

#ifdef HWY_NATIVE_POPCNT
#undef HWY_NATIVE_POPCNT
#else
#define HWY_NATIVE_POPCNT
#endif

template <typename T>
HWY_API Vec1<T> PopulationCount(Vec1<T> v) {
  return Vec1<T>(static_cast<T>(PopCount(v.raw)));
}

// ------------------------------ Mask

template <typename TFrom, typename TTo>
HWY_API Mask1<TTo> RebindMask(Sisd<TTo> /*tag*/, Mask1<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size");
  return Mask1<TTo>{m.bits};
}

// v must be 0 or FF..FF.
template <typename T>
HWY_API Mask1<T> MaskFromVec(const Vec1<T> v) {
  Mask1<T> mask;
  CopyBytes<sizeof(mask.bits)>(&v.raw, &mask.bits);
  return mask;
}

template <typename T>
Vec1<T> VecFromMask(const Mask1<T> mask) {
  Vec1<T> v;
  CopyBytes<sizeof(v.raw)>(&mask.bits, &v.raw);
  return v;
}

template <typename T>
Vec1<T> VecFromMask(Sisd<T> /* tag */, const Mask1<T> mask) {
  Vec1<T> v;
  CopyBytes<sizeof(v.raw)>(&mask.bits, &v.raw);
  return v;
}

template <typename T>
HWY_API Mask1<T> FirstN(Sisd<T> /*tag*/, size_t n) {
  return Mask1<T>::FromBool(n != 0);
}

// Returns mask ? yes : no.
template <typename T>
HWY_API Vec1<T> IfThenElse(const Mask1<T> mask, const Vec1<T> yes,
                           const Vec1<T> no) {
  return mask.bits ? yes : no;
}

template <typename T>
HWY_API Vec1<T> IfThenElseZero(const Mask1<T> mask, const Vec1<T> yes) {
  return mask.bits ? yes : Vec1<T>(0);
}

template <typename T>
HWY_API Vec1<T> IfThenZeroElse(const Mask1<T> mask, const Vec1<T> no) {
  return mask.bits ? Vec1<T>(0) : no;
}

template <typename T>
HWY_API Vec1<T> IfNegativeThenElse(Vec1<T> v, Vec1<T> yes, Vec1<T> no) {
  return v.raw < 0 ? yes : no;
}

template <typename T>
HWY_API Vec1<T> ZeroIfNegative(const Vec1<T> v) {
  return v.raw < 0 ? Vec1<T>(0) : v;
}

// ------------------------------ Mask logical

template <typename T>
HWY_API Mask1<T> Not(const Mask1<T> m) {
  return MaskFromVec(Not(VecFromMask(Sisd<T>(), m)));
}

template <typename T>
HWY_API Mask1<T> And(const Mask1<T> a, Mask1<T> b) {
  const Sisd<T> d;
  return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask1<T> AndNot(const Mask1<T> a, Mask1<T> b) {
  const Sisd<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask1<T> Or(const Mask1<T> a, Mask1<T> b) {
  const Sisd<T> d;
  return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask1<T> Xor(const Mask1<T> a, Mask1<T> b) {
  const Sisd<T> d;
  return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b)));
}

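// Illustrative sketch (hypothetical helper, not part of the Highway API):
// masks hold all-zero or all-one bits, and IfThenElse selects between lanes.
// Here the mask is true (FirstN with n != 0), so the result is Set(d, T(1)).
template <typename T>
HWY_API Vec1<T> ExampleIfThenElseUsage() {
  const Sisd<T> d;
  const Mask1<T> m = FirstN(d, 1);  // true
  return IfThenElse(m, Set(d, T(1)), Set(d, T(2)));
}
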
// ================================================== SHIFTS

// ------------------------------ ShiftLeft/ShiftRight (BroadcastSignBit)

template <int kBits, typename T>
HWY_API Vec1<T> ShiftLeft(const Vec1<T> v) {
  static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift");
  return Vec1<T>(
      static_cast<T>(static_cast<hwy::MakeUnsigned<T>>(v.raw) << kBits));
}

template <int kBits, typename T>
HWY_API Vec1<T> ShiftRight(const Vec1<T> v) {
  static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift");
#if __cplusplus >= 202002L
  // Signed right shift is now guaranteed to be arithmetic (rounding toward
  // negative infinity, i.e. shifting in the sign bit).
  return Vec1<T>(static_cast<T>(v.raw >> kBits));
#else
  if (IsSigned<T>()) {
    // Emulate arithmetic shift using only logical (unsigned) shifts, because
    // signed shifts are still implementation-defined.
    using TU = hwy::MakeUnsigned<T>;
    const Sisd<TU> du;
    const TU shifted = static_cast<TU>(BitCast(du, v).raw >> kBits);
    const TU sign = BitCast(du, BroadcastSignBit(v)).raw;
    const TU upper = static_cast<TU>(sign << (sizeof(TU) * 8 - 1 - kBits));
    return BitCast(Sisd<T>(), Vec1<TU>(shifted | upper));
  } else {
    return Vec1<T>(static_cast<T>(v.raw >> kBits));  // unsigned, logical shift
  }
#endif
}

// ------------------------------ RotateRight (ShiftRight)

template <int kBits, typename T>
HWY_API Vec1<T> RotateRight(const Vec1<T> v) {
  static_assert(0 <= kBits && kBits < sizeof(T) * 8, "Invalid shift");
  if (kBits == 0) return v;
  // HWY_MIN clamps the shift count so the (unused) kBits == 0 instantiation
  // still satisfies ShiftLeft's static_assert.
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(sizeof(T) * 8 - 1, sizeof(T) * 8 - kBits)>(v));
}

// ------------------------------ ShiftLeftSame (BroadcastSignBit)

template <typename T>
HWY_API Vec1<T> ShiftLeftSame(const Vec1<T> v, int bits) {
  return Vec1<T>(
      static_cast<T>(static_cast<hwy::MakeUnsigned<T>>(v.raw) << bits));
}

template <typename T>
HWY_API Vec1<T> ShiftRightSame(const Vec1<T> v, int bits) {
#if __cplusplus >= 202002L
  // Signed right shift is now guaranteed to be arithmetic (rounding toward
  // negative infinity, i.e. shifting in the sign bit).
  return Vec1<T>(static_cast<T>(v.raw >> bits));
#else
  if (IsSigned<T>()) {
    // Emulate arithmetic shift using only logical (unsigned) shifts, because
    // signed shifts are still implementation-defined.
    using TU = hwy::MakeUnsigned<T>;
    const Sisd<TU> du;
    const TU shifted = static_cast<TU>(BitCast(du, v).raw >> bits);
    const TU sign = BitCast(du, BroadcastSignBit(v)).raw;
    const TU upper = static_cast<TU>(sign << (sizeof(TU) * 8 - 1 - bits));
    return BitCast(Sisd<T>(), Vec1<TU>(shifted | upper));
  } else {
    return Vec1<T>(static_cast<T>(v.raw >> bits));  // unsigned, logical shift
  }
#endif
}

// ------------------------------ Shl

// Single-lane => same as ShiftLeftSame except for the argument type.
template <typename T>
HWY_API Vec1<T> operator<<(const Vec1<T> v, const Vec1<T> bits) {
  return ShiftLeftSame(v, static_cast<int>(bits.raw));
}

template <typename T>
HWY_API Vec1<T> operator>>(const Vec1<T> v, const Vec1<T> bits) {
  return ShiftRightSame(v, static_cast<int>(bits.raw));
}

// ================================================== ARITHMETIC

template <typename T>
HWY_API Vec1<T> operator+(Vec1<T> a, Vec1<T> b) {
  const uint64_t a64 = static_cast<uint64_t>(a.raw);
  const uint64_t b64 = static_cast<uint64_t>(b.raw);
  return Vec1<T>(static_cast<T>((a64 + b64) & static_cast<uint64_t>(~T(0))));
}
HWY_API Vec1<float> operator+(const Vec1<float> a, const Vec1<float> b) {
  return Vec1<float>(a.raw + b.raw);
}
HWY_API Vec1<double> operator+(const Vec1<double> a, const Vec1<double> b) {
  return Vec1<double>(a.raw + b.raw);
}

template <typename T>
HWY_API Vec1<T> operator-(Vec1<T> a, Vec1<T> b) {
  const uint64_t a64 = static_cast<uint64_t>(a.raw);
  const uint64_t b64 = static_cast<uint64_t>(b.raw);
  return Vec1<T>(static_cast<T>((a64 - b64) & static_cast<uint64_t>(~T(0))));
}
HWY_API Vec1<float> operator-(const Vec1<float> a, const Vec1<float> b) {
  return Vec1<float>(a.raw - b.raw);
}
HWY_API Vec1<double> operator-(const Vec1<double> a, const Vec1<double> b) {
  return Vec1<double>(a.raw - b.raw);
}

// ------------------------------ SumsOf8

HWY_API Vec1<uint64_t> SumsOf8(const Vec1<uint8_t> v) {
  return Vec1<uint64_t>(v.raw);
}

// ------------------------------ SaturatedAdd

// Returns a + b clamped to the destination range.

// Unsigned
HWY_API Vec1<uint8_t> SaturatedAdd(const Vec1<uint8_t> a,
                                   const Vec1<uint8_t> b) {
  return Vec1<uint8_t>(
      static_cast<uint8_t>(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 255)));
}
HWY_API Vec1<uint16_t> SaturatedAdd(const Vec1<uint16_t> a,
                                    const Vec1<uint16_t> b) {
  return Vec1<uint16_t>(
      static_cast<uint16_t>(HWY_MIN(HWY_MAX(0, a.raw + b.raw), 65535)));
}

// Signed
HWY_API Vec1<int8_t> SaturatedAdd(const Vec1<int8_t> a, const Vec1<int8_t> b) {
  return Vec1<int8_t>(
      static_cast<int8_t>(HWY_MIN(HWY_MAX(-128, a.raw + b.raw), 127)));
}
HWY_API Vec1<int16_t> SaturatedAdd(const Vec1<int16_t> a,
                                   const Vec1<int16_t> b) {
  return Vec1<int16_t>(
      static_cast<int16_t>(HWY_MIN(HWY_MAX(-32768, a.raw + b.raw), 32767)));
}

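// Illustrative sketch (hypothetical helper, not part of the Highway API):
// saturating arithmetic clamps to the lane range instead of wrapping, so for
// uint8_t lanes, 200 + 100 yields 255 rather than (200 + 100) % 256 == 44.
HWY_API Vec1<uint8_t> ExampleSaturatedAddUsage() {
  const Sisd<uint8_t> d;
  return SaturatedAdd(Set(d, 200), Set(d, 100));  // 255, not 44
}
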
// ------------------------------ Saturating subtraction

// Returns a - b clamped to the destination range.

// Unsigned
HWY_API Vec1<uint8_t> SaturatedSub(const Vec1<uint8_t> a,
                                   const Vec1<uint8_t> b) {
  return Vec1<uint8_t>(
      static_cast<uint8_t>(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 255)));
}
HWY_API Vec1<uint16_t> SaturatedSub(const Vec1<uint16_t> a,
                                    const Vec1<uint16_t> b) {
  return Vec1<uint16_t>(
      static_cast<uint16_t>(HWY_MIN(HWY_MAX(0, a.raw - b.raw), 65535)));
}

// Signed
HWY_API Vec1<int8_t> SaturatedSub(const Vec1<int8_t> a, const Vec1<int8_t> b) {
  return Vec1<int8_t>(
      static_cast<int8_t>(HWY_MIN(HWY_MAX(-128, a.raw - b.raw), 127)));
}
HWY_API Vec1<int16_t> SaturatedSub(const Vec1<int16_t> a,
                                   const Vec1<int16_t> b) {
  return Vec1<int16_t>(
      static_cast<int16_t>(HWY_MIN(HWY_MAX(-32768, a.raw - b.raw), 32767)));
}

// ------------------------------ Average

// Returns (a + b + 1) / 2

HWY_API Vec1<uint8_t> AverageRound(const Vec1<uint8_t> a,
                                   const Vec1<uint8_t> b) {
  return Vec1<uint8_t>(static_cast<uint8_t>((a.raw + b.raw + 1) / 2));
}
HWY_API Vec1<uint16_t> AverageRound(const Vec1<uint16_t> a,
                                    const Vec1<uint16_t> b) {
  return Vec1<uint16_t>(static_cast<uint16_t>((a.raw + b.raw + 1) / 2));
}

// ------------------------------ Absolute value

template <typename T>
HWY_API Vec1<T> Abs(const Vec1<T> a) {
  const T i = a.raw;
  return (i >= 0 || i == hwy::LimitsMin<T>()) ? a : Vec1<T>(-i);
}
HWY_API Vec1<float> Abs(const Vec1<float> a) {
  return Vec1<float>(std::abs(a.raw));
}
HWY_API Vec1<double> Abs(const Vec1<double> a) {
  return Vec1<double>(std::abs(a.raw));
}

// ------------------------------ min/max

template <typename T, HWY_IF_NOT_FLOAT(T)>
HWY_API Vec1<T> Min(const Vec1<T> a, const Vec1<T> b) {
  return Vec1<T>(HWY_MIN(a.raw, b.raw));
}

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec1<T> Min(const Vec1<T> a, const Vec1<T> b) {
  if (std::isnan(a.raw)) return b;
  if (std::isnan(b.raw)) return a;
  return Vec1<T>(HWY_MIN(a.raw, b.raw));
}

template <typename T, HWY_IF_NOT_FLOAT(T)>
HWY_API Vec1<T> Max(const Vec1<T> a, const Vec1<T> b) {
  return Vec1<T>(HWY_MAX(a.raw, b.raw));
}

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec1<T> Max(const Vec1<T> a, const Vec1<T> b) {
  if (std::isnan(a.raw)) return b;
  if (std::isnan(b.raw)) return a;
  return Vec1<T>(HWY_MAX(a.raw, b.raw));
}

// ------------------------------ Floating-point negate

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec1<T> Neg(const Vec1<T> v) {
  return Xor(v, SignBit(Sisd<T>()));
}

template <typename T, HWY_IF_NOT_FLOAT(T)>
HWY_API Vec1<T> Neg(const Vec1<T> v) {
  return Zero(Sisd<T>()) - v;
}

// ------------------------------ mul/div

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec1<T> operator*(const Vec1<T> a, const Vec1<T> b) {
  return Vec1<T>(static_cast<T>(double(a.raw) * b.raw));
}

template <typename T, HWY_IF_SIGNED(T)>
HWY_API Vec1<T> operator*(const Vec1<T> a, const Vec1<T> b) {
  return Vec1<T>(static_cast<T>(int64_t(a.raw) * b.raw));
}

template <typename T, HWY_IF_UNSIGNED(T)>
HWY_API Vec1<T> operator*(const Vec1<T> a, const Vec1<T> b) {
  return Vec1<T>(static_cast<T>(uint64_t(a.raw) * b.raw));
}

template <typename T>
HWY_API Vec1<T> operator/(const Vec1<T> a, const Vec1<T> b) {
  return Vec1<T>(a.raw / b.raw);
}

// Returns the upper 16 bits of a * b in each lane.
HWY_API Vec1<int16_t> MulHigh(const Vec1<int16_t> a, const Vec1<int16_t> b) {
  return Vec1<int16_t>(static_cast<int16_t>((a.raw * b.raw) >> 16));
}
HWY_API Vec1<uint16_t> MulHigh(const Vec1<uint16_t> a,
                               const Vec1<uint16_t> b) {
  // Cast to uint32_t first to prevent overflow. Otherwise the result of
  // uint16_t * uint16_t is in "int" which may overflow. In practice the result
  // is the same but this way it is also defined.
  return Vec1<uint16_t>(static_cast<uint16_t>(
      (static_cast<uint32_t>(a.raw) * static_cast<uint32_t>(b.raw)) >> 16));
}

// Multiplies even lanes (0, 2 ..) and returns the double-wide result.
HWY_API Vec1<int64_t> MulEven(const Vec1<int32_t> a, const Vec1<int32_t> b) {
  const int64_t a64 = a.raw;
  return Vec1<int64_t>(a64 * b.raw);
}
HWY_API Vec1<uint64_t> MulEven(const Vec1<uint32_t> a,
                               const Vec1<uint32_t> b) {
  const uint64_t a64 = a.raw;
  return Vec1<uint64_t>(a64 * b.raw);
}

// Approximate reciprocal
HWY_API Vec1<float> ApproximateReciprocal(const Vec1<float> v) {
  // Zero inputs are allowed, but callers are responsible for replacing the
  // return value with something else (typically using IfThenElse). This check
  // avoids a ubsan error. The return value is arbitrary.
  if (v.raw == 0.0f) return Vec1<float>(0.0f);
  return Vec1<float>(1.0f / v.raw);
}

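// Illustrative sketch (hypothetical helper, not part of the Highway API):
// MulHigh returns the upper 16 bits of the full 32-bit product, e.g.
// 0x4000 * 0x0200 == 0x00800000, whose upper half is 0x0080.
HWY_API Vec1<uint16_t> ExampleMulHighUsage() {
  const Sisd<uint16_t> d;
  return MulHigh(Set(d, 0x4000), Set(d, 0x0200));  // 0x0080
}
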
// Absolute value of difference.
HWY_API Vec1<float> AbsDiff(const Vec1<float> a, const Vec1<float> b) {
  return Abs(a - b);
}

// ------------------------------ Floating-point multiply-add variants

template <typename T>
HWY_API Vec1<T> MulAdd(const Vec1<T> mul, const Vec1<T> x, const Vec1<T> add) {
  return mul * x + add;
}

template <typename T>
HWY_API Vec1<T> NegMulAdd(const Vec1<T> mul, const Vec1<T> x,
                          const Vec1<T> add) {
  return add - mul * x;
}

template <typename T>
HWY_API Vec1<T> MulSub(const Vec1<T> mul, const Vec1<T> x, const Vec1<T> sub) {
  return mul * x - sub;
}

template <typename T>
HWY_API Vec1<T> NegMulSub(const Vec1<T> mul, const Vec1<T> x,
                          const Vec1<T> sub) {
  return Neg(mul) * x - sub;
}

// ------------------------------ Floating-point square root

// Approximate reciprocal square root
HWY_API Vec1<float> ApproximateReciprocalSqrt(const Vec1<float> v) {
  float f = v.raw;
  const float half = f * 0.5f;
  uint32_t bits;
  CopyBytes<4>(&f, &bits);
  // Initial guess based on log2(f)
  bits = 0x5F3759DF - (bits >> 1);
  CopyBytes<4>(&bits, &f);
  // One Newton-Raphson iteration
  return Vec1<float>(f * (1.5f - (half * f * f)));
}

// Square root
HWY_API Vec1<float> Sqrt(const Vec1<float> v) {
  return Vec1<float>(std::sqrt(v.raw));
}
HWY_API Vec1<double> Sqrt(const Vec1<double> v) {
  return Vec1<double>(std::sqrt(v.raw));
}

// ------------------------------ Floating-point rounding

template <typename T>
HWY_API Vec1<T> Round(const Vec1<T> v) {
  using TI = MakeSigned<T>;
  if (!(Abs(v).raw < MantissaEnd<T>())) {  // Huge or NaN
    return v;
  }
  const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5);
  const TI rounded = static_cast<TI>(v.raw + bias);
  if (rounded == 0) return CopySignToAbs(Vec1<T>(0), v);
  // Round to even
  if ((rounded & 1) && std::abs(static_cast<T>(rounded) - v.raw) == T(0.5)) {
    return Vec1<T>(static_cast<T>(rounded - (v.raw < T(0) ? -1 : 1)));
  }
  return Vec1<T>(static_cast<T>(rounded));
}

// Round-to-nearest even.
HWY_API Vec1<int32_t> NearestInt(const Vec1<float> v) {
  using T = float;
  using TI = int32_t;

  const T abs = Abs(v).raw;
  const bool signbit = std::signbit(v.raw);

  if (!(abs < MantissaEnd<T>())) {  // Huge or NaN
    // Check if too large to cast or NaN
    if (!(abs <= static_cast<T>(LimitsMax<TI>()))) {
      return Vec1<TI>(signbit ? LimitsMin<TI>() : LimitsMax<TI>());
    }
    return Vec1<TI>(static_cast<TI>(v.raw));
  }
  const T bias = v.raw < T(0.0) ? T(-0.5) : T(0.5);
  const TI rounded = static_cast<TI>(v.raw + bias);
  if (rounded == 0) return Vec1<TI>(0);
  // Round to even
  if ((rounded & 1) && std::abs(static_cast<T>(rounded) - v.raw) == T(0.5)) {
    return Vec1<TI>(rounded - (signbit ? -1 : 1));
  }
  return Vec1<TI>(rounded);
}

template <typename T>
HWY_API Vec1<T> Trunc(const Vec1<T> v) {
  using TI = MakeSigned<T>;
  if (!(Abs(v).raw <= MantissaEnd<T>())) {  // Huge or NaN
    return v;
  }
  const TI truncated = static_cast<TI>(v.raw);
  if (truncated == 0) return CopySignToAbs(Vec1<T>(0), v);
  return Vec1<T>(static_cast<T>(truncated));
}

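// Illustrative sketch (hypothetical helper, not part of the Highway API):
// Round and NearestInt break ties toward the even neighbor, so 2.5 rounds to
// 2 rather than 3.
HWY_API Vec1<int32_t> ExampleNearestIntUsage() {
  const Sisd<float> df;
  return NearestInt(Set(df, 2.5f));  // 2 (ties-to-even), not 3
}
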
template <typename Float, typename Bits, int kMantissaBits, int kExponentBits,
          class V>
V Ceiling(const V v) {
  const Bits kExponentMask = (1ull << kExponentBits) - 1;
  const Bits kMantissaMask = (1ull << kMantissaBits) - 1;
  const Bits kBias = kExponentMask / 2;

  Float f = v.raw;
  const bool positive = f > Float(0.0);

  Bits bits;
  CopyBytes<sizeof(Bits)>(&v, &bits);

  const int exponent =
      static_cast<int>(((bits >> kMantissaBits) & kExponentMask) - kBias);
  // Already an integer.
  if (exponent >= kMantissaBits) return v;
  // |v| <= 1 => 0 or 1.
  if (exponent < 0) return positive ? V(1) : V(-0.0);

  const Bits mantissa_mask = kMantissaMask >> exponent;
  // Already an integer
  if ((bits & mantissa_mask) == 0) return v;

  // Clear fractional bits and round up
  if (positive) bits += (kMantissaMask + 1) >> exponent;
  bits &= ~mantissa_mask;

  CopyBytes<sizeof(Bits)>(&bits, &f);
  return V(f);
}

template <typename Float, typename Bits, int kMantissaBits, int kExponentBits,
          class V>
V Floor(const V v) {
  const Bits kExponentMask = (1ull << kExponentBits) - 1;
  const Bits kMantissaMask = (1ull << kMantissaBits) - 1;
  const Bits kBias = kExponentMask / 2;

  Float f = v.raw;
  const bool negative = f < Float(0.0);

  Bits bits;
  CopyBytes<sizeof(Bits)>(&v, &bits);

  const int exponent =
      static_cast<int>(((bits >> kMantissaBits) & kExponentMask) - kBias);
  // Already an integer.
  if (exponent >= kMantissaBits) return v;
  // |v| <= 1 => -1 or 0.
  if (exponent < 0) return V(negative ? Float(-1.0) : Float(0.0));

  const Bits mantissa_mask = kMantissaMask >> exponent;
  // Already an integer
  if ((bits & mantissa_mask) == 0) return v;

  // Clear fractional bits and round down
  if (negative) bits += (kMantissaMask + 1) >> exponent;
  bits &= ~mantissa_mask;

  CopyBytes<sizeof(Bits)>(&bits, &f);
  return V(f);
}

// Toward +infinity, aka ceiling
HWY_API Vec1<float> Ceil(const Vec1<float> v) {
  return Ceiling<float, uint32_t, 23, 8>(v);
}
HWY_API Vec1<double> Ceil(const Vec1<double> v) {
  return Ceiling<double, uint64_t, 52, 11>(v);
}

// Toward -infinity, aka floor
HWY_API Vec1<float> Floor(const Vec1<float> v) {
  return Floor<float, uint32_t, 23, 8>(v);
}
HWY_API Vec1<double> Floor(const Vec1<double> v) {
  return Floor<double, uint64_t, 52, 11>(v);
}

// ================================================== COMPARE

template <typename T>
HWY_API Mask1<T> operator==(const Vec1<T> a, const Vec1<T> b) {
  return Mask1<T>::FromBool(a.raw == b.raw);
}

template <typename T>
HWY_API Mask1<T> operator!=(const Vec1<T> a, const Vec1<T> b) {
  return Mask1<T>::FromBool(a.raw != b.raw);
}

template <typename T>
HWY_API Mask1<T> TestBit(const Vec1<T> v, const Vec1<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return (v & bit) == bit;
}

template <typename T>
HWY_API Mask1<T> operator<(const Vec1<T> a, const Vec1<T> b) {
  return Mask1<T>::FromBool(a.raw < b.raw);
}
template <typename T>
HWY_API Mask1<T> operator>(const Vec1<T> a, const Vec1<T> b) {
  return Mask1<T>::FromBool(a.raw > b.raw);
}

template <typename T>
HWY_API Mask1<T> operator<=(const Vec1<T> a, const Vec1<T> b) {
  return Mask1<T>::FromBool(a.raw <= b.raw);
}
template <typename T>
HWY_API Mask1<T> operator>=(const Vec1<T> a, const Vec1<T> b) {
  return Mask1<T>::FromBool(a.raw >= b.raw);
}

// ================================================== MEMORY

// ------------------------------ Load

template <typename T>
HWY_API Vec1<T> Load(Sisd<T> /* tag */, const T* HWY_RESTRICT aligned) {
  T t;
  CopyBytes<sizeof(T)>(aligned, &t);
  return Vec1<T>(t);
}

template <typename T>
HWY_API Vec1<T> MaskedLoad(Mask1<T> m, Sisd<T> d,
                           const T* HWY_RESTRICT aligned) {
  return IfThenElseZero(m, Load(d, aligned));
}

template <typename T>
HWY_API Vec1<T> LoadU(Sisd<T> d, const T* HWY_RESTRICT p) {
  return Load(d, p);
}

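// Illustrative sketch (hypothetical helper, not part of the Highway API):
// TestBit is true when every bit of `bit` is set in `v`; here 6 & 4 == 4, so
// bit 2 of the value 6 is set.
HWY_API Mask1<uint32_t> ExampleTestBitUsage() {
  const Sisd<uint32_t> d;
  return TestBit(Set(d, 6u), Set(d, 4u));  // FromBool(true)
}
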
// In some use cases, "load single lane" is sufficient; otherwise avoid this.
template <typename T>
HWY_API Vec1<T> LoadDup128(Sisd<T> d, const T* HWY_RESTRICT aligned) {
  return Load(d, aligned);
}

// ------------------------------ Store

template <typename T>
HWY_API void Store(const Vec1<T> v, Sisd<T> /* tag */,
                   T* HWY_RESTRICT aligned) {
  CopyBytes<sizeof(T)>(&v.raw, aligned);
}

template <typename T>
HWY_API void StoreU(const Vec1<T> v, Sisd<T> d, T* HWY_RESTRICT p) {
  return Store(v, d, p);
}

// ------------------------------ StoreInterleaved3

HWY_API void StoreInterleaved3(const Vec1<uint8_t> v0, const Vec1<uint8_t> v1,
                               const Vec1<uint8_t> v2, Sisd<uint8_t> d,
                               uint8_t* HWY_RESTRICT unaligned) {
  StoreU(v0, d, unaligned + 0);
  StoreU(v1, d, unaligned + 1);
  StoreU(v2, d, unaligned + 2);
}

HWY_API void StoreInterleaved4(const Vec1<uint8_t> v0, const Vec1<uint8_t> v1,
                               const Vec1<uint8_t> v2, const Vec1<uint8_t> v3,
                               Sisd<uint8_t> d,
                               uint8_t* HWY_RESTRICT unaligned) {
  StoreU(v0, d, unaligned + 0);
  StoreU(v1, d, unaligned + 1);
  StoreU(v2, d, unaligned + 2);
  StoreU(v3, d, unaligned + 3);
}

// ------------------------------ Stream

template <typename T>
HWY_API void Stream(const Vec1<T> v, Sisd<T> d, T* HWY_RESTRICT aligned) {
  return Store(v, d, aligned);
}

// ------------------------------ Scatter

template <typename T, typename Offset>
HWY_API void ScatterOffset(Vec1<T> v, Sisd<T> d, T* base,
                           const Vec1<Offset> offset) {
  static_assert(sizeof(T) == sizeof(Offset), "Must match for portability");
  uint8_t* const base8 = reinterpret_cast<uint8_t*>(base) + offset.raw;
  return Store(v, d, reinterpret_cast<T*>(base8));
}

template <typename T, typename Index>
HWY_API void ScatterIndex(Vec1<T> v, Sisd<T> d, T* HWY_RESTRICT base,
                          const Vec1<Index> index) {
  static_assert(sizeof(T) == sizeof(Index), "Must match for portability");
  return Store(v, d, base + index.raw);
}

// ------------------------------ Gather

template <typename T, typename Offset>
HWY_API Vec1<T> GatherOffset(Sisd<T> d, const T* base,
                             const Vec1<Offset> offset) {
  static_assert(sizeof(T) == sizeof(Offset), "Must match for portability");
  const uintptr_t addr =
      reinterpret_cast<uintptr_t>(base) + static_cast<uintptr_t>(offset.raw);
  return Load(d, reinterpret_cast<const T*>(addr));
}

template <typename T, typename Index>
HWY_API Vec1<T> GatherIndex(Sisd<T> d, const T* HWY_RESTRICT base,
                            const Vec1<Index> index) {
  static_assert(sizeof(T) == sizeof(Index), "Must match for portability");
  return Load(d, base + index.raw);
}

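// Illustrative sketch (hypothetical helper and table pointer, not part of the
// Highway API): the single-lane GatherIndex simply loads base[index.raw]; the
// index type must have the same size as the lane type.
HWY_API Vec1<int32_t> ExampleGatherIndexUsage(
    const int32_t* HWY_RESTRICT table) {
  const Sisd<int32_t> d;
  return GatherIndex(d, table, Vec1<int32_t>(2));  // loads table[2]
}
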
// ================================================== CONVERT

// ConvertTo and DemoteTo with floating-point input and integer output truncate
// (rounding toward zero).

template <typename FromT, typename ToT>
HWY_API Vec1<ToT> PromoteTo(Sisd<ToT> /* tag */, Vec1<FromT> from) {
  static_assert(sizeof(ToT) > sizeof(FromT), "Not promoting");
  // For bits Y > X, floatX->floatY and intX->intY are always representable.
  return Vec1<ToT>(static_cast<ToT>(from.raw));
}

// MSVC 19.10 cannot deduce the argument type if HWY_IF_FLOAT(FromT) is here,
// so we overload for FromT=double and ToT={float,int32_t}.
HWY_API Vec1<float> DemoteTo(Sisd<float> /* tag */, Vec1<double> from) {
  // Prevent ubsan errors when converting float to narrower integer/float
  if (std::isinf(from.raw) ||
      std::fabs(from.raw) > static_cast<double>(HighestValue<float>())) {
    return Vec1<float>(std::signbit(from.raw) ? LowestValue<float>()
                                              : HighestValue<float>());
  }
  return Vec1<float>(static_cast<float>(from.raw));
}

HWY_API Vec1<int32_t> DemoteTo(Sisd<int32_t> /* tag */, Vec1<double> from) {
  // Prevent ubsan errors when converting int32_t to narrower integer/int32_t
  if (std::isinf(from.raw) ||
      std::fabs(from.raw) > static_cast<double>(HighestValue<int32_t>())) {
    return Vec1<int32_t>(std::signbit(from.raw) ? LowestValue<int32_t>()
                                                : HighestValue<int32_t>());
  }
  return Vec1<int32_t>(static_cast<int32_t>(from.raw));
}

template <typename FromT, typename ToT>
HWY_API Vec1<ToT> DemoteTo(Sisd<ToT> /* tag */, Vec1<FromT> from) {
  static_assert(!IsFloat<FromT>(), "FromT=double are handled above");
  static_assert(sizeof(ToT) < sizeof(FromT), "Not demoting");

  // Int to int: choose closest value in ToT to `from` (avoids UB)
  from.raw = HWY_MIN(HWY_MAX(LimitsMin<ToT>(), from.raw), LimitsMax<ToT>());
  return Vec1<ToT>(static_cast<ToT>(from.raw));
}

HWY_API Vec1<float> PromoteTo(Sisd<float> /* tag */, const Vec1<float16_t> v) {
#if HWY_NATIVE_FLOAT16
  uint16_t bits16;
  CopyBytes<2>(&v.raw, &bits16);
#else
  const uint16_t bits16 = v.raw.bits;
#endif
  const uint32_t sign = static_cast<uint32_t>(bits16 >> 15);
  const uint32_t biased_exp = (bits16 >> 10) & 0x1F;
  const uint32_t mantissa = bits16 & 0x3FF;

  // Subnormal or zero
  if (biased_exp == 0) {
    const float subnormal =
        (1.0f / 16384) * (static_cast<float>(mantissa) * (1.0f / 1024));
    return Vec1<float>(sign ? -subnormal : subnormal);
  }

  // Normalized: convert the representation directly (faster than
  // ldexp/tables).
  const uint32_t biased_exp32 = biased_exp + (127 - 15);
  const uint32_t mantissa32 = mantissa << (23 - 10);
  const uint32_t bits32 = (sign << 31) | (biased_exp32 << 23) | mantissa32;
  float out;
  CopyBytes<4>(&bits32, &out);
  return Vec1<float>(out);
}

HWY_API Vec1<float> PromoteTo(Sisd<float> d, const Vec1<bfloat16_t> v) {
  return Set(d, F32FromBF16(v.raw));
}

HWY_API Vec1<float16_t> DemoteTo(Sisd<float16_t> /* tag */,
                                 const Vec1<float> v) {
  uint32_t bits32;
  CopyBytes<4>(&v.raw, &bits32);
  const uint32_t sign = bits32 >> 31;
  const uint32_t biased_exp32 = (bits32 >> 23) & 0xFF;
  const uint32_t mantissa32 = bits32 & 0x7FFFFF;

  const int32_t exp = HWY_MIN(static_cast<int32_t>(biased_exp32) - 127, 15);

  // Tiny or zero => zero.
  Vec1<float16_t> out;
  if (exp < -24) {
#if HWY_NATIVE_FLOAT16
    const uint16_t zero = 0;
    CopyBytes<2>(&zero, &out.raw);
#else
    out.raw.bits = 0;
#endif
    return out;
  }

  uint32_t biased_exp16, mantissa16;

  // exp = [-24, -15] => subnormal
  if (exp < -14) {
    biased_exp16 = 0;
    const uint32_t sub_exp = static_cast<uint32_t>(-14 - exp);
    HWY_DASSERT(1 <= sub_exp && sub_exp < 11);
    mantissa16 = static_cast<uint32_t>((1u << (10 - sub_exp)) +
                                       (mantissa32 >> (13 + sub_exp)));
  } else {  // exp = [-14, 15]
    biased_exp16 = static_cast<uint32_t>(exp + 15);
    HWY_DASSERT(1 <= biased_exp16 && biased_exp16 < 31);
    mantissa16 = mantissa32 >> 13;
  }

  HWY_DASSERT(mantissa16 < 1024);
  const uint32_t bits16 = (sign << 15) | (biased_exp16 << 10) | mantissa16;
  HWY_DASSERT(bits16 < 0x10000);
#if HWY_NATIVE_FLOAT16
  const uint16_t narrowed = static_cast<uint16_t>(bits16);  // big-endian safe
  CopyBytes<2>(&narrowed, &out.raw);
#else
  out.raw.bits = static_cast<uint16_t>(bits16);
#endif
  return out;
}

HWY_API Vec1<bfloat16_t> DemoteTo(Sisd<bfloat16_t> d, const Vec1<float> v) {
  return Set(d, BF16FromF32(v.raw));
}

template <typename FromT, typename ToT, HWY_IF_FLOAT(FromT)>
HWY_API Vec1<ToT> ConvertTo(Sisd<ToT> /* tag */, Vec1<FromT> from) {
  static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size");
  // float## -> int##: return closest representable value. We cannot exactly
  // represent LimitsMax<ToT> in FromT, so use double.
  const double f = static_cast<double>(from.raw);
  if (std::isinf(from.raw) ||
      std::fabs(f) > static_cast<double>(LimitsMax<ToT>())) {
    return Vec1<ToT>(std::signbit(from.raw) ? LimitsMin<ToT>()
                                            : LimitsMax<ToT>());
  }
  return Vec1<ToT>(static_cast<ToT>(from.raw));
}

template <typename FromT, typename ToT, HWY_IF_NOT_FLOAT(FromT)>
HWY_API Vec1<ToT> ConvertTo(Sisd<ToT> /* tag */, Vec1<FromT> from) {
  static_assert(sizeof(ToT) == sizeof(FromT), "Should have same size");
  // int## -> float##: no check needed
  return Vec1<ToT>(static_cast<ToT>(from.raw));
}

HWY_API Vec1<uint8_t> U8FromU32(const Vec1<uint32_t> v) {
  return DemoteTo(Sisd<uint8_t>(), v);
}

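// Illustrative sketch (hypothetical helper, not part of the Highway API):
// integer DemoteTo clamps to the narrower range instead of truncating bits,
// so int32_t 300 becomes uint8_t 255 rather than 300 & 0xFF == 44.
HWY_API Vec1<uint8_t> ExampleDemoteToUsage() {
  const Sisd<int32_t> d32;
  return DemoteTo(Sisd<uint8_t>(), Set(d32, 300));  // 255
}
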
// ================================================== COMBINE
// UpperHalf, ZeroExtendVector, Combine, Concat* are unsupported.

template <typename T>
HWY_API Vec1<T> LowerHalf(Vec1<T> v) {
  return v;
}

template <typename T>
HWY_API Vec1<T> LowerHalf(Sisd<T> /* tag */, Vec1<T> v) {
  return v;
}

// ================================================== SWIZZLE

template <typename T>
HWY_API T GetLane(const Vec1<T> v) {
  return v.raw;
}

template <typename T>
HWY_API Vec1<T> DupEven(Vec1<T> v) {
  return v;
}
// DupOdd is unsupported.

template <typename T>
HWY_API Vec1<T> OddEven(Vec1<T> /* odd */, Vec1<T> even) {
  return even;
}

template <typename T>
HWY_API Vec1<T> OddEvenBlocks(Vec1<T> /* odd */, Vec1<T> even) {
  return even;
}

// ------------------------------ SwapAdjacentBlocks

template <typename T>
HWY_API Vec1<T> SwapAdjacentBlocks(Vec1<T> v) {
  return v;
}

// ------------------------------ TableLookupLanes

// Returned by SetTableIndices for use by TableLookupLanes.
template <typename T>
struct Indices1 {
  MakeSigned<T> raw;
};

template <typename T, typename TI>
HWY_API Indices1<T> IndicesFromVec(Sisd<T>, Vec1<TI> vec) {
  static_assert(sizeof(T) == sizeof(TI), "Index size must match lane size");
  HWY_DASSERT(vec.raw == 0);
  return Indices1<T>{vec.raw};
}

template <typename T, typename TI>
HWY_API Indices1<T> SetTableIndices(Sisd<T> d, const TI* idx) {
  return IndicesFromVec(d, LoadU(Sisd<TI>(), idx));
}

template <typename T>
HWY_API Vec1<T> TableLookupLanes(const Vec1<T> v,
                                 const Indices1<T> /* idx */) {
  return v;
}

// ------------------------------ ReverseBlocks

// Single block: no change
template <typename T>
HWY_API Vec1<T> ReverseBlocks(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}

// ------------------------------ Reverse

template <typename T>
HWY_API Vec1<T> Reverse(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}
template <typename T>
HWY_API Vec1<T> Reverse2(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}
template <typename T>
HWY_API Vec1<T> Reverse4(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}
template <typename T>
HWY_API Vec1<T> Reverse8(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}

// ================================================== BLOCKWISE
// Shift*Bytes, CombineShiftRightBytes, Interleave*, Shuffle* are unsupported.

// ------------------------------ Broadcast/splat any lane

template <int kLane, typename T>
HWY_API Vec1<T> Broadcast(const Vec1<T> v) {
  static_assert(kLane == 0, "Scalar only has one lane");
  return v;
}

// ------------------------------ TableLookupBytes, TableLookupBytesOr0

template <typename T, typename TI>
HWY_API Vec1<TI> TableLookupBytes(const Vec1<T> in, const Vec1<TI> indices) {
  uint8_t in_bytes[sizeof(T)];
  uint8_t idx_bytes[sizeof(T)];
  uint8_t out_bytes[sizeof(T)];
  CopyBytes<sizeof(T)>(&in, &in_bytes);
  CopyBytes<sizeof(T)>(&indices, &idx_bytes);
  for (size_t i = 0; i < sizeof(T); ++i) {
    out_bytes[i] = in_bytes[idx_bytes[i]];
  }
  TI out;
  CopyBytes<sizeof(TI)>(&out_bytes, &out);
  return Vec1<TI>{out};
}

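// Illustrative sketch (hypothetical helper, not part of the Highway API;
// assumes a little-endian host): each index byte selects a byte of `in`, so
// reversed byte indices turn 0x03020100 into 0x00010203.
HWY_API Vec1<uint32_t> ExampleTableLookupBytesUsage() {
  const Vec1<uint32_t> in(0x03020100u);
  const Vec1<uint32_t> idx(0x00010203u);  // byte indices 3,2,1,0 (LE order)
  return TableLookupBytes(in, idx);       // 0x00010203
}
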
template <typename T, typename TI>
HWY_API Vec1<TI> TableLookupBytesOr0(const Vec1<T> in,
                                     const Vec1<TI> indices) {
  uint8_t in_bytes[sizeof(T)];
  uint8_t idx_bytes[sizeof(T)];
  uint8_t out_bytes[sizeof(T)];
  CopyBytes<sizeof(T)>(&in, &in_bytes);
  CopyBytes<sizeof(T)>(&indices, &idx_bytes);
  for (size_t i = 0; i < sizeof(T); ++i) {
    out_bytes[i] = idx_bytes[i] & 0x80 ? 0 : in_bytes[idx_bytes[i]];
  }
  TI out;
  CopyBytes<sizeof(TI)>(&out_bytes, &out);
  return Vec1<TI>{out};
}

// ------------------------------ ZipLower

HWY_API Vec1<uint16_t> ZipLower(const Vec1<uint8_t> a, const Vec1<uint8_t> b) {
  return Vec1<uint16_t>(static_cast<uint16_t>((uint32_t(b.raw) << 8) + a.raw));
}
HWY_API Vec1<uint32_t> ZipLower(const Vec1<uint16_t> a,
                                const Vec1<uint16_t> b) {
  return Vec1<uint32_t>((uint32_t(b.raw) << 16) + a.raw);
}
HWY_API Vec1<uint64_t> ZipLower(const Vec1<uint32_t> a,
                                const Vec1<uint32_t> b) {
  return Vec1<uint64_t>((uint64_t(b.raw) << 32) + a.raw);
}
HWY_API Vec1<int16_t> ZipLower(const Vec1<int8_t> a, const Vec1<int8_t> b) {
  return Vec1<int16_t>(static_cast<int16_t>((int32_t(b.raw) << 8) + a.raw));
}
HWY_API Vec1<int32_t> ZipLower(const Vec1<int16_t> a, const Vec1<int16_t> b) {
  return Vec1<int32_t>((int32_t(b.raw) << 16) + a.raw);
}
HWY_API Vec1<int64_t> ZipLower(const Vec1<int32_t> a, const Vec1<int32_t> b) {
  return Vec1<int64_t>((int64_t(b.raw) << 32) + a.raw);
}

template <typename T, typename TW = MakeWide<T>, class VW = Vec1<TW>>
HWY_API VW ZipLower(Sisd<TW> /* tag */, Vec1<T> a, Vec1<T> b) {
  return VW(static_cast<TW>((TW{b.raw} << (sizeof(T) * 8)) + a.raw));
}

// ================================================== MASK

template <typename T>
HWY_API bool AllFalse(Sisd<T> /* tag */, const Mask1<T> mask) {
  return mask.bits == 0;
}

template <typename T>
HWY_API bool AllTrue(Sisd<T> /* tag */, const Mask1<T> mask) {
  return mask.bits != 0;
}

// `p` points to at least 8 readable bytes, not all of which need be valid.
template <typename T>
HWY_API Mask1<T> LoadMaskBits(Sisd<T> /* tag */,
                              const uint8_t* HWY_RESTRICT bits) {
  return Mask1<T>::FromBool((bits[0] & 1) != 0);
}

// `p` points to at least 8 writable bytes.
template <typename T>
HWY_API size_t StoreMaskBits(Sisd<T> d, const Mask1<T> mask, uint8_t* bits) {
  *bits = AllTrue(d, mask);
  return 1;
}

template <typename T>
HWY_API size_t CountTrue(Sisd<T> /* tag */, const Mask1<T> mask) {
  return mask.bits == 0 ? 0 : 1;
}

template <typename T>
HWY_API intptr_t FindFirstTrue(Sisd<T> /* tag */, const Mask1<T> mask) {
  return mask.bits == 0 ? -1 : 0;
}

// ------------------------------ Compress, CompressBits

template <typename T>
HWY_API Vec1<T> Compress(Vec1<T> v, const Mask1<T> /* mask */) {
  // Upper lanes are undefined, so result is the same independent of mask.
  return v;
}

template <typename T>
HWY_API Vec1<T> CompressBits(Vec1<T> v,
                             const uint8_t* HWY_RESTRICT /* bits */) {
  return v;
}

// ------------------------------ CompressStore

template <typename T>
HWY_API size_t CompressStore(Vec1<T> v, const Mask1<T> mask, Sisd<T> d,
                             T* HWY_RESTRICT unaligned) {
  StoreU(Compress(v, mask), d, unaligned);
  return CountTrue(d, mask);
}

// ------------------------------ CompressBlendedStore

template <typename T>
HWY_API size_t CompressBlendedStore(Vec1<T> v, const Mask1<T> mask, Sisd<T> d,
                                    T* HWY_RESTRICT unaligned) {
  if (!mask.bits) return 0;
  StoreU(v, d, unaligned);
  return 1;
}

// ------------------------------ CompressBitsStore

template <typename T>
HWY_API size_t CompressBitsStore(Vec1<T> v, const uint8_t* HWY_RESTRICT bits,
                                 Sisd<T> d, T* HWY_RESTRICT unaligned) {
  const Mask1<T> mask = LoadMaskBits(d, bits);
  StoreU(Compress(v, mask), d, unaligned);
  return CountTrue(d, mask);
}

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

HWY_API Vec1<float> ReorderWidenMulAccumulate(Sisd<float> /* tag */,
                                              Vec1<bfloat16_t> a,
                                              Vec1<bfloat16_t> b,
                                              const Vec1<float> sum0,
                                              Vec1<float>& /* sum1 */) {
  return MulAdd(Vec1<float>(F32FromBF16(a.raw)),
                Vec1<float>(F32FromBF16(b.raw)), sum0);
}

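// Illustrative sketch (hypothetical helper, not part of the Highway API):
// with a single lane, CountTrue is 0 or 1 and FindFirstTrue is -1 or 0.
HWY_API size_t ExampleCountTrueUsage() {
  const Sisd<uint32_t> d;
  return CountTrue(d, FirstN(d, 1));  // 1
}
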
// ================================================== REDUCTIONS

// Sum of all lanes, i.e. the only one.
template <typename T>
HWY_API Vec1<T> SumOfLanes(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}
template <typename T>
HWY_API Vec1<T> MinOfLanes(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}
template <typename T>
HWY_API Vec1<T> MaxOfLanes(Sisd<T> /* tag */, const Vec1<T> v) {
  return v;
}

// ================================================== Operator wrapper

template <class V>
HWY_API V Add(V a, V b) {
  return a + b;
}
template <class V>
HWY_API V Sub(V a, V b) {
  return a - b;
}

template <class V>
HWY_API V Mul(V a, V b) {
  return a * b;
}
template <class V>
HWY_API V Div(V a, V b) {
  return a / b;
}

template <class V>
V Shl(V a, V b) {
  return a << b;
}
template <class V>
V Shr(V a, V b) {
  return a >> b;
}

template <class V>
HWY_API auto Eq(V a, V b) -> decltype(a == b) {
  return a == b;
}
template <class V>
HWY_API auto Ne(V a, V b) -> decltype(a == b) {
  return a != b;
}
template <class V>
HWY_API auto Lt(V a, V b) -> decltype(a == b) {
  return a < b;
}

template <class V>
HWY_API auto Gt(V a, V b) -> decltype(a == b) {
  return a > b;
}
template <class V>
HWY_API auto Ge(V a, V b) -> decltype(a == b) {
  return a >= b;
}

template <class V>
HWY_API auto Le(V a, V b) -> decltype(a == b) {
  return a <= b;
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();