// Copyright 2021 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 256-bit WASM vectors and operations. Experimental.
// External include guard in highway.h - see comment there.

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/wasm_128-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename T>
class Vec256 {
 public:
  using PrivateT = T;                                  // only for DFromV
  static constexpr size_t kPrivateN = 32 / sizeof(T);  // only for DFromV

  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec256& operator*=(const Vec256 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec256& operator/=(const Vec256 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec256& operator+=(const Vec256 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec256& operator-=(const Vec256 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec256& operator%=(const Vec256 other) {
    return *this = (*this % other);
  }
  HWY_INLINE Vec256& operator&=(const Vec256 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec256& operator|=(const Vec256 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec256& operator^=(const Vec256 other) {
    return *this = (*this ^ other);
  }

  Vec128<T> v0;
  Vec128<T> v1;
};

template <typename T>
struct Mask256 {
  Mask128<T> m0;
  Mask128<T> m1;
};

// ------------------------------ Zero

// Avoid VFromD here because it is defined in terms of Zero.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API Vec256<TFromD<D>> Zero(D d) {
  const Half<decltype(d)> dh;
  Vec256<TFromD<D>> ret;
  ret.v0 = ret.v1 = Zero(dh);
  return ret;
}

// ------------------------------ BitCast
template <class D, typename TFrom, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> BitCast(D d, Vec256<TFrom> v) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = BitCast(dh, v.v0);
  ret.v1 = BitCast(dh, v.v1);
  return ret;
}

// ------------------------------ ResizeBitCast

// 32-byte vector to 32-byte vector: Same as BitCast
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return BitCast(d, v);
}

// <= 16-byte vector to 32-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
          HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ResizeBitCast(dh, v);
  ret.v1 = Zero(dh);
  return ret;
}

// 32-byte vector to <= 16-byte vector
template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 32),
          HWY_IF_V_SIZE_LE_D(D, 16)>
HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
  return ResizeBitCast(d, v.v0);
}

// ------------------------------ Set
template <class D, HWY_IF_V_SIZE_D(D, 32), typename T2>
HWY_API VFromD<D> Set(D d, const T2 t) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Set(dh, static_cast<TFromD<D>>(t));
  return ret;
}

// Undefined, Iota defined in wasm_128.
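// Illustrative sketch (not part of the API): a 256-bit vector is emulated as
// two 128-bit halves, so most ops in this file simply forward to wasm_128
// twice.
//   const Full256<float> d;  // 8 x f32
//   auto v = Zero(d);        // v.v0 and v.v1 are both Vec128<float>
//   static_assert(decltype(v)::kPrivateN == 8, "32 bytes / 4 bytes per lane");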
// ------------------------------ Dup128VecFromValues
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
                                      TFromD<D> t8, TFromD<D> t9,
                                      TFromD<D> t10, TFromD<D> t11,
                                      TFromD<D> t12, TFromD<D> t13,
                                      TFromD<D> t14, TFromD<D> t15) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7, t8,
                                        t9, t10, t11, t12, t13, t14, t15);
  return ret;
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
                                      TFromD<D> t5, TFromD<D> t6,
                                      TFromD<D> t7) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3, t4, t5, t6, t7);
  return ret;
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
                                      TFromD<D> t2, TFromD<D> t3) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1, t2, t3);
  return ret;
}

template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Dup128VecFromValues(dh, t0, t1);
  return ret;
}

// ================================================== ARITHMETIC

template <typename T>
HWY_API Vec256<T> operator+(Vec256<T> a, const Vec256<T> b) {
  a.v0 += b.v0;
  a.v1 += b.v1;
  return a;
}

template <typename T>
HWY_API Vec256<T> operator-(Vec256<T> a, const Vec256<T> b) {
  a.v0 -= b.v0;
  a.v1 -= b.v1;
  return a;
}

// ------------------------------ SumsOf8
HWY_API Vec256<uint64_t> SumsOf8(const Vec256<uint8_t> v) {
  Vec256<uint64_t> ret;
  ret.v0 = SumsOf8(v.v0);
  ret.v1 = SumsOf8(v.v1);
  return ret;
}

HWY_API Vec256<int64_t> SumsOf8(const Vec256<int8_t> v) {
  Vec256<int64_t> ret;
  ret.v0 = SumsOf8(v.v0);
  ret.v1 = SumsOf8(v.v1);
  return ret;
}

template <typename T>
HWY_API Vec256<T> SaturatedAdd(Vec256<T> a, const Vec256<T> b) {
  a.v0 = SaturatedAdd(a.v0, b.v0);
  a.v1 = SaturatedAdd(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> SaturatedSub(Vec256<T> a, const Vec256<T> b) {
  a.v0 = SaturatedSub(a.v0, b.v0);
  a.v1 = SaturatedSub(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> AverageRound(Vec256<T> a, const Vec256<T> b) {
  a.v0 = AverageRound(a.v0, b.v0);
  a.v1 = AverageRound(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Abs(Vec256<T> v) {
  v.v0 = Abs(v.v0);
  v.v1 = Abs(v.v1);
  return v;
}

// ------------------------------ Shift lanes by constant #bits

template <int kBits, typename T>
HWY_API Vec256<T> ShiftLeft(Vec256<T> v) {
  v.v0 = ShiftLeft<kBits>(v.v0);
  v.v1 = ShiftLeft<kBits>(v.v1);
  return v;
}

template <int kBits, typename T>
HWY_API Vec256<T> ShiftRight(Vec256<T> v) {
  v.v0 = ShiftRight<kBits>(v.v0);
  v.v1 = ShiftRight<kBits>(v.v1);
  return v;
}

// ------------------------------ RotateRight (ShiftRight, Or)
template <int kBits, typename T>
HWY_API Vec256<T> RotateRight(const Vec256<T> v) {
  constexpr size_t kSizeInBits = sizeof(T) * 8;
  static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count");
  if (kBits == 0) return v;
  return Or(ShiftRight<kBits>(v),
            ShiftLeft<HWY_MIN(kSizeInBits - 1, kSizeInBits - kBits)>(v));
}

// ------------------------------ Shift lanes by same variable #bits

template <typename T>
HWY_API Vec256<T> ShiftLeftSame(Vec256<T> v, const int bits) {
  v.v0 = ShiftLeftSame(v.v0, bits);
  v.v1 = ShiftLeftSame(v.v1, bits);
  return v;
}

template <typename T>
HWY_API Vec256<T> ShiftRightSame(Vec256<T> v, const int bits) {
  v.v0 = ShiftRightSame(v.v0, bits);
  v.v1 = ShiftRightSame(v.v1, bits);
  return v;
}

// ------------------------------ Min, Max

template <typename T>
HWY_API Vec256<T> Min(Vec256<T> a, const Vec256<T> b) {
  a.v0 = Min(a.v0, b.v0);
  a.v1 = Min(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Max(Vec256<T> a, const Vec256<T> b) {
  a.v0 = Max(a.v0, b.v0);
  a.v1 = Max(a.v1, b.v1);
  return a;
}

// ------------------------------ Integer multiplication

template <typename T>
HWY_API Vec256<T> operator*(Vec256<T> a, const Vec256<T> b) {
  a.v0 *= b.v0;
  a.v1 *= b.v1;
  return a;
}

template <typename T>
HWY_API Vec256<T> MulHigh(Vec256<T> a, const Vec256<T> b) {
  a.v0 = MulHigh(a.v0, b.v0);
  a.v1 = MulHigh(a.v1, b.v1);
  return a;
}
template <typename T>
HWY_API Vec256<T> MulFixedPoint15(Vec256<T> a, const Vec256<T> b) {
  a.v0 = MulFixedPoint15(a.v0, b.v0);
  a.v1 = MulFixedPoint15(a.v1, b.v1);
  return a;
}

// Cannot use MakeWide because that returns uint128_t for uint64_t, but we want
// uint64_t.
template <typename T>
HWY_API Vec256<MakeWide<T>> MulEven(Vec256<T> a, const Vec256<T> b) {
  Vec256<MakeWide<T>> ret;
  ret.v0 = MulEven(a.v0, b.v0);
  ret.v1 = MulEven(a.v1, b.v1);
  return ret;
}
HWY_API Vec256<uint64_t> MulEven(Vec256<uint64_t> a,
                                 const Vec256<uint64_t> b) {
  Vec256<uint64_t> ret;
  ret.v0 = MulEven(a.v0, b.v0);
  ret.v1 = MulEven(a.v1, b.v1);
  return ret;
}

template <typename T>
HWY_API Vec256<MakeWide<T>> MulOdd(Vec256<T> a, const Vec256<T> b) {
  Vec256<MakeWide<T>> ret;
  ret.v0 = MulOdd(a.v0, b.v0);
  ret.v1 = MulOdd(a.v1, b.v1);
  return ret;
}
HWY_API Vec256<uint64_t> MulOdd(Vec256<uint64_t> a, const Vec256<uint64_t> b) {
  Vec256<uint64_t> ret;
  ret.v0 = MulOdd(a.v0, b.v0);
  ret.v1 = MulOdd(a.v1, b.v1);
  return ret;
}

// ------------------------------ Negate

template <typename T>
HWY_API Vec256<T> Neg(Vec256<T> v) {
  v.v0 = Neg(v.v0);
  v.v1 = Neg(v.v1);
  return v;
}

// ------------------------------ AbsDiff

// generic_ops takes care of integer T.
template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec256<T> AbsDiff(const Vec256<T> a, const Vec256<T> b) {
  return Abs(a - b);
}

// ------------------------------ Floating-point division

// generic_ops takes care of integer T.
template <typename T, HWY_IF_FLOAT(T)>
HWY_API Vec256<T> operator/(Vec256<T> a, const Vec256<T> b) {
  a.v0 /= b.v0;
  a.v1 /= b.v1;
  return a;
}

// Approximate reciprocal
HWY_API Vec256<float> ApproximateReciprocal(const Vec256<float> v) {
  const Vec256<float> one = Set(Full256<float>(), 1.0f);
  return one / v;
}

// ------------------------------ Floating-point multiply-add variants

HWY_API Vec256<float> MulAdd(Vec256<float> mul, Vec256<float> x,
                             Vec256<float> add) {
  mul.v0 = MulAdd(mul.v0, x.v0, add.v0);
  mul.v1 = MulAdd(mul.v1, x.v1, add.v1);
  return mul;
}

HWY_API Vec256<float> NegMulAdd(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> add) {
  mul.v0 = NegMulAdd(mul.v0, x.v0, add.v0);
  mul.v1 = NegMulAdd(mul.v1, x.v1, add.v1);
  return mul;
}

HWY_API Vec256<float> MulSub(Vec256<float> mul, Vec256<float> x,
                             Vec256<float> sub) {
  mul.v0 = MulSub(mul.v0, x.v0, sub.v0);
  mul.v1 = MulSub(mul.v1, x.v1, sub.v1);
  return mul;
}

HWY_API Vec256<float> NegMulSub(Vec256<float> mul, Vec256<float> x,
                                Vec256<float> sub) {
  mul.v0 = NegMulSub(mul.v0, x.v0, sub.v0);
  mul.v1 = NegMulSub(mul.v1, x.v1, sub.v1);
  return mul;
}

// ------------------------------ Floating-point square root

template <typename T>
HWY_API Vec256<T> Sqrt(Vec256<T> v) {
  v.v0 = Sqrt(v.v0);
  v.v1 = Sqrt(v.v1);
  return v;
}

// Approximate reciprocal square root
HWY_API Vec256<float> ApproximateReciprocalSqrt(const Vec256<float> v) {
  // TODO(eustas): find a cheaper way to calculate this.
  const Vec256<float> one = Set(Full256<float>(), 1.0f);
  return one / Sqrt(v);
}

// ------------------------------ Floating-point rounding

// Toward nearest integer, ties to even
HWY_API Vec256<float> Round(Vec256<float> v) {
  v.v0 = Round(v.v0);
  v.v1 = Round(v.v1);
  return v;
}

// Toward zero, aka truncate
HWY_API Vec256<float> Trunc(Vec256<float> v) {
  v.v0 = Trunc(v.v0);
  v.v1 = Trunc(v.v1);
  return v;
}

// Toward +infinity, aka ceiling
HWY_API Vec256<float> Ceil(Vec256<float> v) {
  v.v0 = Ceil(v.v0);
  v.v1 = Ceil(v.v1);
  return v;
}

// Toward -infinity, aka floor
HWY_API Vec256<float> Floor(Vec256<float> v) {
  v.v0 = Floor(v.v0);
  v.v1 = Floor(v.v1);
  return v;
}

// ------------------------------ Floating-point classification

template <typename T>
HWY_API Mask256<T> IsNaN(const Vec256<T> v) {
  return v != v;
}

template <typename T, HWY_IF_FLOAT(T)>
HWY_API Mask256<T> IsInf(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const VFromD<decltype(du)> vu = BitCast(du, v);
  // 'Shift left' to clear the sign bit, check for exponent=max and mantissa=0.
  return RebindMask(d, Eq(Add(vu, vu), Set(du, hwy::MaxExponentTimes2<T>())));
}
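// Worked example of the IsInf bit trick above (f32 bit patterns):
//   +inf = 0x7F800000; vu + vu = 0xFF000000 == MaxExponentTimes2<float>()
//   -inf = 0xFF800000; vu + vu = 0xFF000000 (sign bit shifted out)
//   1.0f = 0x3F800000; vu + vu = 0x7F000000 != 0xFF000000
// Adding an unsigned value to itself is a cheap shift-left by one, which
// discards the sign bit so +inf and -inf compare equal.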
// Returns whether normal/subnormal/zero.
template <typename T, HWY_IF_FLOAT(T)>
HWY_API Mask256<T> IsFinite(const Vec256<T> v) {
  const DFromV<decltype(v)> d;
  const RebindToUnsigned<decltype(d)> du;
  const RebindToSigned<decltype(d)> di;  // cheaper than unsigned comparison
  const VFromD<decltype(du)> vu = BitCast(du, v);
  // 'Shift left' to clear the sign bit, then right so we can compare with the
  // max exponent (cannot compare with MaxExponentTimes2 directly because it is
  // negative and non-negative floats would be greater).
  const VFromD<decltype(di)> exp =
      BitCast(di, ShiftRight<hwy::MantissaBits<T>() + 1>(Add(vu, vu)));
  return RebindMask(d, Lt(exp, Set(di, hwy::MaxExponentField<T>())));
}

// ================================================== COMPARE

// Comparisons fill a lane with 1-bits if the condition is true, else 0.

template <class DTo, typename TFrom, typename TTo = TFromD<DTo>>
HWY_API MFromD<DTo> RebindMask(DTo /*tag*/, Mask256<TFrom> m) {
  static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size");
  return MFromD<DTo>{Mask128<TTo>{m.m0.raw}, Mask128<TTo>{m.m1.raw}};
}

template <typename T>
HWY_API Mask256<T> TestBit(Vec256<T> v, Vec256<T> bit) {
  static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
  return (v & bit) == bit;
}

template <typename T>
HWY_API Mask256<T> operator==(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator==(a.v0, b.v0);
  m.m1 = operator==(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator!=(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator!=(a.v0, b.v0);
  m.m1 = operator!=(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator<(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator<(a.v0, b.v0);
  m.m1 = operator<(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator>(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator>(a.v0, b.v0);
  m.m1 = operator>(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator<=(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator<=(a.v0, b.v0);
  m.m1 = operator<=(a.v1, b.v1);
  return m;
}

template <typename T>
HWY_API Mask256<T> operator>=(Vec256<T> a, const Vec256<T> b) {
  Mask256<T> m;
  m.m0 = operator>=(a.v0, b.v0);
  m.m1 = operator>=(a.v1, b.v1);
  return m;
}

// ------------------------------ FirstN (Iota, Lt)

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API MFromD<D> FirstN(const D d, size_t num) {
  const RebindToSigned<decltype(d)> di;  // Signed comparisons may be cheaper.
  using TI = TFromD<decltype(di)>;
  return RebindMask(d, Iota(di, 0) < Set(di, static_cast<TI>(num)));
}
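// Example use of FirstN (an illustrative sketch, not part of this header):
//   const Full256<int32_t> d;
//   const auto m = FirstN(d, 3);                   // lanes 0..2 true
//   const auto v = IfThenElseZero(m, Iota(d, 1));  // {1,2,3,0,0,0,0,0}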
// ================================================== LOGICAL

template <typename T>
HWY_API Vec256<T> Not(Vec256<T> v) {
  v.v0 = Not(v.v0);
  v.v1 = Not(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) {
  a.v0 = And(a.v0, b.v0);
  a.v1 = And(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) {
  not_mask.v0 = AndNot(not_mask.v0, mask.v0);
  not_mask.v1 = AndNot(not_mask.v1, mask.v1);
  return not_mask;
}

template <typename T>
HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) {
  a.v0 = Or(a.v0, b.v0);
  a.v1 = Or(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) {
  a.v0 = Xor(a.v0, b.v0);
  a.v1 = Xor(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) {
  return Xor(x1, Xor(x2, x3));
}

template <typename T>
HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
  return Or(o1, Or(o2, o3));
}

template <typename T>
HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) {
  return Or(o, And(a1, a2));
}

template <typename T>
HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) {
  return IfThenElse(MaskFromVec(mask), yes, no);
}

// ------------------------------ Operator overloads (internal-only if float)

template <typename T>
HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) {
  return And(a, b);
}

template <typename T>
HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) {
  return Or(a, b);
}

template <typename T>
HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) {
  return Xor(a, b);
}

// ------------------------------ CopySign
template <typename T>
HWY_API Vec256<T> CopySign(const Vec256<T> magn, const Vec256<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  const DFromV<decltype(magn)> d;
  return BitwiseIfThenElse(SignBit(d), sign, magn);
}

// ------------------------------ CopySignToAbs
template <typename T>
HWY_API Vec256<T> CopySignToAbs(const Vec256<T> abs, const Vec256<T> sign) {
  static_assert(IsFloat<T>(), "Only makes sense for floating-point");
  const DFromV<decltype(abs)> d;
  return OrAnd(abs, SignBit(d), sign);
}

// ------------------------------ Mask

// Mask and Vec are the same (true = FF..FF).
template <typename T>
HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
  Mask256<T> m;
  m.m0 = MaskFromVec(v.v0);
  m.m1 = MaskFromVec(v.v1);
  return m;
}

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> VecFromMask(D d, Mask256<T> m) {
  const Half<decltype(d)> dh;
  Vec256<T> v;
  v.v0 = VecFromMask(dh, m.m0);
  v.v1 = VecFromMask(dh, m.m1);
  return v;
}

// mask ? yes : no
template <typename T>
HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) {
  yes.v0 = IfThenElse(mask.m0, yes.v0, no.v0);
  yes.v1 = IfThenElse(mask.m1, yes.v1, no.v1);
  return yes;
}

// mask ? yes : 0
template <typename T>
HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) {
  return yes & VecFromMask(DFromV<decltype(yes)>(), mask);
}
// mask ? 0 : no
template <typename T>
HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) {
  return AndNot(VecFromMask(DFromV<decltype(no)>(), mask), no);
}

template <typename T>
HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes,
                                     Vec256<T> no) {
  v.v0 = IfNegativeThenElse(v.v0, yes.v0, no.v0);
  v.v1 = IfNegativeThenElse(v.v1, yes.v1, no.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) {
  return IfThenZeroElse(v < Zero(DFromV<decltype(v)>()), v);
}

// ------------------------------ Mask logical

template <typename T>
HWY_API Mask256<T> Not(const Mask256<T> m) {
  return MaskFromVec(Not(VecFromMask(Full256<T>(), m)));
}

template <typename T>
HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b)));
}

template <typename T>
HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) {
  const Full256<T> d;
  return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b))));
}

// ------------------------------ Shl (BroadcastSignBit, IfThenElse)
template <typename T>
HWY_API Vec256<T> operator<<(Vec256<T> v, const Vec256<T> bits) {
  v.v0 = operator<<(v.v0, bits.v0);
  v.v1 = operator<<(v.v1, bits.v1);
  return v;
}

// ------------------------------ Shr (BroadcastSignBit, IfThenElse)
template <typename T>
HWY_API Vec256<T> operator>>(Vec256<T> v, const Vec256<T> bits) {
  v.v0 = operator>>(v.v0, bits.v0);
  v.v1 = operator>>(v.v1, bits.v1);
  return v;
}

// ------------------------------ BroadcastSignBit (compare, VecFromMask)

template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
HWY_API Vec256<T> BroadcastSignBit(const Vec256<T> v) {
  return ShiftRight<sizeof(T) * 8 - 1>(v);
}
HWY_API Vec256<int8_t> BroadcastSignBit(const Vec256<int8_t> v) {
  const DFromV<decltype(v)> d;
  return VecFromMask(d, v < Zero(d));
}

// ================================================== MEMORY

// ------------------------------ Load

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Load(D d, const TFromD<D>* HWY_RESTRICT aligned) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = Load(dh, aligned);
  ret.v1 = Load(dh, aligned + Lanes(dh));
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> MaskedLoad(Mask256<T> m, D d, const T* HWY_RESTRICT aligned) {
  return IfThenElseZero(m, Load(d, aligned));
}

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> MaskedLoadOr(Vec256<T> v, Mask256<T> m, D d,
                               const T* HWY_RESTRICT aligned) {
  return IfThenElse(m, Load(d, aligned), v);
}

// LoadU == Load.
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
  return Load(d, p);
}

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = ret.v1 = Load(dh, p);
  return ret;
}

// ------------------------------ Store

template <class D, typename T = TFromD<D>>
HWY_API void Store(Vec256<T> v, D d, T* HWY_RESTRICT aligned) {
  const Half<decltype(d)> dh;
  Store(v.v0, dh, aligned);
  Store(v.v1, dh, aligned + Lanes(dh));
}

// StoreU == Store.
template <class D, typename T = TFromD<D>>
HWY_API void StoreU(Vec256<T> v, D d, T* HWY_RESTRICT p) {
  Store(v, d, p);
}

template <class D, typename T = TFromD<D>>
HWY_API void BlendedStore(Vec256<T> v, Mask256<T> m, D d, T* HWY_RESTRICT p) {
  StoreU(IfThenElse(m, v, LoadU(d, p)), d, p);
}

// ------------------------------ Stream
template <class D, typename T = TFromD<D>>
HWY_API void Stream(Vec256<T> v, D d, T* HWY_RESTRICT aligned) {
  // Same as aligned stores.
  Store(v, d, aligned);
}
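// Example of the split Load/Store layout (a sketch; `in`/`out` hypothetical):
//   const Full256<uint32_t> d;  // Lanes(d) == 8, Lanes(Half<...>()) == 4
//   auto v = Load(d, in);       // v.v0 = in[0..3], v.v1 = in[4..7]
//   Store(v, d, out);           // one contiguous 32-byte store, two halves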
// ------------------------------ Scatter, Gather defined in wasm_128

// ================================================== SWIZZLE

// ------------------------------ ExtractLane
template <typename T>
HWY_API T ExtractLane(const Vec256<T> v, size_t i) {
  alignas(32) T lanes[32 / sizeof(T)];
  Store(v, DFromV<decltype(v)>(), lanes);
  return lanes[i];
}

// ------------------------------ InsertLane
template <typename T>
HWY_API Vec256<T> InsertLane(const Vec256<T> v, size_t i, T t) {
  DFromV<decltype(v)> d;
  alignas(32) T lanes[32 / sizeof(T)];
  Store(v, d, lanes);
  lanes[i] = t;
  return Load(d, lanes);
}

// ------------------------------ ExtractBlock
template <int kBlockIdx, typename T>
HWY_API Vec128<T> ExtractBlock(Vec256<T> v) {
  static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
  return (kBlockIdx == 0) ? v.v0 : v.v1;
}

// ------------------------------ InsertBlock
template <int kBlockIdx, typename T>
HWY_API Vec256<T> InsertBlock(Vec256<T> v, Vec128<T> blk_to_insert) {
  static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
  Vec256<T> result;
  if (kBlockIdx == 0) {
    result.v0 = blk_to_insert;
    result.v1 = v.v1;
  } else {
    result.v0 = v.v0;
    result.v1 = blk_to_insert;
  }
  return result;
}

// ------------------------------ BroadcastBlock
template <int kBlockIdx, typename T>
HWY_API Vec256<T> BroadcastBlock(Vec256<T> v) {
  static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
  Vec256<T> result;
  result.v0 = result.v1 = (kBlockIdx == 0 ? v.v0 : v.v1);
  return result;
}

// ------------------------------ LowerHalf

template <class D, typename T = TFromD<D>>
HWY_API Vec128<T> LowerHalf(D /* tag */, Vec256<T> v) {
  return v.v0;
}

template <typename T>
HWY_API Vec128<T> LowerHalf(Vec256<T> v) {
  return v.v0;
}

// ------------------------------ GetLane (LowerHalf)
template <typename T>
HWY_API T GetLane(const Vec256<T> v) {
  return GetLane(LowerHalf(v));
}

// ------------------------------ ShiftLeftBytes

template <int kBytes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftLeftBytes(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = ShiftLeftBytes<kBytes>(dh, v.v0);
  v.v1 = ShiftLeftBytes<kBytes>(dh, v.v1);
  return v;
}

template <int kBytes, typename T>
HWY_API Vec256<T> ShiftLeftBytes(Vec256<T> v) {
  return ShiftLeftBytes<kBytes>(DFromV<decltype(v)>(), v);
}

// ------------------------------ ShiftLeftLanes

template <int kLanes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftLeftLanes(D d, const Vec256<T> v) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, ShiftLeftBytes<kLanes * sizeof(T)>(BitCast(d8, v)));
}

template <int kLanes, typename T>
HWY_API Vec256<T> ShiftLeftLanes(const Vec256<T> v) {
  return ShiftLeftLanes<kLanes>(DFromV<decltype(v)>(), v);
}

// ------------------------------ ShiftRightBytes
template <int kBytes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftRightBytes(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = ShiftRightBytes<kBytes>(dh, v.v0);
  v.v1 = ShiftRightBytes<kBytes>(dh, v.v1);
  return v;
}

// ------------------------------ ShiftRightLanes
template <int kLanes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> ShiftRightLanes(D d, const Vec256<T> v) {
  const Repartition<uint8_t, decltype(d)> d8;
  return BitCast(d, ShiftRightBytes<kLanes * sizeof(T)>(d8, BitCast(d8, v)));
}

// ------------------------------ UpperHalf (ShiftRightBytes)
template <class D, typename T = TFromD<D>>
HWY_API Vec128<T> UpperHalf(D /* tag */, const Vec256<T> v) {
  return v.v1;
}

// ------------------------------ CombineShiftRightBytes
template <int kBytes, class D, typename T = TFromD<D>>
HWY_API Vec256<T> CombineShiftRightBytes(D d, Vec256<T> hi, Vec256<T> lo) {
  const Half<decltype(d)> dh;
  hi.v0 = CombineShiftRightBytes<kBytes>(dh, hi.v0, lo.v0);
  hi.v1 = CombineShiftRightBytes<kBytes>(dh, hi.v1, lo.v1);
  return hi;
}

// ------------------------------ Broadcast/splat any lane

template <int kLane, typename T>
HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
  Vec256<T> ret;
  ret.v0 = Broadcast<kLane>(v.v0);
  ret.v1 = Broadcast<kLane>(v.v1);
  return ret;
}

template <int kLane, typename T>
HWY_API Vec256<T> BroadcastLane(const Vec256<T> v) {
  constexpr int kLanesPerBlock = static_cast<int>(16 / sizeof(T));
  static_assert(0 <= kLane && kLane < kLanesPerBlock * 2, "Invalid lane");
  constexpr int kLaneInBlkIdx = kLane & (kLanesPerBlock - 1);
  Vec256<T> ret;
  ret.v0 = ret.v1 =
      Broadcast<kLaneInBlkIdx>(kLane >= kLanesPerBlock ? v.v1 : v.v0);
  return ret;
}
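// Example of BroadcastLane (a sketch): broadcasting a lane from the upper
// block. With u32, kLanesPerBlock == 4, so lane 5 is lane 1 of v.v1:
//   const Full256<uint32_t> d;
//   const auto v = Iota(d, 0);           // {0,1,2,3,4,5,6,7}
//   const auto b = BroadcastLane<5>(v);  // all lanes = 5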
// ------------------------------ TableLookupBytes

// Both full
template <typename T, typename TI>
HWY_API Vec256<TI> TableLookupBytes(const Vec256<T> bytes, Vec256<TI> from) {
  from.v0 = TableLookupBytes(bytes.v0, from.v0);
  from.v1 = TableLookupBytes(bytes.v1, from.v1);
  return from;
}

// Partial index vector
template <typename T, typename TI, size_t NI>
HWY_API Vec128<TI, NI> TableLookupBytes(Vec256<T> bytes,
                                        const Vec128<TI, NI> from) {
  // First expand to full 128, then 256.
  const auto from_256 = ZeroExtendVector(Full256<TI>(), Vec128<TI>{from.raw});
  const auto tbl_full = TableLookupBytes(bytes, from_256);
  // Shrink to 128, then partial.
  return Vec128<TI, NI>{LowerHalf(Full128<TI>(), tbl_full).raw};
}

// Partial table vector
template <typename T, size_t N, typename TI>
HWY_API Vec256<TI> TableLookupBytes(Vec128<T, N> bytes, const Vec256<TI> from) {
  // First expand to full 128, then 256.
  const auto bytes_256 = ZeroExtendVector(Full256<T>(), Vec128<T>{bytes.raw});
  return TableLookupBytes(bytes_256, from);
}

// Partial both are handled by wasm_128.

template <class V, class VI>
HWY_API VI TableLookupBytesOr0(V bytes, VI from) {
  // wasm out-of-bounds policy already zeros, so TableLookupBytes is fine.
  return TableLookupBytes(bytes, from);
}

// ------------------------------ Hard-coded shuffles

template <typename T>
HWY_API Vec256<T> Shuffle01(Vec256<T> v) {
  v.v0 = Shuffle01(v.v0);
  v.v1 = Shuffle01(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle2301(Vec256<T> v) {
  v.v0 = Shuffle2301(v.v0);
  v.v1 = Shuffle2301(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle1032(Vec256<T> v) {
  v.v0 = Shuffle1032(v.v0);
  v.v1 = Shuffle1032(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle0321(Vec256<T> v) {
  v.v0 = Shuffle0321(v.v0);
  v.v1 = Shuffle0321(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle2103(Vec256<T> v) {
  v.v0 = Shuffle2103(v.v0);
  v.v1 = Shuffle2103(v.v1);
  return v;
}

template <typename T>
HWY_API Vec256<T> Shuffle0123(Vec256<T> v) {
  v.v0 = Shuffle0123(v.v0);
  v.v1 = Shuffle0123(v.v1);
  return v;
}

// Used by generic_ops-inl.h
namespace detail {

template <typename T>
HWY_API Vec256<T> ShuffleTwo2301(Vec256<T> a, const Vec256<T> b) {
  a.v0 = ShuffleTwo2301(a.v0, b.v0);
  a.v1 = ShuffleTwo2301(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> ShuffleTwo1230(Vec256<T> a, const Vec256<T> b) {
  a.v0 = ShuffleTwo1230(a.v0, b.v0);
  a.v1 = ShuffleTwo1230(a.v1, b.v1);
  return a;
}

template <typename T>
HWY_API Vec256<T> ShuffleTwo3012(Vec256<T> a, const Vec256<T> b) {
  a.v0 = ShuffleTwo3012(a.v0, b.v0);
  a.v1 = ShuffleTwo3012(a.v1, b.v1);
  return a;
}

}  // namespace detail

// ------------------------------ TableLookupLanes

// Returned by SetTableIndices for use by TableLookupLanes.
template <typename T>
struct Indices256 {
  __v128_u i0;
  __v128_u i1;
};

template <class D, typename T = TFromD<D>, typename TI>
HWY_API Indices256<T> IndicesFromVec(D /* tag */, Vec256<TI> vec) {
  static_assert(sizeof(T) == sizeof(TI), "Index size must match lane");
  Indices256<T> ret;
  ret.i0 = vec.v0.raw;
  ret.i1 = vec.v1.raw;
  return ret;
}

template <class D, typename TI>
HWY_API Indices256<TFromD<D>> SetTableIndices(D d, const TI* idx) {
  const Rebind<TI, decltype(d)> di;
  return IndicesFromVec(d, LoadU(di, idx));
}

template <typename T>
HWY_API Vec256<T> TableLookupLanes(const Vec256<T> v, Indices256<T> idx) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  const auto idx_i0 = IndicesFromVec(dh, Vec128<T>{idx.i0});
  const auto idx_i1 = IndicesFromVec(dh, Vec128<T>{idx.i1});
  Vec256<T> result;
  result.v0 = TwoTablesLookupLanes(v.v0, v.v1, idx_i0);
  result.v1 = TwoTablesLookupLanes(v.v0, v.v1, idx_i1);
  return result;
}
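// Example of a cross-block permutation (a sketch; reverses all 8 u32 lanes):
//   const Full256<uint32_t> d;
//   alignas(32) const int32_t idx[8] = {7, 6, 5, 4, 3, 2, 1, 0};
//   const auto r = TableLookupLanes(v, SetTableIndices(d, idx));
// Each half is a TwoTablesLookupLanes over both input halves, so an index may
// reference any of the 8 lanes, not just those within its own block.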
template <typename T>
HWY_API Vec256<T> TableLookupLanesOr0(Vec256<T> v, Indices256<T> idx) {
  // The out of bounds behavior will already zero lanes. Note: must forward to
  // TableLookupLanes; calling TableLookupLanesOr0 here would recurse forever.
  return TableLookupLanes(v, idx);
}

template <typename T>
HWY_API Vec256<T> TwoTablesLookupLanes(const Vec256<T> a, const Vec256<T> b,
                                       Indices256<T> idx) {
  const DFromV<decltype(a)> d;
  const Half<decltype(d)> dh;
  const RebindToUnsigned<decltype(d)> du;
  using TU = MakeUnsigned<T>;
  constexpr size_t kLanesPerVect = 32 / sizeof(TU);

  Vec256<TU> vi;
  vi.v0 = Vec128<TU>{idx.i0};
  vi.v1 = Vec128<TU>{idx.i1};
  const auto vmod = vi & Set(du, TU{kLanesPerVect - 1});
  const auto is_lo = RebindMask(d, vi == vmod);

  const auto idx_i0 = IndicesFromVec(dh, vmod.v0);
  const auto idx_i1 = IndicesFromVec(dh, vmod.v1);

  Vec256<T> result_lo;
  Vec256<T> result_hi;
  result_lo.v0 = TwoTablesLookupLanes(a.v0, a.v1, idx_i0);
  result_lo.v1 = TwoTablesLookupLanes(a.v0, a.v1, idx_i1);
  result_hi.v0 = TwoTablesLookupLanes(b.v0, b.v1, idx_i0);
  result_hi.v1 = TwoTablesLookupLanes(b.v0, b.v1, idx_i1);
  return IfThenElse(is_lo, result_lo, result_hi);
}

// ------------------------------ Reverse
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> Reverse(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v1 = Reverse(dh, v.v0);  // note reversed v1 member order
  ret.v0 = Reverse(dh, v.v1);
  return ret;
}

// ------------------------------ Reverse2
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> Reverse2(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = Reverse2(dh, v.v0);
  v.v1 = Reverse2(dh, v.v1);
  return v;
}

// ------------------------------ Reverse4

// Each block has only 2 lanes, so swap blocks and their lanes.
template <class D, typename T = TFromD<D>, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> Reverse4(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Reverse2(dh, v.v1);  // swapped
  ret.v1 = Reverse2(dh, v.v0);
  return ret;
}

template <class D, typename T = TFromD<D>, HWY_IF_NOT_T_SIZE(T, 8)>
HWY_API Vec256<T> Reverse4(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = Reverse4(dh, v.v0);
  v.v1 = Reverse4(dh, v.v1);
  return v;
}

// ------------------------------ Reverse8

template <class D, typename T = TFromD<D>, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec256<T> Reverse8(D /* tag */, Vec256<T> /* v */) {
  HWY_ASSERT(0);  // don't have 8 u64 lanes
}

// Each block has only 4 lanes, so swap blocks and their lanes.
template <class D, typename T = TFromD<D>, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec256<T> Reverse8(D d, const Vec256<T> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Reverse4(dh, v.v1);  // swapped
  ret.v1 = Reverse4(dh, v.v0);
  return ret;
}

template <class D, typename T = TFromD<D>,
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API Vec256<T> Reverse8(D d, Vec256<T> v) {
  const Half<decltype(d)> dh;
  v.v0 = Reverse8(dh, v.v0);
  v.v1 = Reverse8(dh, v.v1);
  return v;
}

// ------------------------------ InterleaveLower

template <typename T>
HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
  a.v0 = InterleaveLower(a.v0, b.v0);
  a.v1 = InterleaveLower(a.v1, b.v1);
  return a;
}

// wasm_128 already defines a template with D, V, V args.
// ------------------------------ InterleaveUpper (UpperHalf)

template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> InterleaveUpper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  a.v0 = InterleaveUpper(dh, a.v0, b.v0);
  a.v1 = InterleaveUpper(dh, a.v1, b.v1);
  return a;
}

// ------------------------------ InterleaveWholeLower
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = InterleaveLower(a.v0, b.v0);
  ret.v1 = InterleaveUpper(dh, a.v0, b.v0);
  return ret;
}

// ------------------------------ InterleaveWholeUpper
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  ret.v0 = InterleaveLower(a.v1, b.v1);
  ret.v1 = InterleaveUpper(dh, a.v1, b.v1);
  return ret;
}

// ------------------------------ ZipLower/ZipUpper defined in wasm_128

// ================================================== COMBINE

// ------------------------------ Combine (InterleaveLower)
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> Combine(D /* d */, Vec128<T> hi, Vec128<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi;
  ret.v0 = lo;
  return ret;
}

// ------------------------------ ZeroExtendVector (Combine)
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ZeroExtendVector(D d, Vec128<T> lo) {
  const Half<decltype(d)> dh;
  return Combine(d, Zero(dh), lo);
}

// ------------------------------ ZeroExtendResizeBitCast

namespace detail {

template <size_t kFromVectSize, class DTo, class DFrom>
HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
    hwy::SizeTag<kFromVectSize> /* from_size_tag */,
    hwy::SizeTag<32> /* to_size_tag */, DTo d_to, DFrom d_from,
    VFromD<DFrom> v) {
  const Half<decltype(d_to)> dh_to;
  return ZeroExtendVector(d_to, ZeroExtendResizeBitCast(dh_to, d_from, v));
}

}  // namespace detail

// ------------------------------ ConcatLowerLower
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatLowerLower(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v0;
  ret.v0 = lo.v0;
  return ret;
}

// ------------------------------ ConcatUpperUpper
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatUpperUpper(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v1;
  ret.v0 = lo.v1;
  return ret;
}

// ------------------------------ ConcatLowerUpper
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatLowerUpper(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v0;
  ret.v0 = lo.v1;
  return ret;
}

// ------------------------------ ConcatUpperLower
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatUpperLower(D /* tag */, Vec256<T> hi, Vec256<T> lo) {
  Vec256<T> ret;
  ret.v1 = hi.v1;
  ret.v0 = lo.v0;
  return ret;
}

// ------------------------------ ConcatOdd
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatOdd(D d, Vec256<T> hi, Vec256<T> lo) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = ConcatOdd(dh, lo.v1, lo.v0);
  ret.v1 = ConcatOdd(dh, hi.v1, hi.v0);
  return ret;
}

// ------------------------------ ConcatEven
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ConcatEven(D d, Vec256<T> hi, Vec256<T> lo) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = ConcatEven(dh, lo.v1, lo.v0);
  ret.v1 = ConcatEven(dh, hi.v1, hi.v0);
  return ret;
}

// ------------------------------ DupEven
template <typename T>
HWY_API Vec256<T> DupEven(Vec256<T> v) {
  v.v0 = DupEven(v.v0);
  v.v1 = DupEven(v.v1);
  return v;
}

// ------------------------------ DupOdd
template <typename T>
HWY_API Vec256<T> DupOdd(Vec256<T> v) {
  v.v0 = DupOdd(v.v0);
  v.v1 = DupOdd(v.v1);
  return v;
}

// ------------------------------ OddEven
template <typename T>
HWY_API Vec256<T> OddEven(Vec256<T> a, const Vec256<T> b) {
  a.v0 = OddEven(a.v0, b.v0);
  a.v1 = OddEven(a.v1, b.v1);
  return a;
}

// ------------------------------ OddEvenBlocks
template <typename T>
HWY_API Vec256<T> OddEvenBlocks(Vec256<T> odd, Vec256<T> even) {
  odd.v0 = even.v0;
  return odd;
}

// ------------------------------ SwapAdjacentBlocks
template <typename T>
HWY_API Vec256<T> SwapAdjacentBlocks(Vec256<T> v) {
  Vec256<T> ret;
  ret.v0 = v.v1;  // swapped order
  ret.v1 = v.v0;
  return ret;
}
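// Block notation for the Concat ops above, writing a vector as
// (upper block, lower block):
//   ConcatLowerLower(d, hi, lo) == (hi.v0, lo.v0)
//   ConcatUpperUpper(d, hi, lo) == (hi.v1, lo.v1)
//   ConcatLowerUpper(d, hi, lo) == (hi.v0, lo.v1)
//   ConcatUpperLower(d, hi, lo) == (hi.v1, lo.v0)
//   SwapAdjacentBlocks(v)       == (v.v0, v.v1)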
// ------------------------------ ReverseBlocks
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ReverseBlocks(D /* tag */, const Vec256<T> v) {
  return SwapAdjacentBlocks(v);  // 2 blocks, so Swap = Reverse
}

// ------------------------------ Per4LaneBlockShuffle
namespace detail {

template <size_t kIdx3210, class V>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<1> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  using VH = VFromD<decltype(dh)>;

  constexpr int kIdx3 = static_cast<int>((kIdx3210 >> 6) & 3);
  constexpr int kIdx2 = static_cast<int>((kIdx3210 >> 4) & 3);
  constexpr int kIdx1 = static_cast<int>((kIdx3210 >> 2) & 3);
  constexpr int kIdx0 = static_cast<int>(kIdx3210 & 3);

  V ret;
  ret.v0 = VH{wasm_i8x16_shuffle(
      v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3, kIdx0 + 4, kIdx1 + 4,
      kIdx2 + 4, kIdx3 + 4, kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8,
      kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)};
  ret.v1 = VH{wasm_i8x16_shuffle(
      v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3, kIdx0 + 4, kIdx1 + 4,
      kIdx2 + 4, kIdx3 + 4, kIdx0 + 8, kIdx1 + 8, kIdx2 + 8, kIdx3 + 8,
      kIdx0 + 12, kIdx1 + 12, kIdx2 + 12, kIdx3 + 12)};
  return ret;
}

template <size_t kIdx3210, class V>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<2> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  using VH = VFromD<decltype(dh)>;

  constexpr int kIdx3 = static_cast<int>((kIdx3210 >> 6) & 3);
  constexpr int kIdx2 = static_cast<int>((kIdx3210 >> 4) & 3);
  constexpr int kIdx1 = static_cast<int>((kIdx3210 >> 2) & 3);
  constexpr int kIdx0 = static_cast<int>(kIdx3210 & 3);

  V ret;
  ret.v0 = VH{wasm_i16x8_shuffle(v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2,
                                 kIdx3, kIdx0 + 4, kIdx1 + 4, kIdx2 + 4,
                                 kIdx3 + 4)};
  ret.v1 = VH{wasm_i16x8_shuffle(v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2,
                                 kIdx3, kIdx0 + 4, kIdx1 + 4, kIdx2 + 4,
                                 kIdx3 + 4)};
  return ret;
}

template <size_t kIdx3210, class V>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<4> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  using VH = VFromD<decltype(dh)>;

  constexpr int kIdx3 = static_cast<int>((kIdx3210 >> 6) & 3);
  constexpr int kIdx2 = static_cast<int>((kIdx3210 >> 4) & 3);
  constexpr int kIdx1 = static_cast<int>((kIdx3210 >> 2) & 3);
  constexpr int kIdx0 = static_cast<int>(kIdx3210 & 3);

  V ret;
  ret.v0 =
      VH{wasm_i32x4_shuffle(v.v0.raw, v.v0.raw, kIdx0, kIdx1, kIdx2, kIdx3)};
  ret.v1 =
      VH{wasm_i32x4_shuffle(v.v1.raw, v.v1.raw, kIdx0, kIdx1, kIdx2, kIdx3)};
  return ret;
}

template <size_t kIdx3210, class V>
HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
                                  hwy::SizeTag<8> /*lane_size_tag*/,
                                  hwy::SizeTag<32> /*vect_size_tag*/, V v) {
  const DFromV<decltype(v)> d;
  const Half<decltype(d)> dh;
  using VH = VFromD<decltype(dh)>;

  constexpr int kIdx3 = static_cast<int>((kIdx3210 >> 6) & 3);
  constexpr int kIdx2 = static_cast<int>((kIdx3210 >> 4) & 3);
  constexpr int kIdx1 = static_cast<int>((kIdx3210 >> 2) & 3);
  constexpr int kIdx0 = static_cast<int>(kIdx3210 & 3);

  V ret;
  ret.v0 = VH{wasm_i64x2_shuffle(v.v0.raw, v.v1.raw, kIdx0, kIdx1)};
  ret.v1 = VH{wasm_i64x2_shuffle(v.v0.raw, v.v1.raw, kIdx2, kIdx3)};
  return ret;
}

}  // namespace detail

// ------------------------------ SlideUpBlocks
template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> SlideUpBlocks(D d, VFromD<D> v) {
  static_assert(0 <= kBlocks && kBlocks <= 1,
                "kBlocks must be between 0 and 1");
  return (kBlocks == 1) ? ConcatLowerLower(d, v, Zero(d)) : v;
}

// ------------------------------ SlideDownBlocks
template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> SlideDownBlocks(D d, VFromD<D> v) {
  static_assert(0 <= kBlocks && kBlocks <= 1,
                "kBlocks must be between 0 and 1");
  const Half<decltype(d)> dh;
  return (kBlocks == 1) ? ZeroExtendVector(d, UpperHalf(dh, v)) : v;
}
// ------------------------------ SlideUpLanes

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
  const Half<decltype(d)> dh;
  const RebindToUnsigned<decltype(d)> du;
  const RebindToUnsigned<decltype(dh)> dh_u;
  const auto vu = BitCast(du, v);
  VFromD<D> ret;

#if !HWY_IS_DEBUG_BUILD
  constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD<D>);
  if (__builtin_constant_p(amt) && amt < kLanesPerBlock) {
    switch (amt * sizeof(TFromD<D>)) {
      case 0:
        return v;
      case 1:
        ret.v0 = BitCast(dh, ShiftLeftBytes<1>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<15>(dh_u, vu.v1, vu.v0));
        return ret;
      case 2:
        ret.v0 = BitCast(dh, ShiftLeftBytes<2>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<14>(dh_u, vu.v1, vu.v0));
        return ret;
      case 3:
        ret.v0 = BitCast(dh, ShiftLeftBytes<3>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<13>(dh_u, vu.v1, vu.v0));
        return ret;
      case 4:
        ret.v0 = BitCast(dh, ShiftLeftBytes<4>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<12>(dh_u, vu.v1, vu.v0));
        return ret;
      case 5:
        ret.v0 = BitCast(dh, ShiftLeftBytes<5>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<11>(dh_u, vu.v1, vu.v0));
        return ret;
      case 6:
        ret.v0 = BitCast(dh, ShiftLeftBytes<6>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<10>(dh_u, vu.v1, vu.v0));
        return ret;
      case 7:
        ret.v0 = BitCast(dh, ShiftLeftBytes<7>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<9>(dh_u, vu.v1, vu.v0));
        return ret;
      case 8:
        ret.v0 = BitCast(dh, ShiftLeftBytes<8>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<8>(dh_u, vu.v1, vu.v0));
        return ret;
      case 9:
        ret.v0 = BitCast(dh, ShiftLeftBytes<9>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<7>(dh_u, vu.v1, vu.v0));
        return ret;
      case 10:
        ret.v0 = BitCast(dh, ShiftLeftBytes<10>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<6>(dh_u, vu.v1, vu.v0));
        return ret;
      case 11:
        ret.v0 = BitCast(dh, ShiftLeftBytes<11>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<5>(dh_u, vu.v1, vu.v0));
        return ret;
      case 12:
        ret.v0 = BitCast(dh, ShiftLeftBytes<12>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<4>(dh_u, vu.v1, vu.v0));
        return ret;
      case 13:
        ret.v0 = BitCast(dh, ShiftLeftBytes<13>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<3>(dh_u, vu.v1, vu.v0));
        return ret;
      case 14:
        ret.v0 = BitCast(dh, ShiftLeftBytes<14>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<2>(dh_u, vu.v1, vu.v0));
        return ret;
      case 15:
        ret.v0 = BitCast(dh, ShiftLeftBytes<15>(dh_u, vu.v0));
        ret.v1 = BitCast(dh, CombineShiftRightBytes<1>(dh_u, vu.v1, vu.v0));
        return ret;
    }
  }

  if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) {
    ret.v0 = Zero(dh);
    ret.v1 = SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock);
    return ret;
  }
#endif

  const Repartition<uint8_t, decltype(d)> du8;
  const RebindToSigned<decltype(du8)> di8;
  const Half<decltype(di8)> dh_i8;
  const auto lo_byte_idx = BitCast(
      di8,
      Iota(du8, static_cast<uint8_t>(size_t{0} - amt * sizeof(TFromD<D>))));
  const auto hi_byte_idx =
      UpperHalf(dh_i8, lo_byte_idx) - Set(dh_i8, int8_t{16});
  const auto hi_sel_mask =
      UpperHalf(dh_i8, lo_byte_idx) > Set(dh_i8, int8_t{15});
  ret = BitCast(d,
                TableLookupBytesOr0(ConcatLowerLower(du, vu, vu), lo_byte_idx));
  ret.v1 = BitCast(
      dh, IfThenElse(hi_sel_mask,
                     TableLookupBytes(UpperHalf(dh_u, vu), hi_byte_idx),
                     BitCast(dh_i8, ret.v1)));
  return ret;
}

// ------------------------------ Slide1Up
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
  VFromD<D> ret;
  const Half<decltype(d)> dh;
  constexpr int kShrByteAmt = static_cast<int>(16 - sizeof(TFromD<D>));
  ret.v0 = ShiftLeftLanes<1>(dh, v.v0);
  ret.v1 = CombineShiftRightBytes<kShrByteAmt>(dh, v.v1, v.v0);
  return ret;
}
// ------------------------------ SlideDownLanes

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
  const Half<decltype(d)> dh;
  const RebindToUnsigned<decltype(d)> du;
  const RebindToUnsigned<decltype(dh)> dh_u;
  VFromD<D> ret;
  const auto vu = BitCast(du, v);

#if !HWY_IS_DEBUG_BUILD
  constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD<D>);
  if (__builtin_constant_p(amt) && amt < kLanesPerBlock) {
    switch (amt * sizeof(TFromD<D>)) {
      case 0:
        return v;
      case 1:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<1>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<1>(dh_u, vu.v1));
        return ret;
      case 2:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<2>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<2>(dh_u, vu.v1));
        return ret;
      case 3:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<3>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<3>(dh_u, vu.v1));
        return ret;
      case 4:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<4>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<4>(dh_u, vu.v1));
        return ret;
      case 5:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<5>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<5>(dh_u, vu.v1));
        return ret;
      case 6:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<6>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<6>(dh_u, vu.v1));
        return ret;
      case 7:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<7>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<7>(dh_u, vu.v1));
        return ret;
      case 8:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<8>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<8>(dh_u, vu.v1));
        return ret;
      case 9:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<9>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<9>(dh_u, vu.v1));
        return ret;
      case 10:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<10>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<10>(dh_u, vu.v1));
        return ret;
      case 11:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<11>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<11>(dh_u, vu.v1));
        return ret;
      case 12:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<12>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<12>(dh_u, vu.v1));
        return ret;
      case 13:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<13>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<13>(dh_u, vu.v1));
        return ret;
      case 14:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<14>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<14>(dh_u, vu.v1));
        return ret;
      case 15:
        ret.v0 = BitCast(dh, CombineShiftRightBytes<15>(dh_u, vu.v1, vu.v0));
        ret.v1 = BitCast(dh, ShiftRightBytes<15>(dh_u, vu.v1));
        return ret;
    }
  }

  if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) {
    ret.v0 = SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock);
    ret.v1 = Zero(dh);
    return ret;
  }
#endif

  const Repartition<uint8_t, decltype(d)> du8;
  const Half<decltype(du8)> dh_u8;
  const auto lo_byte_idx =
      Iota(du8, static_cast<uint8_t>(amt * sizeof(TFromD<D>)));
  const auto u8_16 = Set(du8, uint8_t{16});
  const auto hi_byte_idx = lo_byte_idx - u8_16;
  const auto lo_sel_mask =
      LowerHalf(dh_u8, lo_byte_idx) < LowerHalf(dh_u8, u8_16);
  ret = BitCast(
      d, IfThenElseZero(hi_byte_idx < u8_16,
                        TableLookupBytes(ConcatUpperUpper(du, vu, vu),
                                         hi_byte_idx)));
  ret.v0 = BitCast(
      dh, IfThenElse(lo_sel_mask,
                     TableLookupBytes(LowerHalf(dh_u, vu),
                                      LowerHalf(dh_u8, lo_byte_idx)),
                     BitCast(dh_u8, LowerHalf(dh, ret))));
  return ret;
}

// ------------------------------ Slide1Down
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
  VFromD<D> ret;
  const Half<decltype(d)> dh;
  constexpr int kShrByteAmt = static_cast<int>(sizeof(TFromD<D>));
  ret.v0 = CombineShiftRightBytes<kShrByteAmt>(dh, v.v1, v.v0);
  ret.v1 = ShiftRightBytes<kShrByteAmt>(dh, v.v1);
  return ret;
}
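// Example semantics of the single-lane slides (a sketch for u32, 8 lanes):
//   v                = {0, 1, 2, 3, 4, 5, 6, 7}
//   Slide1Up(d, v)   -> {0, 0, 1, 2, 3, 4, 5, 6}  // zero enters at lane 0
//   Slide1Down(d, v) -> {1, 2, 3, 4, 5, 6, 7, 0}  // zero enters at lane 7
// The lane crossing the 128-bit block boundary is produced by
// CombineShiftRightBytes on the two halves.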
// ================================================== CONVERT

// ------------------------------ PromoteTo

template <class D, HWY_IF_V_SIZE_D(D, 32), typename TN>
HWY_API VFromD<D> PromoteTo(D d, Vec128<TN> v) {
  const Half<decltype(d)> dh;
  VFromD<D> ret;
  // PromoteLowerTo is defined later in generic_ops-inl.h.
  ret.v0 = PromoteTo(dh, LowerHalf(v));
  ret.v1 = PromoteUpperTo(dh, v);
  return ret;
}

// 4x promotion: 8-bit to 32-bit or 16-bit to 64-bit
template <class DW, HWY_IF_V_SIZE_D(DW, 32), typename TN>
HWY_API Vec256<TFromD<DW>> PromoteTo(DW d, Vec64<TN> v) {
  const Half<decltype(d)> dh;
  // 16-bit lanes for UI8->UI32, 32-bit lanes for UI16->UI64
  const Rebind<MakeWide<TN>, decltype(d)> d2;
  const auto v_2x = PromoteTo(d2, v);
  Vec256<TFromD<DW>> ret;
  // PromoteLowerTo is defined later in generic_ops-inl.h.
  ret.v0 = PromoteTo(dh, LowerHalf(v_2x));
  ret.v1 = PromoteUpperTo(dh, v_2x);
  return ret;
}

// 8x promotion: 8-bit to 64-bit
template <class DW, HWY_IF_V_SIZE_D(DW, 32), typename TN>
HWY_API Vec256<TFromD<DW>> PromoteTo(DW d, Vec32<TN> v) {
  const Half<decltype(d)> dh;
  const Repartition<MakeWide<MakeWide<TN>>, decltype(dh)> d4;  // 32-bit lanes
  const auto v32 = PromoteTo(d4, v);
  Vec256<TFromD<DW>> ret;
  // PromoteLowerTo is defined later in generic_ops-inl.h.
  ret.v0 = PromoteTo(dh, LowerHalf(v32));
  ret.v1 = PromoteUpperTo(dh, v32);
  return ret;
}

// ------------------------------ PromoteUpperTo

// Not native, but still define this here because wasm_128 toggles
// HWY_NATIVE_PROMOTE_UPPER_TO.
template <class D, typename T>
HWY_API VFromD<D> PromoteUpperTo(D d, Vec256<T> v) {
  // Lanes(d) may differ from Lanes(DFromV<decltype(v)>()). Use the lane type
  // from v because it cannot be deduced from D (could be either bf16 or f16).
  const Rebind<T, decltype(d)> dh;
  return PromoteTo(d, UpperHalf(dh, v));
}

// ------------------------------ DemoteTo

template <class D, HWY_IF_U16_D(D)>
HWY_API Vec128<uint16_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  return Vec128<uint16_t>{wasm_u16x8_narrow_i32x4(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_I16_D(D)>
HWY_API Vec128<int16_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  return Vec128<int16_t>{wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec64<uint8_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw);
  return Vec64<uint8_t>{wasm_u8x16_narrow_i16x8(intermediate, intermediate)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec128<uint8_t> DemoteTo(D /* tag */, Vec256<int16_t> v) {
  return Vec128<uint8_t>{wasm_u8x16_narrow_i16x8(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_I8_D(D)>
HWY_API Vec64<int8_t> DemoteTo(D /* tag */, Vec256<int32_t> v) {
  const auto intermediate = wasm_i16x8_narrow_i32x4(v.v0.raw, v.v1.raw);
  return Vec64<int8_t>{wasm_i8x16_narrow_i16x8(intermediate, intermediate)};
}

template <class D, HWY_IF_I8_D(D)>
HWY_API Vec128<int8_t> DemoteTo(D /* tag */, Vec256<int16_t> v) {
  return Vec128<int8_t>{wasm_i8x16_narrow_i16x8(v.v0.raw, v.v1.raw)};
}

template <class D, HWY_IF_I32_D(D)>
HWY_API Vec128<int32_t> DemoteTo(D di, Vec256<double> v) {
  const Vec64<int32_t> lo{wasm_i32x4_trunc_sat_f64x2_zero(v.v0.raw)};
  const Vec64<int32_t> hi{wasm_i32x4_trunc_sat_f64x2_zero(v.v1.raw)};
  return Combine(di, hi, lo);
}

template <class D, HWY_IF_U32_D(D)>
HWY_API Vec128<uint32_t> DemoteTo(D di, Vec256<double> v) {
  const Vec64<uint32_t> lo{wasm_u32x4_trunc_sat_f64x2_zero(v.v0.raw)};
  const Vec64<uint32_t> hi{wasm_u32x4_trunc_sat_f64x2_zero(v.v1.raw)};
  return Combine(di, hi, lo);
}

template <class D, HWY_IF_F32_D(D)>
HWY_API Vec128<float> DemoteTo(D df, Vec256<double> v) {
  const Vec64<float> lo = DemoteTo(Full64<float>(), v.v0);
  const Vec64<float> hi = DemoteTo(Full64<float>(), v.v1);
  return Combine(df, hi, lo);
}

template <class D, HWY_IF_F32_D(D)>
HWY_API Vec128<float> DemoteTo(D df, Vec256<int64_t> v) {
  const Vec64<float> lo = DemoteTo(Full64<float>(), v.v0);
  const Vec64<float> hi = DemoteTo(Full64<float>(), v.v1);
  return Combine(df, hi, lo);
}

template <class D, HWY_IF_F16_D(D)>
HWY_API Vec128<float16_t> DemoteTo(D d16, Vec256<float> v) {
  const Half<decltype(d16)> d16h;
  const Vec64<float16_t> lo = DemoteTo(d16h, v.v0);
  const Vec64<float16_t> hi = DemoteTo(d16h, v.v1);
  return Combine(d16, hi, lo);
}

template <class D, HWY_IF_BF16_D(D)>
HWY_API Vec128<bfloat16_t> DemoteTo(D dbf16, Vec256<float> v) {
  const Half<decltype(dbf16)> dbf16h;
  const Vec64<bfloat16_t> lo = DemoteTo(dbf16h, v.v0);
  const Vec64<bfloat16_t> hi = DemoteTo(dbf16h, v.v1);
  return Combine(dbf16, hi, lo);
}
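// Example of demotion narrowing both halves into one vector (a sketch):
//   const Full256<int32_t> d32;
//   const Full128<int16_t> d16;
//   Vec128<int16_t> narrow = DemoteTo(d16, v32);  // saturating i32 -> i16
// A single wasm_i16x8_narrow_i32x4 consumes v32.v0 and v32.v1 directly.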
// For already range-limited input [0, 255].
HWY_API Vec64<uint8_t> U8FromU32(Vec256<uint32_t> v) {
  const Full64<uint8_t> du8;
  const Full256<int32_t> di32;  // no unsigned DemoteTo
  return DemoteTo(du8, BitCast(di32, v));
}

// ------------------------------ Truncations

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec32<uint8_t> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  return Vec32<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 8, 16, 24, 0,
                                           8, 16, 24, 0, 8, 16, 24, 0, 8, 16,
                                           24)};
}

template <class D, HWY_IF_U16_D(D)>
HWY_API Vec64<uint16_t> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  return Vec64<uint16_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 8, 9, 16,
                                            17, 24, 25, 0, 1, 8, 9, 16, 17, 24,
                                            25)};
}

template <class D, HWY_IF_U32_D(D)>
HWY_API Vec128<uint32_t> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
  return Vec128<uint32_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 2, 3, 8,
                                             9, 10, 11, 16, 17, 18, 19, 24, 25,
                                             26, 27)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec64<uint8_t> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
  return Vec64<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 4, 8, 12, 16,
                                           20, 24, 28, 0, 4, 8, 12, 16, 20, 24,
                                           28)};
}

template <class D, HWY_IF_U16_D(D)>
HWY_API Vec128<uint16_t> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
  return Vec128<uint16_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 1, 4, 5, 8,
                                             9, 12, 13, 16, 17, 20, 21, 24, 25,
                                             28, 29)};
}

template <class D, HWY_IF_U8_D(D)>
HWY_API Vec128<uint8_t> TruncateTo(D /* tag */, Vec256<uint16_t> v) {
  return Vec128<uint8_t>{wasm_i8x16_shuffle(v.v0.raw, v.v1.raw, 0, 2, 4, 6, 8,
                                            10, 12, 14, 16, 18, 20, 22, 24, 26,
                                            28, 30)};
}

// ------------------------------ ReorderDemote2To
template <class DBF16, HWY_IF_BF16_D(DBF16)>
HWY_API Vec256<bfloat16_t> ReorderDemote2To(DBF16 dbf16, Vec256<float> a,
                                            Vec256<float> b) {
  const RebindToUnsigned<decltype(dbf16)> du16;
  return BitCast(dbf16, ConcatOdd(du16, BitCast(du16, b), BitCast(du16, a)));
}

template <class DN, typename V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<DN>),
          HWY_IF_SIGNED_V(V),
          HWY_IF_T_SIZE_ONE_OF_D(DN, (1 << 1) | (1 << 2) | (1 << 4)),
          HWY_IF_T_SIZE_V(V, sizeof(TFromD<DN>) * 2)>
HWY_API VFromD<DN> ReorderDemote2To(DN dn, V a, V b) {
  const Half<decltype(dn)> dnh;
  VFromD<DN> demoted;
  demoted.v0 = DemoteTo(dnh, a);
  demoted.v1 = DemoteTo(dnh, b);
  return demoted;
}

template <class DN, typename V, HWY_IF_UNSIGNED_D(DN), HWY_IF_UNSIGNED_V(V),
          HWY_IF_T_SIZE_V(V, sizeof(TFromD<DN>) * 2)>
HWY_API VFromD<DN> ReorderDemote2To(DN dn, V a, V b) {
  const Half<decltype(dn)> dnh;
  VFromD<DN> demoted;
  demoted.v0 = DemoteTo(dnh, a);
  demoted.v1 = DemoteTo(dnh, b);
  return demoted;
}

// ------------------------------ Convert i32 <=> f32 (Round)

template <class DTo, typename TFrom, typename TTo = TFromD<DTo>>
HWY_API Vec256<TTo> ConvertTo(DTo d, const Vec256<TFrom> v) {
  const Half<decltype(d)> dh;
  Vec256<TTo> ret;
  ret.v0 = ConvertTo(dh, v.v0);
  ret.v1 = ConvertTo(dh, v.v1);
  return ret;
}

HWY_API Vec256<int32_t> NearestInt(const Vec256<float> v) {
  return ConvertTo(Full256<int32_t>(), Round(v));
}

// ================================================== MISC

// ------------------------------ LoadMaskBits (TestBit)

// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  const Half<decltype(d)> dh;
  MFromD<D> ret;
  ret.m0 = LoadMaskBits(dh, bits);
  // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8.
  // Both halves fit in one byte's worth of mask bits.
  constexpr size_t kBitsPerHalf = 16 / sizeof(TFromD<D>);
  const uint8_t bits_upper[8] = {
      static_cast<uint8_t>(bits[0] >> kBitsPerHalf)};
  ret.m1 = LoadMaskBits(dh, bits_upper);
  return ret;
}
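// Mask-bit layout used above (a sketch for lane size 4): each 128-bit half
// contributes kBitsPerHalf = 16 / sizeof(T) = 4 bits, so the whole Vec256
// mask fits in one byte: bits[0] = m0 bits | (m1 bits << 4).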
template <class D, HWY_IF_V_SIZE_D(D, 32),
          HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
  const Half<decltype(d)> dh;
  MFromD<D> ret;
  ret.m0 = LoadMaskBits(dh, bits);
  constexpr size_t kLanesPerHalf = 16 / sizeof(TFromD<D>);
  constexpr size_t kBytesPerHalf = kLanesPerHalf / 8;
  static_assert(kBytesPerHalf != 0,
                "Lane size <= 16 bits => at least 8 lanes");
  ret.m1 = LoadMaskBits(dh, bits + kBytesPerHalf);
  return ret;
}

template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
  const Half<decltype(d)> dh;
  MFromD<D> ret;
  ret.m0 = ret.m1 = Dup128MaskFromMaskBits(dh, mask_bits);
  return ret;
}

// ------------------------------ Mask

// `p` points to at least 8 writable bytes.
template <class D, typename T = TFromD<D>,
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))>
HWY_API size_t StoreMaskBits(D d, const Mask256<T> mask, uint8_t* bits) {
  const Half<decltype(d)> dh;
  StoreMaskBits(dh, mask.m0, bits);
  const uint8_t lo = bits[0];
  StoreMaskBits(dh, mask.m1, bits);
  // If size=4, one 128-bit vector has 4 mask bits; otherwise 2 for size=8.
  // Both halves fit in one byte's worth of mask bits.
  constexpr size_t kBitsPerHalf = 16 / sizeof(T);
  bits[0] = static_cast<uint8_t>(lo | (bits[0] << kBitsPerHalf));
  return (kBitsPerHalf * 2 + 7) / 8;
}

template <class D, typename T = TFromD<D>,
          HWY_IF_T_SIZE_ONE_OF(T, (1 << 1) | (1 << 2))>
HWY_API size_t StoreMaskBits(D d, const Mask256<T> mask, uint8_t* bits) {
  const Half<decltype(d)> dh;
  constexpr size_t kLanesPerHalf = 16 / sizeof(T);
  constexpr size_t kBytesPerHalf = kLanesPerHalf / 8;
  static_assert(kBytesPerHalf != 0,
                "Lane size <= 16 bits => at least 8 lanes");
  StoreMaskBits(dh, mask.m0, bits);
  StoreMaskBits(dh, mask.m1, bits + kBytesPerHalf);
  return kBytesPerHalf * 2;
}

template <class D, typename T = TFromD<D>>
HWY_API size_t CountTrue(D d, const Mask256<T> m) {
  const Half<decltype(d)> dh;
  return CountTrue(dh, m.m0) + CountTrue(dh, m.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API bool AllFalse(D d, const Mask256<T> m) {
  const Half<decltype(d)> dh;
  return AllFalse(dh, m.m0) && AllFalse(dh, m.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API bool AllTrue(D d, const Mask256<T> m) {
  const Half<decltype(d)> dh;
  return AllTrue(dh, m.m0) && AllTrue(dh, m.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API size_t FindKnownFirstTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  const intptr_t lo = FindFirstTrue(dh, mask.m0);  // not known
  constexpr size_t kLanesPerHalf = 16 / sizeof(T);
  return lo >= 0 ? static_cast<size_t>(lo)
                 : kLanesPerHalf + FindKnownFirstTrue(dh, mask.m1);
}

template <class D, typename T = TFromD<D>>
HWY_API intptr_t FindFirstTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  const intptr_t lo = FindFirstTrue(dh, mask.m0);
  constexpr int kLanesPerHalf = 16 / sizeof(T);
  if (lo >= 0) return lo;
  const intptr_t hi = FindFirstTrue(dh, mask.m1);
  return hi + (hi >= 0 ? kLanesPerHalf : 0);
}

template <class D, typename T = TFromD<D>>
HWY_API size_t FindKnownLastTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  const intptr_t hi = FindLastTrue(dh, mask.m1);  // not known
  constexpr size_t kLanesPerHalf = 16 / sizeof(T);
  return hi >= 0 ? kLanesPerHalf + static_cast<size_t>(hi)
                 : FindKnownLastTrue(dh, mask.m0);
}

template <class D, typename T = TFromD<D>>
HWY_API intptr_t FindLastTrue(D d, const Mask256<T> mask) {
  const Half<decltype(d)> dh;
  constexpr int kLanesPerHalf = 16 / sizeof(T);
  const intptr_t hi = FindLastTrue(dh, mask.m1);
  return hi >= 0 ? kLanesPerHalf + hi : FindLastTrue(dh, mask.m0);
}
// ------------------------------ CompressStore
template <class D, typename T = TFromD<D>>
HWY_API size_t CompressStore(Vec256<T> v, const Mask256<T> mask, D d,
                             T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  const size_t count = CompressStore(v.v0, mask.m0, dh, unaligned);
  const size_t count2 = CompressStore(v.v1, mask.m1, dh, unaligned + count);
  return count + count2;
}

// ------------------------------ CompressBlendedStore
template <class D, typename T = TFromD<D>>
HWY_API size_t CompressBlendedStore(Vec256<T> v, const Mask256<T> m, D d,
                                    T* HWY_RESTRICT unaligned) {
  const Half<decltype(d)> dh;
  const size_t count = CompressBlendedStore(v.v0, m.m0, dh, unaligned);
  const size_t count2 =
      CompressBlendedStore(v.v1, m.m1, dh, unaligned + count);
  return count + count2;
}

// ------------------------------ CompressBitsStore
template <class D, typename T = TFromD<D>>
HWY_API size_t CompressBitsStore(Vec256<T> v, const uint8_t* HWY_RESTRICT bits,
                                 D d, T* HWY_RESTRICT unaligned) {
  const Mask256<T> m = LoadMaskBits(d, bits);
  return CompressStore(v, m, d, unaligned);
}

// ------------------------------ Compress
template <typename T>
HWY_API Vec256<T> Compress(const Vec256<T> v, const Mask256<T> mask) {
  const DFromV<decltype(v)> d;
  alignas(32) T lanes[32 / sizeof(T)] = {};
  (void)CompressStore(v, mask, d, lanes);
  return Load(d, lanes);
}

// ------------------------------ CompressNot
template <typename T>
HWY_API Vec256<T> CompressNot(Vec256<T> v, const Mask256<T> mask) {
  return Compress(v, Not(mask));
}

// ------------------------------ CompressBlocksNot
HWY_API Vec256<uint64_t> CompressBlocksNot(Vec256<uint64_t> v,
                                           Mask256<uint64_t> mask) {
  const Full128<uint64_t> dh;
  // Because the non-selected (mask=1) blocks are undefined, we can return the
  // input unless mask = 01, in which case we must bring down the upper block.
  return AllTrue(dh, AndNot(mask.m1, mask.m0)) ? SwapAdjacentBlocks(v) : v;
}

// ------------------------------ CompressBits
template <typename T>
HWY_API Vec256<T> CompressBits(Vec256<T> v, const uint8_t* HWY_RESTRICT bits) {
  const Mask256<T> m = LoadMaskBits(DFromV<decltype(v)>(), bits);
  return Compress(v, m);
}

// ------------------------------ Expand
template <typename T>
HWY_API Vec256<T> Expand(const Vec256<T> v, const Mask256<T> mask) {
  Vec256<T> ret;
  const Full256<T> d;
  const Half<decltype(d)> dh;
  alignas(32) T lanes[32 / sizeof(T)] = {};
  Store(v, d, lanes);
  ret.v0 = Expand(v.v0, mask.m0);
  ret.v1 = Expand(LoadU(dh, lanes + CountTrue(dh, mask.m0)), mask.m1);
  return ret;
}

// ------------------------------ LoadExpand
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
                             const TFromD<D>* HWY_RESTRICT unaligned) {
  return Expand(LoadU(d, unaligned), mask);
}
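// Example relating Compress and Expand (a sketch; with m = FirstN(d, k),
// Expand undoes Compress):
//   const Full256<float> d;
//   const auto m = FirstN(d, 5);
//   const auto packed = Compress(v, m);     // selected lanes moved to front
//   const auto spread = Expand(packed, m);  // back to mask positions, 0 else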
// ------------------------------ LoadInterleaved3/4

// Implemented in generic_ops, we just overload LoadTransposedBlocks3/4.

namespace detail {
// Input:
// 1 0 (<- first block of unaligned)
// 3 2
// 5 4
// Output:
// 3 0
// 4 1
// 5 2
template <class D, typename T = TFromD<D>>
HWY_API void LoadTransposedBlocks3(D d, const T* HWY_RESTRICT unaligned,
                                   Vec256<T>& A, Vec256<T>& B, Vec256<T>& C) {
  const Vec256<T> v10 = LoadU(d, unaligned + 0 * MaxLanes(d));
  const Vec256<T> v32 = LoadU(d, unaligned + 1 * MaxLanes(d));
  const Vec256<T> v54 = LoadU(d, unaligned + 2 * MaxLanes(d));

  A = ConcatUpperLower(d, v32, v10);
  B = ConcatLowerUpper(d, v54, v10);
  C = ConcatUpperLower(d, v54, v32);
}

// Input (128-bit blocks):
// 1 0 (first block of unaligned)
// 3 2
// 5 4
// 7 6
// Output:
// 4 0 (LSB of A)
// 5 1
// 6 2
// 7 3
template <class D, typename T = TFromD<D>>
HWY_API void LoadTransposedBlocks4(D d, const T* HWY_RESTRICT unaligned,
                                   Vec256<T>& vA, Vec256<T>& vB, Vec256<T>& vC,
                                   Vec256<T>& vD) {
  const Vec256<T> v10 = LoadU(d, unaligned + 0 * MaxLanes(d));
  const Vec256<T> v32 = LoadU(d, unaligned + 1 * MaxLanes(d));
  const Vec256<T> v54 = LoadU(d, unaligned + 2 * MaxLanes(d));
  const Vec256<T> v76 = LoadU(d, unaligned + 3 * MaxLanes(d));

  vA = ConcatLowerLower(d, v54, v10);
  vB = ConcatUpperUpper(d, v54, v10);
  vC = ConcatLowerLower(d, v76, v32);
  vD = ConcatUpperUpper(d, v76, v32);
}
}  // namespace detail

// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower)

// Implemented in generic_ops, we just overload StoreTransposedBlocks2/3/4.

namespace detail {
// Input (128-bit blocks):
// 2 0 (LSB of i)
// 3 1
// Output:
// 1 0
// 3 2
template <class D, typename T = TFromD<D>>
HWY_API void StoreTransposedBlocks2(Vec256<T> i, Vec256<T> j, D d,
                                    T* HWY_RESTRICT unaligned) {
  const Vec256<T> out0 = ConcatLowerLower(d, j, i);
  const Vec256<T> out1 = ConcatUpperUpper(d, j, i);
  StoreU(out0, d, unaligned + 0 * MaxLanes(d));
  StoreU(out1, d, unaligned + 1 * MaxLanes(d));
}

// Input (128-bit blocks):
// 3 0 (LSB of i)
// 4 1
// 5 2
// Output:
// 1 0
// 3 2
// 5 4
template <class D, typename T = TFromD<D>>
HWY_API void StoreTransposedBlocks3(Vec256<T> i, Vec256<T> j, Vec256<T> k, D d,
                                    T* HWY_RESTRICT unaligned) {
  const Vec256<T> out0 = ConcatLowerLower(d, j, i);
  const Vec256<T> out1 = ConcatUpperLower(d, i, k);
  const Vec256<T> out2 = ConcatUpperUpper(d, k, j);
  StoreU(out0, d, unaligned + 0 * MaxLanes(d));
  StoreU(out1, d, unaligned + 1 * MaxLanes(d));
  StoreU(out2, d, unaligned + 2 * MaxLanes(d));
}
// Input (128-bit blocks):
// 4 0 (LSB of i)
// 5 1
// 6 2
// 7 3
// Output:
// 1 0
// 3 2
// 5 4
// 7 6
template <class D, typename T = TFromD<D>>
HWY_API void StoreTransposedBlocks4(Vec256<T> i, Vec256<T> j, Vec256<T> k,
                                    Vec256<T> l, D d,
                                    T* HWY_RESTRICT unaligned) {
  // Write lower halves, then upper.
  const Vec256<T> out0 = ConcatLowerLower(d, j, i);
  const Vec256<T> out1 = ConcatLowerLower(d, l, k);
  StoreU(out0, d, unaligned + 0 * MaxLanes(d));
  StoreU(out1, d, unaligned + 1 * MaxLanes(d));
  const Vec256<T> out2 = ConcatUpperUpper(d, j, i);
  const Vec256<T> out3 = ConcatUpperUpper(d, l, k);
  StoreU(out2, d, unaligned + 2 * MaxLanes(d));
  StoreU(out3, d, unaligned + 3 * MaxLanes(d));
}
}  // namespace detail

// ------------------------------ Additional mask logical operations
template <typename T>
HWY_API Mask256<T> SetAtOrAfterFirst(Mask256<T> mask) {
  const Full256<T> d;
  const Half<decltype(d)> dh;
  const Repartition<int64_t, decltype(dh)> dh_i64;

  Mask256<T> result;
  result.m0 = SetAtOrAfterFirst(mask.m0);
  result.m1 = SetAtOrAfterFirst(mask.m1);

  // Copy the sign bit of the lower 128-bit half to the upper 128-bit half
  const auto vmask_lo = BitCast(dh_i64, VecFromMask(dh, result.m0));
  result.m1 = Or(result.m1,
                 MaskFromVec(BitCast(
                     dh, BroadcastSignBit(InterleaveUpper(dh_i64, vmask_lo,
                                                          vmask_lo)))));
  return result;
}

template <typename T>
HWY_API Mask256<T> SetBeforeFirst(Mask256<T> mask) {
  return Not(SetAtOrAfterFirst(mask));
}

template <typename T>
HWY_API Mask256<T> SetOnlyFirst(Mask256<T> mask) {
  const Full256<T> d;
  const RebindToSigned<decltype(d)> di;
  const Repartition<int64_t, decltype(d)> di64;
  const Half<decltype(di64)> dh_i64;

  const auto zero = Zero(di64);
  const auto vmask = BitCast(di64, VecFromMask(d, mask));

  const auto vmask_eq_0 = VecFromMask(di64, vmask == zero);
  auto vmask2_lo = LowerHalf(dh_i64, vmask_eq_0);
  auto vmask2_hi = UpperHalf(dh_i64, vmask_eq_0);

  vmask2_lo = And(vmask2_lo, InterleaveLower(vmask2_lo, vmask2_lo));
  vmask2_hi = And(ConcatLowerUpper(dh_i64, vmask2_hi, vmask2_lo),
                  InterleaveUpper(dh_i64, vmask2_lo, vmask2_lo));
  vmask2_lo = InterleaveLower(Set(dh_i64, int64_t{-1}), vmask2_lo);

  const auto vmask2 = Combine(di64, vmask2_hi, vmask2_lo);
  const auto only_first_vmask = Neg(BitCast(di, And(vmask, Neg(vmask))));
  return MaskFromVec(BitCast(d, And(only_first_vmask, BitCast(di, vmask2))));
}

template <typename T>
HWY_API Mask256<T> SetAtOrBeforeFirst(Mask256<T> mask) {
  const Full256<T> d;
  constexpr size_t kLanesPerBlock = MaxLanes(d) / 2;

  const auto vmask = VecFromMask(d, mask);
  const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d));
  return SetBeforeFirst(
      MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>(
          d, vmask, vmask_lo)));
}

// ------------------------------ WidenMulPairwiseAdd
template <class D32, typename T16, typename T32 = TFromD<D32>>
HWY_API Vec256<T32> WidenMulPairwiseAdd(D32 d32, Vec256<T16> a,
                                        Vec256<T16> b) {
  const Half<decltype(d32)> d32h;
  Vec256<T32> result;
  result.v0 = WidenMulPairwiseAdd(d32h, a.v0, b.v0);
  result.v1 = WidenMulPairwiseAdd(d32h, a.v1, b.v1);
  return result;
}

// ------------------------------ ReorderWidenMulAccumulate
template <class D32, typename T16, typename T32 = TFromD<D32>>
HWY_API Vec256<T32> ReorderWidenMulAccumulate(D32 d32, Vec256<T16> a,
                                              Vec256<T16> b, Vec256<T32> sum0,
                                              Vec256<T32>& sum1) {
  const Half<decltype(d32)> d32h;
  sum0.v0 = ReorderWidenMulAccumulate(d32h, a.v0, b.v0, sum0.v0, sum1.v0);
  sum0.v1 = ReorderWidenMulAccumulate(d32h, a.v1, b.v1, sum0.v1, sum1.v1);
  return sum0;
}

// ------------------------------ RearrangeToOddPlusEven
template <typename TW>
HWY_API Vec256<TW> RearrangeToOddPlusEven(Vec256<TW> sum0, Vec256<TW> sum1) {
  sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0);
  sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1);
  return sum0;
}

// ------------------------------ Reductions in generic_ops

// ------------------------------ Lt128

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Lt128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Lt128(dh, a.v0, b.v0);
  ret.m1 = Lt128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Lt128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Lt128Upper(dh, a.v0, b.v0);
  ret.m1 = Lt128Upper(dh, a.v1, b.v1);
  return ret;
}
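// Example: Lt128 treats each 128-bit block as one unsigned 128-bit integer
// (lane 0 = low u64, lane 1 = high u64) and yields an all-true or all-false
// mask per block (a sketch; useful for 128-bit sort keys):
//   const Full256<uint64_t> d;
//   const auto m = Lt128(d, a, b);  // forwards to wasm_128's Lt128 per half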
template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Eq128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Eq128(dh, a.v0, b.v0);
  ret.m1 = Eq128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Eq128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Eq128Upper(dh, a.v0, b.v0);
  ret.m1 = Eq128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Ne128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Ne128(dh, a.v0, b.v0);
  ret.m1 = Ne128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Mask256<T> Ne128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Mask256<T> ret;
  ret.m0 = Ne128Upper(dh, a.v0, b.v0);
  ret.m1 = Ne128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Min128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Min128(dh, a.v0, b.v0);
  ret.v1 = Min128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Max128(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Max128(dh, a.v0, b.v0);
  ret.v1 = Max128(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Min128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Min128Upper(dh, a.v0, b.v0);
  ret.v1 = Min128Upper(dh, a.v1, b.v1);
  return ret;
}

template <class D, typename T = TFromD<D>>
HWY_INLINE Vec256<T> Max128Upper(D d, Vec256<T> a, Vec256<T> b) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = Max128Upper(dh, a.v0, b.v0);
  ret.v1 = Max128Upper(dh, a.v1, b.v1);
  return ret;
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();