// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 256-bit WASM vectors and operations. Experimental.
// External include guard in highway.h - see comment there.

#include <stddef.h>
#include <stdint.h>
#include <wasm_simd128.h>

#include "hwy/base.h"
#include "hwy/ops/shared-inl.h"
#include "hwy/ops/wasm_128-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

template <typename T>
using Full256 = Simd<T, 32 / sizeof(T)>;

template <typename T>
using Full128 = Simd<T, 16 / sizeof(T)>;

// TODO(richardwinterton): add this to DeduceD in wasm_128 similar to x86_128.
template <typename T>
class Vec256 {
 public:
  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec256& operator*=(const Vec256 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec256& operator/=(const Vec256 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec256& operator+=(const Vec256 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec256& operator-=(const Vec256 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec256& operator&=(const Vec256 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec256& operator|=(const Vec256 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec256& operator^=(const Vec256 other) {
    return *this = (*this ^ other);
  }

  Vec128<T> v0;
  Vec128<T> v1;
};

template <typename T>
struct Mask256 {
  Mask128<T> m0;
  Mask128<T> m1;
};

// ------------------------------ BitCast

template <typename T, typename FromT>
HWY_API Vec256<T> BitCast(Full256<T> d, Vec256<FromT> v) {
  const Half<decltype(d)> dh;
  Vec256<T> ret;
  ret.v0 = BitCast(dh, v.v0);
  ret.v1 = BitCast(dh, v.v1);
  return ret;
  // TODO(richardwinterton): implement other ops like this
}

// ------------------------------ Zero

// Returns an all-zero vector/part.
template <typename T>
HWY_API Vec256<T> Zero(Full256<T> /* tag */) {
  return Vec256<T>{wasm_i32x4_splat(0)};
}
HWY_API Vec256<float> Zero(Full256<float> /* tag */) {
  return Vec256<float>{wasm_f32x4_splat(0.0f)};
}

template <class D>
using VFromD = decltype(Zero(D()));

// ------------------------------ Set

// Returns a vector/part with all lanes set to "t".
HWY_API Vec256<uint8_t> Set(Full256<uint8_t> /* tag */, const uint8_t t) {
  return Vec256<uint8_t>{wasm_i8x16_splat(static_cast<int8_t>(t))};
}
HWY_API Vec256<uint16_t> Set(Full256<uint16_t> /* tag */, const uint16_t t) {
  return Vec256<uint16_t>{wasm_i16x8_splat(static_cast<int16_t>(t))};
}
HWY_API Vec256<uint32_t> Set(Full256<uint32_t> /* tag */, const uint32_t t) {
  return Vec256<uint32_t>{wasm_i32x4_splat(static_cast<int32_t>(t))};
}
HWY_API Vec256<uint64_t> Set(Full256<uint64_t> /* tag */, const uint64_t t) {
  return Vec256<uint64_t>{wasm_i64x2_splat(static_cast<int64_t>(t))};
}

HWY_API Vec256<int8_t> Set(Full256<int8_t> /* tag */, const int8_t t) {
  return Vec256<int8_t>{wasm_i8x16_splat(t)};
}
HWY_API Vec256<int16_t> Set(Full256<int16_t> /* tag */, const int16_t t) {
  return Vec256<int16_t>{wasm_i16x8_splat(t)};
}
HWY_API Vec256<int32_t> Set(Full256<int32_t> /* tag */, const int32_t t) {
  return Vec256<int32_t>{wasm_i32x4_splat(t)};
}
HWY_API Vec256<int64_t> Set(Full256<int64_t> /* tag */, const int64_t t) {
  return Vec256<int64_t>{wasm_i64x2_splat(t)};
}

HWY_API Vec256<float> Set(Full256<float> /* tag */, const float t) {
  return Vec256<float>{wasm_f32x4_splat(t)};
}

HWY_DIAGNOSTICS(push)
HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")

// Returns a vector with uninitialized elements.
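// (Implemented via Zero below, so the lanes are actually zero; callers must
// still not rely on their values.)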
template HWY_API Vec256 Undefined(Full256 d) { return Zero(d); } HWY_DIAGNOSTICS(pop) // Returns a vector with lane i=[0, N) set to "first" + i. template Vec256 Iota(const Full256 d, const T2 first) { HWY_ALIGN T lanes[16 / sizeof(T)]; for (size_t i = 0; i < 16 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); } return Load(d, lanes); } // ================================================== ARITHMETIC // ------------------------------ Addition // Unsigned HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_add(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_add(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_add(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_add(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_add(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_add(a.raw, b.raw)}; } // Float HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{wasm_f32x4_add(a.raw, b.raw)}; } // ------------------------------ Subtraction // Unsigned HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_sub(a.raw, b.raw)}; } HWY_API Vec256 operator-(Vec256 a, Vec256 b) { return Vec256{wasm_i16x8_sub(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_sub(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_sub(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_sub(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_sub(a.raw, b.raw)}; } // Float HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{wasm_f32x4_sub(a.raw, b.raw)}; } // ------------------------------ SumsOf8 HWY_API Vec256 SumsOf8(const Vec256 v) { HWY_ABORT("not implemented"); } // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. // Unsigned HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{wasm_u8x16_add_sat(a.raw, b.raw)}; } HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{wasm_u16x8_add_sat(a.raw, b.raw)}; } // Signed HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_add_sat(a.raw, b.raw)}; } HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_add_sat(a.raw, b.raw)}; } // ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. 
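// Illustrative example (hypothetical values): for uint8_t lanes,
// SaturatedSub(Set(d, 5), Set(d, 10)) yields 0 rather than wrapping to 251.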
// Unsigned HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{wasm_u8x16_sub_sat(a.raw, b.raw)}; } HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{wasm_u16x8_sub_sat(a.raw, b.raw)}; } // Signed HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_sub_sat(a.raw, b.raw)}; } HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_sub_sat(a.raw, b.raw)}; } // ------------------------------ Average // Returns (a + b + 1) / 2 // Unsigned HWY_API Vec256 AverageRound(const Vec256 a, const Vec256 b) { return Vec256{wasm_u8x16_avgr(a.raw, b.raw)}; } HWY_API Vec256 AverageRound(const Vec256 a, const Vec256 b) { return Vec256{wasm_u16x8_avgr(a.raw, b.raw)}; } // ------------------------------ Absolute value // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. HWY_API Vec256 Abs(const Vec256 v) { return Vec256{wasm_i8x16_abs(v.raw)}; } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{wasm_i16x8_abs(v.raw)}; } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{wasm_i32x4_abs(v.raw)}; } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{wasm_i62x2_abs(v.raw)}; } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{wasm_f32x4_abs(v.raw)}; } // ------------------------------ Shift lanes by constant #bits // Unsigned template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{wasm_i16x8_shl(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{wasm_u16x8_shr(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{wasm_i32x4_shl(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{wasm_u32x4_shr(v.raw, kBits)}; } // Signed template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{wasm_i16x8_shl(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{wasm_i16x8_shr(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{wasm_i32x4_shl(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{wasm_i32x4_shr(v.raw, kBits)}; } // 8-bit template HWY_API Vec256 ShiftLeft(const Vec256 v) { const Full256 d8; // Use raw instead of BitCast to support N=1. const Vec256 shifted{ShiftLeft(Vec128>{v.raw}).raw}; return kBits == 1 ? (v + v) : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } template HWY_API Vec256 ShiftRight(const Vec256 v) { const Full256 d8; // Use raw instead of BitCast to support N=1. 
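// Shift the whole vector as 16-bit lanes, then mask off the bits that leaked
// in from the neighboring byte (0xFF >> kBits keeps only the valid bits).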
const Vec256 shifted{ShiftRight(Vec128{v.raw}).raw}; return shifted & Set(d8, 0xFF >> kBits); } template HWY_API Vec256 ShiftRight(const Vec256 v) { const Full256 di; const Full256 du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ RotateRight (ShiftRight, Or) template HWY_API Vec256 RotateRight(const Vec256 v) { constexpr size_t kSizeInBits = sizeof(T) * 8; static_assert(0 <= kBits && kBits < kSizeInBits, "Invalid shift count"); if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); } // ------------------------------ Shift lanes by same variable #bits // Unsigned HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{wasm_i16x8_shl(v.raw, bits)}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{wasm_u16x8_shr(v.raw, bits)}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{wasm_i32x4_shl(v.raw, bits)}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{wasm_u32x4_shr(v.raw, bits)}; } // Signed HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{wasm_i16x8_shl(v.raw, bits)}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{wasm_i16x8_shr(v.raw, bits)}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{wasm_i32x4_shl(v.raw, bits)}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{wasm_i32x4_shr(v.raw, bits)}; } // 8-bit template HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { const Full256 d8; // Use raw instead of BitCast to support N=1. const Vec256 shifted{ShiftLeftSame(Vec128>{v.raw}, bits).raw}; return shifted & Set(d8, (0xFF << bits) & 0xFF); } HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { const Full256 d8; // Use raw instead of BitCast to support N=1. 
const Vec256 shifted{ ShiftRightSame(Vec128{v.raw}, bits).raw}; return shifted & Set(d8, 0xFF >> bits); } HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { const Full256 di; const Full256 du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> bits)); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ Minimum // Unsigned HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_u8x16_min(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_u16x8_min(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_u32x4_min(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { alignas(32) float min[4]; min[0] = HWY_MIN(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); min[1] = HWY_MIN(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); return Vec256{wasm_v128_load(min)}; } // Signed HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_min(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_min(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_min(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { alignas(32) float min[4]; min[0] = HWY_MIN(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); min[1] = HWY_MIN(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); return Vec256{wasm_v128_load(min)}; } // Float HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{wasm_f32x4_min(a.raw, b.raw)}; } // ------------------------------ Maximum // Unsigned HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_u8x16_max(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_u16x8_max(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_u32x4_max(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { alignas(32) float max[4]; max[0] = HWY_MAX(wasm_u64x2_extract_lane(a, 0), wasm_u64x2_extract_lane(b, 0)); max[1] = HWY_MAX(wasm_u64x2_extract_lane(a, 1), wasm_u64x2_extract_lane(b, 1)); return Vec256{wasm_v128_load(max)}; } // Signed HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_i8x16_max(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_max(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_max(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { alignas(32) float max[4]; max[0] = HWY_MAX(wasm_i64x2_extract_lane(a, 0), wasm_i64x2_extract_lane(b, 0)); max[1] = HWY_MAX(wasm_i64x2_extract_lane(a, 1), wasm_i64x2_extract_lane(b, 1)); return Vec256{wasm_v128_load(max)}; } // Float HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{wasm_f32x4_max(a.raw, b.raw)}; } // ------------------------------ Integer multiplication // Unsigned HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_mul(a.raw, b.raw)}; } HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_mul(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_mul(a.raw, b.raw)}; } HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_mul(a.raw, b.raw)}; } // Returns 
the upper 16 bits of a * b in each lane. HWY_API Vec256 MulHigh(const Vec256 a, const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto al = wasm_u32x4_extend_low_u16x8(a.raw); const auto ah = wasm_u32x4_extend_high_u16x8(a.raw); const auto bl = wasm_u32x4_extend_low_u16x8(b.raw); const auto bh = wasm_u32x4_extend_high_u16x8(b.raw); const auto l = wasm_i32x4_mul(al, bl); const auto h = wasm_i32x4_mul(ah, bh); // TODO(eustas): shift-right + narrow? return Vec256{wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; } HWY_API Vec256 MulHigh(const Vec256 a, const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto al = wasm_i32x4_extend_low_i16x8(a.raw); const auto ah = wasm_i32x4_extend_high_i16x8(a.raw); const auto bl = wasm_i32x4_extend_low_i16x8(b.raw); const auto bh = wasm_i32x4_extend_high_i16x8(b.raw); const auto l = wasm_i32x4_mul(al, bl); const auto h = wasm_i32x4_mul(ah, bh); // TODO(eustas): shift-right + narrow? return Vec256{wasm_i16x8_shuffle(l, h, 1, 3, 5, 7, 9, 11, 13, 15)}; } // Multiplies even lanes (0, 2 ..) and returns the double-width result. HWY_API Vec256 MulEven(const Vec256 a, const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); const auto ae = wasm_v128_and(a.raw, kEvenMask); const auto be = wasm_v128_and(b.raw, kEvenMask); return Vec256{wasm_i64x2_mul(ae, be)}; } HWY_API Vec256 MulEven(const Vec256 a, const Vec256 b) { // TODO(eustas): replace, when implemented in WASM. const auto kEvenMask = wasm_i32x4_make(-1, 0, -1, 0); const auto ae = wasm_v128_and(a.raw, kEvenMask); const auto be = wasm_v128_and(b.raw, kEvenMask); return Vec256{wasm_i64x2_mul(ae, be)}; } // ------------------------------ Negate template HWY_API Vec256 Neg(const Vec256 v) { return Xor(v, SignBit(Full256())); } HWY_API Vec256 Neg(const Vec256 v) { return Vec256{wasm_i8x16_neg(v.raw)}; } HWY_API Vec256 Neg(const Vec256 v) { return Vec256{wasm_i16x8_neg(v.raw)}; } HWY_API Vec256 Neg(const Vec256 v) { return Vec256{wasm_i32x4_neg(v.raw)}; } HWY_API Vec256 Neg(const Vec256 v) { return Vec256{wasm_i64x2_neg(v.raw)}; } // ------------------------------ Floating-point mul / div HWY_API Vec256 operator*(Vec256 a, Vec256 b) { return Vec256{wasm_f32x4_mul(a.raw, b.raw)}; } HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { return Vec256{wasm_f32x4_div(a.raw, b.raw)}; } // Approximate reciprocal HWY_API Vec256 ApproximateReciprocal(const Vec256 v) { const Vec256 one = Vec256{wasm_f32x4_splat(1.0f)}; return one / v; } // Absolute value of difference. HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants // Returns mul * x + add HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { // TODO(eustas): replace, when implemented in WASM. // TODO(eustas): is it wasm_f32x4_qfma? return mul * x + add; } // Returns add - mul * x HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { // TODO(eustas): replace, when implemented in WASM. return add - mul * x; } // Returns mul * x - sub HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { // TODO(eustas): replace, when implemented in WASM. // TODO(eustas): is it wasm_f32x4_qfms? return mul * x - sub; } // Returns -mul * x - sub HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { // TODO(eustas): replace, when implemented in WASM. 
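// Computes -mul * x - sub via Neg; there is no fused negated multiply-subtract
// to call here yet (see TODO above).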
return Neg(mul) * x - sub; } // ------------------------------ Floating-point square root // Full precision square root HWY_API Vec256 Sqrt(const Vec256 v) { return Vec256{wasm_f32x4_sqrt(v.raw)}; } // Approximate reciprocal square root HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { // TODO(eustas): find cheaper a way to calculate this. const Vec256 one = Vec256{wasm_f32x4_splat(1.0f)}; return one / Sqrt(v); } // ------------------------------ Floating-point rounding // Toward nearest integer, ties to even HWY_API Vec256 Round(const Vec256 v) { return Vec256{wasm_f32x4_nearest(v.raw)}; } // Toward zero, aka truncate HWY_API Vec256 Trunc(const Vec256 v) { return Vec256{wasm_f32x4_trunc(v.raw)}; } // Toward +infinity, aka ceiling HWY_API Vec256 Ceil(const Vec256 v) { return Vec256{wasm_f32x4_ceil(v.raw)}; } // Toward -infinity, aka floor HWY_API Vec256 Floor(const Vec256 v) { return Vec256{wasm_f32x4_floor(v.raw)}; } // ================================================== COMPARE // Comparisons fill a lane with 1-bits if the condition is true, else 0. template HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask256{m.raw}; } template HWY_API Mask256 TestBit(Vec256 v, Vec256 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return (v & bit) == bit; } // ------------------------------ Equality // Unsigned HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{wasm_i8x16_eq(a.raw, b.raw)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{wasm_i16x8_eq(a.raw, b.raw)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{wasm_i32x4_eq(a.raw, b.raw)}; } // Signed HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{wasm_i8x16_eq(a.raw, b.raw)}; } HWY_API Mask256 operator==(Vec256 a, Vec256 b) { return Mask256{wasm_i16x8_eq(a.raw, b.raw)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{wasm_i32x4_eq(a.raw, b.raw)}; } // Float HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{wasm_f32x4_eq(a.raw, b.raw)}; } // ------------------------------ Inequality // Unsigned HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{wasm_i8x16_ne(a.raw, b.raw)}; } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{wasm_i16x8_ne(a.raw, b.raw)}; } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{wasm_i32x4_ne(a.raw, b.raw)}; } // Signed HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{wasm_i8x16_ne(a.raw, b.raw)}; } HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { return Mask256{wasm_i16x8_ne(a.raw, b.raw)}; } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{wasm_i32x4_ne(a.raw, b.raw)}; } // Float HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{wasm_f32x4_ne(a.raw, b.raw)}; } // ------------------------------ Strict inequality HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{wasm_i8x16_gt(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{wasm_i16x8_gt(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{wasm_i32x4_gt(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { const Rebind < int32_t, DFromV d32; const auto a32 = BitCast(d32, a); const auto b32 = 
BitCast(d32, b); // If the upper half is less than or greater, this is the answer. const auto m_gt = a32 < b32; // Otherwise, the lower half decides. const auto m_eq = a32 == b32; const auto lo_in_hi = wasm_i32x4_shuffle(m_gt, m_gt, 2, 2, 0, 0); const auto lo_gt = And(m_eq, lo_in_hi); const auto gt = Or(lo_gt, m_gt); // Copy result in upper 32 bits to lower 32 bits. return Mask256{wasm_i32x4_shuffle(gt, gt, 3, 3, 1, 1)}; } template HWY_API Mask256 operator>(Vec256 a, Vec256 b) { const Full256 du; const RebindToSigned di; const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{wasm_f32x4_gt(a.raw, b.raw)}; } template HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { return operator>(b, a); } // ------------------------------ Weak inequality // Float <= >= HWY_API Mask256 operator<=(const Vec256 a, const Vec256 b) { return Mask256{wasm_f32x4_le(a.raw, b.raw)}; } HWY_API Mask256 operator>=(const Vec256 a, const Vec256 b) { return Mask256{wasm_f32x4_ge(a.raw, b.raw)}; } // ------------------------------ FirstN (Iota, Lt) template HWY_API Mask256 FirstN(const Full256 d, size_t num) { const RebindToSigned di; // Signed comparisons may be cheaper. return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(num))); } // ================================================== LOGICAL // ------------------------------ Not template HWY_API Vec256 Not(Vec256 v) { return Vec256{wasm_v128_not(v.raw)}; } // ------------------------------ And template HWY_API Vec256 And(Vec256 a, Vec256 b) { return Vec256{wasm_v128_and(a.raw, b.raw)}; } // ------------------------------ AndNot // Returns ~not_mask & mask. template HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { return Vec256{wasm_v128_andnot(mask.raw, not_mask.raw)}; } // ------------------------------ Or template HWY_API Vec256 Or(Vec256 a, Vec256 b) { return Vec256{wasm_v128_or(a.raw, b.raw)}; } // ------------------------------ Xor template HWY_API Vec256 Xor(Vec256 a, Vec256 b) { return Vec256{wasm_v128_xor(a.raw, b.raw)}; } // ------------------------------ OrAnd template HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { return Or(o, And(a1, a2)); } // ------------------------------ IfVecThenElse template HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { return IfThenElse(MaskFromVec(mask), yes, no); } // ------------------------------ Operator overloads (internal-only if float) template HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { return And(a, b); } template HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { return Or(a, b); } template HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { return Xor(a, b); } // ------------------------------ CopySign template HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const auto msb = SignBit(Full256()); return Or(AndNot(msb, magn), And(msb, sign)); } template HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); return Or(abs, And(SignBit(Full256()), sign)); } // ------------------------------ BroadcastSignBit (compare) template HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return ShiftRight(v); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return VecFromMask(Full256(), v < Zero(Full256())); } // ------------------------------ Mask // Mask 
and Vec are the same (true = FF..FF). template HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{v.raw}; } template HWY_API Vec256 VecFromMask(Full256 /* tag */, Mask256 v) { return Vec256{v.raw}; } // mask ? yes : no template HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{wasm_v128_bitselect(yes.raw, no.raw, mask.raw)}; } // mask ? yes : 0 template HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return yes & VecFromMask(Full256(), mask); } // mask ? 0 : no template HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return AndNot(VecFromMask(Full256(), mask), no); } template HWY_API Vec256 < T IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { HWY_ASSERT(0); } template HWY_API Vec256 ZeroIfNegative(Vec256 v) { const Full256 d; const auto zero = Zero(d); return IfThenElse(Mask256{(v > zero).raw}, v, zero); } // ------------------------------ Mask logical template HWY_API Mask256 Not(const Mask256 m) { return MaskFromVec(Not(VecFromMask(Full256(), m))); } template HWY_API Mask256 And(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 Or(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } // ------------------------------ Shl (BroadcastSignBit, IfThenElse) // The x86 multiply-by-Pow2() trick will not work because WASM saturates // float->int correctly to 2^31-1 (not 2^31). Because WASM's shifts take a // scalar count operand, per-lane shift instructions would require extract_lane // for each lane, and hoping that shuffle is correctly mapped to a native // instruction. Using non-vector shifts would incur a store-load forwarding // stall when loading the result vector. We instead test bits of the shift // count to "predicate" a shift of the entire vector by a constant. template HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { const Full256 d; Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. test = ShiftLeft<12>(test); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<8>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<4>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<2>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); return IfThenElse(mask, ShiftLeft<1>(v), v); } template HWY_API Vec256 operator<<(Vec256 v, const Vec256 bits) { const Full256 d; Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. 
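// 32-bit lanes: valid shift counts use 5 bits, so shifting left by 27 moves
// bit 4 of the count into the sign bit.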
test = ShiftLeft<27>(test); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<16>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<8>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<4>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftLeft<2>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); return IfThenElse(mask, ShiftLeft<1>(v), v); } // ------------------------------ Shr (BroadcastSignBit, IfThenElse) template HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { const Full256 d; Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. test = ShiftLeft<12>(test); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<8>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<4>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<2>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); return IfThenElse(mask, ShiftRight<1>(v), v); } template HWY_API Vec256 operator>>(Vec256 v, const Vec256 bits) { const Full256 d; Mask256 mask; // Need a signed type for BroadcastSignBit. auto test = BitCast(RebindToSigned(), bits); // Move the highest valid bit of the shift count into the sign bit. test = ShiftLeft<27>(test); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<16>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<8>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<4>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); test = ShiftLeft<1>(test); // next bit (descending order) v = IfThenElse(mask, ShiftRight<2>(v), v); mask = RebindMask(d, MaskFromVec(BroadcastSignBit(test))); return IfThenElse(mask, ShiftRight<1>(v), v); } // ================================================== MEMORY // ------------------------------ Load template HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{wasm_v128_load(aligned)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } // LoadU == Load. template HWY_API Vec256 LoadU(Full256 d, const T* HWY_RESTRICT p) { return Load(d, p); } // 128-bit SIMD => nothing to duplicate, same as an unaligned load. template HWY_API Vec256 LoadDup128(Full256 d, const T* HWY_RESTRICT p) { return Load(d, p); } // ------------------------------ Store template HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { wasm_v128_store(aligned, v.raw); } // StoreU == Store. 
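// (WASM v128 loads/stores have no alignment requirement, so the unaligned
// variant simply forwards to Store.)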
template HWY_API void StoreU(Vec256 v, Full256 d, T* HWY_RESTRICT p) { Store(v, d, p); } // ------------------------------ Non-temporal stores // Same as aligned stores on non-x86. template HWY_API void Stream(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { wasm_v128_store(aligned, v.raw); } // ------------------------------ Scatter (Store) template HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(32) T lanes[32 / sizeof(T)]; Store(v, d, lanes); alignas(32) Offset offset_lanes[32 / sizeof(T)]; Store(offset, Full256(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); } } template HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(32) T lanes[32 / sizeof(T)]; Store(v, d, lanes); alignas(32) Index index_lanes[32 / sizeof(T)]; Store(index, Full256(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; } } // ------------------------------ Gather (Load/Store) template HWY_API Vec256 GatherOffset(const Full256 d, const T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); alignas(32) Offset offset_lanes[32 / sizeof(T)]; Store(offset, Full256(), offset_lanes); alignas(32) T lanes[32 / sizeof(T)]; const uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(base_bytes + offset_lanes[i], &lanes[i]); } return Load(d, lanes); } template HWY_API Vec256 GatherIndex(const Full256 d, const T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); alignas(32) Index index_lanes[32 / sizeof(T)]; Store(index, Full256(), index_lanes); alignas(32) T lanes[32 / sizeof(T)]; for (size_t i = 0; i < N; ++i) { lanes[i] = base[index_lanes[i]]; } return Load(d, lanes); } // ================================================== SWIZZLE // ------------------------------ Extract lane // Gets the single value stored in a vector/part. 
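// Illustrative example (hypothetical values): GetLane(Set(d, 7)) returns 7,
// i.e. the value of lane 0.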
HWY_API uint8_t GetLane(const Vec256 v) { return wasm_i8x16_extract_lane(v.raw, 0); } HWY_API int8_t GetLane(const Vec256 v) { return wasm_i8x16_extract_lane(v.raw, 0); } HWY_API uint16_t GetLane(const Vec256 v) { return wasm_i16x8_extract_lane(v.raw, 0); } HWY_API int16_t GetLane(const Vec256 v) { return wasm_i16x8_extract_lane(v.raw, 0); } HWY_API uint32_t GetLane(const Vec256 v) { return wasm_i32x4_extract_lane(v.raw, 0); } HWY_API int32_t GetLane(const Vec256 v) { return wasm_i32x4_extract_lane(v.raw, 0); } HWY_API uint64_t GetLane(const Vec256 v) { return wasm_i64x2_extract_lane(v.raw, 0); } HWY_API int64_t GetLane(const Vec256 v) { return wasm_i64x2_extract_lane(v.raw, 0); } HWY_API float GetLane(const Vec256 v) { return wasm_f32x4_extract_lane(v.raw, 0); } // ------------------------------ LowerHalf template HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{v.raw}; } template HWY_API Vec128 LowerHalf(Vec256 v) { return LowerHalf(Full128(), v); } // ------------------------------ ShiftLeftBytes // 0x01..0F, kBytes = 1 => 0x02..0F00 template HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); const __i8x16 zero = wasm_i8x16_splat(0); switch (kBytes) { case 0: return v; case 1: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)}; case 2: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13)}; case 3: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)}; case 4: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)}; case 5: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)}; case 6: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)}; case 7: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7, 8)}; case 8: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6, 7)}; case 9: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5, 6)}; case 10: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4, 5)}; case 11: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3, 4)}; case 12: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2, 3)}; case 13: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1, 2)}; case 14: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0, 1)}; case 15: return Vec256{wasm_i8x16_shuffle(v.raw, zero, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 0)}; } return Vec256{zero}; } template HWY_API Vec256 ShiftLeftBytes(Vec256 v) { return ShiftLeftBytes(Full256(), v); } // ------------------------------ ShiftLeftLanes template HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { return ShiftLeftLanes(Full256(), v); } // ------------------------------ ShiftRightBytes namespace detail { // Helper function allows zeroing invalid lanes in caller. 
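// Shuffle indices 16..31 select from the second operand (here a zero vector),
// so index 16 below injects zero bytes at the positions vacated by the shift.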
template HWY_API __i8x16 ShrBytes(const Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); const __i8x16 zero = wasm_i8x16_splat(0); switch (kBytes) { case 0: return v.raw; case 1: return wasm_i8x16_shuffle(v.raw, zero, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16); case 2: return wasm_i8x16_shuffle(v.raw, zero, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16); case 3: return wasm_i8x16_shuffle(v.raw, zero, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16); case 4: return wasm_i8x16_shuffle(v.raw, zero, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16); case 5: return wasm_i8x16_shuffle(v.raw, zero, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16); case 6: return wasm_i8x16_shuffle(v.raw, zero, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16); case 7: return wasm_i8x16_shuffle(v.raw, zero, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16); case 8: return wasm_i8x16_shuffle(v.raw, zero, 8, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16); case 9: return wasm_i8x16_shuffle(v.raw, zero, 9, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 10: return wasm_i8x16_shuffle(v.raw, zero, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 11: return wasm_i8x16_shuffle(v.raw, zero, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 12: return wasm_i8x16_shuffle(v.raw, zero, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 13: return wasm_i8x16_shuffle(v.raw, zero, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 14: return wasm_i8x16_shuffle(v.raw, zero, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 15: return wasm_i8x16_shuffle(v.raw, zero, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16); case 16: return zero; } } } // namespace detail // 0x01..0F, kBytes = 1 => 0x0001..0E template HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, Vec256 v) { return Vec256{detail::ShrBytes(v)}; } // ------------------------------ ShiftRightLanes template HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(BitCast(d8, v))); } // ------------------------------ UpperHalf (ShiftRightBytes) // Full input: copy hi into lo (smaller instruction encoding than shifts). 
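// The shuffle (2, 3, 2, 3) copies the upper two 32-bit lanes into the lower
// positions (and also duplicates them into the upper positions).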
template HWY_API Vec128 UpperHalf(Full128 /* tag */, const Vec256 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; } HWY_API Vec128 UpperHalf(Full128 /* tag */, const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 2, 3, 2, 3)}; } // ------------------------------ CombineShiftRightBytes template > HWY_API V CombineShiftRightBytes(Full256 /* tag */, V hi, V lo) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); switch (kBytes) { case 0: return lo; case 1: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)}; case 2: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)}; case 3: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)}; case 4: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)}; case 5: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)}; case 6: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21)}; case 7: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)}; case 8: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)}; case 9: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24)}; case 10: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25)}; case 11: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26)}; case 12: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27)}; case 13: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28)}; case 14: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29)}; case 15: return V{wasm_i8x16_shuffle(lo.raw, hi.raw, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30)}; } return hi; } // ------------------------------ Broadcast/splat any lane // Unsigned template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec256{wasm_i16x8_shuffle( v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec256{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } // Signed template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec256{wasm_i16x8_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane, kLane, kLane, kLane, kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec256{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } // Float template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < N, "Invalid lane"); return Vec256{ wasm_i32x4_shuffle(v.raw, v.raw, kLane, kLane, kLane, kLane)}; } // ------------------------------ TableLookupBytes // Returns vector of bytes[from[i]]. "from" is also interpreted as bytes, i.e. // lane indices in [0, 16). 
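// Illustrative example (hypothetical values): if "from" starts with
// {1, 0, 2, 3, ...}, the first two output bytes are bytes[1] and bytes[0].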
template HWY_API Vec256 TableLookupBytes(const Vec256 bytes, const Vec256 from) { // Not yet available in all engines, see // https://github.com/WebAssembly/simd/blob/bdcc304b2d379f4601c2c44ea9b44ed9484fde7e/proposals/simd/ImplementationStatus.md // V8 implementation of this had a bug, fixed on 2021-04-03: // https://chromium-review.googlesource.com/c/v8/v8/+/2822951 #if 0 return Vec256{wasm_i8x16_swizzle(bytes.raw, from.raw)}; #else alignas(32) uint8_t control[16]; alignas(32) uint8_t input[16]; alignas(32) uint8_t output[16]; wasm_v128_store(control, from.raw); wasm_v128_store(input, bytes.raw); for (size_t i = 0; i < 16; ++i) { output[i] = control[i] < 16 ? input[control[i]] : 0; } return Vec256{wasm_v128_load(output)}; #endif } template HWY_API Vec256 TableLookupBytesOr0(const Vec256 bytes, const Vec256 from) { const Full256 d; // Mask size must match vector type, so cast everything to this type. Repartition di8; Repartition> d_bytes8; const auto msb = BitCast(di8, from) < Zero(di8); const auto lookup = TableLookupBytes(BitCast(d_bytes8, bytes), BitCast(di8, from)); return BitCast(d, IfThenZeroElse(msb, lookup)); } // ------------------------------ Hard-coded shuffles // Notation: let Vec128 have lanes 3,2,1,0 (0 is least-significant). // Shuffle0321 rotates one lane to the right (the previous least-significant // lane is now most-significant). These could also be implemented via // CombineShiftRightBytes but the shuffle_abcd notation is more convenient. // Swap 32-bit halves in 64-bit halves. HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; } HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; } HWY_API Vec128 Shuffle2301(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 0, 3, 2)}; } // Swap 64-bit halves HWY_API Vec128 Shuffle1032(const Vec128 v) { return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; } HWY_API Vec128 Shuffle1032(const Vec128 v) { return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; } HWY_API Vec128 Shuffle1032(const Vec128 v) { return Vec128{wasm_i64x2_shuffle(v.raw, v.raw, 1, 0)}; } // Rotate right 32 bits HWY_API Vec128 Shuffle0321(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; } HWY_API Vec128 Shuffle0321(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; } HWY_API Vec128 Shuffle0321(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 1, 2, 3, 0)}; } // Rotate left 32 bits HWY_API Vec128 Shuffle2103(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; } HWY_API Vec128 Shuffle2103(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; } HWY_API Vec128 Shuffle2103(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 0, 1, 2)}; } // Reverse HWY_API Vec128 Shuffle0123(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; } HWY_API Vec128 Shuffle0123(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; } HWY_API Vec128 Shuffle0123(const Vec128 v) { return Vec128{wasm_i32x4_shuffle(v.raw, v.raw, 3, 2, 1, 0)}; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices for use by TableLookupLanes. 
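// Usage sketch (assuming "indices" is an array of lane indices valid for d):
//   const auto idx = SetTableIndices(d, indices);
//   const auto permuted = TableLookupLanes(v, idx);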
template struct Indices256 { __v128_u raw; }; template HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); return Indices256{}; } template HWY_API Indices256 SetTableIndices(Full256 d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { using TI = MakeSigned; const Full256 d; const Full256 di; return BitCast(d, TableLookupBytes(BitCast(di, v), Vec256{idx.raw})); } // ------------------------------ Reverse (Shuffle0123, Shuffle2301, Shuffle01) template HWY_API Vec256 Reverse(Full256 /* tag */, const Vec256 v) { return Shuffle01(v); } // Four lanes: shuffle template HWY_API Vec256 Reverse(Full256 /* tag */, const Vec256 v) { return Shuffle0123(v); } // 16-bit template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { const RepartitionToWide> du32; return BitCast(d, RotateRight<16>(Reverse(du32, BitCast(du32, v)))); } // ------------------------------ Reverse2 template HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { HWY_ASSERT(0); } // ------------------------------ Reverse4 template HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { HWY_ASSERT(0); } // ------------------------------ Reverse8 template HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { HWY_ASSERT(0); } // ------------------------------ InterleaveLower HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 0, 8, 1, 9, 2, 10, 3, 11)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 0, 2)}; } HWY_API Vec256 InterleaveLower(Vec256 a, Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 0, 4, 1, 5)}; } // Additional overload for the optional tag. template > HWY_API V InterleaveLower(Full256 /* tag */, V a, V b) { return InterleaveLower(a, b); } // ------------------------------ InterleaveUpper (UpperHalf) // All functions inside detail lack the required D parameter. 
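// The public InterleaveUpper overload defined after this namespace supplies
// the D parameter and forwards here.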
namespace detail { HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i8x16_shuffle(a.raw, b.raw, 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{ wasm_i16x8_shuffle(a.raw, b.raw, 4, 12, 5, 13, 6, 14, 7, 15)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 1, 3)}; } HWY_API Vec256 InterleaveUpper(Vec256 a, Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 2, 6, 3, 7)}; } } // namespace detail template > HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. template >> HWY_API VFromD ZipLower(Vec256 a, Vec256 b) { return BitCast(DW(), InterleaveLower(a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipLower(DW dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveLower(D(), a, b)); } template , class DW = RepartitionToWide> HWY_API VFromD ZipUpper(DW dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveUpper(D(), a, b)); } // ================================================== COMBINE // ------------------------------ Combine (InterleaveLower) // N = N/2 + N/2 (upper half undefined) template HWY_API Vec256 Combine(Full256 d, Vec128 hi_half, Vec128 lo_half) { const Half d2; const RebindToUnsigned du2; // Treat half-width input as one lane, and expand to two lanes. 
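// InterleaveLower of these single-lane views places lo_half below hi_half,
// which is exactly the desired concatenation.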
using VU = Vec128, 2>; const VU lo{BitCast(du2, lo_half).raw}; const VU hi{BitCast(du2, hi_half).raw}; return BitCast(d, InterleaveLower(lo, hi)); } // ------------------------------ ZeroExtendVector (Combine, IfThenElseZero) template HWY_API Vec256 ZeroExtendVector(Full256 d, Vec128 lo) { return IfThenElseZero(FirstN(d, 16 / sizeof(T)), Vec256{lo.raw}); } // ------------------------------ ConcatLowerLower // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template HWY_API Vec256 ConcatLowerLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{wasm_i64x2_shuffle(lo.raw, hi.raw, 0, 2)}; } // ------------------------------ ConcatUpperUpper template HWY_API Vec256 ConcatUpperUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{wasm_i64x2_shuffle(lo.raw, hi.raw, 1, 3)}; } // ------------------------------ ConcatLowerUpper template HWY_API Vec256 ConcatLowerUpper(Full256 d, const Vec256 hi, const Vec256 lo) { return CombineShiftRightBytes<8>(d, hi, lo); } // ------------------------------ ConcatUpperLower template HWY_API Vec256 ConcatUpperLower(Full256 d, const Vec256 hi, const Vec256 lo) { return IfThenElse(FirstN(d, Lanes(d) / 2), lo, hi); } // ------------------------------ ConcatOdd // 32-bit template HWY_API Vec256 ConcatOdd(Full256 /* tag */, Vec256 hi, Vec256 lo) { return Vec256{wasm_i32x4_shuffle(lo.raw, hi.raw, 1, 3, 5, 7)}; } // 64-bit full - no partial because we need at least two inputs to have // even/odd. template HWY_API Vec256 ConcatOdd(Full256 /* tag */, Vec256 hi, Vec256 lo) { return InterleaveUpper(Full256(), lo, hi); } // ------------------------------ ConcatEven (InterleaveLower) // 32-bit full template HWY_API Vec256 ConcatEven(Full256 /* tag */, Vec256 hi, Vec256 lo) { return Vec256{wasm_i32x4_shuffle(lo.raw, hi.raw, 0, 2, 4, 6)}; } // 64-bit full - no partial because we need at least two inputs to have // even/odd. 
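// For 64-bit lanes, the even lane of each input is simply its lower lane, so
// InterleaveLower(lo, hi) already produces ConcatEven.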
template HWY_API Vec256 ConcatEven(Full256 /* tag */, Vec256 hi, Vec256 lo) { return InterleaveLower(Full256(), lo, hi); } // ------------------------------ DupEven template HWY_API Vec256 DupEven(Vec256 v) { HWY_ASSERT(0); } // ------------------------------ DupOdd template HWY_API Vec256 DupOdd(Vec256 v) { HWY_ASSERT(0); } // ------------------------------ OddEven namespace detail { template HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, const Vec256 b) { const Full256 d; const Repartition d8; alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; return IfThenElse(MaskFromVec(BitCast(d, Load(d8, mask))), b, a); } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{wasm_i16x8_shuffle(a.raw, b.raw, 8, 1, 10, 3, 12, 5, 14, 7)}; } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{wasm_i64x2_shuffle(a.raw, b.raw, 2, 1)}; } } // namespace detail template HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return detail::OddEven(hwy::SizeTag(), a, b); } HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return Vec256{wasm_i32x4_shuffle(a.raw, b.raw, 4, 1, 6, 3)}; } // ------------------------------ OddEvenBlocks template HWY_API Vec256 OddEvenBlocks(Vec256 /* odd */, Vec256 even) { return even; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return v; } // ------------------------------ ReverseBlocks template HWY_API Vec256 ReverseBlocks(Full256 /* tag */, const Vec256 v) { return v; } // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) // Unsigned: zero-extend. HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_u16x8_extend_low_u8x16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{ wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_u16x8_extend_low_u8x16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{ wasm_u32x4_extend_low_u16x8(wasm_u16x8_extend_low_u8x16(v.raw))}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_u32x4_extend_low_u16x8(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_u32x4_extend_low_u16x8(v.raw)}; } // Signed: replicate sign bit. HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_i16x8_extend_low_i8x16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{ wasm_i32x4_extend_low_i16x8(wasm_i16x8_extend_low_i8x16(v.raw))}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_i32x4_extend_low_i16x8(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{wasm_f64x2_convert_low_i32x4(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { const Full256 di32; const Full256 du32; const Full256 df32; // Expand to u32 so we can shift. 
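// float16 layout: 1 sign bit, 5 exponent bits (bias 15), 10 mantissa bits.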
const auto bits16 = PromoteTo(du32, Vec256{v.raw}); const auto sign = ShiftRight<15>(bits16); const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); const auto mantissa = bits16 & Set(du32, 0x3FF); const auto subnormal = BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * Set(df32, 1.0f / 16384 / 1024)); const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); return BitCast(df32, ShiftLeft<31>(sign) | bits32); } HWY_API Vec256 PromoteTo(Full256 df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } // ------------------------------ Demotions (full -> part w/ narrow lanes) HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { return Vec128{wasm_u16x8_narrow_i32x4(v.raw, v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { return Vec128{wasm_i16x8_narrow_i32x4(v.raw, v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); return Vec128{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { return Vec128{wasm_u8x16_narrow_i16x8(v.raw, v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); return Vec128{wasm_i8x16_narrow_i16x8(intermediate, intermediate)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { return Vec128{wasm_i8x16_narrow_i16x8(v.raw, v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* di */, const Vec256 v) { return Vec128{wasm_i32x4_trunc_sat_f64x2_zero(v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const Full256 di; const Full256 du; const Full256 du16; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); const auto k15 = Set(di, 15); const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); const auto is_tiny = exp < Set(di, -24); const auto is_subnormal = exp < Set(di, -14); const auto biased_exp16 = BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + (mantissa32 >> (Set(du, 13) + sub_exp)); const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, ShiftRight<13>(mantissa32)); // <1024 const auto sign16 = ShiftLeft<15>(sign); const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); return Vec128{DemoteTo(du16, bits16).raw}; } HWY_API Vec128 DemoteTo(Full128 dbf16, const Vec256 v) { const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } HWY_API Vec128 ReorderDemote2To(Full128 dbf16, Vec256 a, Vec256 b) { const RebindToUnsigned du16; const Repartition du32; const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } // For already range-limited input [0, 255]. 
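// Both narrowing steps below (i32x4 -> i16x8 -> u8x16) are saturating, so
// out-of-range inputs would be clamped rather than truncated modulo 256.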
HWY_API Vec256 U8FromU32(const Vec256 v) { const auto intermediate = wasm_i16x8_narrow_i32x4(v.raw, v.raw); return Vec256{wasm_u8x16_narrow_i16x8(intermediate, intermediate)}; } // ------------------------------ Convert i32 <=> f32 (Round) HWY_API Vec256 ConvertTo(Full256 /* tag */, const Vec256 v) { return Vec256{wasm_f32x4_convert_i32x4(v.raw)}; } // Truncates (rounds toward zero). HWY_API Vec256 ConvertTo(Full256 /* tag */, const Vec256 v) { return Vec256{wasm_i32x4_trunc_sat_f32x4(v.raw)}; } HWY_API Vec256 NearestInt(const Vec256 v) { return ConvertTo(Full256(), Round(v)); } // ================================================== MISC // ------------------------------ LoadMaskBits (TestBit) namespace detail { template HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; // Easier than Set(), which would require an >8-bit type, which would not // compile for T=uint8_t, N=1. const Vec256 vbits{wasm_i32x4_splat(static_cast(bits))}; // Replicate bytes 8x such that each byte contains the bit that governs it. alignas(32) constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}; const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8)); alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; alignas(32) constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits(Full256 d, uint64_t bits) { const RebindToUnsigned du; alignas(32) constexpr uint64_t kBit[8] = {1, 2}; return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit))); } } // namespace detail // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask256 LoadMaskBits(Full256 d, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; CopyBytes<(N + 7) / 8>(bits, &mask_bits); return detail::LoadMaskBits(d, mask_bits); } // ------------------------------ Mask namespace detail { // Full template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<1> /*tag*/, const Mask128 mask) { alignas(32) uint64_t lanes[2]; wasm_v128_store(lanes, mask.raw); constexpr uint64_t kMagic = 0x103070F1F3F80ULL; const uint64_t lo = ((lanes[0] * kMagic) >> 56); const uint64_t hi = ((lanes[1] * kMagic) >> 48) & 0xFF00; return (hi + lo); } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { // Remove useless lower half of each u16 while preserving the sign bit. const __i16x8 zero = wasm_i16x8_splat(0); const Mask256 mask8{wasm_i8x16_narrow_i16x8(mask.raw, zero)}; return BitsFromMask(hwy::SizeTag<1>(), mask8); } template HWY_INLINE uint64_t BitsFromMask(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { const __i32x4 mask_i = static_cast<__i32x4>(mask.raw); const __i32x4 slice = wasm_i32x4_make(1, 2, 4, 8); const __i32x4 sliced_mask = wasm_v128_and(mask_i, slice); alignas(32) uint32_t lanes[4]; wasm_v128_store(lanes, sliced_mask); return lanes[0] | lanes[1] | lanes[2] | lanes[3]; } // Returns 0xFF for bytes with index >= N, otherwise 0. constexpr __i8x16 BytesAbove() { return /**/ (N == 0) ? 
wasm_i32x4_make(-1, -1, -1, -1) : (N == 4) ? wasm_i32x4_make(0, -1, -1, -1) : (N == 8) ? wasm_i32x4_make(0, 0, -1, -1) : (N == 12) ? wasm_i32x4_make(0, 0, 0, -1) : (N == 16) ? wasm_i32x4_make(0, 0, 0, 0) : (N == 2) ? wasm_i16x8_make(0, -1, -1, -1, -1, -1, -1, -1) : (N == 6) ? wasm_i16x8_make(0, 0, 0, -1, -1, -1, -1, -1) : (N == 10) ? wasm_i16x8_make(0, 0, 0, 0, 0, -1, -1, -1) : (N == 14) ? wasm_i16x8_make(0, 0, 0, 0, 0, 0, 0, -1) : (N == 1) ? wasm_i8x16_make(0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1) : (N == 3) ? wasm_i8x16_make(0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1) : (N == 5) ? wasm_i8x16_make(0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1) : (N == 7) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1) : (N == 9) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1, -1, -1) : (N == 11) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1, -1, -1) : (N == 13) ? wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, -1) : wasm_i8x16_make(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1); } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { return BitsFromMask(hwy::SizeTag(), mask); } template HWY_INLINE size_t CountTrue(hwy::SizeTag<1> tag, const Mask128 m) { return PopCount(BitsFromMask(tag, m)); } template HWY_INLINE size_t CountTrue(hwy::SizeTag<2> tag, const Mask128 m) { return PopCount(BitsFromMask(tag, m)); } template HWY_INLINE size_t CountTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { const __i32x4 var_shift = wasm_i32x4_make(1, 2, 4, 8); const __i32x4 shifted_bits = wasm_v128_and(m.raw, var_shift); alignas(32) uint64_t lanes[2]; wasm_v128_store(lanes, shifted_bits); return PopCount(lanes[0] | lanes[1]); } } // namespace detail // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, uint8_t* bits) { const uint64_t mask_bits = detail::BitsFromMask(mask); const size_t kNumBytes = (N + 7) / 8; CopyBytes(&mask_bits, bits); return kNumBytes; } template HWY_API size_t CountTrue(const Full256 /* tag */, const Mask128 m) { return detail::CountTrue(hwy::SizeTag(), m); } template HWY_API bool AllFalse(const Full256 d, const Mask128 m) { #if 0 // Casting followed by wasm_i8x16_any_true results in wasm error: // i32.eqz[0] expected type i32, found i8x16.popcnt of type s128 const auto v8 = BitCast(Full256(), VecFromMask(d, m)); return !wasm_i8x16_any_true(v8.raw); #else (void)d; return (wasm_i64x2_extract_lane(m.raw, 0) | wasm_i64x2_extract_lane(m.raw, 1)) == 0; #endif } // Full vector namespace detail { template HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask128 m) { return wasm_i8x16_all_true(m.raw); } template HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask128 m) { return wasm_i16x8_all_true(m.raw); } template HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask128 m) { return wasm_i32x4_all_true(m.raw); } } // namespace detail template HWY_API bool AllTrue(const Full256 /* tag */, const Mask128 m) { return detail::AllTrue(hwy::SizeTag(), m); } template HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, const Mask256 mask) { const uint64_t bits = detail::BitsFromMask(mask); return bits ? 
Num0BitsBelowLS1Bit_Nonzero64(bits) : -1; } // ------------------------------ Compress namespace detail { template HWY_INLINE Vec256 Idx16x8FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 256); const Full256 d; const Rebind d8; const Full256 du; // We need byte indices for TableLookupBytes (one vector's worth for each of // 256 combinations of 8 mask bits). Loading them directly requires 4 KiB. We // can instead store lane indices and convert to byte indices (2*lane + 0..1), // with the doubling baked into the table. Unpacking nibbles is likely more // costly than the higher cache footprint from storing bytes. alignas(32) constexpr uint8_t table[256 * 8] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 0, 0, 2, 4, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 0, 0, 2, 6, 0, 0, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 0, 0, 4, 6, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 0, 0, 2, 4, 6, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 0, 0, 2, 8, 0, 0, 0, 0, 0, 4, 8, 0, 0, 0, 0, 0, 0, 0, 4, 8, 0, 0, 0, 0, 0, 2, 4, 8, 0, 0, 0, 0, 0, 0, 2, 4, 8, 0, 0, 0, 0, 6, 8, 0, 0, 0, 0, 0, 0, 0, 6, 8, 0, 0, 0, 0, 0, 2, 6, 8, 0, 0, 0, 0, 0, 0, 2, 6, 8, 0, 0, 0, 0, 4, 6, 8, 0, 0, 0, 0, 0, 0, 4, 6, 8, 0, 0, 0, 0, 2, 4, 6, 8, 0, 0, 0, 0, 0, 2, 4, 6, 8, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 0, 0, 0, 2, 10, 0, 0, 0, 0, 0, 4, 10, 0, 0, 0, 0, 0, 0, 0, 4, 10, 0, 0, 0, 0, 0, 2, 4, 10, 0, 0, 0, 0, 0, 0, 2, 4, 10, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 0, 0, 6, 10, 0, 0, 0, 0, 0, 2, 6, 10, 0, 0, 0, 0, 0, 0, 2, 6, 10, 0, 0, 0, 0, 4, 6, 10, 0, 0, 0, 0, 0, 0, 4, 6, 10, 0, 0, 0, 0, 2, 4, 6, 10, 0, 0, 0, 0, 0, 2, 4, 6, 10, 0, 0, 0, 8, 10, 0, 0, 0, 0, 0, 0, 0, 8, 10, 0, 0, 0, 0, 0, 2, 8, 10, 0, 0, 0, 0, 0, 0, 2, 8, 10, 0, 0, 0, 0, 4, 8, 10, 0, 0, 0, 0, 0, 0, 4, 8, 10, 0, 0, 0, 0, 2, 4, 8, 10, 0, 0, 0, 0, 0, 2, 4, 8, 10, 0, 0, 0, 6, 8, 10, 0, 0, 0, 0, 0, 0, 6, 8, 10, 0, 0, 0, 0, 2, 6, 8, 10, 0, 0, 0, 0, 0, 2, 6, 8, 10, 0, 0, 0, 4, 6, 8, 10, 0, 0, 0, 0, 0, 4, 6, 8, 10, 0, 0, 0, 2, 4, 6, 8, 10, 0, 0, 0, 0, 2, 4, 6, 8, 10, 0, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 0, 0, 0, 0, 0, 2, 12, 0, 0, 0, 0, 0, 4, 12, 0, 0, 0, 0, 0, 0, 0, 4, 12, 0, 0, 0, 0, 0, 2, 4, 12, 0, 0, 0, 0, 0, 0, 2, 4, 12, 0, 0, 0, 0, 6, 12, 0, 0, 0, 0, 0, 0, 0, 6, 12, 0, 0, 0, 0, 0, 2, 6, 12, 0, 0, 0, 0, 0, 0, 2, 6, 12, 0, 0, 0, 0, 4, 6, 12, 0, 0, 0, 0, 0, 0, 4, 6, 12, 0, 0, 0, 0, 2, 4, 6, 12, 0, 0, 0, 0, 0, 2, 4, 6, 12, 0, 0, 0, 8, 12, 0, 0, 0, 0, 0, 0, 0, 8, 12, 0, 0, 0, 0, 0, 2, 8, 12, 0, 0, 0, 0, 0, 0, 2, 8, 12, 0, 0, 0, 0, 4, 8, 12, 0, 0, 0, 0, 0, 0, 4, 8, 12, 0, 0, 0, 0, 2, 4, 8, 12, 0, 0, 0, 0, 0, 2, 4, 8, 12, 0, 0, 0, 6, 8, 12, 0, 0, 0, 0, 0, 0, 6, 8, 12, 0, 0, 0, 0, 2, 6, 8, 12, 0, 0, 0, 0, 0, 2, 6, 8, 12, 0, 0, 0, 4, 6, 8, 12, 0, 0, 0, 0, 0, 4, 6, 8, 12, 0, 0, 0, 2, 4, 6, 8, 12, 0, 0, 0, 0, 2, 4, 6, 8, 12, 0, 0, 10, 12, 0, 0, 0, 0, 0, 0, 0, 10, 12, 0, 0, 0, 0, 0, 2, 10, 12, 0, 0, 0, 0, 0, 0, 2, 10, 12, 0, 0, 0, 0, 4, 10, 12, 0, 0, 0, 0, 0, 0, 4, 10, 12, 0, 0, 0, 0, 2, 4, 10, 12, 0, 0, 0, 0, 0, 2, 4, 10, 12, 0, 0, 0, 6, 10, 12, 0, 0, 0, 0, 0, 0, 6, 10, 12, 0, 0, 0, 0, 2, 6, 10, 12, 0, 0, 0, 0, 0, 2, 6, 10, 12, 0, 0, 0, 4, 6, 10, 12, 0, 0, 0, 0, 0, 4, 6, 10, 12, 0, 0, 0, 2, 4, 6, 10, 12, 0, 0, 0, 0, 2, 4, 6, 10, 12, 0, 0, 8, 10, 12, 0, 0, 0, 0, 0, 0, 8, 10, 12, 0, 0, 0, 0, 2, 8, 10, 12, 0, 0, 0, 0, 0, 2, 8, 10, 
12, 0, 0, 0, 4, 8, 10, 12, 0, 0, 0, 0, 0, 4, 8, 10, 12, 0, 0, 0, 2, 4, 8, 10, 12, 0, 0, 0, 0, 2, 4, 8, 10, 12, 0, 0, 6, 8, 10, 12, 0, 0, 0, 0, 0, 6, 8, 10, 12, 0, 0, 0, 2, 6, 8, 10, 12, 0, 0, 0, 0, 2, 6, 8, 10, 12, 0, 0, 4, 6, 8, 10, 12, 0, 0, 0, 0, 4, 6, 8, 10, 12, 0, 0, 2, 4, 6, 8, 10, 12, 0, 0, 0, 2, 4, 6, 8, 10, 12, 0, 14, 0, 0, 0, 0, 0, 0, 0, 0, 14, 0, 0, 0, 0, 0, 0, 2, 14, 0, 0, 0, 0, 0, 0, 0, 2, 14, 0, 0, 0, 0, 0, 4, 14, 0, 0, 0, 0, 0, 0, 0, 4, 14, 0, 0, 0, 0, 0, 2, 4, 14, 0, 0, 0, 0, 0, 0, 2, 4, 14, 0, 0, 0, 0, 6, 14, 0, 0, 0, 0, 0, 0, 0, 6, 14, 0, 0, 0, 0, 0, 2, 6, 14, 0, 0, 0, 0, 0, 0, 2, 6, 14, 0, 0, 0, 0, 4, 6, 14, 0, 0, 0, 0, 0, 0, 4, 6, 14, 0, 0, 0, 0, 2, 4, 6, 14, 0, 0, 0, 0, 0, 2, 4, 6, 14, 0, 0, 0, 8, 14, 0, 0, 0, 0, 0, 0, 0, 8, 14, 0, 0, 0, 0, 0, 2, 8, 14, 0, 0, 0, 0, 0, 0, 2, 8, 14, 0, 0, 0, 0, 4, 8, 14, 0, 0, 0, 0, 0, 0, 4, 8, 14, 0, 0, 0, 0, 2, 4, 8, 14, 0, 0, 0, 0, 0, 2, 4, 8, 14, 0, 0, 0, 6, 8, 14, 0, 0, 0, 0, 0, 0, 6, 8, 14, 0, 0, 0, 0, 2, 6, 8, 14, 0, 0, 0, 0, 0, 2, 6, 8, 14, 0, 0, 0, 4, 6, 8, 14, 0, 0, 0, 0, 0, 4, 6, 8, 14, 0, 0, 0, 2, 4, 6, 8, 14, 0, 0, 0, 0, 2, 4, 6, 8, 14, 0, 0, 10, 14, 0, 0, 0, 0, 0, 0, 0, 10, 14, 0, 0, 0, 0, 0, 2, 10, 14, 0, 0, 0, 0, 0, 0, 2, 10, 14, 0, 0, 0, 0, 4, 10, 14, 0, 0, 0, 0, 0, 0, 4, 10, 14, 0, 0, 0, 0, 2, 4, 10, 14, 0, 0, 0, 0, 0, 2, 4, 10, 14, 0, 0, 0, 6, 10, 14, 0, 0, 0, 0, 0, 0, 6, 10, 14, 0, 0, 0, 0, 2, 6, 10, 14, 0, 0, 0, 0, 0, 2, 6, 10, 14, 0, 0, 0, 4, 6, 10, 14, 0, 0, 0, 0, 0, 4, 6, 10, 14, 0, 0, 0, 2, 4, 6, 10, 14, 0, 0, 0, 0, 2, 4, 6, 10, 14, 0, 0, 8, 10, 14, 0, 0, 0, 0, 0, 0, 8, 10, 14, 0, 0, 0, 0, 2, 8, 10, 14, 0, 0, 0, 0, 0, 2, 8, 10, 14, 0, 0, 0, 4, 8, 10, 14, 0, 0, 0, 0, 0, 4, 8, 10, 14, 0, 0, 0, 2, 4, 8, 10, 14, 0, 0, 0, 0, 2, 4, 8, 10, 14, 0, 0, 6, 8, 10, 14, 0, 0, 0, 0, 0, 6, 8, 10, 14, 0, 0, 0, 2, 6, 8, 10, 14, 0, 0, 0, 0, 2, 6, 8, 10, 14, 0, 0, 4, 6, 8, 10, 14, 0, 0, 0, 0, 4, 6, 8, 10, 14, 0, 0, 2, 4, 6, 8, 10, 14, 0, 0, 0, 2, 4, 6, 8, 10, 14, 0, 12, 14, 0, 0, 0, 0, 0, 0, 0, 12, 14, 0, 0, 0, 0, 0, 2, 12, 14, 0, 0, 0, 0, 0, 0, 2, 12, 14, 0, 0, 0, 0, 4, 12, 14, 0, 0, 0, 0, 0, 0, 4, 12, 14, 0, 0, 0, 0, 2, 4, 12, 14, 0, 0, 0, 0, 0, 2, 4, 12, 14, 0, 0, 0, 6, 12, 14, 0, 0, 0, 0, 0, 0, 6, 12, 14, 0, 0, 0, 0, 2, 6, 12, 14, 0, 0, 0, 0, 0, 2, 6, 12, 14, 0, 0, 0, 4, 6, 12, 14, 0, 0, 0, 0, 0, 4, 6, 12, 14, 0, 0, 0, 2, 4, 6, 12, 14, 0, 0, 0, 0, 2, 4, 6, 12, 14, 0, 0, 8, 12, 14, 0, 0, 0, 0, 0, 0, 8, 12, 14, 0, 0, 0, 0, 2, 8, 12, 14, 0, 0, 0, 0, 0, 2, 8, 12, 14, 0, 0, 0, 4, 8, 12, 14, 0, 0, 0, 0, 0, 4, 8, 12, 14, 0, 0, 0, 2, 4, 8, 12, 14, 0, 0, 0, 0, 2, 4, 8, 12, 14, 0, 0, 6, 8, 12, 14, 0, 0, 0, 0, 0, 6, 8, 12, 14, 0, 0, 0, 2, 6, 8, 12, 14, 0, 0, 0, 0, 2, 6, 8, 12, 14, 0, 0, 4, 6, 8, 12, 14, 0, 0, 0, 0, 4, 6, 8, 12, 14, 0, 0, 2, 4, 6, 8, 12, 14, 0, 0, 0, 2, 4, 6, 8, 12, 14, 0, 10, 12, 14, 0, 0, 0, 0, 0, 0, 10, 12, 14, 0, 0, 0, 0, 2, 10, 12, 14, 0, 0, 0, 0, 0, 2, 10, 12, 14, 0, 0, 0, 4, 10, 12, 14, 0, 0, 0, 0, 0, 4, 10, 12, 14, 0, 0, 0, 2, 4, 10, 12, 14, 0, 0, 0, 0, 2, 4, 10, 12, 14, 0, 0, 6, 10, 12, 14, 0, 0, 0, 0, 0, 6, 10, 12, 14, 0, 0, 0, 2, 6, 10, 12, 14, 0, 0, 0, 0, 2, 6, 10, 12, 14, 0, 0, 4, 6, 10, 12, 14, 0, 0, 0, 0, 4, 6, 10, 12, 14, 0, 0, 2, 4, 6, 10, 12, 14, 0, 0, 0, 2, 4, 6, 10, 12, 14, 0, 8, 10, 12, 14, 0, 0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 2, 8, 10, 12, 14, 0, 0, 0, 0, 2, 8, 10, 12, 14, 0, 0, 4, 8, 10, 12, 14, 0, 0, 0, 0, 4, 8, 10, 12, 14, 0, 0, 2, 4, 8, 10, 12, 14, 0, 0, 0, 2, 4, 8, 10, 12, 14, 0, 6, 8, 10, 12, 14, 0, 0, 0, 0, 6, 8, 10, 12, 14, 0, 0, 2, 6, 8, 10, 12, 14, 0, 0, 0, 2, 6, 8, 10, 12, 14, 0, 4, 6, 8, 
10, 12, 14, 0, 0, 0, 4, 6, 8, 10, 12, 14, 0, 2, 4, 6, 8, 10, 12, 14, 0, 0, 2, 4, 6, 8, 10, 12, 14}; const Vec256 byte_idx{Load(d8, table + mask_bits * 8).raw}; const Vec256 pairs = ZipLower(byte_idx, byte_idx); return BitCast(d, pairs + Set(du, 0x0100)); } template HWY_INLINE Vec256 Idx32x4FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 16); // There are only 4 lanes, so we can afford to load the index vector directly. alignas(32) constexpr uint8_t packed_array[16 * 16] = { 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 0, 1, 2, 3, // 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, // 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, // 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3, // 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const Full256 d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } #if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 template HWY_INLINE Vec256 Idx64x2FromBits(const uint64_t mask_bits) { HWY_DASSERT(mask_bits < 4); // There are only 2 lanes, so we can afford to load the index vector directly. alignas(32) constexpr uint8_t packed_array[4 * 16] = { 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, // 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, // 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; const Full256 d; const Repartition d8; return BitCast(d, Load(d8, packed_array + 16 * mask_bits)); } #endif // Helper functions called by both Compress and CompressStore - avoids a // redundant BitsFromMask in the latter. 
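// Sketch of the lookup (illustrative): for 16-bit lanes with
// mask_bits = 0b101 (lanes 0 and 2 selected), Idx16x8FromBits loads the table
// row {0, 4, 0, ...}; ZipLower repeats each lane index into a byte pair and
// adding 0x0100 yields byte indices {0, 1, 4, 5, ...}, so TableLookupBytes
// packs lanes 0 and 2 to the front while the remaining lanes repeat lane 0.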
template HWY_INLINE Vec256 Compress(hwy::SizeTag<2> /*tag*/, Vec256 v, const uint64_t mask_bits) { const auto idx = detail::Idx16x8FromBits(mask_bits); using D = Full256; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } template HWY_INLINE Vec256 Compress(hwy::SizeTag<4> /*tag*/, Vec256 v, const uint64_t mask_bits) { const auto idx = detail::Idx32x4FromBits(mask_bits); using D = Full256; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } #if HWY_HAVE_INTEGER64 || HWY_HAVE_FLOAT64 template HWY_INLINE Vec256 Compress(hwy::SizeTag<8> /*tag*/, Vec256 v, const uint64_t mask_bits) { const auto idx = detail::Idx64x2FromBits(mask_bits); using D = Full256; const RebindToSigned di; return BitCast(D(), TableLookupBytes(BitCast(di, v), BitCast(di, idx))); } #endif } // namespace detail template HWY_API Vec256 Compress(Vec256 v, const Mask256 mask) { const uint64_t mask_bits = detail::BitsFromMask(mask); return detail::Compress(hwy::SizeTag(), v, mask_bits); } // ------------------------------ CompressBits template HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return detail::Compress(hwy::SizeTag(), v, mask_bits); } // ------------------------------ CompressStore template HWY_API size_t CompressStore(Vec256 v, const Mask256 mask, Full256 d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(mask); const auto c = detail::Compress(hwy::SizeTag(), v, mask_bits); StoreU(c, d, unaligned); return PopCount(mask_bits); } // ------------------------------ CompressBlendedStore template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const RebindToUnsigned du; // so we can support fp16/bf16 using TU = TFromD; const uint64_t mask_bits = detail::BitsFromMask(m); const size_t count = PopCount(mask_bits); const Mask256 store_mask = FirstN(du, count); const Vec256 compressed = detail::Compress(hwy::SizeTag(), BitCast(du, v), mask_bits); const Vec256 prev = BitCast(du, LoadU(d, unaligned)); StoreU(BitCast(d, IfThenElse(store_mask, compressed, prev)), d, unaligned); return count; } // ------------------------------ CompressBitsStore template HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, Full256 d, T* HWY_RESTRICT unaligned) { uint64_t mask_bits = 0; constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } const auto c = detail::Compress(hwy::SizeTag(), v, mask_bits); StoreU(c, d, unaligned); return PopCount(mask_bits); } // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, // TableLookupBytes) HWY_API void StoreInterleaved3(const Vec256 a, const Vec256 b, const Vec256 c, Full256 d, uint8_t* HWY_RESTRICT unaligned) { const auto k5 = Set(d, 5); const auto k6 = Set(d, 6); // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. // 0x80 so lanes to be filled from other vectors are 0 for blending. 
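// The first output block is r0 g0 b0 r1 g1 b1 ... r5, so each channel's bytes
// land in every third position: tbl_r0 places a[0..5] at offsets
// 0,3,6,9,12,15 and tbl_g0 places b[0..4] at offsets 1,4,7,10,13. Index 0x80
// is out of range for the byte swizzle and therefore yields 0, so the three
// per-channel lookups can be combined with OR.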
alignas(32) static constexpr uint8_t tbl_r0[16] = { 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; alignas(32) static constexpr uint8_t tbl_g0[16] = { 0x80, 0, 0x80, 0x80, 1, 0x80, // 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; const auto shuf_r0 = Load(d, tbl_r0); const auto shuf_g0 = Load(d, tbl_g0); // cannot reuse r0 due to 5 in MSB const auto shuf_b0 = CombineShiftRightBytes<15>(d, shuf_g0, shuf_g0); const auto r0 = TableLookupBytes(a, shuf_r0); // 5..4..3..2..1..0 const auto g0 = TableLookupBytes(b, shuf_g0); // ..4..3..2..1..0. const auto b0 = TableLookupBytes(c, shuf_b0); // .4..3..2..1..0.. const auto int0 = r0 | g0 | b0; StoreU(int0, d, unaligned + 0 * 16); // Second vector: g10,r10, bgr[9:6], b5,g5 const auto shuf_r1 = shuf_b0 + k6; // .A..9..8..7..6.. const auto shuf_g1 = shuf_r0 + k5; // A..9..8..7..6..5 const auto shuf_b1 = shuf_g0 + k5; // ..9..8..7..6..5. const auto r1 = TableLookupBytes(a, shuf_r1); const auto g1 = TableLookupBytes(b, shuf_g1); const auto b1 = TableLookupBytes(c, shuf_b1); const auto int1 = r1 | g1 | b1; StoreU(int1, d, unaligned + 1 * 16); // Third vector: bgr[15:11], b10 const auto shuf_r2 = shuf_b1 + k6; // ..F..E..D..C..B. const auto shuf_g2 = shuf_r1 + k5; // .F..E..D..C..B.. const auto shuf_b2 = shuf_g1 + k5; // F..E..D..C..B..A const auto r2 = TableLookupBytes(a, shuf_r2); const auto g2 = TableLookupBytes(b, shuf_g2); const auto b2 = TableLookupBytes(c, shuf_b2); const auto int2 = r2 | g2 | b2; StoreU(int2, d, unaligned + 2 * 16); } // ------------------------------ StoreInterleaved4 HWY_API void StoreInterleaved4(const Vec256 v0, const Vec256 v1, const Vec256 v2, const Vec256 v3, Full256 d8, uint8_t* HWY_RESTRICT unaligned) { const RepartitionToWide d16; const RepartitionToWide d32; // let a,b,c,d denote v0..3. const auto ba0 = ZipLower(d16, v0, v1); // b7 a7 .. b0 a0 const auto dc0 = ZipLower(d16, v2, v3); // d7 c7 .. 
d0 c0 const auto ba8 = ZipUpper(d16, v0, v1); const auto dc8 = ZipUpper(d16, v2, v3); const auto dcba_0 = ZipLower(d32, ba0, dc0); // d..a3 d..a0 const auto dcba_4 = ZipUpper(d32, ba0, dc0); // d..a7 d..a4 const auto dcba_8 = ZipLower(d32, ba8, dc8); // d..aB d..a8 const auto dcba_C = ZipUpper(d32, ba8, dc8); // d..aF d..aC StoreU(BitCast(d8, dcba_0), d8, unaligned + 0 * 16); StoreU(BitCast(d8, dcba_4), d8, unaligned + 1 * 16); StoreU(BitCast(d8, dcba_8), d8, unaligned + 2 * 16); StoreU(BitCast(d8, dcba_C), d8, unaligned + 3 * 16); } // ------------------------------ MulEven/Odd (Load) HWY_INLINE Vec256 MulEven(const Vec256 a, const Vec256 b) { alignas(32) uint64_t mul[2]; mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 0)), static_cast(wasm_i64x2_extract_lane(b.raw, 0)), &mul[1]); return Load(Full256(), mul); } HWY_INLINE Vec256 MulOdd(const Vec256 a, const Vec256 b) { alignas(32) uint64_t mul[2]; mul[0] = Mul128(static_cast(wasm_i64x2_extract_lane(a.raw, 1)), static_cast(wasm_i64x2_extract_lane(b.raw, 1)), &mul[1]); return Load(Full256(), mul); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) HWY_API Vec256 ReorderWidenMulAccumulate(Full256 df32, Vec256 a, Vec256 b, const Vec256 sum0, Vec256& sum1) { const Repartition du16; const RebindToUnsigned du32; const Vec256 zero = Zero(du16); const Vec256 a0 = ZipLower(du32, zero, BitCast(du16, a)); const Vec256 a1 = ZipUpper(du32, zero, BitCast(du16, a)); const Vec256 b0 = ZipLower(du32, zero, BitCast(du16, b)); const Vec256 b1 = ZipUpper(du32, zero, BitCast(du16, b)); sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); } // ------------------------------ Reductions namespace detail { // u32/i32/f32: template HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const Vec256 v1032 = Shuffle1032(v3210); const Vec256 v31_20_31_20 = v3210 + v1032; const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); return v20_31_20_31 + v31_20_31_20; } template HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const Vec256 v1032 = Shuffle1032(v3210); const Vec256 v31_20_31_20 = Min(v3210, v1032); const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); return Min(v20_31_20_31, v31_20_31_20); } template HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const Vec256 v1032 = Shuffle1032(v3210); const Vec256 v31_20_31_20 = Max(v3210, v1032); const Vec256 v20_31_20_31 = Shuffle0321(v31_20_31_20); return Max(v20_31_20_31, v31_20_31_20); } // u64/i64/f64: template HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const Vec256 v01 = Shuffle01(v10); return v10 + v01; } template HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const Vec256 v01 = Shuffle01(v10); return Min(v10, v01); } template HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const Vec256 v01 = Shuffle01(v10); return Max(v10, v01); } // u16/i16 template HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(d32, Min(even, odd)); // Also broadcast into odd lanes. 
  return BitCast(Full256<T>(), Or(min, ShiftLeft<16>(min)));
}

template <typename T>
HWY_API Vec256<T> MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec256<T> v) {
  const Repartition<int32_t, Full256<T>> d32;
  const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto max = MaxOfLanes(d32, Max(even, odd));
  // Also broadcast into odd lanes.
  return BitCast(Full256<T>(), Or(max, ShiftLeft<16>(max)));
}

}  // namespace detail

// Supported for u/i/f 32/64. Returns the same value in each lane.
template <typename T>
HWY_API Vec256<T> SumOfLanes(Full256<T> /* tag */, const Vec256<T> v) {
  return detail::SumOfLanes(hwy::SizeTag<sizeof(T)>(), v);
}
template <typename T>
HWY_API Vec256<T> MinOfLanes(Full256<T> /* tag */, const Vec256<T> v) {
  return detail::MinOfLanes(hwy::SizeTag<sizeof(T)>(), v);
}
template <typename T>
HWY_API Vec256<T> MaxOfLanes(Full256<T> /* tag */, const Vec256<T> v) {
  return detail::MaxOfLanes(hwy::SizeTag<sizeof(T)>(), v);
}

// ------------------------------ Lt128

// Note: not yet implemented for 256-bit vectors; the bodies below are
// placeholders and must not be called.
template <typename T>
HWY_INLINE Mask256<T> Lt128(Full256<T> d, Vec256<T> a, Vec256<T> b) {}

template <typename T>
HWY_INLINE Vec256<T> Min128(Full256<T> d, Vec256<T> a, Vec256<T> b) {}

template <typename T>
HWY_INLINE Vec256<T> Max128(Full256<T> d, Vec256<T> a, Vec256<T> b) {}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();
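// ----------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header): how user code
// typically reaches ops such as CompressStore via Highway's dispatch. The
// function name, its namespaces and the omitted remainder handling are
// hypothetical.
//
//   #include "hwy/highway.h"
//
//   namespace project {
//   namespace HWY_NAMESPACE {
//
//   // Copies negative elements of `in` to `out`; returns how many were kept.
//   size_t CopyNegative(const float* HWY_RESTRICT in, size_t count,
//                       float* HWY_RESTRICT out) {
//     const hwy::HWY_NAMESPACE::ScalableTag<float> d;
//     size_t written = 0;
//     // Assumes count is a multiple of Lanes(d); real code also handles the
//     // remainder.
//     for (size_t i = 0; i < count; i += Lanes(d)) {
//       const auto v = LoadU(d, in + i);
//       written += CompressStore(v, v < Zero(d), d, out + written);
//     }
//     return written;
//   }
//
//   }  // namespace HWY_NAMESPACE
//   }  // namespace project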