// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 256-bit vectors and AVX2 instructions, plus some AVX512-VL operations when
// compiling for that target.
// External include guard in highway.h - see comment there.

// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.

#include <immintrin.h>  // AVX2+

#include "hwy/base.h"

#if defined(_MSC_VER) && defined(__clang__)
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
#include <avxintrin.h>
// avxintrin defines __m256i and must come before avx2intrin.
#include <avx2intrin.h>
#include <bmi2intrin.h>  // _pext_u64
#include <f16cintrin.h>
#include <fmaintrin.h>
#include <smmintrin.h>
#endif

#include <stddef.h>
#include <stdint.h>

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_128-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {

template <typename T>
struct Raw256 {
  using type = __m256i;
};
template <>
struct Raw256<float> {
  using type = __m256;
};
template <>
struct Raw256<double> {
  using type = __m256d;
};

}  // namespace detail

template <typename T>
class Vec256 {
  using Raw = typename detail::Raw256<T>::type;

 public:
  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec256& operator*=(const Vec256 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec256& operator/=(const Vec256 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec256& operator+=(const Vec256 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec256& operator-=(const Vec256 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec256& operator&=(const Vec256 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec256& operator|=(const Vec256 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec256& operator^=(const Vec256 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

#if HWY_TARGET <= HWY_AVX3

namespace detail {

// Template arg: sizeof(lane type)
template <size_t size>
struct RawMask256 {};
template <>
struct RawMask256<1> {
  using type = __mmask32;
};
template <>
struct RawMask256<2> {
  using type = __mmask16;
};
template <>
struct RawMask256<4> {
  using type = __mmask8;
};
template <>
struct RawMask256<8> {
  using type = __mmask8;
};

}  // namespace detail

template <typename T>
struct Mask256 {
  using Raw = typename detail::RawMask256<sizeof(T)>::type;

  static Mask256<T> FromBits(uint64_t mask_bits) {
    return Mask256<T>{static_cast<Raw>(mask_bits)};
  }

  Raw raw;
};

#else  // AVX2

// FF..FF or 0.
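// (Without AVX-512 mask registers, a mask is stored as a full vector whose
// lanes are either all-ones or all-zero.)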
template struct Mask256 { typename detail::Raw256::type raw; }; #endif // HWY_TARGET <= HWY_AVX3 // ------------------------------ BitCast namespace detail { HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; } HWY_INLINE __m256i BitCastToInteger(__m256 v) { return _mm256_castps_si256(v); } HWY_INLINE __m256i BitCastToInteger(__m256d v) { return _mm256_castpd_si256(v); } template HWY_INLINE Vec256 BitCastToByte(Vec256 v) { return Vec256{BitCastToInteger(v.raw)}; } // Cannot rely on function overloading because return types differ. template struct BitCastFromInteger256 { HWY_INLINE __m256i operator()(__m256i v) { return v; } }; template <> struct BitCastFromInteger256 { HWY_INLINE __m256 operator()(__m256i v) { return _mm256_castsi256_ps(v); } }; template <> struct BitCastFromInteger256 { HWY_INLINE __m256d operator()(__m256i v) { return _mm256_castsi256_pd(v); } }; template HWY_INLINE Vec256 BitCastFromByte(Full256 /* tag */, Vec256 v) { return Vec256{BitCastFromInteger256()(v.raw)}; } } // namespace detail template HWY_API Vec256 BitCast(Full256 d, Vec256 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } // ------------------------------ Set // Returns an all-zero vector. template HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{_mm256_setzero_si256()}; } HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{_mm256_setzero_ps()}; } HWY_API Vec256 Zero(Full256 /* tag */) { return Vec256{_mm256_setzero_pd()}; } // Returns a vector with all lanes set to "t". HWY_API Vec256 Set(Full256 /* tag */, const uint8_t t) { return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const uint16_t t) { return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const uint32_t t) { return Vec256{_mm256_set1_epi32(static_cast(t))}; } HWY_API Vec256 Set(Full256 /* tag */, const uint64_t t) { return Vec256{ _mm256_set1_epi64x(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const int8_t t) { return Vec256{_mm256_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const int16_t t) { return Vec256{_mm256_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const int32_t t) { return Vec256{_mm256_set1_epi32(t)}; } HWY_API Vec256 Set(Full256 /* tag */, const int64_t t) { return Vec256{ _mm256_set1_epi64x(static_cast(t))}; // NOLINT } HWY_API Vec256 Set(Full256 /* tag */, const float t) { return Vec256{_mm256_set1_ps(t)}; } HWY_API Vec256 Set(Full256 /* tag */, const double t) { return Vec256{_mm256_set1_pd(t)}; } HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. template HWY_API Vec256 Undefined(Full256 /* tag */) { // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC // generate an XOR instruction. 
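// (The lane values are unspecified; callers are expected to overwrite them
// before reading.)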
return Vec256{_mm256_undefined_si256()}; } HWY_API Vec256 Undefined(Full256 /* tag */) { return Vec256{_mm256_undefined_ps()}; } HWY_API Vec256 Undefined(Full256 /* tag */) { return Vec256{_mm256_undefined_pd()}; } HWY_DIAGNOSTICS(pop) // ================================================== LOGICAL // ------------------------------ And template HWY_API Vec256 And(Vec256 a, Vec256 b) { return Vec256{_mm256_and_si256(a.raw, b.raw)}; } HWY_API Vec256 And(const Vec256 a, const Vec256 b) { return Vec256{_mm256_and_ps(a.raw, b.raw)}; } HWY_API Vec256 And(const Vec256 a, const Vec256 b) { return Vec256{_mm256_and_pd(a.raw, b.raw)}; } // ------------------------------ AndNot // Returns ~not_mask & mask. template HWY_API Vec256 AndNot(Vec256 not_mask, Vec256 mask) { return Vec256{_mm256_andnot_si256(not_mask.raw, mask.raw)}; } HWY_API Vec256 AndNot(const Vec256 not_mask, const Vec256 mask) { return Vec256{_mm256_andnot_ps(not_mask.raw, mask.raw)}; } HWY_API Vec256 AndNot(const Vec256 not_mask, const Vec256 mask) { return Vec256{_mm256_andnot_pd(not_mask.raw, mask.raw)}; } // ------------------------------ Or template HWY_API Vec256 Or(Vec256 a, Vec256 b) { return Vec256{_mm256_or_si256(a.raw, b.raw)}; } HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { return Vec256{_mm256_or_ps(a.raw, b.raw)}; } HWY_API Vec256 Or(const Vec256 a, const Vec256 b) { return Vec256{_mm256_or_pd(a.raw, b.raw)}; } // ------------------------------ Xor template HWY_API Vec256 Xor(Vec256 a, Vec256 b) { return Vec256{_mm256_xor_si256(a.raw, b.raw)}; } HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { return Vec256{_mm256_xor_ps(a.raw, b.raw)}; } HWY_API Vec256 Xor(const Vec256 a, const Vec256 b) { return Vec256{_mm256_xor_pd(a.raw, b.raw)}; } // ------------------------------ Not template HWY_API Vec256 Not(const Vec256 v) { using TU = MakeUnsigned; #if HWY_TARGET <= HWY_AVX3 const __m256i vu = BitCast(Full256(), v).raw; return BitCast(Full256(), Vec256{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)}); #else return Xor(v, BitCast(Full256(), Vec256{_mm256_set1_epi32(-1)})); #endif } // ------------------------------ OrAnd template HWY_API Vec256 OrAnd(Vec256 o, Vec256 a1, Vec256 a2) { #if HWY_TARGET <= HWY_AVX3 const Full256 d; const RebindToUnsigned du; using VU = VFromD; const __m256i ret = _mm256_ternarylogic_epi64( BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); return BitCast(d, VU{ret}); #else return Or(o, And(a1, a2)); #endif } // ------------------------------ IfVecThenElse template HWY_API Vec256 IfVecThenElse(Vec256 mask, Vec256 yes, Vec256 no) { #if HWY_TARGET <= HWY_AVX3 const Full256 d; const RebindToUnsigned du; using VU = VFromD; return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, BitCast(du, no).raw, 0xCA)}); #else return IfThenElse(MaskFromVec(mask), yes, no); #endif } // ------------------------------ Operator overloads (internal-only if float) template HWY_API Vec256 operator&(const Vec256 a, const Vec256 b) { return And(a, b); } template HWY_API Vec256 operator|(const Vec256 a, const Vec256 b) { return Or(a, b); } template HWY_API Vec256 operator^(const Vec256 a, const Vec256 b) { return Xor(a, b); } // ------------------------------ PopulationCount // 8/16 require BITALG, 32/64 require VPOPCNTDQ. 
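// For example, PopulationCount of a u8 vector whose lanes are all 0xF0
// yields 4 in every lane.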
#if HWY_TARGET == HWY_AVX3_DL #ifdef HWY_NATIVE_POPCNT #undef HWY_NATIVE_POPCNT #else #define HWY_NATIVE_POPCNT #endif namespace detail { template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<1> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi8(v.raw)}; } template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<2> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi16(v.raw)}; } template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<4> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi32(v.raw)}; } template HWY_INLINE Vec256 PopulationCount(hwy::SizeTag<8> /* tag */, Vec256 v) { return Vec256{_mm256_popcnt_epi64(v.raw)}; } } // namespace detail template HWY_API Vec256 PopulationCount(Vec256 v) { return detail::PopulationCount(hwy::SizeTag(), v); } #endif // HWY_TARGET == HWY_AVX3_DL // ================================================== SIGN // ------------------------------ CopySign template HWY_API Vec256 CopySign(const Vec256 magn, const Vec256 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const Full256 d; const auto msb = SignBit(d); #if HWY_TARGET <= HWY_AVX3 const Rebind, decltype(d)> du; // Truth table for msb, magn, sign | bitwise msb ? sign : mag // 0 0 0 | 0 // 0 0 1 | 0 // 0 1 0 | 1 // 0 1 1 | 1 // 1 0 0 | 0 // 1 0 1 | 1 // 1 1 0 | 0 // 1 1 1 | 1 // The lane size does not matter because we are not using predication. const __m256i out = _mm256_ternarylogic_epi32( BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); return BitCast(d, decltype(Zero(du)){out}); #else return Or(AndNot(msb, magn), And(msb, sign)); #endif } template HWY_API Vec256 CopySignToAbs(const Vec256 abs, const Vec256 sign) { #if HWY_TARGET <= HWY_AVX3 // AVX3 can also handle abs < 0, so no extra action needed. return CopySign(abs, sign); #else return Or(abs, And(SignBit(Full256()), sign)); #endif } // ================================================== MASK #if HWY_TARGET <= HWY_AVX3 // ------------------------------ IfThenElse // Returns mask ? b : a. namespace detail { // Templates for signed/unsigned integer of a particular size. 
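// Each hwy::SizeTag<sizeof(T)> overload below forwards to the AVX-512
// masked-move intrinsic of the matching lane width.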
template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<1> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<2> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<4> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElse(hwy::SizeTag<8> /* tag */, Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); } HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_ps(no.raw, mask.raw, yes.raw)}; } HWY_API Vec256 IfThenElse(Mask256 mask, Vec256 yes, Vec256 no) { return Vec256{_mm256_mask_mov_pd(no.raw, mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<1> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi8(mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<2> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi16(mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<4> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi32(mask.raw, yes.raw)}; } template HWY_INLINE Vec256 IfThenElseZero(hwy::SizeTag<8> /* tag */, Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_epi64(mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_ps(mask.raw, yes.raw)}; } HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return Vec256{_mm256_maskz_mov_pd(mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<1> /* tag */, Mask256 mask, Vec256 no) { // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. 
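// (no - no == 0 in lanes where the mask is set; other lanes keep `no`.)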
return Vec256{_mm256_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<2> /* tag */, Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<4> /* tag */, Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec256 IfThenZeroElse(hwy::SizeTag<8> /* tag */, Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; } } // namespace detail template HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; } HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return Vec256{_mm256_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_API Vec256 ZeroIfNegative(const Vec256 v) { // AVX3 MaskFromVec only looks at the MSB return IfThenZeroElse(MaskFromVec(v), v); } // ------------------------------ Mask logical namespace detail { template HWY_INLINE Mask256 And(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 And(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 And(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 And(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kand_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 AndNot(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kandn_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return 
Mask256{_kor_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Or(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw | b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<1> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask32(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask32>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<2> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask16(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask16>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<4> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask256 Xor(hwy::SizeTag<8> /*tag*/, const Mask256 a, const Mask256 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask256{_kxor_mask8(a.raw, b.raw)}; #else return Mask256{static_cast<__mmask8>(a.raw ^ b.raw)}; #endif } } // namespace detail template HWY_API Mask256 And(const Mask256 a, Mask256 b) { return detail::And(hwy::SizeTag(), a, b); } template HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { return detail::AndNot(hwy::SizeTag(), a, b); } template HWY_API Mask256 Or(const Mask256 a, Mask256 b) { return detail::Or(hwy::SizeTag(), a, b); } template HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { return detail::Xor(hwy::SizeTag(), a, b); } template HWY_API Mask256 Not(const Mask256 m) { // Flip only the valid bits. constexpr size_t N = 32 / sizeof(T); return Xor(m, Mask256::FromBits((1ull << N) - 1)); } #else // AVX2 // ------------------------------ Mask // Mask and Vec are the same (true = FF..FF). template HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{v.raw}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{v.raw}; } template HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { return Vec256{v.raw}; } // ------------------------------ IfThenElse // mask ? yes : no template HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, const Vec256 no) { return Vec256{_mm256_blendv_epi8(no.raw, yes.raw, mask.raw)}; } HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, const Vec256 no) { return Vec256{_mm256_blendv_ps(no.raw, yes.raw, mask.raw)}; } HWY_API Vec256 IfThenElse(const Mask256 mask, const Vec256 yes, const Vec256 no) { return Vec256{_mm256_blendv_pd(no.raw, yes.raw, mask.raw)}; } // mask ? yes : 0 template HWY_API Vec256 IfThenElseZero(Mask256 mask, Vec256 yes) { return yes & VecFromMask(Full256(), mask); } // mask ? 
0 : no template HWY_API Vec256 IfThenZeroElse(Mask256 mask, Vec256 no) { return AndNot(VecFromMask(Full256(), mask), no); } template HWY_API Vec256 ZeroIfNegative(Vec256 v) { const auto zero = Zero(Full256()); // AVX2 IfThenElse only looks at the MSB for 32/64-bit lanes return IfThenElse(MaskFromVec(v), zero, v); } // ------------------------------ Mask logical template HWY_API Mask256 Not(const Mask256 m) { return MaskFromVec(Not(VecFromMask(Full256(), m))); } template HWY_API Mask256 And(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 AndNot(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 Or(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b))); } template HWY_API Mask256 Xor(const Mask256 a, Mask256 b) { const Full256 d; return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b))); } #endif // HWY_TARGET <= HWY_AVX3 // ================================================== COMPARE #if HWY_TARGET <= HWY_AVX3 // Comparisons set a mask bit to 1 if the condition is true, else 0. template HWY_API Mask256 RebindMask(Full256 /*tag*/, Mask256 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask256{m.raw}; } namespace detail { template HWY_INLINE Mask256 TestBit(hwy::SizeTag<1> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi8_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask256 TestBit(hwy::SizeTag<2> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi16_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask256 TestBit(hwy::SizeTag<4> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi32_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask256 TestBit(hwy::SizeTag<8> /*tag*/, const Vec256 v, const Vec256 bit) { return Mask256{_mm256_test_epi64_mask(v.raw, bit.raw)}; } } // namespace detail template HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return detail::TestBit(hwy::SizeTag(), v, bit); } // ------------------------------ Equality template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator==(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } HWY_API Mask256 operator==(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } // ------------------------------ Inequality template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpneq_epi64_mask(a.raw, 
b.raw)}; } HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } HWY_API Mask256 operator!=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } // ------------------------------ Strict inequality HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi8_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi16_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi32_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epi64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmpgt_epu8_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epu16_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epu32_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epu64_mask(a.raw, b.raw)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; } // ------------------------------ Weak inequality HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; } HWY_API Mask256 operator>=(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; } // ------------------------------ Mask namespace detail { template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi8_mask(v.raw)}; } template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi16_mask(v.raw)}; } template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi32_mask(v.raw)}; } template HWY_INLINE Mask256 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec256 v) { return Mask256{_mm256_movepi64_mask(v.raw)}; } } // namespace detail template HWY_API Mask256 MaskFromVec(const Vec256 v) { return detail::MaskFromVec(hwy::SizeTag(), v); } // There do not seem to be native floating-point versions of these instructions. HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; } HWY_API Mask256 MaskFromVec(const Vec256 v) { return Mask256{MaskFromVec(BitCast(Full256(), v)).raw}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi8(v.raw)}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi16(v.raw)}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi32(v.raw)}; } template HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_movm_epi64(v.raw)}; } HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_castsi256_ps(_mm256_movm_epi32(v.raw))}; } HWY_API Vec256 VecFromMask(const Mask256 v) { return Vec256{_mm256_castsi256_pd(_mm256_movm_epi64(v.raw))}; } template HWY_API Vec256 VecFromMask(Full256 /* tag */, const Mask256 v) { return VecFromMask(v); } #else // AVX2 // Comparisons fill a lane with 1-bits if the condition is true, else 0. 
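// For example, an equality compare of u32 lanes yields 0xFFFFFFFF in each
// matching lane, which can be used directly as a blend mask.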
template HWY_API Mask256 RebindMask(Full256 d_to, Mask256 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return MaskFromVec(BitCast(d_to, VecFromMask(Full256(), m))); } template HWY_API Mask256 TestBit(const Vec256 v, const Vec256 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return (v & bit) == bit; } // ------------------------------ Equality template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi8(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi16(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi32(a.raw, b.raw)}; } template HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpeq_epi64(a.raw, b.raw)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_EQ_OQ)}; } HWY_API Mask256 operator==(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_EQ_OQ)}; } // ------------------------------ Inequality template HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Not(a == b); } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_NEQ_OQ)}; } HWY_API Mask256 operator!=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_NEQ_OQ)}; } // ------------------------------ Strict inequality // Pre-9.3 GCC immintrin.h uses char, which may be unsigned, causing cmpgt_epi8 // to perform an unsigned comparison instead of the intended signed. Workaround // is to cast to an explicitly signed type. See https://godbolt.org/z/PL7Ujy #if HWY_COMPILER_GCC != 0 && HWY_COMPILER_GCC < 930 #define HWY_AVX2_GCC_CMPGT8_WORKAROUND 1 #else #define HWY_AVX2_GCC_CMPGT8_WORKAROUND 0 #endif HWY_API Mask256 operator>(Vec256 a, Vec256 b) { #if HWY_AVX2_GCC_CMPGT8_WORKAROUND using i8x32 = signed char __attribute__((__vector_size__(32))); return Mask256{static_cast<__m256i>(reinterpret_cast(a.raw) > reinterpret_cast(b.raw))}; #else return Mask256{_mm256_cmpgt_epi8(a.raw, b.raw)}; #endif } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epi16(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epi32(a.raw, b.raw)}; } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmpgt_epi64(a.raw, b.raw)}; } template HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { const Full256 du; const RebindToSigned di; const Vec256 msb = Set(du, (LimitsMax() >> 1) + 1); return RebindMask(du, BitCast(di, Xor(a, msb)) > BitCast(di, Xor(b, msb))); } HWY_API Mask256 operator>(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GT_OQ)}; } HWY_API Mask256 operator>(Vec256 a, Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GT_OQ)}; } // ------------------------------ Weak inequality HWY_API Mask256 operator>=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_ps(a.raw, b.raw, _CMP_GE_OQ)}; } HWY_API Mask256 operator>=(const Vec256 a, const Vec256 b) { return Mask256{_mm256_cmp_pd(a.raw, b.raw, _CMP_GE_OQ)}; } #endif // HWY_TARGET <= HWY_AVX3 // ------------------------------ Reversed comparisons template HWY_API Mask256 operator<(const Vec256 a, const Vec256 b) { return b > a; } template HWY_API Mask256 operator<=(const 
Vec256 a, const Vec256 b) { return b >= a; } // ------------------------------ Min (Gt, IfThenElse) // Unsigned HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epu8(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epu16(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epu32(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_min_epu64(a.raw, b.raw)}; #else const Full256 du; const Full256 di; const auto msb = Set(du, 1ull << 63); const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); return IfThenElse(gt, b, a); #endif } // Signed HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epi8(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epi16(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_epi32(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_min_epi64(a.raw, b.raw)}; #else return IfThenElse(a < b, a, b); #endif } // Float HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_ps(a.raw, b.raw)}; } HWY_API Vec256 Min(const Vec256 a, const Vec256 b) { return Vec256{_mm256_min_pd(a.raw, b.raw)}; } // ------------------------------ Max (Gt, IfThenElse) // Unsigned HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epu8(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epu16(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epu32(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_max_epu64(a.raw, b.raw)}; #else const Full256 du; const Full256 di; const auto msb = Set(du, 1ull << 63); const auto gt = RebindMask(du, BitCast(di, a ^ msb) > BitCast(di, b ^ msb)); return IfThenElse(gt, a, b); #endif } // Signed HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epi8(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epi16(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_epi32(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_max_epi64(a.raw, b.raw)}; #else return IfThenElse(a < b, b, a); #endif } // Float HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_ps(a.raw, b.raw)}; } HWY_API Vec256 Max(const Vec256 a, const Vec256 b) { return Vec256{_mm256_max_pd(a.raw, b.raw)}; } // ------------------------------ FirstN (Iota, Lt) template HWY_API Mask256 FirstN(const Full256 d, size_t n) { #if HWY_TARGET <= HWY_AVX3 (void)d; constexpr size_t N = 32 / sizeof(T); #if HWY_ARCH_X86_64 const uint64_t all = (1ull << N) - 1; // BZHI only looks at the lower 8 bits of n! return Mask256::FromBits((n > 255) ? all : _bzhi_u64(all, n)); #else const uint32_t all = static_cast((1ull << N) - 1); // BZHI only looks at the lower 8 bits of n! return Mask256::FromBits( (n > 255) ? all : _bzhi_u32(all, static_cast(n))); #endif // HWY_ARCH_X86_64 #else const RebindToSigned di; // Signed comparisons are cheaper. 
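// (Iota(di, 0) is 0, 1, 2, ...; lanes whose index is below n compare true.)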
return RebindMask(d, Iota(di, 0) < Set(di, static_cast>(n))); #endif } // ================================================== ARITHMETIC // ------------------------------ Addition // Unsigned HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_epi64(a.raw, b.raw)}; } // Float HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_ps(a.raw, b.raw)}; } HWY_API Vec256 operator+(const Vec256 a, const Vec256 b) { return Vec256{_mm256_add_pd(a.raw, b.raw)}; } // ------------------------------ Subtraction // Unsigned HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi8(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi32(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_epi64(a.raw, b.raw)}; } // Float HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_ps(a.raw, b.raw)}; } HWY_API Vec256 operator-(const Vec256 a, const Vec256 b) { return Vec256{_mm256_sub_pd(a.raw, b.raw)}; } // ------------------------------ SumsOf8 HWY_API Vec256 SumsOf8(const Vec256 v) { return Vec256{_mm256_sad_epu8(v.raw, _mm256_setzero_si256())}; } // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. // Unsigned HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epu8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epi8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedAdd(const Vec256 a, const Vec256 b) { return Vec256{_mm256_adds_epi16(a.raw, b.raw)}; } // ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. 
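// For example, with u8 lanes SaturatedSub(Set(d, 5), Set(d, 9)) yields 0
// rather than wrapping around to 252.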
// Unsigned HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epu8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epi8(a.raw, b.raw)}; } HWY_API Vec256 SaturatedSub(const Vec256 a, const Vec256 b) { return Vec256{_mm256_subs_epi16(a.raw, b.raw)}; } // ------------------------------ Average // Returns (a + b + 1) / 2 // Unsigned HWY_API Vec256 AverageRound(const Vec256 a, const Vec256 b) { return Vec256{_mm256_avg_epu8(a.raw, b.raw)}; } HWY_API Vec256 AverageRound(const Vec256 a, const Vec256 b) { return Vec256{_mm256_avg_epu16(a.raw, b.raw)}; } // ------------------------------ Abs (Sub) // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. HWY_API Vec256 Abs(const Vec256 v) { #if HWY_COMPILER_MSVC // Workaround for incorrect codegen? (wrong result) const auto zero = Zero(Full256()); return Vec256{_mm256_max_epi8(v.raw, (zero - v).raw)}; #else return Vec256{_mm256_abs_epi8(v.raw)}; #endif } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{_mm256_abs_epi16(v.raw)}; } HWY_API Vec256 Abs(const Vec256 v) { return Vec256{_mm256_abs_epi32(v.raw)}; } // i64 is implemented after BroadcastSignBit. HWY_API Vec256 Abs(const Vec256 v) { const Vec256 mask{_mm256_set1_epi32(0x7FFFFFFF)}; return v & BitCast(Full256(), mask); } HWY_API Vec256 Abs(const Vec256 v) { const Vec256 mask{_mm256_set1_epi64x(0x7FFFFFFFFFFFFFFFLL)}; return v & BitCast(Full256(), mask); } // ------------------------------ Integer multiplication // Unsigned HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; } // Signed HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mullo_epi32(a.raw, b.raw)}; } // Returns the upper 16 bits of a * b in each lane. HWY_API Vec256 MulHigh(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mulhi_epu16(a.raw, b.raw)}; } HWY_API Vec256 MulHigh(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mulhi_epi16(a.raw, b.raw)}; } // Multiplies even lanes (0, 2 ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. 
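// For example, 64-bit result lane 0 holds a[0] * b[0] and lane 1 holds
// a[2] * b[2]; the odd input lanes are ignored.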
HWY_API Vec256 MulEven(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mul_epi32(a.raw, b.raw)}; } HWY_API Vec256 MulEven(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mul_epu32(a.raw, b.raw)}; } // ------------------------------ ShiftLeft template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi64(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { return Vec256{_mm256_slli_epi64(v.raw, kBits)}; } template HWY_API Vec256 ShiftLeft(const Vec256 v) { const Full256 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); return kBits == 1 ? (v + v) : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } // ------------------------------ ShiftRight template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srli_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srli_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srli_epi64(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { const Full256 d8; // Use raw instead of BitCast to support N=1. const Vec256 shifted{ShiftRight(Vec256{v.raw}).raw}; return shifted & Set(d8, 0xFF >> kBits); } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srai_epi16(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { return Vec256{_mm256_srai_epi32(v.raw, kBits)}; } template HWY_API Vec256 ShiftRight(const Vec256 v) { const Full256 di; const Full256 du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; } // i64 is implemented after BroadcastSignBit. 
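
// ------------------------------ Example (illustrative only)
// Minimal sketch, not part of the original header: composes Set, operator+
// and ShiftRight from above to divide each u32 lane by 8, rounding to
// nearest. The name RoundedDivBy8 is hypothetical.
HWY_API Vec256<uint32_t> RoundedDivBy8(const Vec256<uint32_t> v) {
  const Full256<uint32_t> d;
  // Adding half of the divisor before shifting rounds to nearest (assumes the
  // addition does not overflow).
  return ShiftRight<3>(v + Set(d, 4u));
}
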
// ------------------------------ RotateRight template HWY_API Vec256 RotateRight(const Vec256 v) { static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_ror_epi32(v.raw, kBits)}; #else if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); #endif } template HWY_API Vec256 RotateRight(const Vec256 v) { static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_ror_epi64(v.raw, kBits)}; #else if (kBits == 0) return v; return Or(ShiftRight(v), ShiftLeft(v)); #endif } // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return VecFromMask(v < Zero(Full256())); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return ShiftRight<15>(v); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { return ShiftRight<31>(v); } HWY_API Vec256 BroadcastSignBit(const Vec256 v) { #if HWY_TARGET == HWY_AVX2 return VecFromMask(v < Zero(Full256())); #else return Vec256{_mm256_srai_epi64(v.raw, 63)}; #endif } template HWY_API Vec256 ShiftRight(const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srai_epi64(v.raw, kBits)}; #else const Full256 di; const Full256 du; const auto right = BitCast(di, ShiftRight(BitCast(du, v))); const auto sign = ShiftLeft<64 - kBits>(BroadcastSignBit(v)); return right | sign; #endif } HWY_API Vec256 Abs(const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_abs_epi64(v.raw)}; #else const auto zero = Zero(Full256()); return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v); #endif } // ------------------------------ IfNegativeThenElse (BroadcastSignBit) HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { // int8: AVX2 IfThenElse only looks at the MSB. return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { static_assert(IsSigned(), "Only works for signed/float"); const Full256 d; const RebindToSigned di; // 16-bit: no native blendv, so copy sign to lower byte's MSB. v = BitCast(d, BroadcastSignBit(BitCast(di, v))); return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec256 IfNegativeThenElse(Vec256 v, Vec256 yes, Vec256 no) { static_assert(IsSigned(), "Only works for signed/float"); const Full256 d; const RebindToFloat df; // 32/64-bit: use float IfThenElse, which only looks at the MSB. 
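// (The MSB of a negative value is set, so the float blend selects `yes`
// exactly for the negative lanes.)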
const MFromD msb = MaskFromVec(BitCast(df, v)); return BitCast(d, IfThenElse(msb, BitCast(df, yes), BitCast(df, no))); } // ------------------------------ ShiftLeftSame HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { return Vec256{_mm256_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } template HWY_API Vec256 ShiftLeftSame(const Vec256 v, const int bits) { const Full256 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); } // ------------------------------ ShiftRightSame (BroadcastSignBit) HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { const Full256 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast(0xFF >> bits)); } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { return Vec256{_mm256_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec256 ShiftRightSame(const Vec256 v, const int bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; #else const Full256 di; const Full256 du; const auto right = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto sign = ShiftLeftSame(BroadcastSignBit(v), 64 - bits); return right | sign; #endif } HWY_API Vec256 ShiftRightSame(Vec256 v, const int bits) { const Full256 di; const Full256 du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, static_cast(0x80 >> bits))); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ Neg (Xor, Sub) template HWY_API Vec256 Neg(const Vec256 v) { return Xor(v, SignBit(Full256())); } template HWY_API Vec256 Neg(const Vec256 v) { return Zero(Full256()) - v; } // ------------------------------ Floating-point mul / div HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mul_ps(a.raw, b.raw)}; } HWY_API Vec256 operator*(const Vec256 a, const Vec256 b) { return Vec256{_mm256_mul_pd(a.raw, b.raw)}; } HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { return Vec256{_mm256_div_ps(a.raw, b.raw)}; } HWY_API Vec256 operator/(const Vec256 a, const Vec256 b) { return Vec256{_mm256_div_pd(a.raw, b.raw)}; } // Approximate reciprocal HWY_API Vec256 
ApproximateReciprocal(const Vec256 v) { return Vec256{_mm256_rcp_ps(v.raw)}; } // Absolute value of difference. HWY_API Vec256 AbsDiff(const Vec256 a, const Vec256 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants // Returns mul * x + add HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x + add; #else return Vec256{_mm256_fmadd_ps(mul.raw, x.raw, add.raw)}; #endif } HWY_API Vec256 MulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x + add; #else return Vec256{_mm256_fmadd_pd(mul.raw, x.raw, add.raw)}; #endif } // Returns add - mul * x HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return add - mul * x; #else return Vec256{_mm256_fnmadd_ps(mul.raw, x.raw, add.raw)}; #endif } HWY_API Vec256 NegMulAdd(const Vec256 mul, const Vec256 x, const Vec256 add) { #ifdef HWY_DISABLE_BMI2_FMA return add - mul * x; #else return Vec256{_mm256_fnmadd_pd(mul.raw, x.raw, add.raw)}; #endif } // Returns mul * x - sub HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x - sub; #else return Vec256{_mm256_fmsub_ps(mul.raw, x.raw, sub.raw)}; #endif } HWY_API Vec256 MulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return mul * x - sub; #else return Vec256{_mm256_fmsub_pd(mul.raw, x.raw, sub.raw)}; #endif } // Returns -mul * x - sub HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return Neg(mul * x) - sub; #else return Vec256{_mm256_fnmsub_ps(mul.raw, x.raw, sub.raw)}; #endif } HWY_API Vec256 NegMulSub(const Vec256 mul, const Vec256 x, const Vec256 sub) { #ifdef HWY_DISABLE_BMI2_FMA return Neg(mul * x) - sub; #else return Vec256{_mm256_fnmsub_pd(mul.raw, x.raw, sub.raw)}; #endif } // ------------------------------ Floating-point square root // Full precision square root HWY_API Vec256 Sqrt(const Vec256 v) { return Vec256{_mm256_sqrt_ps(v.raw)}; } HWY_API Vec256 Sqrt(const Vec256 v) { return Vec256{_mm256_sqrt_pd(v.raw)}; } // Approximate reciprocal square root HWY_API Vec256 ApproximateReciprocalSqrt(const Vec256 v) { return Vec256{_mm256_rsqrt_ps(v.raw)}; } // ------------------------------ Floating-point rounding // Toward nearest integer, tie to even HWY_API Vec256 Round(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Round(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } // Toward zero, aka truncate HWY_API Vec256 Trunc(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Trunc(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } // Toward +infinity, aka ceiling HWY_API Vec256 Ceil(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Ceil(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } // Toward -infinity, aka floor HWY_API Vec256 Floor(const Vec256 v) { return Vec256{ _mm256_round_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec256 Floor(const Vec256 v) { return Vec256{ _mm256_round_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } // 
================================================== MEMORY // ------------------------------ Load template HWY_API Vec256 Load(Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{ _mm256_load_si256(reinterpret_cast(aligned))}; } HWY_API Vec256 Load(Full256 /* tag */, const float* HWY_RESTRICT aligned) { return Vec256{_mm256_load_ps(aligned)}; } HWY_API Vec256 Load(Full256 /* tag */, const double* HWY_RESTRICT aligned) { return Vec256{_mm256_load_pd(aligned)}; } template HWY_API Vec256 LoadU(Full256 /* tag */, const T* HWY_RESTRICT p) { return Vec256{_mm256_loadu_si256(reinterpret_cast(p))}; } HWY_API Vec256 LoadU(Full256 /* tag */, const float* HWY_RESTRICT p) { return Vec256{_mm256_loadu_ps(p)}; } HWY_API Vec256 LoadU(Full256 /* tag */, const double* HWY_RESTRICT p) { return Vec256{_mm256_loadu_pd(p)}; } // ------------------------------ MaskedLoad #if HWY_TARGET <= HWY_AVX3 template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{_mm256_maskz_load_epi32(m.raw, aligned)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{_mm256_maskz_load_epi64(m.raw, aligned)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const float* HWY_RESTRICT aligned) { return Vec256{_mm256_maskz_load_ps(m.raw, aligned)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const double* HWY_RESTRICT aligned) { return Vec256{_mm256_maskz_load_pd(m.raw, aligned)}; } // There is no load_epi8/16, so use loadu instead. template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{_mm256_maskz_loadu_epi8(m.raw, aligned)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT aligned) { return Vec256{_mm256_maskz_loadu_epi16(m.raw, aligned)}; } #else // AVX2 template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT aligned) { auto aligned_p = reinterpret_cast(aligned); // NOLINT return Vec256{_mm256_maskload_epi32(aligned_p, m.raw)}; } template HWY_API Vec256 MaskedLoad(Mask256 m, Full256 /* tag */, const T* HWY_RESTRICT aligned) { auto aligned_p = reinterpret_cast(aligned); // NOLINT return Vec256{_mm256_maskload_epi64(aligned_p, m.raw)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const float* HWY_RESTRICT aligned) { const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); return Vec256{_mm256_maskload_ps(aligned, mi.raw)}; } HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const double* HWY_RESTRICT aligned) { const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m)); return Vec256{_mm256_maskload_pd(aligned, mi.raw)}; } // There is no maskload_epi8/16, so blend instead. template * = nullptr> HWY_API Vec256 MaskedLoad(Mask256 m, Full256 d, const T* HWY_RESTRICT aligned) { return IfThenElseZero(m, Load(d, aligned)); } #endif // ------------------------------ LoadDup128 // Loads 128 bit and duplicates into both 128-bit halves. This avoids the // 3-cycle cost of moving data between 128-bit halves and avoids port 5. template HWY_API Vec256 LoadDup128(Full256 /* tag */, const T* HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m256i out; asm("vbroadcasti128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); return Vec256{out}; #elif HWY_COMPILER_MSVC && !HWY_COMPILER_CLANG // Workaround for incorrect results with _mm256_broadcastsi128_si256. 
Note // that MSVC also lacks _mm256_zextsi128_si256, but cast (which leaves the // upper half undefined) is fine because we're overwriting that anyway. const __m128i v128 = LoadU(Full128(), p).raw; return Vec256{ _mm256_inserti128_si256(_mm256_castsi128_si256(v128), v128, 1)}; #else return Vec256{_mm256_broadcastsi128_si256(LoadU(Full128(), p).raw)}; #endif } HWY_API Vec256 LoadDup128(Full256 /* tag */, const float* const HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m256 out; asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); return Vec256{out}; #elif HWY_COMPILER_MSVC && !HWY_COMPILER_CLANG const __m128 v128 = LoadU(Full128(), p).raw; return Vec256{ _mm256_insertf128_ps(_mm256_castps128_ps256(v128), v128, 1)}; #else return Vec256{_mm256_broadcast_ps(reinterpret_cast(p))}; #endif } HWY_API Vec256 LoadDup128(Full256 /* tag */, const double* const HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m256d out; asm("vbroadcastf128 %1, %[reg]" : [ reg ] "=x"(out) : "m"(p[0])); return Vec256{out}; #elif HWY_COMPILER_MSVC && !HWY_COMPILER_CLANG const __m128d v128 = LoadU(Full128(), p).raw; return Vec256{ _mm256_insertf128_pd(_mm256_castpd128_pd256(v128), v128, 1)}; #else return Vec256{ _mm256_broadcast_pd(reinterpret_cast(p))}; #endif } // ------------------------------ Store template HWY_API void Store(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { _mm256_store_si256(reinterpret_cast<__m256i*>(aligned), v.raw); } HWY_API void Store(const Vec256 v, Full256 /* tag */, float* HWY_RESTRICT aligned) { _mm256_store_ps(aligned, v.raw); } HWY_API void Store(const Vec256 v, Full256 /* tag */, double* HWY_RESTRICT aligned) { _mm256_store_pd(aligned, v.raw); } template HWY_API void StoreU(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT p) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(p), v.raw); } HWY_API void StoreU(const Vec256 v, Full256 /* tag */, float* HWY_RESTRICT p) { _mm256_storeu_ps(p, v.raw); } HWY_API void StoreU(const Vec256 v, Full256 /* tag */, double* HWY_RESTRICT p) { _mm256_storeu_pd(p, v.raw); } // ------------------------------ Non-temporal stores template HWY_API void Stream(Vec256 v, Full256 /* tag */, T* HWY_RESTRICT aligned) { _mm256_stream_si256(reinterpret_cast<__m256i*>(aligned), v.raw); } HWY_API void Stream(const Vec256 v, Full256 /* tag */, float* HWY_RESTRICT aligned) { _mm256_stream_ps(aligned, v.raw); } HWY_API void Stream(const Vec256 v, Full256 /* tag */, double* HWY_RESTRICT aligned) { _mm256_stream_pd(aligned, v.raw); } // ------------------------------ Scatter // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
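// ScatterIndex(v, d, base, index) stores lane i of v to base[index[i]];
// ScatterOffset instead interprets its last argument as byte offsets from
// base.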
HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") #if HWY_TARGET <= HWY_AVX3 namespace detail { template HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 offset) { _mm256_i32scatter_epi32(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 index) { _mm256_i32scatter_epi32(base, index.raw, v.raw, 4); } template HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 offset) { _mm256_i64scatter_epi64(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec256 v, Full256 /* tag */, T* HWY_RESTRICT base, const Vec256 index) { _mm256_i64scatter_epi64(base, index.raw, v.raw, 8); } } // namespace detail template HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); } template HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); } HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, float* HWY_RESTRICT base, const Vec256 offset) { _mm256_i32scatter_ps(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, float* HWY_RESTRICT base, const Vec256 index) { _mm256_i32scatter_ps(base, index.raw, v.raw, 4); } HWY_API void ScatterOffset(Vec256 v, Full256 /* tag */, double* HWY_RESTRICT base, const Vec256 offset) { _mm256_i64scatter_pd(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec256 v, Full256 /* tag */, double* HWY_RESTRICT base, const Vec256 index) { _mm256_i64scatter_pd(base, index.raw, v.raw, 8); } #else template HWY_API void ScatterOffset(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); constexpr size_t N = 32 / sizeof(T); alignas(32) T lanes[N]; Store(v, d, lanes); alignas(32) Offset offset_lanes[N]; Store(offset, Full256(), offset_lanes); uint8_t* base_bytes = reinterpret_cast(base); for (size_t i = 0; i < N; ++i) { CopyBytes(&lanes[i], base_bytes + offset_lanes[i]); } } template HWY_API void ScatterIndex(Vec256 v, Full256 d, T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); constexpr size_t N = 32 / sizeof(T); alignas(32) T lanes[N]; Store(v, d, lanes); alignas(32) Index index_lanes[N]; Store(index, Full256(), index_lanes); for (size_t i = 0; i < N; ++i) { base[index_lanes[i]] = lanes[i]; } } #endif // ------------------------------ Gather namespace detail { template HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<4> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i32gather_epi32( reinterpret_cast(base), offset.raw, 1)}; } template HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<4> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i32gather_epi32( reinterpret_cast(base), index.raw, 4)}; } template HWY_INLINE Vec256 GatherOffset(hwy::SizeTag<8> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 offset) { return 
Vec256{_mm256_i64gather_epi64( reinterpret_cast(base), offset.raw, 1)}; } template HWY_INLINE Vec256 GatherIndex(hwy::SizeTag<8> /* tag */, Full256 /* tag */, const T* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i64gather_epi64( reinterpret_cast(base), index.raw, 8)}; } } // namespace detail template HWY_API Vec256 GatherOffset(Full256 d, const T* HWY_RESTRICT base, const Vec256 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::GatherOffset(hwy::SizeTag(), d, base, offset); } template HWY_API Vec256 GatherIndex(Full256 d, const T* HWY_RESTRICT base, const Vec256 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::GatherIndex(hwy::SizeTag(), d, base, index); } HWY_API Vec256 GatherOffset(Full256 /* tag */, const float* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i32gather_ps(base, offset.raw, 1)}; } HWY_API Vec256 GatherIndex(Full256 /* tag */, const float* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i32gather_ps(base, index.raw, 4)}; } HWY_API Vec256 GatherOffset(Full256 /* tag */, const double* HWY_RESTRICT base, const Vec256 offset) { return Vec256{_mm256_i64gather_pd(base, offset.raw, 1)}; } HWY_API Vec256 GatherIndex(Full256 /* tag */, const double* HWY_RESTRICT base, const Vec256 index) { return Vec256{_mm256_i64gather_pd(base, index.raw, 8)}; } HWY_DIAGNOSTICS(pop) // ================================================== SWIZZLE // ------------------------------ LowerHalf template HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_castsi256_si128(v.raw)}; } HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_castps256_ps128(v.raw)}; } HWY_API Vec128 LowerHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_castpd256_pd128(v.raw)}; } template HWY_API Vec128 LowerHalf(Vec256 v) { return LowerHalf(Full128(), v); } // ------------------------------ UpperHalf template HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_extracti128_si256(v.raw, 1)}; } HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_extractf128_ps(v.raw, 1)}; } HWY_API Vec128 UpperHalf(Full128 /* tag */, Vec256 v) { return Vec128{_mm256_extractf128_pd(v.raw, 1)}; } // ------------------------------ GetLane (LowerHalf) template HWY_API T GetLane(const Vec256 v) { return GetLane(LowerHalf(v)); } // ------------------------------ ZeroExtendVector // Unfortunately the initial _mm256_castsi128_si256 intrinsic leaves the upper // bits undefined. Although it makes sense for them to be zero (VEX encoded // 128-bit instructions zero the upper lanes to avoid large penalties), a // compiler could decide to optimize out code that relies on this. // // The newer _mm256_zextsi128_si256 intrinsic fixes this by specifying the // zeroing, but it is not available on MSVC nor GCC until 10.1. For older GCC, // we can still obtain the desired code thanks to pattern recognition; note that // the expensive insert instruction is not actually generated, see // https://gcc.godbolt.org/z/1MKGaP. 
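// In other words, ZeroExtendVector returns a full vector whose lower 128 bits
// equal `lo` and whose upper 128 bits are zero. Illustrative sketch (not part
// of the API):
//   const Full256<uint32_t> d256;
//   const Vec128<uint32_t> lo = Set(Full128<uint32_t>(), 7u);
//   const Vec256<uint32_t> v = ZeroExtendVector(d256, lo);
//   // Lower four lanes of v are 7, upper four lanes are 0.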
template HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) return Vec256{_mm256_inserti128_si256(_mm256_setzero_si256(), lo.raw, 0)}; #else return Vec256{_mm256_zextsi128_si256(lo.raw)}; #endif } HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) return Vec256{_mm256_insertf128_ps(_mm256_setzero_ps(), lo.raw, 0)}; #else return Vec256{_mm256_zextps128_ps256(lo.raw)}; #endif } HWY_API Vec256 ZeroExtendVector(Full256 /* tag */, Vec128 lo) { #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) return Vec256{_mm256_insertf128_pd(_mm256_setzero_pd(), lo.raw, 0)}; #else return Vec256{_mm256_zextpd128_pd256(lo.raw)}; #endif } // ------------------------------ Combine template HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { const auto lo256 = ZeroExtendVector(d, lo); return Vec256{_mm256_inserti128_si256(lo256.raw, hi.raw, 1)}; } HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { const auto lo256 = ZeroExtendVector(d, lo); return Vec256{_mm256_insertf128_ps(lo256.raw, hi.raw, 1)}; } HWY_API Vec256 Combine(Full256 d, Vec128 hi, Vec128 lo) { const auto lo256 = ZeroExtendVector(d, lo); return Vec256{_mm256_insertf128_pd(lo256.raw, hi.raw, 1)}; } // ------------------------------ ShiftLeftBytes template HWY_API Vec256 ShiftLeftBytes(Full256 /* tag */, const Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); // This is the same operation as _mm256_bslli_epi128. return Vec256{_mm256_slli_si256(v.raw, kBytes)}; } template HWY_API Vec256 ShiftLeftBytes(const Vec256 v) { return ShiftLeftBytes(Full256(), v); } // ------------------------------ ShiftLeftLanes template HWY_API Vec256 ShiftLeftLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec256 ShiftLeftLanes(const Vec256 v) { return ShiftLeftLanes(Full256(), v); } // ------------------------------ ShiftRightBytes template HWY_API Vec256 ShiftRightBytes(Full256 /* tag */, const Vec256 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); // This is the same operation as _mm256_bsrli_epi128. return Vec256{_mm256_srli_si256(v.raw, kBytes)}; } // ------------------------------ ShiftRightLanes template HWY_API Vec256 ShiftRightLanes(Full256 d, const Vec256 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ CombineShiftRightBytes // Extracts 128 bits from by skipping the least-significant kBytes. 
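// That is: within each 128-bit block, form the 32-byte concatenation {hi, lo}
// (hi occupying the upper 16 bytes) and return its bytes [kBytes, kBytes + 16).
// Illustrative sketch for kBytes = 4 and byte lanes (not part of the API),
// listing bytes from most- to least-significant per block:
//   lo = b15..b0, hi = c15..c0
//   CombineShiftRightBytes<4>(d, hi, lo) == c3 c2 c1 c0 b15 b14 .. b4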
template > HWY_API V CombineShiftRightBytes(Full256 d, V hi, V lo) { const Repartition d8; return BitCast(d, Vec256{_mm256_alignr_epi8( BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); } // ------------------------------ Broadcast/splat any lane // Unsigned template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec256{_mm256_unpacklo_epi64(lo, lo)}; } else { const __m256i hi = _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec256{_mm256_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; } // Signed template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m256i lo = _mm256_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec256{_mm256_unpacklo_epi64(lo, lo)}; } else { const __m256i hi = _mm256_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec256{_mm256_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, 0x55 * kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec256{_mm256_shuffle_epi32(v.raw, kLane ? 0xEE : 0x44)}; } // Float template HWY_API Vec256 Broadcast(Vec256 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x55 * kLane)}; } template HWY_API Vec256 Broadcast(const Vec256 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 15 * kLane)}; } // ------------------------------ Hard-coded shuffles // Notation: let Vec256 have lanes 7,6,5,4,3,2,1,0 (0 is // least-significant). Shuffle0321 rotates four-lane blocks one lane to the // right (the previous least-significant lane is now most-significant => // 47650321). These could also be implemented via CombineShiftRightBytes but // the shuffle_abcd notation is more convenient. // Swap 32-bit halves in 64-bit halves. HWY_API Vec256 Shuffle2301(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; } HWY_API Vec256 Shuffle2301(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0xB1)}; } HWY_API Vec256 Shuffle2301(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0xB1)}; } // Swap 64-bit halves HWY_API Vec256 Shuffle1032(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle1032(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle1032(const Vec256 v) { // Shorter encoding than _mm256_permute_ps. return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x4E)}; } HWY_API Vec256 Shuffle01(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle01(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x4E)}; } HWY_API Vec256 Shuffle01(const Vec256 v) { // Shorter encoding than _mm256_permute_pd. 
return Vec256{_mm256_shuffle_pd(v.raw, v.raw, 5)}; } // Rotate right 32 bits HWY_API Vec256 Shuffle0321(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; } HWY_API Vec256 Shuffle0321(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x39)}; } HWY_API Vec256 Shuffle0321(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x39)}; } // Rotate left 32 bits HWY_API Vec256 Shuffle2103(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; } HWY_API Vec256 Shuffle2103(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x93)}; } HWY_API Vec256 Shuffle2103(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x93)}; } // Reverse HWY_API Vec256 Shuffle0123(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; } HWY_API Vec256 Shuffle0123(const Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, 0x1B)}; } HWY_API Vec256 Shuffle0123(const Vec256 v) { return Vec256{_mm256_shuffle_ps(v.raw, v.raw, 0x1B)}; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. template struct Indices256 { __m256i raw; }; // Native 8x32 instruction: indices remain unchanged template HWY_API Indices256 IndicesFromVec(Full256 /* tag */, Vec256 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD const Full256 di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(32 / sizeof(T)))))); #endif return Indices256{vec.raw}; } // 64-bit lanes: convert indices to 8x32 unless AVX3 is available template HWY_API Indices256 IndicesFromVec(Full256 d, Vec256 idx64) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); const Rebind di; (void)di; // potentially unused #if HWY_IS_DEBUG_BUILD HWY_DASSERT(AllFalse(di, Lt(idx64, Zero(di))) && AllTrue(di, Lt(idx64, Set(di, static_cast(32 / sizeof(T)))))); #endif #if HWY_TARGET <= HWY_AVX3 (void)d; return Indices256{idx64.raw}; #else const Repartition df; // 32-bit! // Replicate 64-bit index into upper 32 bits const Vec256 dup = BitCast(di, Vec256{_mm256_moveldup_ps(BitCast(df, idx64).raw)}); // For each idx64 i, idx32 are 2*i and 2*i+1. 
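  // For instance, if idx64 = {3, 0, 2, 1} (lane 0 listed first), idx32 becomes
  // {6, 7, 0, 1, 4, 5, 2, 3}: 32-bit lanes 2*k and 2*k+1 together select both
  // halves of 64-bit lane k, which is what permutevar8x32 needs below.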
const Vec256 idx32 = dup + dup + Set(di, TI(1) << 32); return Indices256{idx32.raw}; #endif } template HWY_API Indices256 SetTableIndices(const Full256 d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; } template HWY_API Vec256 TableLookupLanes(Vec256 v, Indices256 idx) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_permutexvar_epi64(idx.raw, v.raw)}; #else return Vec256{_mm256_permutevar8x32_epi32(v.raw, idx.raw)}; #endif } HWY_API Vec256 TableLookupLanes(const Vec256 v, const Indices256 idx) { return Vec256{_mm256_permutevar8x32_ps(v.raw, idx.raw)}; } HWY_API Vec256 TableLookupLanes(const Vec256 v, const Indices256 idx) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_permutexvar_pd(idx.raw, v.raw)}; #else const Full256 df; const Full256 du; return BitCast(df, Vec256{_mm256_permutevar8x32_epi32( BitCast(du, v).raw, idx.raw)}); #endif } // ------------------------------ Reverse (RotateRight) template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { alignas(32) constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { alignas(32) constexpr int64_t kReverse[4] = {3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } template HWY_API Vec256 Reverse(Full256 d, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 const RebindToSigned di; alignas(32) constexpr int16_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; const Vec256 idx = Load(di, kReverse); return BitCast(d, Vec256{ _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RepartitionToWide> du32; const Vec256 rev32 = Reverse(du32, BitCast(du32, v)); return BitCast(d, RotateRight<16>(rev32)); #endif } // ------------------------------ Reverse2 template HWY_API Vec256 Reverse2(Full256 d, const Vec256 v) { const Full256 du32; return BitCast(d, RotateRight<16>(BitCast(du32, v))); } template HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { return Shuffle2301(v); } template HWY_API Vec256 Reverse2(Full256 /* tag */, const Vec256 v) { return Shuffle01(v); } // ------------------------------ Reverse4 template HWY_API Vec256 Reverse4(Full256 d, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 const RebindToSigned di; alignas(32) constexpr int16_t kReverse4[16] = {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12}; const Vec256 idx = Load(di, kReverse4); return BitCast(d, Vec256{ _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RepartitionToWide dw; return Reverse2(d, BitCast(d, Shuffle2301(BitCast(dw, v)))); #endif } template HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { return Shuffle0123(v); } template HWY_API Vec256 Reverse4(Full256 /* tag */, const Vec256 v) { return Vec256{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; } HWY_API Vec256 Reverse4(Full256 /* tag */, Vec256 v) { return Vec256{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; } // ------------------------------ Reverse8 template HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 const RebindToSigned di; alignas(32) constexpr int16_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; const Vec256 idx = Load(di, kReverse8); return BitCast(d, Vec256{ _mm256_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); #else const RepartitionToWide 
dw; return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v)))); #endif } template HWY_API Vec256 Reverse8(Full256 d, const Vec256 v) { return Reverse(d, v); } template HWY_API Vec256 Reverse8(Full256 /* tag */, const Vec256 /* v */) { HWY_ASSERT(0); // AVX2 does not have 8 64-bit lanes } // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides // the least-significant lane) and "b". To concatenate two half-width integers // into one, use ZipLower/Upper instead (also works with scalar). HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_ps(a.raw, b.raw)}; } HWY_API Vec256 InterleaveLower(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpacklo_pd(a.raw, b.raw)}; } // ------------------------------ InterleaveUpper // All functions inside detail lack the required D parameter. namespace detail { HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_ps(a.raw, b.raw)}; } HWY_API Vec256 InterleaveUpper(const Vec256 a, const Vec256 b) { return Vec256{_mm256_unpackhi_pd(a.raw, b.raw)}; } } // namespace detail template > HWY_API V InterleaveUpper(Full256 /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. 
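// Illustrative sketch (not part of the API): zipping u8 lanes into u16 lanes.
//   const Full256<uint8_t> d8;
//   const Full256<uint16_t> d16;
//   const Vec256<uint8_t> lo8 = Set(d8, uint8_t{1});
//   const Vec256<uint8_t> hi8 = Set(d8, uint8_t{2});
//   const Vec256<uint16_t> w = ZipLower(d16, lo8, hi8);
//   // Each u16 lane of w is 0x0201: the first argument supplies the lower
//   // byte. Only the lower half of each 128-bit block is consumed; ZipUpper
//   // consumes the upper half.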
template > HWY_API Vec256 ZipLower(Vec256 a, Vec256 b) { return BitCast(Full256(), InterleaveLower(a, b)); } template > HWY_API Vec256 ZipLower(Full256 dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveLower(a, b)); } template > HWY_API Vec256 ZipUpper(Full256 dw, Vec256 a, Vec256 b) { return BitCast(dw, InterleaveUpper(Full256(), a, b)); } // ------------------------------ Blocks (LowerHalf, ZeroExtendVector) // _mm256_broadcastsi128_si256 has 7 cycle latency. _mm256_permute2x128_si256 is // slow on Zen1 (8 uops); we can avoid it for LowerLower and UpperLower, and on // UpperUpper at the cost of one extra cycle/instruction. // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return Vec256{_mm256_inserti128_si256(lo.raw, LowerHalf(d2, hi).raw, 1)}; } HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return Vec256{_mm256_insertf128_ps(lo.raw, LowerHalf(d2, hi).raw, 1)}; } HWY_API Vec256 ConcatLowerLower(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return Vec256{_mm256_insertf128_pd(lo.raw, LowerHalf(d2, hi).raw, 1)}; } // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) template HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2x128_si256(lo.raw, hi.raw, 0x21)}; } HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2f128_ps(lo.raw, hi.raw, 0x21)}; } HWY_API Vec256 ConcatLowerUpper(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_permute2f128_pd(lo.raw, hi.raw, 0x21)}; } // hiH,hiL loH,loL |-> hiH,loL (= outer halves) template HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_blend_epi32(hi.raw, lo.raw, 0x0F)}; } HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_blend_ps(hi.raw, lo.raw, 0x0F)}; } HWY_API Vec256 ConcatUpperLower(Full256 /* tag */, const Vec256 hi, const Vec256 lo) { return Vec256{_mm256_blend_pd(hi.raw, lo.raw, 3)}; } // hiH,hiL loH,loL |-> hiH,loH (= upper halves) template HWY_API Vec256 ConcatUpperUpper(Full256 d, const Vec256 hi, const Vec256 lo) { const Half d2; return ConcatUpperLower(d, hi, ZeroExtendVector(d, UpperHalf(d2, lo))); } // ------------------------------ ConcatOdd template HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v3131{_mm256_shuffle_ps( BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(3, 1, 3, 1))}; return Vec256{_mm256_permute4x64_epi64(BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(32) constexpr uint32_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else const Vec256 v3131{ _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(3, 1, 3, 1))}; return BitCast(d, Vec256{_mm256_permute4x64_epi64( BitCast(du, v3131).raw, _MM_SHUFFLE(3, 1, 2, 0))}); #endif } template HWY_API Vec256 
ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v31{ _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 15)}; return Vec256{ _mm256_permute4x64_epi64(BitCast(du, v31).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatOdd(Full256 d, Vec256 hi, Vec256 lo) { #if HWY_TARGET <= HWY_AVX3 const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[4] = {1, 3, 5, 7}; return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else (void)d; const Vec256 v31{_mm256_shuffle_pd(lo.raw, hi.raw, 15)}; return Vec256{ _mm256_permute4x64_pd(v31.raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } // ------------------------------ ConcatEven template HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v2020{_mm256_shuffle_ps( BitCast(df, lo).raw, BitCast(df, hi).raw, _MM_SHUFFLE(2, 0, 2, 0))}; return Vec256{_mm256_permute4x64_epi64(BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint32_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return Vec256{_mm256_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else const Vec256 v2020{ _mm256_shuffle_ps(lo.raw, hi.raw, _MM_SHUFFLE(2, 0, 2, 0))}; return BitCast(d, Vec256{_mm256_permute4x64_epi64( BitCast(du, v2020).raw, _MM_SHUFFLE(3, 1, 2, 0))}); #endif } template HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { const RebindToUnsigned du; #if HWY_TARGET <= HWY_AVX3 alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; return BitCast(d, Vec256{_mm256_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); #else const RebindToFloat df; const Vec256 v20{ _mm256_shuffle_pd(BitCast(df, lo).raw, BitCast(df, hi).raw, 0)}; return Vec256{ _mm256_permute4x64_epi64(BitCast(du, v20).raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } HWY_API Vec256 ConcatEven(Full256 d, Vec256 hi, Vec256 lo) { #if HWY_TARGET <= HWY_AVX3 const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[4] = {0, 2, 4, 6}; return Vec256{_mm256_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; #else (void)d; const Vec256 v20{_mm256_shuffle_pd(lo.raw, hi.raw, 0)}; return Vec256{ _mm256_permute4x64_pd(v20.raw, _MM_SHUFFLE(3, 1, 2, 0))}; #endif } // ------------------------------ DupEven (InterleaveLower) template HWY_API Vec256 DupEven(Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; } HWY_API Vec256 DupEven(Vec256 v) { return Vec256{ _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(2, 2, 0, 0))}; } template HWY_API Vec256 DupEven(const Vec256 v) { return InterleaveLower(Full256(), v, v); } // ------------------------------ DupOdd (InterleaveUpper) template HWY_API Vec256 DupOdd(Vec256 v) { return Vec256{_mm256_shuffle_epi32(v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; } HWY_API Vec256 DupOdd(Vec256 v) { 
return Vec256{ _mm256_shuffle_ps(v.raw, v.raw, _MM_SHUFFLE(3, 3, 1, 1))}; } template HWY_API Vec256 DupOdd(const Vec256 v) { return InterleaveUpper(Full256(), v, v); } // ------------------------------ OddEven namespace detail { template HWY_INLINE Vec256 OddEven(hwy::SizeTag<1> /* tag */, const Vec256 a, const Vec256 b) { const Full256 d; const Full256 d8; alignas(32) constexpr uint8_t mask[16] = {0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0, 0xFF, 0}; return IfThenElse(MaskFromVec(BitCast(d, LoadDup128(d8, mask))), b, a); } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<2> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_epi16(a.raw, b.raw, 0x55)}; } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<4> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x55)}; } template HWY_INLINE Vec256 OddEven(hwy::SizeTag<8> /* tag */, const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_epi32(a.raw, b.raw, 0x33)}; } } // namespace detail template HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return detail::OddEven(hwy::SizeTag(), a, b); } HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_ps(a.raw, b.raw, 0x55)}; } HWY_API Vec256 OddEven(const Vec256 a, const Vec256 b) { return Vec256{_mm256_blend_pd(a.raw, b.raw, 5)}; } // ------------------------------ OddEvenBlocks template Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { return Vec256{_mm256_blend_epi32(odd.raw, even.raw, 0xFu)}; } HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { return Vec256{_mm256_blend_ps(odd.raw, even.raw, 0xFu)}; } HWY_API Vec256 OddEvenBlocks(Vec256 odd, Vec256 even) { return Vec256{_mm256_blend_pd(odd.raw, even.raw, 0x3u)}; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return Vec256{_mm256_permute4x64_epi64(v.raw, _MM_SHUFFLE(1, 0, 3, 2))}; } HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { const Full256 df; const Full256 di; // Avoid _mm256_permute2f128_ps - slow on AMD. return BitCast(df, Vec256{_mm256_permute4x64_epi64( BitCast(di, v).raw, _MM_SHUFFLE(1, 0, 3, 2))}); } HWY_API Vec256 SwapAdjacentBlocks(Vec256 v) { return Vec256{_mm256_permute4x64_pd(v.raw, _MM_SHUFFLE(1, 0, 3, 2))}; } // ------------------------------ ReverseBlocks (ConcatLowerUpper) template HWY_API Vec256 ReverseBlocks(Full256 d, Vec256 v) { return ConcatLowerUpper(d, v, v); } // ------------------------------ TableLookupBytes (ZeroExtendVector) // Both full template HWY_API Vec256 TableLookupBytes(const Vec256 bytes, const Vec256 from) { return Vec256{_mm256_shuffle_epi8(bytes.raw, from.raw)}; } // Partial index vector template HWY_API Vec128 TableLookupBytes(const Vec256 bytes, const Vec128 from) { // First expand to full 128, then 256. const auto from_256 = ZeroExtendVector(Full256(), Vec128{from.raw}); const auto tbl_full = TableLookupBytes(bytes, from_256); // Shrink to 128, then partial. return Vec128{LowerHalf(Full128(), tbl_full).raw}; } // Partial table vector template HWY_API Vec256 TableLookupBytes(const Vec128 bytes, const Vec256 from) { // First expand to full 128, then 256. const auto bytes_256 = ZeroExtendVector(Full256(), Vec128{bytes.raw}); return TableLookupBytes(bytes_256, from); } // Partial both are handled by x86_128. // ------------------------------ Shl (Mul, ZipLower) #if HWY_TARGET > HWY_AVX3 // AVX2 or older namespace detail { // Returns 2^v for use as per-lane multipliers to emulate 16-bit shifts. 
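// For an IEEE binary32, writing (127 + v) into the exponent field with a zero
// mantissa produces the value 2^v exactly, so converting that float back to an
// integer yields the desired per-lane multiplier. For example, a shift amount
// of 3 becomes the multiplier 8, and x * 8 == x << 3. The 16-bit operator<<
// below uses this as v * Pow2(bits) on AVX2.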
template HWY_INLINE Vec256> Pow2(const Vec256 v) { const Full256 d; const RepartitionToWide dw; const Rebind df; const auto zero = Zero(d); // Move into exponent (this u16 will become the upper half of an f32) const auto exp = ShiftLeft<23 - 16>(v); const auto upper = exp + Set(d, 0x3F80); // upper half of 1.0f // Insert 0 into lower halves for reinterpreting as binary32. const auto f0 = ZipLower(dw, zero, upper); const auto f1 = ZipUpper(dw, zero, upper); // Do not use ConvertTo because it checks for overflow, which is redundant // because we only care about v in [0, 16). const Vec256 bits0{_mm256_cvttps_epi32(BitCast(df, f0).raw)}; const Vec256 bits1{_mm256_cvttps_epi32(BitCast(df, f1).raw)}; return Vec256>{_mm256_packus_epi32(bits0.raw, bits1.raw)}; } } // namespace detail #endif // HWY_TARGET > HWY_AVX3 HWY_API Vec256 operator<<(const Vec256 v, const Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_sllv_epi16(v.raw, bits.raw)}; #else return v * detail::Pow2(bits); #endif } HWY_API Vec256 operator<<(const Vec256 v, const Vec256 bits) { return Vec256{_mm256_sllv_epi32(v.raw, bits.raw)}; } HWY_API Vec256 operator<<(const Vec256 v, const Vec256 bits) { return Vec256{_mm256_sllv_epi64(v.raw, bits.raw)}; } // Signed left shift is the same as unsigned. template HWY_API Vec256 operator<<(const Vec256 v, const Vec256 bits) { const Full256 di; const Full256> du; return BitCast(di, BitCast(du, v) << BitCast(du, bits)); } // ------------------------------ Shr (MulHigh, IfThenElse, Not) HWY_API Vec256 operator>>(const Vec256 v, const Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srlv_epi16(v.raw, bits.raw)}; #else const Full256 d; // For bits=0, we cannot mul by 2^16, so fix the result later. const auto out = MulHigh(v, detail::Pow2(Set(d, 16) - bits)); // Replace output with input where bits == 0. return IfThenElse(bits == Zero(d), v, out); #endif } HWY_API Vec256 operator>>(const Vec256 v, const Vec256 bits) { return Vec256{_mm256_srlv_epi32(v.raw, bits.raw)}; } HWY_API Vec256 operator>>(const Vec256 v, const Vec256 bits) { return Vec256{_mm256_srlv_epi64(v.raw, bits.raw)}; } HWY_API Vec256 operator>>(const Vec256 v, const Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srav_epi16(v.raw, bits.raw)}; #else return detail::SignedShr(Full256(), v, bits); #endif } HWY_API Vec256 operator>>(const Vec256 v, const Vec256 bits) { return Vec256{_mm256_srav_epi32(v.raw, bits.raw)}; } HWY_API Vec256 operator>>(const Vec256 v, const Vec256 bits) { #if HWY_TARGET <= HWY_AVX3 return Vec256{_mm256_srav_epi64(v.raw, bits.raw)}; #else return detail::SignedShr(Full256(), v, bits); #endif } HWY_INLINE Vec256 MulEven(const Vec256 a, const Vec256 b) { const DFromV du64; const RepartitionToNarrow du32; const auto maskL = Set(du64, 0xFFFFFFFFULL); const auto a32 = BitCast(du32, a); const auto b32 = BitCast(du32, b); // Inputs for MulEven: we only need the lower 32 bits const auto aH = Shuffle2301(a32); const auto bH = Shuffle2301(b32); // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need // the even (lower 64 bits of every 128-bit block) results. 
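  // Writing a = aH*2^32 + aL and b = bH*2^32 + bL gives
  //   a*b = aH*bH*2^64 + (aH*bL + aL*bH)*2^32 + aL*bL;
  // the 32x32->64 partial products below are combined with explicit carry
  // propagation (w1..w3, k) because there is no 64x64->128 vector multiply.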
See // https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.tat const auto aLbL = MulEven(a32, b32); const auto w3 = aLbL & maskL; const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); const auto w2 = t2 & maskL; const auto w1 = ShiftRight<32>(t2); const auto t = MulEven(a32, bH) + w2; const auto k = ShiftRight<32>(t); const auto mulH = MulEven(aH, bH) + w1 + k; const auto mulL = ShiftLeft<32>(t) + w3; return InterleaveLower(mulL, mulH); } HWY_INLINE Vec256 MulOdd(const Vec256 a, const Vec256 b) { const DFromV du64; const RepartitionToNarrow du32; const auto maskL = Set(du64, 0xFFFFFFFFULL); const auto a32 = BitCast(du32, a); const auto b32 = BitCast(du32, b); // Inputs for MulEven: we only need bits [95:64] (= upper half of input) const auto aH = Shuffle2301(a32); const auto bH = Shuffle2301(b32); // Same as above, but we're using the odd results (upper 64 bits per block). const auto aLbL = MulEven(a32, b32); const auto w3 = aLbL & maskL; const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL); const auto w2 = t2 & maskL; const auto w1 = ShiftRight<32>(t2); const auto t = MulEven(a32, bH) + w2; const auto k = ShiftRight<32>(t); const auto mulH = MulEven(aH, bH) + w1 + k; const auto mulL = ShiftLeft<32>(t) + w3; return InterleaveUpper(du64, mulL, mulH); } // ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower) HWY_API Vec256 ReorderWidenMulAccumulate(Full256 df32, Vec256 a, Vec256 b, const Vec256 sum0, Vec256& sum1) { // TODO(janwas): _mm256_dpbf16_ps when available const Repartition du16; const RebindToUnsigned du32; const Vec256 zero = Zero(du16); // Lane order within sum0/1 is undefined, hence we can avoid the // longer-latency lane-crossing PromoteTo. const Vec256 a0 = ZipLower(du32, zero, BitCast(du16, a)); const Vec256 a1 = ZipUpper(du32, zero, BitCast(du16, a)); const Vec256 b0 = ZipLower(du32, zero, BitCast(du16, b)); const Vec256 b1 = ZipUpper(du32, zero, BitCast(du16, b)); sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1); return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0); } // ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{_mm256_cvtps_pd(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, const Vec128 v) { return Vec256{_mm256_cvtepi32_pd(v.raw)}; } // Unsigned: zero-extend. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then Zip* would be faster. HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu8_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu16_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu16_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepu32_epi64(v.raw)}; } // Signed: replicate sign bit. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by // signed shift would be faster. 
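// Illustrative sketch of sign-extending promotion (not part of the API):
//   const Full128<int16_t> d16;
//   const Full256<int32_t> d32;
//   const Vec256<int32_t> wide = PromoteTo(d32, Set(d16, -3));
//   // All eight int32_t lanes are -3; the sign bit was replicated.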
HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi8_epi16(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi8_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi16_epi32(v.raw)}; } HWY_API Vec256 PromoteTo(Full256 /* tag */, Vec128 v) { return Vec256{_mm256_cvtepi32_epi64(v.raw)}; } // ------------------------------ Demotions (full -> part w/ narrow lanes) HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i u16 = _mm256_packus_epi32(v.raw, v.raw); // Concatenating lower halves of both 128-bit blocks afterward is more // efficient than an extra input with low block = high block of v. return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(u16, 0x88))}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i i16 = _mm256_packs_epi32(v.raw, v.raw); return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(i16, 0x88))}; } HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec256 v) { const __m256i u16_blocks = _mm256_packus_epi32(v.raw, v.raw); // Concatenate lower 64 bits of each 128-bit block const __m256i u16_concat = _mm256_permute4x64_epi64(u16_blocks, 0x88); const __m128i u16 = _mm256_castsi256_si128(u16_concat); // packus treats the input as signed; we want unsigned. Clear the MSB to get // unsigned saturation to u8. const __m128i i16 = _mm_and_si128(u16, _mm_set1_epi16(0x7FFF)); return Vec128{_mm_packus_epi16(i16, i16)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i u8 = _mm256_packus_epi16(v.raw, v.raw); return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(u8, 0x88))}; } HWY_API Vec128 DemoteTo(Full64 /* tag */, const Vec256 v) { const __m256i i16_blocks = _mm256_packs_epi32(v.raw, v.raw); // Concatenate lower 64 bits of each 128-bit block const __m256i i16_concat = _mm256_permute4x64_epi64(i16_blocks, 0x88); const __m128i i16 = _mm256_castsi256_si128(i16_concat); return Vec128{_mm_packs_epi16(i16, i16)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const __m256i i8 = _mm256_packs_epi16(v.raw, v.raw); return Vec128{ _mm256_castsi256_si128(_mm256_permute4x64_epi64(i8, 0x88))}; } // Avoid "value of intrinsic immediate argument '8' is out of range '0 - 7'". // 8 is the correct value of _MM_FROUND_NO_EXC, which is allowed here. 
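// Illustrative sketch (not part of the API): rounding f32 lanes to f16 and
// widening them back.
//   const Full256<float> df;
//   const Full128<float16_t> df16;
//   const Vec128<float16_t> h = DemoteTo(df16, Set(df, 1.5f));
//   const Vec256<float> back = PromoteTo(df, h);  // 1.5f is exact in f16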
HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4556, ignored "-Wsign-conversion") HWY_API Vec128 DemoteTo(Full128 df16, const Vec256 v) { #ifdef HWY_DISABLE_F16C const RebindToUnsigned du16; const Rebind du; const RebindToSigned di; const auto bits32 = BitCast(du, v); const auto sign = ShiftRight<31>(bits32); const auto biased_exp32 = ShiftRight<23>(bits32) & Set(du, 0xFF); const auto mantissa32 = bits32 & Set(du, 0x7FFFFF); const auto k15 = Set(di, 15); const auto exp = Min(BitCast(di, biased_exp32) - Set(di, 127), k15); const auto is_tiny = exp < Set(di, -24); const auto is_subnormal = exp < Set(di, -14); const auto biased_exp16 = BitCast(du, IfThenZeroElse(is_subnormal, exp + k15)); const auto sub_exp = BitCast(du, Set(di, -14) - exp); // [1, 11) const auto sub_m = (Set(du, 1) << (Set(du, 10) - sub_exp)) + (mantissa32 >> (Set(du, 13) + sub_exp)); const auto mantissa16 = IfThenElse(RebindMask(du, is_subnormal), sub_m, ShiftRight<13>(mantissa32)); // <1024 const auto sign16 = ShiftLeft<15>(sign); const auto normal16 = sign16 | ShiftLeft<10>(biased_exp16) | mantissa16; const auto bits16 = IfThenZeroElse(is_tiny, BitCast(di, normal16)); return BitCast(df16, DemoteTo(du16, bits16)); #else (void)df16; return Vec128{_mm256_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; #endif } HWY_DIAGNOSTICS(pop) HWY_API Vec128 DemoteTo(Full128 dbf16, const Vec256 v) { // TODO(janwas): _mm256_cvtneps_pbh once we have avx512bf16. const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } HWY_API Vec256 ReorderDemote2To(Full256 dbf16, Vec256 a, Vec256 b) { // TODO(janwas): _mm256_cvtne2ps_pbh once we have avx512bf16. const RebindToUnsigned du16; const Repartition du32; const Vec256 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { return Vec128{_mm256_cvtpd_ps(v.raw)}; } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec256 v) { const auto clamped = detail::ClampF64ToI32Max(Full256(), v); return Vec128{_mm256_cvttpd_epi32(clamped.raw)}; } // For already range-limited input [0, 255]. HWY_API Vec128 U8FromU32(const Vec256 v) { const Full256 d32; alignas(32) static constexpr uint32_t k8From32[8] = { 0x0C080400u, ~0u, ~0u, ~0u, ~0u, 0x0C080400u, ~0u, ~0u}; // Place first four bytes in lo[0], remaining 4 in hi[1]. const auto quad = TableLookupBytes(v, Load(d32, k8From32)); // Interleave both quadruplets - OR instead of unpack reduces port5 pressure. 
const auto lo = LowerHalf(quad); const auto hi = UpperHalf(Full128(), quad); const auto pair = LowerHalf(lo | hi); return BitCast(Full64(), pair); } // ------------------------------ Integer <=> fp (ShiftRight, OddEven) HWY_API Vec256 ConvertTo(Full256 /* tag */, const Vec256 v) { return Vec256{_mm256_cvtepi32_ps(v.raw)}; } HWY_API Vec256 ConvertTo(Full256 dd, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 (void)dd; return Vec256{_mm256_cvtepi64_pd(v.raw)}; #else // Based on wim's approach (https://stackoverflow.com/questions/41144668/) const Repartition d32; const Repartition d64; // Toggle MSB of lower 32-bits and insert exponent for 2^84 + 2^63 const auto k84_63 = Set(d64, 0x4530000080000000ULL); const auto v_upper = BitCast(dd, ShiftRight<32>(BitCast(d64, v)) ^ k84_63); // Exponent is 2^52, lower 32 bits from v (=> 32-bit OddEven) const auto k52 = Set(d32, 0x43300000); const auto v_lower = BitCast(dd, OddEven(k52, BitCast(d32, v))); const auto k84_63_52 = BitCast(dd, Set(d64, 0x4530000080100000ULL)); return (v_upper - k84_63_52) + v_lower; // order matters! #endif } // Truncates (rounds toward zero). HWY_API Vec256 ConvertTo(Full256 d, const Vec256 v) { return detail::FixConversionOverflow(d, v, _mm256_cvttps_epi32(v.raw)); } HWY_API Vec256 ConvertTo(Full256 di, const Vec256 v) { #if HWY_TARGET <= HWY_AVX3 return detail::FixConversionOverflow(di, v, _mm256_cvttpd_epi64(v.raw)); #else using VI = decltype(Zero(di)); const VI k0 = Zero(di); const VI k1 = Set(di, 1); const VI k51 = Set(di, 51); // Exponent indicates whether the number can be represented as int64_t. const VI biased_exp = ShiftRight<52>(BitCast(di, v)) & Set(di, 0x7FF); const VI exp = biased_exp - Set(di, 0x3FF); const auto in_range = exp < Set(di, 63); // If we were to cap the exponent at 51 and add 2^52, the number would be in // [2^52, 2^53) and mantissa bits could be read out directly. We need to // round-to-0 (truncate), but changing rounding mode in MXCSR hits a // compiler reordering bug: https://gcc.godbolt.org/z/4hKj6c6qc . We instead // manually shift the mantissa into place (we already have many of the // inputs anyway). const VI shift_mnt = Max(k51 - exp, k0); const VI shift_int = Max(exp - k51, k0); const VI mantissa = BitCast(di, v) & Set(di, (1ULL << 52) - 1); // Include implicit 1-bit; shift by one more to ensure it's in the mantissa. const VI int52 = (mantissa | Set(di, 1ULL << 52)) >> (shift_mnt + k1); // For inputs larger than 2^52, insert zeros at the bottom. const VI shifted = int52 << shift_int; // Restore the one bit lost when shifting in the implicit 1-bit. const VI restored = shifted | ((mantissa & k1) << (shift_int - k1)); // Saturate to LimitsMin (unchanged when negating below) or LimitsMax. const VI sign_mask = BroadcastSignBit(BitCast(di, v)); const VI limit = Set(di, LimitsMax()) - sign_mask; const VI magnitude = IfThenElse(in_range, restored, limit); // If the input was negative, negate the integer (two's complement). return (magnitude ^ sign_mask) - sign_mask; #endif } HWY_API Vec256 NearestInt(const Vec256 v) { const Full256 di; return detail::FixConversionOverflow(di, v, _mm256_cvtps_epi32(v.raw)); } HWY_API Vec256 PromoteTo(Full256 df32, const Vec128 v) { #ifdef HWY_DISABLE_F16C const RebindToSigned di32; const RebindToUnsigned du32; // Expand to u32 so we can shift. 
const auto bits16 = PromoteTo(du32, Vec128{v.raw}); const auto sign = ShiftRight<15>(bits16); const auto biased_exp = ShiftRight<10>(bits16) & Set(du32, 0x1F); const auto mantissa = bits16 & Set(du32, 0x3FF); const auto subnormal = BitCast(du32, ConvertTo(df32, BitCast(di32, mantissa)) * Set(df32, 1.0f / 16384 / 1024)); const auto biased_exp32 = biased_exp + Set(du32, 127 - 15); const auto mantissa32 = ShiftLeft<23 - 10>(mantissa); const auto normal = ShiftLeft<23>(biased_exp32) | mantissa32; const auto bits32 = IfThenElse(biased_exp == Zero(du32), subnormal, normal); return BitCast(df32, ShiftLeft<31>(sign) | bits32); #else (void)df32; return Vec256{_mm256_cvtph_ps(v.raw)}; #endif } HWY_API Vec256 PromoteTo(Full256 df32, const Vec128 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } // ================================================== CRYPTO #if !defined(HWY_DISABLE_PCLMUL_AES) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. #ifdef HWY_NATIVE_AES #undef HWY_NATIVE_AES #else #define HWY_NATIVE_AES #endif HWY_API Vec256 AESRound(Vec256 state, Vec256 round_key) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_aesenc_epi128(state.raw, round_key.raw)}; #else const Full256 d; const Half d2; return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), AESRound(LowerHalf(state), LowerHalf(round_key))); #endif } HWY_API Vec256 AESLastRound(Vec256 state, Vec256 round_key) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_aesenclast_epi128(state.raw, round_key.raw)}; #else const Full256 d; const Half d2; return Combine(d, AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), AESLastRound(LowerHalf(state), LowerHalf(round_key))); #endif } HWY_API Vec256 CLMulLower(Vec256 a, Vec256 b) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x00)}; #else const Full256 d; const Half d2; return Combine(d, CLMulLower(UpperHalf(d2, a), UpperHalf(d2, b)), CLMulLower(LowerHalf(a), LowerHalf(b))); #endif } HWY_API Vec256 CLMulUpper(Vec256 a, Vec256 b) { #if HWY_TARGET == HWY_AVX3_DL return Vec256{_mm256_clmulepi64_epi128(a.raw, b.raw, 0x11)}; #else const Full256 d; const Half d2; return Combine(d, CLMulUpper(UpperHalf(d2, a), UpperHalf(d2, b)), CLMulUpper(LowerHalf(a), LowerHalf(b))); #endif } #endif // HWY_DISABLE_PCLMUL_AES // ================================================== MISC // Returns a vector with lane i=[0, N) set to "first" + i. template HWY_API Vec256 Iota(const Full256 d, const T2 first) { HWY_ALIGN T lanes[32 / sizeof(T)]; for (size_t i = 0; i < 32 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); } return Load(d, lanes); } #if HWY_TARGET <= HWY_AVX3 // ------------------------------ LoadMaskBits // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask256 LoadMaskBits(const Full256 /* tag */, const uint8_t* HWY_RESTRICT bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return Mask256::FromBits(mask_bits); } // ------------------------------ StoreMaskBits // `p` points to at least 8 writable bytes. 
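// Illustrative sketch (not part of the API): round-tripping a mask through
// packed bits.
//   const Full256<int32_t> d;
//   const Mask256<int32_t> m = FirstN(d, 3);      // lanes 0..2 true
//   uint8_t bits[8];
//   StoreMaskBits(d, m, bits);                    // bits[0] == 0x07
//   const Mask256<int32_t> m2 = LoadMaskBits(d, bits);  // equivalent to m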
template HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, uint8_t* bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; CopyBytes(&mask.raw, bits); // Non-full byte, need to clear the undefined upper bits. if (N < 8) { const int mask = static_cast((1ull << N) - 1); bits[0] = static_cast(bits[0] & mask); } return kNumBytes; } // ------------------------------ Mask testing template HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { return PopCount(static_cast(mask.raw)); } template HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, const Mask256 mask) { return mask.raw ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask.raw)) : -1; } // Beware: the suffix indicates the number of mask bits, not lane size! namespace detail { template HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { return (uint64_t{mask.raw} & 0xF) == 0; } } // namespace detail template HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { return detail::AllFalse(hwy::SizeTag(), mask); } namespace detail { template HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFFFFFu; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFu; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask256 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFu; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask256 mask) { // Cannot use _kortestc because we have less than 8 mask bits. return mask.raw == 0xFu; } } // namespace detail template HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { return detail::AllTrue(hwy::SizeTag(), mask); } // ------------------------------ Compress // 16-bit is defined in x86_512 so we can use 512-bit vectors. 
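// Illustrative sketch (not part of the API): packing the lanes selected by a
// mask toward lane 0, with the remaining lanes zeroed by the maskz intrinsics
// below.
//   const Full256<float> d;
//   const Vec256<float> v = Iota(d, 0.0f);            // 0,1,2,...,7
//   const Mask256<float> keep = Lt(v, Set(d, 3.0f));  // lanes 0..2 true
//   const Vec256<float> packed = Compress(v, keep);   // 0,1,2,0,0,0,0,0
//   // CompressStore (below) additionally writes the packed lanes to memory
//   // and returns their count, here 3.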
template HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_epi32(mask.raw, v.raw)}; } template HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_epi64(mask.raw, v.raw)}; } HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_ps(mask.raw, v.raw)}; } HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { return Vec256{_mm256_maskz_compress_pd(mask.raw, v.raw)}; } // ------------------------------ CompressBits (LoadMaskBits) template HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { return Compress(v, LoadMaskBits(Full256(), bits)); } // ------------------------------ CompressStore template HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 d, T* HWY_RESTRICT unaligned) { const Rebind du; const auto vu = BitCast(du, v); // (required for float16_t inputs) const uint64_t mask_bits{mask.raw}; #if HWY_TARGET == HWY_AVX3_DL // VBMI2 _mm256_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); #else // Split into halves to keep the table size manageable. const Half duh; const auto vL = LowerHalf(duh, vu); const auto vH = UpperHalf(duh, vu); const uint64_t mask_bitsL = mask_bits & 0xFF; const uint64_t mask_bitsH = mask_bits >> 8; const auto idxL = detail::IndicesForCompress16(mask_bitsL); const auto idxH = detail::IndicesForCompress16(mask_bitsH); // Compress and 128-bit halves. const Vec128 cL{_mm_permutexvar_epi16(idxL.raw, vL.raw)}; const Vec128 cH{_mm_permutexvar_epi16(idxH.raw, vH.raw)}; const Half dh; StoreU(BitCast(dh, cL), dh, unaligned); StoreU(BitCast(dh, cH), dh, unaligned + PopCount(mask_bitsL)); #endif // HWY_TARGET == HWY_AVX3_DL return PopCount(mask_bits); } template HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, T* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw}); } template HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, T* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw} & 0xFull); } HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, float* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw}); } HWY_API size_t CompressStore(Vec256 v, Mask256 mask, Full256 /* tag */, double* HWY_RESTRICT unaligned) { _mm256_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw} & 0xFull); } // ------------------------------ CompressBlendedStore (CompressStore) #if HWY_TARGET == HWY_AVX2 namespace detail { // Intel SDM says "No AC# reported for any mask bit combinations". However, AMD // allows AC# if "Alignment checking enabled and: 256-bit memory operand not // 32-byte aligned". Fortunately AC# is not enabled by default and requires both // OS support (CR0) and the application to set rflags.AC. We assume these remain // disabled because x86/x64 code and compiler output often contain misaligned // scalar accesses, which would also fault. // // Caveat: these are slow on AMD Jaguar/Bulldozer. 
template HWY_API void MaskedStore(Mask256 m, Vec256 v, Full256 /* tag */,
                                  T* HWY_RESTRICT unaligned) {
  auto unaligned_p = reinterpret_cast(unaligned);  // NOLINT
  _mm256_maskstore_epi32(unaligned_p, m.raw, v.raw);
}

template HWY_API void MaskedStore(Mask256 m, Vec256 v, Full256 /* tag */,
                                  T* HWY_RESTRICT unaligned) {
  auto unaligned_p = reinterpret_cast(unaligned);  // NOLINT
  _mm256_maskstore_epi64(unaligned_p, m.raw, v.raw);
}

HWY_API void MaskedStore(Mask256 m, Vec256 v, Full256 d,
                         float* HWY_RESTRICT unaligned) {
  const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m));
  _mm256_maskstore_ps(unaligned, mi.raw, v.raw);
}

HWY_API void MaskedStore(Mask256 m, Vec256 v, Full256 d,
                         double* HWY_RESTRICT unaligned) {
  const Vec256 mi = BitCast(RebindToSigned(), VecFromMask(d, m));
  _mm256_maskstore_pd(unaligned, mi.raw, v.raw);
}

// There is no maskstore_epi8/16, so blend instead.
template * = nullptr> HWY_API void MaskedStore(Mask256 m, Vec256 v, Full256 d,
                                               T* HWY_RESTRICT unaligned) {
  StoreU(IfThenElse(m, v, LoadU(d, unaligned)), d, unaligned);
}

}  // namespace detail
#endif  // HWY_TARGET == HWY_AVX2

#if HWY_TARGET <= HWY_AVX3

template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d,
                                             T* HWY_RESTRICT unaligned) {
  // Native (32 or 64-bit) AVX-512 instruction already does the blending at no
  // extra cost (latency 11, rthroughput 2 - same as compress plus store).
  return CompressStore(v, m, d, unaligned);
}

template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d,
                                             T* HWY_RESTRICT unaligned) {
#if HWY_TARGET <= HWY_AVX3_DL
  return CompressStore(v, m, d, unaligned);  // also native
#else
  const size_t count = CountTrue(d, m);
  const Vec256 compressed = Compress(v, m);
  // There is no 16-bit MaskedStore, so blend.
  const Vec256 prev = LoadU(d, unaligned);
  StoreU(IfThenElse(FirstN(d, count), compressed, prev), d, unaligned);
  return count;
#endif
}

#else  // AVX2

template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d,
                                             T* HWY_RESTRICT unaligned) {
  const size_t count = CountTrue(d, m);
  detail::MaskedStore(FirstN(d, count), Compress(v, m), d, unaligned);
  return count;
}

template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d,
                                             T* HWY_RESTRICT unaligned) {
  // There is no 16-bit MaskedStore, so blend.
  const size_t count = CountTrue(d, m);
  const Vec256 compressed = Compress(v, m);
  const Vec256 prev = LoadU(d, unaligned);
  StoreU(IfThenElse(FirstN(d, count), compressed, prev), d, unaligned);
  return count;
}

#endif  // AVX2

// ------------------------------ CompressBitsStore (LoadMaskBits)

template HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits,
                                          Full256 d, T* HWY_RESTRICT unaligned) {
  return CompressStore(v, LoadMaskBits(d, bits), d, unaligned);
}

#else  // AVX2

// ------------------------------ LoadMaskBits (TestBit)

namespace detail {

// 256 suffix avoids ambiguity with x86_128 without needing HWY_IF_LE128 there.
template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) {
  const RebindToUnsigned du;
  const Repartition du32;
  const auto vbits = BitCast(du, Set(du32, static_cast(mask_bits)));

  // Replicate bytes 8x such that each byte contains the bit that governs it.
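  // Worked example (illustrative): if bit 10 of mask_bits is set, it lives in
  // byte 1 of vbits (bit 2 of that byte, i.e. the 0x04 bit). kRep8 copies
  // byte 1 into bytes 8..15, so byte 10 of rep8 contains it, and TestBit
  // against kBit[10 % 16] == 4 then sets lane 10 of the mask.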
const Repartition du64; alignas(32) constexpr uint64_t kRep8[4] = { 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull, 0x0303030303030303ull}; const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8))); alignas(32) constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128}; return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(32) constexpr uint16_t kBit[16] = { 1, 2, 4, 8, 16, 32, 64, 128, 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; const auto vmask_bits = Set(du, static_cast(mask_bits)); return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(32) constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128}; const auto vmask_bits = Set(du, static_cast(mask_bits)); return RebindMask(d, TestBit(vmask_bits, Load(du, kBit))); } template HWY_INLINE Mask256 LoadMaskBits256(Full256 d, uint64_t mask_bits) { const RebindToUnsigned du; alignas(32) constexpr uint64_t kBit[8] = {1, 2, 4, 8}; return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit))); } } // namespace detail // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask256 LoadMaskBits(Full256 d, const uint8_t* HWY_RESTRICT bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return detail::LoadMaskBits256(d, mask_bits); } // ------------------------------ StoreMaskBits namespace detail { template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { const Full256 d; const Full256 d8; const auto sign_bits = BitCast(d8, VecFromMask(d, mask)).raw; // Prevent sign-extension of 32-bit masks because the intrinsic returns int. return static_cast(_mm256_movemask_epi8(sign_bits)); } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { #if HWY_ARCH_X86_64 const Full256 d; const Full256 d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); const uint64_t sign_bits8 = BitsFromMask(mask8); // Skip the bits from the lower byte of each u16 (better not to use the // same packs_epi16 as SSE4, because that requires an extra swizzle here). return _pext_u64(sign_bits8, 0xAAAAAAAAull); #else // Slow workaround for 32-bit builds, which lack _pext_u64. // Remove useless lower half of each u16 while preserving the sign bit. // Bytes [0, 8) and [16, 24) have the same sign bits as the input lanes. const auto sign_bits = _mm256_packs_epi16(mask.raw, _mm256_setzero_si256()); // Move odd qwords (value zero) to top so they don't affect the mask value. const auto compressed = _mm256_permute4x64_epi64(sign_bits, _MM_SHUFFLE(3, 1, 2, 0)); return static_cast(_mm256_movemask_epi8(compressed)); #endif // HWY_ARCH_X86_64 } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { const Full256 d; const Full256 df; const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; return static_cast(_mm256_movemask_ps(sign_bits)); } template HWY_INLINE uint64_t BitsFromMask(const Mask256 mask) { const Full256 d; const Full256 df; const auto sign_bits = BitCast(df, VecFromMask(d, mask)).raw; return static_cast(_mm256_movemask_pd(sign_bits)); } } // namespace detail // `p` points to at least 8 writable bytes. 
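// Example (illustrative only): for Full256<uint32_t> (8 lanes) with mask lanes
// {1,1,0,0,1,0,0,0}, BitsFromMask returns 0x13 and StoreMaskBits writes the
// single byte 0x13; kNumBytes is (8 + 7) / 8 = 1.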
template HWY_API size_t StoreMaskBits(const Full256 /* tag */, const Mask256 mask, uint8_t* bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; const uint64_t mask_bits = detail::BitsFromMask(mask); CopyBytes(&mask_bits, bits); return kNumBytes; } // ------------------------------ Mask testing // Specialize for 16-bit lanes to avoid unnecessary pext. This assumes each mask // lane is 0 or ~0. template HWY_API bool AllFalse(const Full256 d, const Mask256 mask) { const Repartition d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); return detail::BitsFromMask(mask8) == 0; } template HWY_API bool AllFalse(const Full256 /* tag */, const Mask256 mask) { // Cheaper than PTEST, which is 2 uop / 3L. return detail::BitsFromMask(mask) == 0; } template HWY_API bool AllTrue(const Full256 d, const Mask256 mask) { const Repartition d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); return detail::BitsFromMask(mask8) == (1ull << 32) - 1; } template HWY_API bool AllTrue(const Full256 /* tag */, const Mask256 mask) { constexpr uint64_t kAllBits = (1ull << (32 / sizeof(T))) - 1; return detail::BitsFromMask(mask) == kAllBits; } template HWY_API size_t CountTrue(const Full256 d, const Mask256 mask) { const Repartition d8; const Mask256 mask8 = MaskFromVec(BitCast(d8, VecFromMask(d, mask))); return PopCount(detail::BitsFromMask(mask8)) >> 1; } template HWY_API size_t CountTrue(const Full256 /* tag */, const Mask256 mask) { return PopCount(detail::BitsFromMask(mask)); } template HWY_API intptr_t FindFirstTrue(const Full256 /* tag */, const Mask256 mask) { const uint64_t mask_bits = detail::BitsFromMask(mask); return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1; } // ------------------------------ Compress, CompressBits namespace detail { template HWY_INLINE Indices256 IndicesFromBits(Full256 d, uint64_t mask_bits) { const RebindToUnsigned d32; // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT // of SetTableIndices would require 8 KiB, a large part of L1D. The other // alternative is _pext_u64, but this is extremely slow on Zen2 (18 cycles) // and unavailable in 32-bit builds. We instead compress each index into 4 // bits, for a total of 1 KiB. 
alignas(16) constexpr uint32_t packed_array[256] = { 0x00000000, 0x00000000, 0x00000001, 0x00000010, 0x00000002, 0x00000020, 0x00000021, 0x00000210, 0x00000003, 0x00000030, 0x00000031, 0x00000310, 0x00000032, 0x00000320, 0x00000321, 0x00003210, 0x00000004, 0x00000040, 0x00000041, 0x00000410, 0x00000042, 0x00000420, 0x00000421, 0x00004210, 0x00000043, 0x00000430, 0x00000431, 0x00004310, 0x00000432, 0x00004320, 0x00004321, 0x00043210, 0x00000005, 0x00000050, 0x00000051, 0x00000510, 0x00000052, 0x00000520, 0x00000521, 0x00005210, 0x00000053, 0x00000530, 0x00000531, 0x00005310, 0x00000532, 0x00005320, 0x00005321, 0x00053210, 0x00000054, 0x00000540, 0x00000541, 0x00005410, 0x00000542, 0x00005420, 0x00005421, 0x00054210, 0x00000543, 0x00005430, 0x00005431, 0x00054310, 0x00005432, 0x00054320, 0x00054321, 0x00543210, 0x00000006, 0x00000060, 0x00000061, 0x00000610, 0x00000062, 0x00000620, 0x00000621, 0x00006210, 0x00000063, 0x00000630, 0x00000631, 0x00006310, 0x00000632, 0x00006320, 0x00006321, 0x00063210, 0x00000064, 0x00000640, 0x00000641, 0x00006410, 0x00000642, 0x00006420, 0x00006421, 0x00064210, 0x00000643, 0x00006430, 0x00006431, 0x00064310, 0x00006432, 0x00064320, 0x00064321, 0x00643210, 0x00000065, 0x00000650, 0x00000651, 0x00006510, 0x00000652, 0x00006520, 0x00006521, 0x00065210, 0x00000653, 0x00006530, 0x00006531, 0x00065310, 0x00006532, 0x00065320, 0x00065321, 0x00653210, 0x00000654, 0x00006540, 0x00006541, 0x00065410, 0x00006542, 0x00065420, 0x00065421, 0x00654210, 0x00006543, 0x00065430, 0x00065431, 0x00654310, 0x00065432, 0x00654320, 0x00654321, 0x06543210, 0x00000007, 0x00000070, 0x00000071, 0x00000710, 0x00000072, 0x00000720, 0x00000721, 0x00007210, 0x00000073, 0x00000730, 0x00000731, 0x00007310, 0x00000732, 0x00007320, 0x00007321, 0x00073210, 0x00000074, 0x00000740, 0x00000741, 0x00007410, 0x00000742, 0x00007420, 0x00007421, 0x00074210, 0x00000743, 0x00007430, 0x00007431, 0x00074310, 0x00007432, 0x00074320, 0x00074321, 0x00743210, 0x00000075, 0x00000750, 0x00000751, 0x00007510, 0x00000752, 0x00007520, 0x00007521, 0x00075210, 0x00000753, 0x00007530, 0x00007531, 0x00075310, 0x00007532, 0x00075320, 0x00075321, 0x00753210, 0x00000754, 0x00007540, 0x00007541, 0x00075410, 0x00007542, 0x00075420, 0x00075421, 0x00754210, 0x00007543, 0x00075430, 0x00075431, 0x00754310, 0x00075432, 0x00754320, 0x00754321, 0x07543210, 0x00000076, 0x00000760, 0x00000761, 0x00007610, 0x00000762, 0x00007620, 0x00007621, 0x00076210, 0x00000763, 0x00007630, 0x00007631, 0x00076310, 0x00007632, 0x00076320, 0x00076321, 0x00763210, 0x00000764, 0x00007640, 0x00007641, 0x00076410, 0x00007642, 0x00076420, 0x00076421, 0x00764210, 0x00007643, 0x00076430, 0x00076431, 0x00764310, 0x00076432, 0x00764320, 0x00764321, 0x07643210, 0x00000765, 0x00007650, 0x00007651, 0x00076510, 0x00007652, 0x00076520, 0x00076521, 0x00765210, 0x00007653, 0x00076530, 0x00076531, 0x00765310, 0x00076532, 0x00765320, 0x00765321, 0x07653210, 0x00007654, 0x00076540, 0x00076541, 0x00765410, 0x00076542, 0x00765420, 0x00765421, 0x07654210, 0x00076543, 0x00765430, 0x00765431, 0x07654310, 0x00765432, 0x07654320, 0x07654321, 0x76543210}; // No need to mask because _mm256_permutevar8x32_epi32 ignores bits 3..31. // Just shift each copy of the 32 bit LUT to extract its 4-bit fields. // If broadcasting 32-bit from memory incurs the 3-cycle block-crossing // latency, it may be faster to use LoadDup128 and PSHUFB. 
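  // Worked example (illustrative): mask_bits = 0b00001010 selects lanes 1 and
  // 3, so packed_array[10] == 0x00000031. After the per-lane right shifts by
  // {0,4,8,...}, lane 0 reads index 1 and lane 1 reads index 3; the remaining
  // lanes read 0, which is harmless because only the first CountTrue lanes of
  // the compressed result are used.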
const auto packed = Set(d32, packed_array[mask_bits]); alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28}; return Indices256{(packed >> Load(d32, shifts)).raw}; } template HWY_INLINE Indices256 IndicesFromBits(Full256 d, uint64_t mask_bits) { const Repartition d32; // For 64-bit, we still need 32-bit indices because there is no 64-bit // permutevar, but there are only 4 lanes, so we can afford to skip the // unpacking and load the entire index vector directly. alignas(32) constexpr uint32_t packed_array[128] = { 0, 1, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 0, 1, 0, 1, 0, 1, // 2, 3, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 2, 3, 0, 1, 0, 1, // 4, 5, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 4, 5, 0, 1, 0, 1, // 2, 3, 4, 5, 0, 1, 0, 1, /**/ 0, 1, 2, 3, 4, 5, 0, 1, // 6, 7, 0, 1, 0, 1, 0, 1, /**/ 0, 1, 6, 7, 0, 1, 0, 1, // 2, 3, 6, 7, 0, 1, 0, 1, /**/ 0, 1, 2, 3, 6, 7, 0, 1, // 4, 5, 6, 7, 0, 1, 0, 1, /**/ 0, 1, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 1, 2, 3, 4, 5, 6, 7}; return Indices256{Load(d32, packed_array + 8 * mask_bits).raw}; } template HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { const Full256 d; const Repartition du32; HWY_DASSERT(mask_bits < (1ull << (32 / sizeof(T)))); const auto indices = IndicesFromBits(d, mask_bits); return BitCast(d, TableLookupLanes(BitCast(du32, v), indices)); } // LUTs are infeasible for 2^16 possible masks. Promoting to 32-bit and using // the native Compress is probably more efficient than 2 LUTs. template HWY_INLINE Vec256 Compress(Vec256 v, const uint64_t mask_bits) { using D = Full256; const Rebind du; const Repartition dw; const auto vu16 = BitCast(du, v); // (required for float16_t inputs) const auto promoted0 = PromoteTo(dw, LowerHalf(vu16)); const auto promoted1 = PromoteTo(dw, UpperHalf(Half(), vu16)); const uint64_t mask_bits0 = mask_bits & 0xFF; const uint64_t mask_bits1 = mask_bits >> 8; const auto compressed0 = Compress(promoted0, mask_bits0); const auto compressed1 = Compress(promoted1, mask_bits1); const Half dh; const auto demoted0 = ZeroExtendVector(du, DemoteTo(dh, compressed0)); const auto demoted1 = ZeroExtendVector(du, DemoteTo(dh, compressed1)); const size_t count0 = PopCount(mask_bits0); // Now combine by shifting demoted1 up. AVX2 lacks VPERMW, so start with // VPERMD for shifting at 4 byte granularity. alignas(32) constexpr int32_t iota4[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7}; const auto indices = SetTableIndices(dw, iota4 + 8 - count0 / 2); const auto shift1_multiple4 = BitCast(du, TableLookupLanes(BitCast(dw, demoted1), indices)); // Whole-register unconditional shift by 2 bytes. // TODO(janwas): slow on AMD, use 2 shifts + permq + OR instead? const __m256i lo_zz = _mm256_permute2x128_si256(shift1_multiple4.raw, shift1_multiple4.raw, 0x08); const auto shift1_multiple2 = Vec256{_mm256_alignr_epi8(shift1_multiple4.raw, lo_zz, 14)}; // Make the shift conditional on the lower bit of count0. const auto m_odd = TestBit(Set(du, static_cast(count0)), Set(du, 1)); const auto shifted1 = IfThenElse(m_odd, shift1_multiple2, shift1_multiple4); // Blend the lower and shifted upper parts. 
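  // Illustrative: if count0 == 5, m_lower below covers lanes [0, 5), so those
  // come from demoted0 and lanes [5, 16) come from shifted1.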
constexpr uint16_t on = 0xFFFF; alignas(32) constexpr uint16_t lower_lanes[32] = {HWY_REP4(on), HWY_REP4(on), HWY_REP4(on), HWY_REP4(on)}; const auto m_lower = MaskFromVec(LoadU(du, lower_lanes + 16 - count0)); return BitCast(D(), IfThenElse(m_lower, demoted0, shifted1)); } } // namespace detail template HWY_API Vec256 Compress(Vec256 v, Mask256 m) { const uint64_t mask_bits = detail::BitsFromMask(m); return detail::Compress(v, mask_bits); } template HWY_API Vec256 CompressBits(Vec256 v, const uint8_t* HWY_RESTRICT bits) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } return detail::Compress(v, mask_bits); } // ------------------------------ CompressStore, CompressBitsStore template HWY_API size_t CompressStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(m); StoreU(detail::Compress(v, mask_bits), d, unaligned); return PopCount(mask_bits); } template HWY_API size_t CompressBlendedStore(Vec256 v, Mask256 m, Full256 d, T* HWY_RESTRICT unaligned) { const uint64_t mask_bits = detail::BitsFromMask(m); const size_t count = PopCount(mask_bits); const Vec256 compress = detail::Compress(v, mask_bits); const Vec256 prev = LoadU(d, unaligned); StoreU(IfThenElse(FirstN(d, count), compress, prev), d, unaligned); return count; } template HWY_API size_t CompressBitsStore(Vec256 v, const uint8_t* HWY_RESTRICT bits, Full256 d, T* HWY_RESTRICT unaligned) { constexpr size_t N = 32 / sizeof(T); constexpr size_t kNumBytes = (N + 7) / 8; uint64_t mask_bits = 0; CopyBytes(bits, &mask_bits); if (N < 8) { mask_bits &= (1ull << N) - 1; } StoreU(detail::Compress(v, mask_bits), d, unaligned); return PopCount(mask_bits); } #endif // HWY_TARGET <= HWY_AVX3 // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, // TableLookupBytes, ConcatUpperLower) HWY_API void StoreInterleaved3(const Vec256 v0, const Vec256 v1, const Vec256 v2, Full256 d, uint8_t* HWY_RESTRICT unaligned) { const auto k5 = Set(d, 5); const auto k6 = Set(d, 6); // Shuffle (v0,v1,v2) vector bytes to (MSB on left): r5, bgr[4:0]. // 0x80 so lanes to be filled from other vectors are 0 for blending. alignas(16) static constexpr uint8_t tbl_r0[16] = { 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; alignas(16) static constexpr uint8_t tbl_g0[16] = { 0x80, 0, 0x80, 0x80, 1, 0x80, // 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; const auto shuf_r0 = LoadDup128(d, tbl_r0); const auto shuf_g0 = LoadDup128(d, tbl_g0); // cannot reuse r0 due to 5 const auto shuf_b0 = CombineShiftRightBytes<15>(d, shuf_g0, shuf_g0); const auto r0 = TableLookupBytes(v0, shuf_r0); // 5..4..3..2..1..0 const auto g0 = TableLookupBytes(v1, shuf_g0); // ..4..3..2..1..0. const auto b0 = TableLookupBytes(v2, shuf_b0); // .4..3..2..1..0.. const auto interleaved_10_00 = r0 | g0 | b0; // Second vector: g10,r10, bgr[9:6], b5,g5 const auto shuf_r1 = shuf_b0 + k6; // .A..9..8..7..6.. const auto shuf_g1 = shuf_r0 + k5; // A..9..8..7..6..5 const auto shuf_b1 = shuf_g0 + k5; // ..9..8..7..6..5. const auto r1 = TableLookupBytes(v0, shuf_r1); const auto g1 = TableLookupBytes(v1, shuf_g1); const auto b1 = TableLookupBytes(v2, shuf_b1); const auto interleaved_15_05 = r1 | g1 | b1; // We want to write the lower halves of the interleaved vectors, then the // upper halves. 
// We could obtain 10_05 and 15_0A via ConcatUpperLower, but that would
// require two unaligned stores. For the lower halves, we can merge two
// 128-bit stores for the same swizzling cost:
  const auto out0 = ConcatLowerLower(d, interleaved_15_05, interleaved_10_00);
  StoreU(out0, d, unaligned + 0 * 32);

  // Third vector: bgr[15:11], b10
  const auto shuf_r2 = shuf_b1 + k6;  // ..F..E..D..C..B.
  const auto shuf_g2 = shuf_r1 + k5;  // .F..E..D..C..B..
  const auto shuf_b2 = shuf_g1 + k5;  // F..E..D..C..B..A
  const auto r2 = TableLookupBytes(v0, shuf_r2);
  const auto g2 = TableLookupBytes(v1, shuf_g2);
  const auto b2 = TableLookupBytes(v2, shuf_b2);
  const auto interleaved_1A_0A = r2 | g2 | b2;

  const auto out1 = ConcatUpperLower(d, interleaved_10_00, interleaved_1A_0A);
  StoreU(out1, d, unaligned + 1 * 32);

  const auto out2 = ConcatUpperUpper(d, interleaved_1A_0A, interleaved_15_05);
  StoreU(out2, d, unaligned + 2 * 32);
}

// ------------------------------ StoreInterleaved4

HWY_API void StoreInterleaved4(const Vec256 v0, const Vec256 v1,
                               const Vec256 v2, const Vec256 v3, Full256 d8,
                               uint8_t* HWY_RESTRICT unaligned) {
  const RepartitionToWide d16;
  const RepartitionToWide d32;
  // let a,b,c,d denote v0..3.
  const auto ba0 = ZipLower(d16, v0, v1);  // b7 a7 .. b0 a0
  const auto dc0 = ZipLower(d16, v2, v3);  // d7 c7 .. d0 c0
  const auto ba8 = ZipUpper(d16, v0, v1);
  const auto dc8 = ZipUpper(d16, v2, v3);
  const auto dcba_0 = ZipLower(d32, ba0, dc0);  // d..a13 d..a10 | d..a03 d..a00
  const auto dcba_4 = ZipUpper(d32, ba0, dc0);  // d..a17 d..a14 | d..a07 d..a04
  const auto dcba_8 = ZipLower(d32, ba8, dc8);  // d..a1B d..a18 | d..a0B d..a08
  const auto dcba_C = ZipUpper(d32, ba8, dc8);  // d..a1F d..a1C | d..a0F d..a0C
  // Write lower halves, then upper. vperm2i128 is slow on Zen1 but we can
  // efficiently combine two lower halves into 256 bits:
  const auto out0 = BitCast(d8, ConcatLowerLower(d32, dcba_4, dcba_0));
  const auto out1 = BitCast(d8, ConcatLowerLower(d32, dcba_C, dcba_8));
  StoreU(out0, d8, unaligned + 0 * 32);
  StoreU(out1, d8, unaligned + 1 * 32);
  const auto out2 = BitCast(d8, ConcatUpperUpper(d32, dcba_4, dcba_0));
  const auto out3 = BitCast(d8, ConcatUpperUpper(d32, dcba_C, dcba_8));
  StoreU(out2, d8, unaligned + 2 * 32);
  StoreU(out3, d8, unaligned + 3 * 32);
}

// ------------------------------ Reductions

namespace detail {

// Returns sum{lane[i]} in each lane. "v3210" is a replicated 128-bit block.
// Same logic as x86/128.h, but with Vec256 arguments.
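// Example (illustrative only): for f32x8 lanes {1,2,3,4,5,6,7,8}, the public
// SumOfLanes below broadcasts 36 to all lanes; MinOfLanes and MaxOfLanes
// broadcast 1 and 8 respectively.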
template HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const auto v1032 = Shuffle1032(v3210); const auto v31_20_31_20 = v3210 + v1032; const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); return v20_31_20_31 + v31_20_31_20; } template HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const auto v1032 = Shuffle1032(v3210); const auto v31_20_31_20 = Min(v3210, v1032); const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); return Min(v20_31_20_31, v31_20_31_20); } template HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<4> /* tag */, const Vec256 v3210) { const auto v1032 = Shuffle1032(v3210); const auto v31_20_31_20 = Max(v3210, v1032); const auto v20_31_20_31 = Shuffle0321(v31_20_31_20); return Max(v20_31_20_31, v31_20_31_20); } template HWY_INLINE Vec256 SumOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const auto v01 = Shuffle01(v10); return v10 + v01; } template HWY_INLINE Vec256 MinOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const auto v01 = Shuffle01(v10); return Min(v10, v01); } template HWY_INLINE Vec256 MaxOfLanes(hwy::SizeTag<8> /* tag */, const Vec256 v10) { const auto v01 = Shuffle01(v10); return Max(v10, v01); } // u16/i16 template HWY_API Vec256 MinOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MinOfLanes(d32, Min(even, odd)); // Also broadcast into odd lanes. return BitCast(Full256(), Or(min, ShiftLeft<16>(min))); } template HWY_API Vec256 MaxOfLanes(hwy::SizeTag<2> /* tag */, Vec256 v) { const Repartition> d32; const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF)); const auto odd = ShiftRight<16>(BitCast(d32, v)); const auto min = MaxOfLanes(d32, Max(even, odd)); // Also broadcast into odd lanes. return BitCast(Full256(), Or(min, ShiftLeft<16>(min))); } } // namespace detail // Supported for {uif}32x8, {uif}64x4. Returns the sum in each lane. template HWY_API Vec256 SumOfLanes(Full256 d, const Vec256 vHL) { const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); return detail::SumOfLanes(hwy::SizeTag(), vLH + vHL); } template HWY_API Vec256 MinOfLanes(Full256 d, const Vec256 vHL) { const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); return detail::MinOfLanes(hwy::SizeTag(), Min(vLH, vHL)); } template HWY_API Vec256 MaxOfLanes(Full256 d, const Vec256 vHL) { const Vec256 vLH = ConcatLowerUpper(d, vHL, vHL); return detail::MaxOfLanes(hwy::SizeTag(), Max(vLH, vHL)); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE();