// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// 512-bit AVX512 vectors and operations.
// External include guard in highway.h - see comment there.

// WARNING: most operations do not cross 128-bit block boundaries. In
// particular, "Broadcast", pack and zip behavior may be surprising.

#include <immintrin.h>  // AVX2+

#include "hwy/base.h"

#if defined(_MSC_VER) && defined(__clang__)
// Including <immintrin.h> should be enough, but Clang's headers helpfully skip
// including these headers when _MSC_VER is defined, like when using clang-cl.
// Include these directly here.
// clang-format off
#include <smmintrin.h>

#include <avxintrin.h>
#include <avx2intrin.h>
#include <f16cintrin.h>
#include <fmaintrin.h>

#include <avx512fintrin.h>
#include <avx512vlintrin.h>
#include <avx512bwintrin.h>
#include <avx512dqintrin.h>
#include <avx512vlbwintrin.h>
#include <avx512vldqintrin.h>
#include <avx512bitalgintrin.h>
#include <avx512vlbitalgintrin.h>
#include <avx512vpopcntdqintrin.h>
#include <avx512vpopcntdqvlintrin.h>
// clang-format on
#endif

#include <stddef.h>
#include <stdint.h>

// For half-width vectors. Already includes base.h and shared-inl.h.
#include "hwy/ops/x86_256-inl.h"

HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

namespace detail {

template <typename T>
struct Raw512 {
  using type = __m512i;
};
template <>
struct Raw512<float> {
  using type = __m512;
};
template <>
struct Raw512<double> {
  using type = __m512d;
};

// Template arg: sizeof(lane type)
template <size_t size>
struct RawMask512 {};
template <>
struct RawMask512<1> {
  using type = __mmask64;
};
template <>
struct RawMask512<2> {
  using type = __mmask32;
};
template <>
struct RawMask512<4> {
  using type = __mmask16;
};
template <>
struct RawMask512<8> {
  using type = __mmask8;
};

}  // namespace detail

template <typename T>
class Vec512 {
  using Raw = typename detail::Raw512<T>::type;

 public:
  // Compound assignment. Only usable if there is a corresponding non-member
  // binary operator overload. For example, only f32 and f64 support division.
  HWY_INLINE Vec512& operator*=(const Vec512 other) {
    return *this = (*this * other);
  }
  HWY_INLINE Vec512& operator/=(const Vec512 other) {
    return *this = (*this / other);
  }
  HWY_INLINE Vec512& operator+=(const Vec512 other) {
    return *this = (*this + other);
  }
  HWY_INLINE Vec512& operator-=(const Vec512 other) {
    return *this = (*this - other);
  }
  HWY_INLINE Vec512& operator&=(const Vec512 other) {
    return *this = (*this & other);
  }
  HWY_INLINE Vec512& operator|=(const Vec512 other) {
    return *this = (*this | other);
  }
  HWY_INLINE Vec512& operator^=(const Vec512 other) {
    return *this = (*this ^ other);
  }

  Raw raw;
};

// Mask register: one bit per lane.
template <typename T>
struct Mask512 {
  typename detail::RawMask512<sizeof(T)>::type raw;
};

// ------------------------------ BitCast

namespace detail {

HWY_INLINE __m512i BitCastToInteger(__m512i v) { return v; }
HWY_INLINE __m512i BitCastToInteger(__m512 v) { return _mm512_castps_si512(v); }
HWY_INLINE __m512i BitCastToInteger(__m512d v) {
  return _mm512_castpd_si512(v);
}

template <typename T>
HWY_INLINE Vec512<uint8_t> BitCastToByte(Vec512<T> v) {
  return Vec512<uint8_t>{BitCastToInteger(v.raw)};
}

// Cannot rely on function overloading because return types differ.
template struct BitCastFromInteger512 { HWY_INLINE __m512i operator()(__m512i v) { return v; } }; template <> struct BitCastFromInteger512 { HWY_INLINE __m512 operator()(__m512i v) { return _mm512_castsi512_ps(v); } }; template <> struct BitCastFromInteger512 { HWY_INLINE __m512d operator()(__m512i v) { return _mm512_castsi512_pd(v); } }; template HWY_INLINE Vec512 BitCastFromByte(Full512 /* tag */, Vec512 v) { return Vec512{BitCastFromInteger512()(v.raw)}; } } // namespace detail template HWY_API Vec512 BitCast(Full512 d, Vec512 v) { return detail::BitCastFromByte(d, detail::BitCastToByte(v)); } // ------------------------------ Set // Returns an all-zero vector. template HWY_API Vec512 Zero(Full512 /* tag */) { return Vec512{_mm512_setzero_si512()}; } HWY_API Vec512 Zero(Full512 /* tag */) { return Vec512{_mm512_setzero_ps()}; } HWY_API Vec512 Zero(Full512 /* tag */) { return Vec512{_mm512_setzero_pd()}; } // Returns a vector with all lanes set to "t". HWY_API Vec512 Set(Full512 /* tag */, const uint8_t t) { return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const uint16_t t) { return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const uint32_t t) { return Vec512{_mm512_set1_epi32(static_cast(t))}; } HWY_API Vec512 Set(Full512 /* tag */, const uint64_t t) { return Vec512{ _mm512_set1_epi64(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const int8_t t) { return Vec512{_mm512_set1_epi8(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const int16_t t) { return Vec512{_mm512_set1_epi16(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const int32_t t) { return Vec512{_mm512_set1_epi32(t)}; } HWY_API Vec512 Set(Full512 /* tag */, const int64_t t) { return Vec512{ _mm512_set1_epi64(static_cast(t))}; // NOLINT } HWY_API Vec512 Set(Full512 /* tag */, const float t) { return Vec512{_mm512_set1_ps(t)}; } HWY_API Vec512 Set(Full512 /* tag */, const double t) { return Vec512{_mm512_set1_pd(t)}; } HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized") // Returns a vector with uninitialized elements. template HWY_API Vec512 Undefined(Full512 /* tag */) { // Available on Clang 6.0, GCC 6.2, ICC 16.03, MSVC 19.14. All but ICC // generate an XOR instruction. return Vec512{_mm512_undefined_epi32()}; } HWY_API Vec512 Undefined(Full512 /* tag */) { return Vec512{_mm512_undefined_ps()}; } HWY_API Vec512 Undefined(Full512 /* tag */) { return Vec512{_mm512_undefined_pd()}; } HWY_DIAGNOSTICS(pop) // ================================================== LOGICAL // ------------------------------ Not template HWY_API Vec512 Not(const Vec512 v) { using TU = MakeUnsigned; const __m512i vu = BitCast(Full512(), v).raw; return BitCast(Full512(), Vec512{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)}); } // ------------------------------ And template HWY_API Vec512 And(const Vec512 a, const Vec512 b) { return Vec512{_mm512_and_si512(a.raw, b.raw)}; } HWY_API Vec512 And(const Vec512 a, const Vec512 b) { return Vec512{_mm512_and_ps(a.raw, b.raw)}; } HWY_API Vec512 And(const Vec512 a, const Vec512 b) { return Vec512{_mm512_and_pd(a.raw, b.raw)}; } // ------------------------------ AndNot // Returns ~not_mask & mask. 
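// Usage sketch (illustrative only, not part of the library): note the argument
// order, which matches _mm512_andnot_*: the first operand is the one that is
// inverted. For example, to clear in `data` all bits that are set in `select`:
//   const auto cleared = AndNot(select, data);  // = ~select & data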
template HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { return Vec512{_mm512_andnot_si512(not_mask.raw, mask.raw)}; } HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { return Vec512{_mm512_andnot_ps(not_mask.raw, mask.raw)}; } HWY_API Vec512 AndNot(const Vec512 not_mask, const Vec512 mask) { return Vec512{_mm512_andnot_pd(not_mask.raw, mask.raw)}; } // ------------------------------ Or template HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { return Vec512{_mm512_or_si512(a.raw, b.raw)}; } HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { return Vec512{_mm512_or_ps(a.raw, b.raw)}; } HWY_API Vec512 Or(const Vec512 a, const Vec512 b) { return Vec512{_mm512_or_pd(a.raw, b.raw)}; } // ------------------------------ Xor template HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_si512(a.raw, b.raw)}; } HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_ps(a.raw, b.raw)}; } HWY_API Vec512 Xor(const Vec512 a, const Vec512 b) { return Vec512{_mm512_xor_pd(a.raw, b.raw)}; } // ------------------------------ OrAnd template HWY_API Vec512 OrAnd(Vec512 o, Vec512 a1, Vec512 a2) { const Full512 d; const RebindToUnsigned du; using VU = VFromD; const __m512i ret = _mm512_ternarylogic_epi64( BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8); return BitCast(d, VU{ret}); } // ------------------------------ IfVecThenElse template HWY_API Vec512 IfVecThenElse(Vec512 mask, Vec512 yes, Vec512 no) { const Full512 d; const RebindToUnsigned du; using VU = VFromD; return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw, BitCast(du, no).raw, 0xCA)}); } // ------------------------------ Operator overloads (internal-only if float) template HWY_API Vec512 operator&(const Vec512 a, const Vec512 b) { return And(a, b); } template HWY_API Vec512 operator|(const Vec512 a, const Vec512 b) { return Or(a, b); } template HWY_API Vec512 operator^(const Vec512 a, const Vec512 b) { return Xor(a, b); } // ------------------------------ PopulationCount // 8/16 require BITALG, 32/64 require VPOPCNTDQ. #if HWY_TARGET == HWY_AVX3_DL #ifdef HWY_NATIVE_POPCNT #undef HWY_NATIVE_POPCNT #else #define HWY_NATIVE_POPCNT #endif namespace detail { template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<1> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi8(v.raw)}; } template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<2> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi16(v.raw)}; } template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<4> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi32(v.raw)}; } template HWY_INLINE Vec512 PopulationCount(hwy::SizeTag<8> /* tag */, Vec512 v) { return Vec512{_mm512_popcnt_epi64(v.raw)}; } } // namespace detail template HWY_API Vec512 PopulationCount(Vec512 v) { return detail::PopulationCount(hwy::SizeTag(), v); } #endif // HWY_TARGET == HWY_AVX3_DL // ================================================== SIGN // ------------------------------ CopySign template HWY_API Vec512 CopySign(const Vec512 magn, const Vec512 sign) { static_assert(IsFloat(), "Only makes sense for floating-point"); const Full512 d; const auto msb = SignBit(d); const Rebind, decltype(d)> du; // Truth table for msb, magn, sign | bitwise msb ? sign : mag // 0 0 0 | 0 // 0 0 1 | 0 // 0 1 0 | 1 // 0 1 1 | 1 // 1 0 0 | 0 // 1 0 1 | 1 // 1 1 0 | 0 // 1 1 1 | 1 // The lane size does not matter because we are not using predication. 
const __m512i out = _mm512_ternarylogic_epi32( BitCast(du, msb).raw, BitCast(du, magn).raw, BitCast(du, sign).raw, 0xAC); return BitCast(d, decltype(Zero(du)){out}); } template HWY_API Vec512 CopySignToAbs(const Vec512 abs, const Vec512 sign) { // AVX3 can also handle abs < 0, so no extra action needed. return CopySign(abs, sign); } // ================================================== MASK // ------------------------------ FirstN // Possibilities for constructing a bitmask of N ones: // - kshift* only consider the lowest byte of the shift count, so they would // not correctly handle large n. // - Scalar shifts >= 64 are UB. // - BZHI has the desired semantics; we assume AVX-512 implies BMI2. However, // we need 64-bit masks for sizeof(T) == 1, so special-case 32-bit builds. #if HWY_ARCH_X86_32 namespace detail { // 32 bit mask is sufficient for lane size >= 2. template HWY_INLINE Mask512 FirstN(size_t n) { Mask512 m; const uint32_t all = ~uint32_t(0); // BZHI only looks at the lower 8 bits of n! m.raw = static_cast((n > 255) ? all : _bzhi_u32(all, n)); return m; } template HWY_INLINE Mask512 FirstN(size_t n) { const uint64_t bits = n < 64 ? ((1ULL << n) - 1) : ~uint64_t(0); return Mask512{static_cast<__mmask64>(bits)}; } } // namespace detail #endif // HWY_ARCH_X86_32 template HWY_API Mask512 FirstN(const Full512 /*tag*/, size_t n) { #if HWY_ARCH_X86_64 Mask512 m; const uint64_t all = ~uint64_t(0); // BZHI only looks at the lower 8 bits of n! m.raw = static_cast((n > 255) ? all : _bzhi_u64(all, n)); return m; #else return detail::FirstN(n); #endif // HWY_ARCH_X86_64 } // ------------------------------ IfThenElse // Returns mask ? b : a. namespace detail { // Templates for signed/unsigned integer of a particular size. template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<1> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi8(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<2> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi16(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<4> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi32(no.raw, mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElse(hwy::SizeTag<8> /* tag */, const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_epi64(no.raw, mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return detail::IfThenElse(hwy::SizeTag(), mask, yes, no); } HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_ps(no.raw, mask.raw, yes.raw)}; } HWY_API Vec512 IfThenElse(const Mask512 mask, const Vec512 yes, const Vec512 no) { return Vec512{_mm512_mask_mov_pd(no.raw, mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<1> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi8(mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<2> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi16(mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<4> /* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi32(mask.raw, yes.raw)}; } template HWY_INLINE Vec512 IfThenElseZero(hwy::SizeTag<8> 
/* tag */, const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_epi64(mask.raw, yes.raw)}; } } // namespace detail template HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return detail::IfThenElseZero(hwy::SizeTag(), mask, yes); } HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_ps(mask.raw, yes.raw)}; } HWY_API Vec512 IfThenElseZero(const Mask512 mask, const Vec512 yes) { return Vec512{_mm512_maskz_mov_pd(mask.raw, yes.raw)}; } namespace detail { template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<1> /* tag */, const Mask512 mask, const Vec512 no) { // xor_epi8/16 are missing, but we have sub, which is just as fast for u8/16. return Vec512{_mm512_mask_sub_epi8(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<2> /* tag */, const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_sub_epi16(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<4> /* tag */, const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_epi32(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_INLINE Vec512 IfThenZeroElse(hwy::SizeTag<8> /* tag */, const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_epi64(no.raw, mask.raw, no.raw, no.raw)}; } } // namespace detail template HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return detail::IfThenZeroElse(hwy::SizeTag(), mask, no); } HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_ps(no.raw, mask.raw, no.raw, no.raw)}; } HWY_API Vec512 IfThenZeroElse(const Mask512 mask, const Vec512 no) { return Vec512{_mm512_mask_xor_pd(no.raw, mask.raw, no.raw, no.raw)}; } template HWY_API Vec512 IfNegativeThenElse(Vec512 v, Vec512 yes, Vec512 no) { static_assert(IsSigned(), "Only works for signed/float"); // AVX3 MaskFromVec only looks at the MSB return IfThenElse(MaskFromVec(v), yes, no); } template HWY_API Vec512 ZeroIfNegative(const Vec512 v) { // AVX3 MaskFromVec only looks at the MSB return IfThenZeroElse(MaskFromVec(v), v); } // ================================================== ARITHMETIC // ------------------------------ Addition // Unsigned HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_ps(a.raw, b.raw)}; } HWY_API Vec512 operator+(const Vec512 a, const Vec512 b) { return Vec512{_mm512_add_pd(a.raw, b.raw)}; } // ------------------------------ Subtraction // Unsigned HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; } HWY_API Vec512 
operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; } // Signed HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi8(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi32(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_ps(a.raw, b.raw)}; } HWY_API Vec512 operator-(const Vec512 a, const Vec512 b) { return Vec512{_mm512_sub_pd(a.raw, b.raw)}; } // ------------------------------ SumsOf8 HWY_API Vec512 SumsOf8(const Vec512 v) { return Vec512{_mm512_sad_epu8(v.raw, _mm512_setzero_si512())}; } // ------------------------------ SaturatedAdd // Returns a + b clamped to the destination range. // Unsigned HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epu8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epi8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedAdd(const Vec512 a, const Vec512 b) { return Vec512{_mm512_adds_epi16(a.raw, b.raw)}; } // ------------------------------ SaturatedSub // Returns a - b clamped to the destination range. // Unsigned HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epu8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epu16(a.raw, b.raw)}; } // Signed HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epi8(a.raw, b.raw)}; } HWY_API Vec512 SaturatedSub(const Vec512 a, const Vec512 b) { return Vec512{_mm512_subs_epi16(a.raw, b.raw)}; } // ------------------------------ Average // Returns (a + b + 1) / 2 // Unsigned HWY_API Vec512 AverageRound(const Vec512 a, const Vec512 b) { return Vec512{_mm512_avg_epu8(a.raw, b.raw)}; } HWY_API Vec512 AverageRound(const Vec512 a, const Vec512 b) { return Vec512{_mm512_avg_epu16(a.raw, b.raw)}; } // ------------------------------ Abs (Sub) // Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1. HWY_API Vec512 Abs(const Vec512 v) { #if HWY_COMPILER_MSVC // Workaround for incorrect codegen? (untested due to internal compiler error) const auto zero = Zero(Full512()); return Vec512{_mm512_max_epi8(v.raw, (zero - v).raw)}; #else return Vec512{_mm512_abs_epi8(v.raw)}; #endif } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi16(v.raw)}; } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi32(v.raw)}; } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_epi64(v.raw)}; } // These aren't native instructions, they also involve AND with constant. 
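// Rough sketch of the equivalent (illustrative, assuming a Full512<float>
// descriptor `d`): clearing the sign bit yields the absolute value, i.e.
//   Abs(v) behaves like AndNot(SignBit(d), v);  // AND with 0x7FFFFFFF per lane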
HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_ps(v.raw)}; } HWY_API Vec512 Abs(const Vec512 v) { return Vec512{_mm512_abs_pd(v.raw)}; } // ------------------------------ ShiftLeft template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { return Vec512{_mm512_slli_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftLeft(const Vec512 v) { const Full512 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftLeft(BitCast(d16, v))); return kBits == 1 ? (v + v) : (shifted & Set(d8, static_cast((0xFF << kBits) & 0xFF))); } // ------------------------------ ShiftRight template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srli_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srli_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srli_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { const Full512 d8; // Use raw instead of BitCast to support N=1. const Vec512 shifted{ShiftRight(Vec512{v.raw}).raw}; return shifted & Set(d8, 0xFF >> kBits); } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srai_epi16(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srai_epi32(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { return Vec512{_mm512_srai_epi64(v.raw, kBits)}; } template HWY_API Vec512 ShiftRight(const Vec512 v) { const Full512 di; const Full512 du; const auto shifted = BitCast(di, ShiftRight(BitCast(du, v))); const auto shifted_sign = BitCast(di, Set(du, 0x80 >> kBits)); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ RotateRight template HWY_API Vec512 RotateRight(const Vec512 v) { static_assert(0 <= kBits && kBits < 32, "Invalid shift count"); return Vec512{_mm512_ror_epi32(v.raw, kBits)}; } template HWY_API Vec512 RotateRight(const Vec512 v) { static_assert(0 <= kBits && kBits < 64, "Invalid shift count"); return Vec512{_mm512_ror_epi64(v.raw, kBits)}; } // ------------------------------ ShiftLeftSame HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { return Vec512{_mm512_sll_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } template HWY_API Vec512 ShiftLeftSame(const Vec512 v, const int bits) { const Full512 d8; const RepartitionToWide d16; const auto shifted = 
BitCast(d8, ShiftLeftSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast((0xFF << bits) & 0xFF)); } // ------------------------------ ShiftRightSame HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_srl_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_srl_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_srl_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { const Full512 d8; const RepartitionToWide d16; const auto shifted = BitCast(d8, ShiftRightSame(BitCast(d16, v), bits)); return shifted & Set(d8, static_cast(0xFF >> bits)); } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_sra_epi16(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_sra_epi32(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(const Vec512 v, const int bits) { return Vec512{_mm512_sra_epi64(v.raw, _mm_cvtsi32_si128(bits))}; } HWY_API Vec512 ShiftRightSame(Vec512 v, const int bits) { const Full512 di; const Full512 du; const auto shifted = BitCast(di, ShiftRightSame(BitCast(du, v), bits)); const auto shifted_sign = BitCast(di, Set(du, static_cast(0x80 >> bits))); return (shifted ^ shifted_sign) - shifted_sign; } // ------------------------------ Shl HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_sllv_epi16(v.raw, bits.raw)}; } HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_sllv_epi32(v.raw, bits.raw)}; } HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_sllv_epi64(v.raw, bits.raw)}; } // Signed left shift is the same as unsigned. 
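// For example (illustrative sketch of the implementation below): an i32 shift
// reuses the u32 overload because the resulting bit pattern is identical:
//   const Full512<int32_t> di;
//   const Full512<uint32_t> du;
//   const auto shifted = BitCast(di, BitCast(du, v) << BitCast(du, bits));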
template HWY_API Vec512 operator<<(const Vec512 v, const Vec512 bits) { const Full512 di; const Full512> du; return BitCast(di, BitCast(du, v) << BitCast(du, bits)); } // ------------------------------ Shr HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srlv_epi16(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srlv_epi32(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srlv_epi64(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srav_epi16(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srav_epi32(v.raw, bits.raw)}; } HWY_API Vec512 operator>>(const Vec512 v, const Vec512 bits) { return Vec512{_mm512_srav_epi64(v.raw, bits.raw)}; } // ------------------------------ Minimum // Unsigned HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu8(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu16(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu32(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epu64(a.raw, b.raw)}; } // Signed HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi8(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi16(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi32(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_ps(a.raw, b.raw)}; } HWY_API Vec512 Min(const Vec512 a, const Vec512 b) { return Vec512{_mm512_min_pd(a.raw, b.raw)}; } // ------------------------------ Maximum // Unsigned HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu8(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu16(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu32(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epu64(a.raw, b.raw)}; } // Signed HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi8(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi16(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi32(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_epi64(a.raw, b.raw)}; } // Float HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_ps(a.raw, b.raw)}; } HWY_API Vec512 Max(const Vec512 a, const Vec512 b) { return Vec512{_mm512_max_pd(a.raw, b.raw)}; } // ------------------------------ Integer multiplication // Unsigned HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; } // Signed HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mullo_epi16(a.raw, b.raw)}; } HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return 
Vec512{_mm512_mullo_epi32(a.raw, b.raw)}; } // Returns the upper 16 bits of a * b in each lane. HWY_API Vec512 MulHigh(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mulhi_epu16(a.raw, b.raw)}; } HWY_API Vec512 MulHigh(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mulhi_epi16(a.raw, b.raw)}; } // Multiplies even lanes (0, 2 ..) and places the double-wide result into // even and the upper half into its odd neighbor lane. HWY_API Vec512 MulEven(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mul_epi32(a.raw, b.raw)}; } HWY_API Vec512 MulEven(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mul_epu32(a.raw, b.raw)}; } // ------------------------------ Neg (Sub) template HWY_API Vec512 Neg(const Vec512 v) { return Xor(v, SignBit(Full512())); } template HWY_API Vec512 Neg(const Vec512 v) { return Zero(Full512()) - v; } // ------------------------------ Floating-point mul / div HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mul_ps(a.raw, b.raw)}; } HWY_API Vec512 operator*(const Vec512 a, const Vec512 b) { return Vec512{_mm512_mul_pd(a.raw, b.raw)}; } HWY_API Vec512 operator/(const Vec512 a, const Vec512 b) { return Vec512{_mm512_div_ps(a.raw, b.raw)}; } HWY_API Vec512 operator/(const Vec512 a, const Vec512 b) { return Vec512{_mm512_div_pd(a.raw, b.raw)}; } // Approximate reciprocal HWY_API Vec512 ApproximateReciprocal(const Vec512 v) { return Vec512{_mm512_rcp14_ps(v.raw)}; } // Absolute value of difference. HWY_API Vec512 AbsDiff(const Vec512 a, const Vec512 b) { return Abs(a - b); } // ------------------------------ Floating-point multiply-add variants // Returns mul * x + add HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fmadd_ps(mul.raw, x.raw, add.raw)}; } HWY_API Vec512 MulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fmadd_pd(mul.raw, x.raw, add.raw)}; } // Returns add - mul * x HWY_API Vec512 NegMulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fnmadd_ps(mul.raw, x.raw, add.raw)}; } HWY_API Vec512 NegMulAdd(const Vec512 mul, const Vec512 x, const Vec512 add) { return Vec512{_mm512_fnmadd_pd(mul.raw, x.raw, add.raw)}; } // Returns mul * x - sub HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fmsub_ps(mul.raw, x.raw, sub.raw)}; } HWY_API Vec512 MulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fmsub_pd(mul.raw, x.raw, sub.raw)}; } // Returns -mul * x - sub HWY_API Vec512 NegMulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fnmsub_ps(mul.raw, x.raw, sub.raw)}; } HWY_API Vec512 NegMulSub(const Vec512 mul, const Vec512 x, const Vec512 sub) { return Vec512{_mm512_fnmsub_pd(mul.raw, x.raw, sub.raw)}; } // ------------------------------ Floating-point square root // Full precision square root HWY_API Vec512 Sqrt(const Vec512 v) { return Vec512{_mm512_sqrt_ps(v.raw)}; } HWY_API Vec512 Sqrt(const Vec512 v) { return Vec512{_mm512_sqrt_pd(v.raw)}; } // Approximate reciprocal square root HWY_API Vec512 ApproximateReciprocalSqrt(const Vec512 v) { return Vec512{_mm512_rsqrt14_ps(v.raw)}; } // ------------------------------ Floating-point rounding // Work around warnings in the intrinsic definitions (passing -1 as a mask). 
HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") // Toward nearest integer, tie to even HWY_API Vec512 Round(const Vec512 v) { return Vec512{_mm512_roundscale_ps( v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Round(const Vec512 v) { return Vec512{_mm512_roundscale_pd( v.raw, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)}; } // Toward zero, aka truncate HWY_API Vec512 Trunc(const Vec512 v) { return Vec512{ _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Trunc(const Vec512 v) { return Vec512{ _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC)}; } // Toward +infinity, aka ceiling HWY_API Vec512 Ceil(const Vec512 v) { return Vec512{ _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Ceil(const Vec512 v) { return Vec512{ _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC)}; } // Toward -infinity, aka floor HWY_API Vec512 Floor(const Vec512 v) { return Vec512{ _mm512_roundscale_ps(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } HWY_API Vec512 Floor(const Vec512 v) { return Vec512{ _mm512_roundscale_pd(v.raw, _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)}; } HWY_DIAGNOSTICS(pop) // ================================================== COMPARE // Comparisons set a mask bit to 1 if the condition is true, else 0. template HWY_API Mask512 RebindMask(Full512 /*tag*/, Mask512 m) { static_assert(sizeof(TFrom) == sizeof(TTo), "Must have same size"); return Mask512{m.raw}; } namespace detail { template HWY_INLINE Mask512 TestBit(hwy::SizeTag<1> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi8_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask512 TestBit(hwy::SizeTag<2> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi16_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask512 TestBit(hwy::SizeTag<4> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi32_mask(v.raw, bit.raw)}; } template HWY_INLINE Mask512 TestBit(hwy::SizeTag<8> /*tag*/, const Vec512 v, const Vec512 bit) { return Mask512{_mm512_test_epi64_mask(v.raw, bit.raw)}; } } // namespace detail template HWY_API Mask512 TestBit(const Vec512 v, const Vec512 bit) { static_assert(!hwy::IsFloat(), "Only integer vectors supported"); return detail::TestBit(hwy::SizeTag(), v, bit); } // ------------------------------ Equality template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi32_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpeq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } HWY_API Mask512 operator==(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_EQ_OQ)}; } // ------------------------------ Inequality template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi8_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi16_mask(a.raw, b.raw)}; } template HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi32_mask(a.raw, b.raw)}; } template HWY_API 
Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpneq_epi64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } HWY_API Mask512 operator!=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_NEQ_OQ)}; } // ------------------------------ Strict inequality HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu8_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu16_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu32_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epu64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi8_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi16_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi32_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmpgt_epi64_mask(a.raw, b.raw)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GT_OQ)}; } HWY_API Mask512 operator>(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GT_OQ)}; } // ------------------------------ Weak inequality HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_ps_mask(a.raw, b.raw, _CMP_GE_OQ)}; } HWY_API Mask512 operator>=(Vec512 a, Vec512 b) { return Mask512{_mm512_cmp_pd_mask(a.raw, b.raw, _CMP_GE_OQ)}; } // ------------------------------ Reversed comparisons template HWY_API Mask512 operator<(Vec512 a, Vec512 b) { return b > a; } template HWY_API Mask512 operator<=(Vec512 a, Vec512 b) { return b >= a; } // ------------------------------ Mask namespace detail { template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<1> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi8_mask(v.raw)}; } template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<2> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi16_mask(v.raw)}; } template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<4> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi32_mask(v.raw)}; } template HWY_INLINE Mask512 MaskFromVec(hwy::SizeTag<8> /*tag*/, const Vec512 v) { return Mask512{_mm512_movepi64_mask(v.raw)}; } } // namespace detail template HWY_API Mask512 MaskFromVec(const Vec512 v) { return detail::MaskFromVec(hwy::SizeTag(), v); } // There do not seem to be native floating-point versions of these instructions. 
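// Illustrative round trip (not part of the library): callers can move between
// vectors and AVX-512 mask registers; only the most-significant bit of each
// lane contributes to the mask:
//   const Mask512<float> m = MaskFromVec(v);      // 1 bit per lane (MSB)
//   const Vec512<float> lanes = VecFromMask(m);   // all-ones or all-zero lanes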
HWY_API Mask512 MaskFromVec(const Vec512 v) { return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; } HWY_API Mask512 MaskFromVec(const Vec512 v) { return Mask512{MaskFromVec(BitCast(Full512(), v)).raw}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi8(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi8(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi16(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi16(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi32(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi32(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_castsi512_ps(_mm512_movm_epi32(v.raw))}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi64(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_movm_epi64(v.raw)}; } HWY_API Vec512 VecFromMask(const Mask512 v) { return Vec512{_mm512_castsi512_pd(_mm512_movm_epi64(v.raw))}; } template HWY_API Vec512 VecFromMask(Full512 /* tag */, const Mask512 v) { return VecFromMask(v); } // ------------------------------ Mask logical namespace detail { template HWY_INLINE Mask512 Not(hwy::SizeTag<1> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask64(m.raw)}; #else return Mask512{~m.raw}; #endif } template HWY_INLINE Mask512 Not(hwy::SizeTag<2> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask32(m.raw)}; #else return Mask512{~m.raw}; #endif } template HWY_INLINE Mask512 Not(hwy::SizeTag<4> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask16(m.raw)}; #else return Mask512{static_cast(~m.raw & 0xFFFF)}; #endif } template HWY_INLINE Mask512 Not(hwy::SizeTag<8> /*tag*/, const Mask512 m) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_knot_mask8(m.raw)}; #else return Mask512{static_cast(~m.raw & 0xFF)}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask64(a.raw, b.raw)}; #else return Mask512{a.raw & b.raw}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask32(a.raw, b.raw)}; #else return Mask512{a.raw & b.raw}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 And(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kand_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask64(a.raw, b.raw)}; #else return Mask512{~a.raw & b.raw}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask32(a.raw, b.raw)}; #else return Mask512{~a.raw & b.raw}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if 
HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 AndNot(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kandn_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(~a.raw & b.raw)}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask64(a.raw, b.raw)}; #else return Mask512{a.raw | b.raw}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask32(a.raw, b.raw)}; #else return Mask512{a.raw | b.raw}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw | b.raw)}; #endif } template HWY_INLINE Mask512 Or(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kor_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw | b.raw)}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<1> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask64(a.raw, b.raw)}; #else return Mask512{a.raw ^ b.raw}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<2> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask32(a.raw, b.raw)}; #else return Mask512{a.raw ^ b.raw}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<4> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask16(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw ^ b.raw)}; #endif } template HWY_INLINE Mask512 Xor(hwy::SizeTag<8> /*tag*/, const Mask512 a, const Mask512 b) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return Mask512{_kxor_mask8(a.raw, b.raw)}; #else return Mask512{static_cast(a.raw ^ b.raw)}; #endif } } // namespace detail template HWY_API Mask512 Not(const Mask512 m) { return detail::Not(hwy::SizeTag(), m); } template HWY_API Mask512 And(const Mask512 a, Mask512 b) { return detail::And(hwy::SizeTag(), a, b); } template HWY_API Mask512 AndNot(const Mask512 a, Mask512 b) { return detail::AndNot(hwy::SizeTag(), a, b); } template HWY_API Mask512 Or(const Mask512 a, Mask512 b) { return detail::Or(hwy::SizeTag(), a, b); } template HWY_API Mask512 Xor(const Mask512 a, Mask512 b) { return detail::Xor(hwy::SizeTag(), a, b); } // ------------------------------ BroadcastSignBit (ShiftRight, compare, mask) HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return VecFromMask(v < Zero(Full512())); } HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return ShiftRight<15>(v); } HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return ShiftRight<31>(v); } HWY_API Vec512 BroadcastSignBit(const Vec512 v) { return Vec512{_mm512_srai_epi64(v.raw, 63)}; } // ================================================== MEMORY // ------------------------------ Load template HWY_API Vec512 Load(Full512 /* tag */, const T* HWY_RESTRICT aligned) { return Vec512{_mm512_load_si512(aligned)}; } HWY_API Vec512 Load(Full512 /* tag */, const float* HWY_RESTRICT aligned) { return Vec512{_mm512_load_ps(aligned)}; } HWY_API Vec512 Load(Full512 /* tag */, const double* HWY_RESTRICT aligned) { 
return Vec512{_mm512_load_pd(aligned)}; } template HWY_API Vec512 LoadU(Full512 /* tag */, const T* HWY_RESTRICT p) { return Vec512{_mm512_loadu_si512(p)}; } HWY_API Vec512 LoadU(Full512 /* tag */, const float* HWY_RESTRICT p) { return Vec512{_mm512_loadu_ps(p)}; } HWY_API Vec512 LoadU(Full512 /* tag */, const double* HWY_RESTRICT p) { return Vec512{_mm512_loadu_pd(p)}; } // ------------------------------ MaskedLoad template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT aligned) { return Vec512{_mm512_maskz_load_epi32(m.raw, aligned)}; } template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT aligned) { return Vec512{_mm512_maskz_load_epi64(m.raw, aligned)}; } HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const float* HWY_RESTRICT aligned) { return Vec512{_mm512_maskz_load_ps(m.raw, aligned)}; } HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const double* HWY_RESTRICT aligned) { return Vec512{_mm512_maskz_load_pd(m.raw, aligned)}; } // There is no load_epi8/16, so use loadu instead. template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT aligned) { return Vec512{_mm512_maskz_loadu_epi8(m.raw, aligned)}; } template HWY_API Vec512 MaskedLoad(Mask512 m, Full512 /* tag */, const T* HWY_RESTRICT aligned) { return Vec512{_mm512_maskz_loadu_epi16(m.raw, aligned)}; } // ------------------------------ LoadDup128 // Loads 128 bit and duplicates into both 128-bit halves. This avoids the // 3-cycle cost of moving data between 128-bit halves and avoids port 5. template HWY_API Vec512 LoadDup128(Full512 /* tag */, const T* const HWY_RESTRICT p) { // Clang 3.9 generates VINSERTF128 which is slower, but inline assembly leads // to "invalid output size for constraint" without -mavx512: // https://gcc.godbolt.org/z/-Jt_-F #if HWY_LOADDUP_ASM __m512i out; asm("vbroadcasti128 %1, %[reg]" : [reg] "=x"(out) : "m"(p[0])); return Vec512{out}; #else const auto x4 = LoadU(Full128(), p); return Vec512{_mm512_broadcast_i32x4(x4.raw)}; #endif } HWY_API Vec512 LoadDup128(Full512 /* tag */, const float* const HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m512 out; asm("vbroadcastf128 %1, %[reg]" : [reg] "=x"(out) : "m"(p[0])); return Vec512{out}; #else const __m128 x4 = _mm_loadu_ps(p); return Vec512{_mm512_broadcast_f32x4(x4)}; #endif } HWY_API Vec512 LoadDup128(Full512 /* tag */, const double* const HWY_RESTRICT p) { #if HWY_LOADDUP_ASM __m512d out; asm("vbroadcastf128 %1, %[reg]" : [reg] "=x"(out) : "m"(p[0])); return Vec512{out}; #else const __m128d x2 = _mm_loadu_pd(p); return Vec512{_mm512_broadcast_f64x2(x2)}; #endif } // ------------------------------ Store template HWY_API void Store(const Vec512 v, Full512 /* tag */, T* HWY_RESTRICT aligned) { _mm512_store_si512(reinterpret_cast<__m512i*>(aligned), v.raw); } HWY_API void Store(const Vec512 v, Full512 /* tag */, float* HWY_RESTRICT aligned) { _mm512_store_ps(aligned, v.raw); } HWY_API void Store(const Vec512 v, Full512 /* tag */, double* HWY_RESTRICT aligned) { _mm512_store_pd(aligned, v.raw); } template HWY_API void StoreU(const Vec512 v, Full512 /* tag */, T* HWY_RESTRICT p) { _mm512_storeu_si512(reinterpret_cast<__m512i*>(p), v.raw); } HWY_API void StoreU(const Vec512 v, Full512 /* tag */, float* HWY_RESTRICT p) { _mm512_storeu_ps(p, v.raw); } HWY_API void StoreU(const Vec512 v, Full512, double* HWY_RESTRICT p) { _mm512_storeu_pd(p, v.raw); } // ------------------------------ Non-temporal stores template HWY_API void Stream(const Vec512 v, Full512 /* 
tag */, T* HWY_RESTRICT aligned) { _mm512_stream_si512(reinterpret_cast<__m512i*>(aligned), v.raw); } HWY_API void Stream(const Vec512 v, Full512 /* tag */, float* HWY_RESTRICT aligned) { _mm512_stream_ps(aligned, v.raw); } HWY_API void Stream(const Vec512 v, Full512, double* HWY_RESTRICT aligned) { _mm512_stream_pd(aligned, v.raw); } // ------------------------------ Scatter // Work around warnings in the intrinsic definitions (passing -1 as a mask). HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") namespace detail { template HWY_INLINE void ScatterOffset(hwy::SizeTag<4> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 offset) { _mm512_i32scatter_epi32(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<4> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 index) { _mm512_i32scatter_epi32(base, index.raw, v.raw, 4); } template HWY_INLINE void ScatterOffset(hwy::SizeTag<8> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 offset) { _mm512_i64scatter_epi64(base, offset.raw, v.raw, 1); } template HWY_INLINE void ScatterIndex(hwy::SizeTag<8> /* tag */, Vec512 v, Full512 /* tag */, T* HWY_RESTRICT base, const Vec512 index) { _mm512_i64scatter_epi64(base, index.raw, v.raw, 8); } } // namespace detail template HWY_API void ScatterOffset(Vec512 v, Full512 d, T* HWY_RESTRICT base, const Vec512 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return detail::ScatterOffset(hwy::SizeTag(), v, d, base, offset); } template HWY_API void ScatterIndex(Vec512 v, Full512 d, T* HWY_RESTRICT base, const Vec512 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::ScatterIndex(hwy::SizeTag(), v, d, base, index); } HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, float* HWY_RESTRICT base, const Vec512 offset) { _mm512_i32scatter_ps(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, float* HWY_RESTRICT base, const Vec512 index) { _mm512_i32scatter_ps(base, index.raw, v.raw, 4); } HWY_API void ScatterOffset(Vec512 v, Full512 /* tag */, double* HWY_RESTRICT base, const Vec512 offset) { _mm512_i64scatter_pd(base, offset.raw, v.raw, 1); } HWY_API void ScatterIndex(Vec512 v, Full512 /* tag */, double* HWY_RESTRICT base, const Vec512 index) { _mm512_i64scatter_pd(base, index.raw, v.raw, 8); } // ------------------------------ Gather namespace detail { template HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<4> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i32gather_epi32(offset.raw, base, 1)}; } template HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<4> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i32gather_epi32(index.raw, base, 4)}; } template HWY_INLINE Vec512 GatherOffset(hwy::SizeTag<8> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i64gather_epi64(offset.raw, base, 1)}; } template HWY_INLINE Vec512 GatherIndex(hwy::SizeTag<8> /* tag */, Full512 /* tag */, const T* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i64gather_epi64(index.raw, base, 8)}; } } // namespace detail template HWY_API Vec512 GatherOffset(Full512 d, const T* HWY_RESTRICT base, const Vec512 offset) { static_assert(sizeof(T) == sizeof(Offset), "Must match for portability"); return 
detail::GatherOffset(hwy::SizeTag(), d, base, offset); } template HWY_API Vec512 GatherIndex(Full512 d, const T* HWY_RESTRICT base, const Vec512 index) { static_assert(sizeof(T) == sizeof(Index), "Must match for portability"); return detail::GatherIndex(hwy::SizeTag(), d, base, index); } HWY_API Vec512 GatherOffset(Full512 /* tag */, const float* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i32gather_ps(offset.raw, base, 1)}; } HWY_API Vec512 GatherIndex(Full512 /* tag */, const float* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i32gather_ps(index.raw, base, 4)}; } HWY_API Vec512 GatherOffset(Full512 /* tag */, const double* HWY_RESTRICT base, const Vec512 offset) { return Vec512{_mm512_i64gather_pd(offset.raw, base, 1)}; } HWY_API Vec512 GatherIndex(Full512 /* tag */, const double* HWY_RESTRICT base, const Vec512 index) { return Vec512{_mm512_i64gather_pd(index.raw, base, 8)}; } HWY_DIAGNOSTICS(pop) // ================================================== SWIZZLE // ------------------------------ LowerHalf template HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_castsi512_si256(v.raw)}; } HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_castps512_ps256(v.raw)}; } HWY_API Vec256 LowerHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_castpd512_pd256(v.raw)}; } template HWY_API Vec256 LowerHalf(Vec512 v) { return LowerHalf(Full256(), v); } // ------------------------------ UpperHalf template HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_extracti32x8_epi32(v.raw, 1)}; } HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_extractf32x8_ps(v.raw, 1)}; } HWY_API Vec256 UpperHalf(Full256 /* tag */, Vec512 v) { return Vec256{_mm512_extractf64x4_pd(v.raw, 1)}; } // ------------------------------ GetLane (LowerHalf) template HWY_API T GetLane(const Vec512 v) { return GetLane(LowerHalf(v)); } // ------------------------------ ZeroExtendVector // Unfortunately the initial _mm512_castsi256_si512 intrinsic leaves the upper // bits undefined. Although it makes sense for them to be zero (EVEX encoded // instructions have that effect), a compiler could decide to optimize out code // that relies on this. // // The newer _mm512_zextsi256_si512 intrinsic fixes this by specifying the // zeroing, but it is not available on GCC until 10.1. For older GCC, we can // still obtain the desired code thanks to pattern recognition; note that the // expensive insert instruction is not actually generated, see // https://gcc.godbolt.org/z/1MKGaP. 
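// Usage sketch (illustrative only): widening a 256-bit result to 512 bits with
// zeroed upper lanes, or joining two 256-bit halves. `d512`, `lo256` and
// `hi256` are assumed caller-side variables:
//   const Vec512<float> widened = ZeroExtendVector(d512, lo256);  // upper half = 0
//   const Vec512<float> joined = Combine(d512, hi256, lo256);     // hi above lo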
template HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) return Vec512{_mm512_inserti32x8(_mm512_setzero_si512(), lo.raw, 0)}; #else return Vec512{_mm512_zextsi256_si512(lo.raw)}; #endif } HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) return Vec512{_mm512_insertf32x8(_mm512_setzero_ps(), lo.raw, 0)}; #else return Vec512{_mm512_zextps256_ps512(lo.raw)}; #endif } HWY_API Vec512 ZeroExtendVector(Full512 /* tag */, Vec256 lo) { #if !HWY_COMPILER_CLANG && HWY_COMPILER_GCC && (HWY_COMPILER_GCC < 1000) return Vec512{_mm512_insertf64x4(_mm512_setzero_pd(), lo.raw, 0)}; #else return Vec512{_mm512_zextpd256_pd512(lo.raw)}; #endif } // ------------------------------ Combine template HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { const auto lo512 = ZeroExtendVector(d, lo); return Vec512{_mm512_inserti32x8(lo512.raw, hi.raw, 1)}; } HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { const auto lo512 = ZeroExtendVector(d, lo); return Vec512{_mm512_insertf32x8(lo512.raw, hi.raw, 1)}; } HWY_API Vec512 Combine(Full512 d, Vec256 hi, Vec256 lo) { const auto lo512 = ZeroExtendVector(d, lo); return Vec512{_mm512_insertf64x4(lo512.raw, hi.raw, 1)}; } // ------------------------------ ShiftLeftBytes template HWY_API Vec512 ShiftLeftBytes(Full512 /* tag */, const Vec512 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); return Vec512{_mm512_bslli_epi128(v.raw, kBytes)}; } template HWY_API Vec512 ShiftLeftBytes(const Vec512 v) { return ShiftLeftBytes(Full512(), v); } // ------------------------------ ShiftLeftLanes template HWY_API Vec512 ShiftLeftLanes(Full512 d, const Vec512 v) { const Repartition d8; return BitCast(d, ShiftLeftBytes(BitCast(d8, v))); } template HWY_API Vec512 ShiftLeftLanes(const Vec512 v) { return ShiftLeftLanes(Full512(), v); } // ------------------------------ ShiftRightBytes template HWY_API Vec512 ShiftRightBytes(Full512 /* tag */, const Vec512 v) { static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes"); return Vec512{_mm512_bsrli_epi128(v.raw, kBytes)}; } // ------------------------------ ShiftRightLanes template HWY_API Vec512 ShiftRightLanes(Full512 d, const Vec512 v) { const Repartition d8; return BitCast(d, ShiftRightBytes(d8, BitCast(d8, v))); } // ------------------------------ CombineShiftRightBytes template > HWY_API V CombineShiftRightBytes(Full512 d, V hi, V lo) { const Repartition d8; return BitCast(d, Vec512{_mm512_alignr_epi8( BitCast(d8, hi).raw, BitCast(d8, lo).raw, kBytes)}); } // ------------------------------ Broadcast/splat any lane // Unsigned template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec512{_mm512_unpacklo_epi64(lo, lo)}; } else { const __m512i hi = _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec512{_mm512_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); constexpr _MM_PERM_ENUM perm = kLane ? 
_MM_PERM_DCDC : _MM_PERM_BABA; return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } // Signed template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 8, "Invalid lane"); if (kLane < 4) { const __m512i lo = _mm512_shufflelo_epi16(v.raw, (0x55 * kLane) & 0xFF); return Vec512{_mm512_unpacklo_epi64(lo, lo)}; } else { const __m512i hi = _mm512_shufflehi_epi16(v.raw, (0x55 * (kLane - 4)) & 0xFF); return Vec512{_mm512_unpackhi_epi64(hi, hi)}; } } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); constexpr _MM_PERM_ENUM perm = kLane ? _MM_PERM_DCDC : _MM_PERM_BABA; return Vec512{_mm512_shuffle_epi32(v.raw, perm)}; } // Float template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 4, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0x55 * kLane); return Vec512{_mm512_shuffle_ps(v.raw, v.raw, perm)}; } template HWY_API Vec512 Broadcast(const Vec512 v) { static_assert(0 <= kLane && kLane < 2, "Invalid lane"); constexpr _MM_PERM_ENUM perm = static_cast<_MM_PERM_ENUM>(0xFF * kLane); return Vec512{_mm512_shuffle_pd(v.raw, v.raw, perm)}; } // ------------------------------ Hard-coded shuffles // Notation: let Vec512 have lanes 7,6,5,4,3,2,1,0 (0 is // least-significant). Shuffle0321 rotates four-lane blocks one lane to the // right (the previous least-significant lane is now most-significant => // 47650321). These could also be implemented via CombineShiftRightBytes but // the shuffle_abcd notation is more convenient. // Swap 32-bit halves in 64-bit halves. HWY_API Vec512 Shuffle2301(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 Shuffle2301(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 Shuffle2301(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CDAB)}; } // Swap 64-bit halves HWY_API Vec512 Shuffle1032(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle1032(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle1032(const Vec512 v) { // Shorter encoding than _mm512_permute_ps. return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle01(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle01(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_BADC)}; } HWY_API Vec512 Shuffle01(const Vec512 v) { // Shorter encoding than _mm512_permute_pd. 
return Vec512{_mm512_shuffle_pd(v.raw, v.raw, _MM_PERM_BBBB)}; } // Rotate right 32 bits HWY_API Vec512 Shuffle0321(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; } HWY_API Vec512 Shuffle0321(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ADCB)}; } HWY_API Vec512 Shuffle0321(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ADCB)}; } // Rotate left 32 bits HWY_API Vec512 Shuffle2103(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; } HWY_API Vec512 Shuffle2103(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CBAD)}; } HWY_API Vec512 Shuffle2103(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CBAD)}; } // Reverse HWY_API Vec512 Shuffle0123(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 Shuffle0123(const Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 Shuffle0123(const Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_ABCD)}; } // ------------------------------ TableLookupLanes // Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes. template struct Indices512 { __m512i raw; }; template HWY_API Indices512 IndicesFromVec(Full512 /* tag */, Vec512 vec) { static_assert(sizeof(T) == sizeof(TI), "Index size must match lane"); #if HWY_IS_DEBUG_BUILD const Full512 di; HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) && AllTrue(di, Lt(vec, Set(di, static_cast(64 / sizeof(T)))))); #endif return Indices512{vec.raw}; } template HWY_API Indices512 SetTableIndices(const Full512 d, const TI* idx) { const Rebind di; return IndicesFromVec(d, LoadU(di, idx)); } template HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_epi32(idx.raw, v.raw)}; } template HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_epi64(idx.raw, v.raw)}; } HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_ps(idx.raw, v.raw)}; } HWY_API Vec512 TableLookupLanes(Vec512 v, Indices512 idx) { return Vec512{_mm512_permutexvar_pd(idx.raw, v.raw)}; } // ------------------------------ Reverse template HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { const RebindToSigned di; alignas(64) constexpr int16_t kReverse[32] = { 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; const Vec512 idx = Load(di, kReverse); return BitCast(d, Vec512{ _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { alignas(64) constexpr int32_t kReverse[16] = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } template HWY_API Vec512 Reverse(Full512 d, const Vec512 v) { alignas(64) constexpr int64_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0}; return TableLookupLanes(v, SetTableIndices(d, kReverse)); } // ------------------------------ Reverse2 template HWY_API Vec512 Reverse2(Full512 d, const Vec512 v) { const Full512 du32; return BitCast(d, RotateRight<16>(BitCast(du32, v))); } template HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { return Shuffle2301(v); } template HWY_API Vec512 Reverse2(Full512 /* tag */, const Vec512 v) { return Shuffle01(v); } // ------------------------------ Reverse4 template HWY_API Vec512 Reverse4(Full512 d, const Vec512 v) { const RebindToSigned di; 
alignas(64) constexpr int16_t kReverse4[32] = { 3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20, 27, 26, 25, 24, 31, 30, 29, 28}; const Vec512 idx = Load(di, kReverse4); return BitCast(d, Vec512{ _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { return Shuffle0123(v); } template HWY_API Vec512 Reverse4(Full512 /* tag */, const Vec512 v) { return Vec512{_mm512_permutex_epi64(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; } HWY_API Vec512 Reverse4(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_permutex_pd(v.raw, _MM_SHUFFLE(0, 1, 2, 3))}; } // ------------------------------ Reverse8 template HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { const RebindToSigned di; alignas(64) constexpr int16_t kReverse8[32] = { 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24}; const Vec512 idx = Load(di, kReverse8); return BitCast(d, Vec512{ _mm512_permutexvar_epi16(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { const RebindToSigned di; alignas(64) constexpr int32_t kReverse8[16] = {7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8}; const Vec512 idx = Load(di, kReverse8); return BitCast(d, Vec512{ _mm512_permutexvar_epi32(idx.raw, BitCast(di, v).raw)}); } template HWY_API Vec512 Reverse8(Full512 d, const Vec512 v) { return Reverse(d, v); } // ------------------------------ InterleaveLower // Interleaves lanes from halves of the 128-bit blocks of "a" (which provides // the least-significant lane) and "b". To concatenate two half-width integers // into one, use ZipLower/Upper instead (also works with scalar). HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_ps(a.raw, b.raw)}; } HWY_API Vec512 InterleaveLower(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpacklo_pd(a.raw, b.raw)}; } // ------------------------------ InterleaveUpper // All functions inside detail lack the required D parameter. 
namespace detail { HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi8(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi16(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi32(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_epi64(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_ps(a.raw, b.raw)}; } HWY_API Vec512 InterleaveUpper(const Vec512 a, const Vec512 b) { return Vec512{_mm512_unpackhi_pd(a.raw, b.raw)}; } } // namespace detail template > HWY_API V InterleaveUpper(Full512 /* tag */, V a, V b) { return detail::InterleaveUpper(a, b); } // ------------------------------ ZipLower/ZipUpper (InterleaveLower) // Same as Interleave*, except that the return lanes are double-width integers; // this is necessary because the single-lane scalar cannot return two values. template > HWY_API Vec512 ZipLower(Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveLower(a, b)); } template > HWY_API Vec512 ZipLower(Full512 /* d */, Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveLower(a, b)); } template > HWY_API Vec512 ZipUpper(Full512 d, Vec512 a, Vec512 b) { return BitCast(Full512(), InterleaveUpper(d, a, b)); } // ------------------------------ Concat* halves // hiH,hiL loH,loL |-> hiL,loL (= lower halves) template HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; } HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BABA)}; } HWY_API Vec512 ConcatLowerLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_BABA)}; } // hiH,hiL loH,loL |-> hiH,loH (= upper halves) template HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; } HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_DCDC)}; } HWY_API Vec512 ConcatUpperUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, _MM_PERM_DCDC)}; } // hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks) template HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_i32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; } HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f32x4(lo.raw, hi.raw, _MM_PERM_BADC)}; } HWY_API Vec512 ConcatLowerUpper(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { return Vec512{_mm512_shuffle_f64x2(lo.raw, hi.raw, 
_MM_PERM_BADC)}; } // hiH,hiL loH,loL |-> hiH,loL (= outer halves) template HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { // There are no imm8 blend in AVX512. Use blend16 because 32-bit masks // are efficiently loaded from 32-bit regs. const __mmask32 mask = /*_cvtu32_mask32 */ (0x0000FFFF); return Vec512{_mm512_mask_blend_epi16(mask, hi.raw, lo.raw)}; } HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { const __mmask16 mask = /*_cvtu32_mask16 */ (0x00FF); return Vec512{_mm512_mask_blend_ps(mask, hi.raw, lo.raw)}; } HWY_API Vec512 ConcatUpperLower(Full512 /* tag */, const Vec512 hi, const Vec512 lo) { const __mmask8 mask = /*_cvtu32_mask8 */ (0x0F); return Vec512{_mm512_mask_blend_pd(mask, hi.raw, lo.raw)}; } // ------------------------------ ConcatOdd template HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31}; return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, hi.raw)}; } template HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatOdd(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {1, 3, 5, 7, 9, 11, 13, 15}; return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; } // ------------------------------ ConcatEven template HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi32( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint32_t kIdx[16] = {0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30}; return Vec512{_mm512_mask2_permutex2var_ps(lo.raw, Load(du, kIdx).raw, __mmask16{0xFFFF}, hi.raw)}; } template HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return BitCast(d, Vec512{_mm512_mask2_permutex2var_epi64( BitCast(du, lo).raw, Load(du, kIdx).raw, __mmask8{0xFF}, BitCast(du, hi).raw)}); } HWY_API Vec512 ConcatEven(Full512 d, Vec512 hi, Vec512 lo) { const RebindToUnsigned du; alignas(64) constexpr uint64_t kIdx[8] = {0, 2, 4, 6, 8, 10, 12, 14}; return Vec512{_mm512_mask2_permutex2var_pd(lo.raw, Load(du, kIdx).raw, __mmask8{0xFF}, hi.raw)}; } // ------------------------------ DupEven (InterleaveLower) template HWY_API Vec512 DupEven(Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_CCAA)}; } HWY_API Vec512 DupEven(Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_CCAA)}; } 
template HWY_API Vec512 DupEven(const Vec512 v) { return InterleaveLower(Full512(), v, v); } // ------------------------------ DupOdd (InterleaveUpper) template HWY_API Vec512 DupOdd(Vec512 v) { return Vec512{_mm512_shuffle_epi32(v.raw, _MM_PERM_DDBB)}; } HWY_API Vec512 DupOdd(Vec512 v) { return Vec512{_mm512_shuffle_ps(v.raw, v.raw, _MM_PERM_DDBB)}; } template HWY_API Vec512 DupOdd(const Vec512 v) { return InterleaveUpper(Full512(), v, v); } // ------------------------------ OddEven template HWY_API Vec512 OddEven(const Vec512 a, const Vec512 b) { constexpr size_t s = sizeof(T); constexpr int shift = s == 1 ? 0 : s == 2 ? 32 : s == 4 ? 48 : 56; return IfThenElse(Mask512{0x5555555555555555ull >> shift}, b, a); } // ------------------------------ OddEvenBlocks template HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { return Vec512{_mm512_mask_blend_epi64(__mmask8{0x33u}, odd.raw, even.raw)}; } HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { return Vec512{ _mm512_mask_blend_ps(__mmask16{0x0F0Fu}, odd.raw, even.raw)}; } HWY_API Vec512 OddEvenBlocks(Vec512 odd, Vec512 even) { return Vec512{ _mm512_mask_blend_pd(__mmask8{0x33u}, odd.raw, even.raw)}; } // ------------------------------ SwapAdjacentBlocks template HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_CDAB)}; } HWY_API Vec512 SwapAdjacentBlocks(Vec512 v) { return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_CDAB)}; } // ------------------------------ ReverseBlocks template HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_shuffle_i32x4(v.raw, v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_shuffle_f32x4(v.raw, v.raw, _MM_PERM_ABCD)}; } HWY_API Vec512 ReverseBlocks(Full512 /* tag */, Vec512 v) { return Vec512{_mm512_shuffle_f64x2(v.raw, v.raw, _MM_PERM_ABCD)}; } // ------------------------------ TableLookupBytes (ZeroExtendVector) // Both full template HWY_API Vec512 TableLookupBytes(Vec512 bytes, Vec512 indices) { return Vec512{_mm512_shuffle_epi8(bytes.raw, indices.raw)}; } // Partial index vector template HWY_API Vec128 TableLookupBytes(Vec512 bytes, Vec128 from) { const Full512 d512; const Half d256; const Half d128; // First expand to full 128, then 256, then 512. const Vec128 from_full{from.raw}; const auto from_512 = ZeroExtendVector(d512, ZeroExtendVector(d256, from_full)); const auto tbl_full = TableLookupBytes(bytes, from_512); // Shrink to 256, then 128, then partial. return Vec128{LowerHalf(d128, LowerHalf(d256, tbl_full)).raw}; } template HWY_API Vec256 TableLookupBytes(Vec512 bytes, Vec256 from) { const auto from_512 = ZeroExtendVector(Full512(), from); return LowerHalf(Full256(), TableLookupBytes(bytes, from_512)); } // Partial table vector template HWY_API Vec512 TableLookupBytes(Vec128 bytes, Vec512 from) { const Full512 d512; const Half d256; const Half d128; // First expand to full 128, then 256, then 512. const Vec128 bytes_full{bytes.raw}; const auto bytes_512 = ZeroExtendVector(d512, ZeroExtendVector(d256, bytes_full)); return TableLookupBytes(bytes_512, from); } template HWY_API Vec512 TableLookupBytes(Vec256 bytes, Vec512 from) { const auto bytes_512 = ZeroExtendVector(Full512(), bytes); return TableLookupBytes(bytes_512, from); } // Partial both are handled by x86_128/256. 
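// ------------------------------ Example: deinterleaving (ConcatEven/ConcatOdd)

// Illustrative sketch only, not part of the Highway API: shows how the
// ConcatEven/ConcatOdd overloads defined above can deinterleave pairs of
// float lanes, e.g. separating the real and imaginary parts of 16 complex
// numbers stored as re0, im0, re1, im1, ... . The function name and the
// assumed buffer layout (32 readable floats in `in`, 16 writable floats in
// each output) are hypothetical and exist only for illustration.
HWY_INLINE void DeinterleaveComplexExample(const float* HWY_RESTRICT in,
                                           float* HWY_RESTRICT real,
                                           float* HWY_RESTRICT imag) {
  const Full512<float> d;
  // Two full vectors cover 16 interleaved (re, im) pairs.
  const Vec512<float> lo = LoadU(d, in);
  const Vec512<float> hi = LoadU(d, in + 16);
  // ConcatEven(d, hi, lo): even lanes of lo form the lower half of the
  // result, even lanes of hi the upper half; likewise ConcatOdd for odd
  // lanes. Even lanes hold the real parts, odd lanes the imaginary parts.
  StoreU(ConcatEven(d, hi, lo), d, real);
  StoreU(ConcatOdd(d, hi, lo), d, imag);
}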
// ================================================== CONVERT // ------------------------------ Promotions (part w/ narrow lanes -> full) // Unsigned: zero-extend. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then Zip* would be faster. HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu8_epi16(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec128 v) { return Vec512{_mm512_cvtepu8_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu8_epi16(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec128 v) { return Vec512{_mm512_cvtepu8_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu16_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu16_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepu32_epi64(v.raw)}; } // Signed: replicate sign bit. // Note: these have 3 cycle latency; if inputs are already split across the // 128 bit blocks (in their upper/lower halves), then ZipUpper/lo followed by // signed shift would be faster. HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi8_epi16(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec128 v) { return Vec512{_mm512_cvtepi8_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi16_epi32(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi32_epi64(v.raw)}; } // Float HWY_API Vec512 PromoteTo(Full512 /* tag */, const Vec256 v) { return Vec512{_mm512_cvtph_ps(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 df32, const Vec256 v) { const Rebind du16; const RebindToSigned di32; return BitCast(df32, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v)))); } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtps_pd(v.raw)}; } HWY_API Vec512 PromoteTo(Full512 /* tag */, Vec256 v) { return Vec512{_mm512_cvtepi32_pd(v.raw)}; } // ------------------------------ Demotions (full -> part w/ narrow lanes) HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u16.raw)}; return LowerHalf(even); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, i16.raw)}; return LowerHalf(even); } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec512 v) { const Vec512 u16{_mm512_packus_epi32(v.raw, v.raw)}; // packus treats the input as signed; we want unsigned. Clear the MSB to get // unsigned saturation to u8. 
const Vec512 i16{ _mm512_and_si512(u16.raw, _mm512_set1_epi16(0x7FFF))}; const Vec512 u8{_mm512_packus_epi16(i16.raw, i16.raw)}; alignas(16) static constexpr uint32_t kLanes[4] = {0, 4, 8, 12}; const auto idx32 = LoadDup128(Full512(), kLanes); const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, u8.raw)}; return LowerHalf(LowerHalf(fixed)); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 u8{_mm512_packus_epi16(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; return LowerHalf(even); } HWY_API Vec128 DemoteTo(Full128 /* tag */, const Vec512 v) { const Vec512 i16{_mm512_packs_epi32(v.raw, v.raw)}; const Vec512 i8{_mm512_packs_epi16(i16.raw, i16.raw)}; alignas(16) static constexpr uint32_t kLanes[16] = {0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12, 0, 4, 8, 12}; const auto idx32 = LoadDup128(Full512(), kLanes); const Vec512 fixed{_mm512_permutexvar_epi32(idx32.raw, i8.raw)}; return LowerHalf(LowerHalf(fixed)); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const Vec512 u8{_mm512_packs_epi16(v.raw, v.raw)}; // Compress even u64 lanes into 256 bit. alignas(64) static constexpr uint64_t kLanes[8] = {0, 2, 4, 6, 0, 2, 4, 6}; const auto idx64 = Load(Full512(), kLanes); const Vec512 even{_mm512_permutexvar_epi64(idx64.raw, u8.raw)}; return LowerHalf(even); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { // Work around warnings in the intrinsic definitions (passing -1 as a mask). HWY_DIAGNOSTICS(push) HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion") return Vec256{_mm512_cvtps_ph(v.raw, _MM_FROUND_NO_EXC)}; HWY_DIAGNOSTICS(pop) } HWY_API Vec256 DemoteTo(Full256 dbf16, const Vec512 v) { // TODO(janwas): _mm512_cvtneps_pbh once we have avx512bf16. const Rebind di32; const Rebind du32; // for logical shift right const Rebind du16; const auto bits_in_32 = BitCast(di32, ShiftRight<16>(BitCast(du32, v))); return BitCast(dbf16, DemoteTo(du16, bits_in_32)); } HWY_API Vec512 ReorderDemote2To(Full512 dbf16, Vec512 a, Vec512 b) { // TODO(janwas): _mm512_cvtne2ps_pbh once we have avx512bf16. const RebindToUnsigned du16; const Repartition du32; const Vec512 b_in_even = ShiftRight<16>(BitCast(du32, b)); return BitCast(dbf16, OddEven(BitCast(du16, a), BitCast(du16, b_in_even))); } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { return Vec256{_mm512_cvtpd_ps(v.raw)}; } HWY_API Vec256 DemoteTo(Full256 /* tag */, const Vec512 v) { const auto clamped = detail::ClampF64ToI32Max(Full512(), v); return Vec256{_mm512_cvttpd_epi32(clamped.raw)}; } // For already range-limited input [0, 255]. HWY_API Vec128 U8FromU32(const Vec512 v) { const Full512 d32; // In each 128 bit block, gather the lower byte of 4 uint32_t lanes into the // lowest 4 bytes. alignas(16) static constexpr uint32_t k8From32[4] = {0x0C080400u, ~0u, ~0u, ~0u}; const auto quads = TableLookupBytes(v, LoadDup128(d32, k8From32)); // Gather the lowest 4 bytes of 4 128-bit blocks. 
alignas(16) static constexpr uint32_t kIndex32[4] = {0, 4, 8, 12}; const Vec512 bytes{ _mm512_permutexvar_epi32(LoadDup128(d32, kIndex32).raw, quads.raw)}; return LowerHalf(LowerHalf(bytes)); } // ------------------------------ Convert integer <=> floating point HWY_API Vec512 ConvertTo(Full512 /* tag */, const Vec512 v) { return Vec512{_mm512_cvtepi32_ps(v.raw)}; } HWY_API Vec512 ConvertTo(Full512 /* tag */, const Vec512 v) { return Vec512{_mm512_cvtepi64_pd(v.raw)}; } // Truncates (rounds toward zero). HWY_API Vec512 ConvertTo(Full512 d, const Vec512 v) { return detail::FixConversionOverflow(d, v, _mm512_cvttps_epi32(v.raw)); } HWY_API Vec512 ConvertTo(Full512 di, const Vec512 v) { return detail::FixConversionOverflow(di, v, _mm512_cvttpd_epi64(v.raw)); } HWY_API Vec512 NearestInt(const Vec512 v) { const Full512 di; return detail::FixConversionOverflow(di, v, _mm512_cvtps_epi32(v.raw)); } // ================================================== CRYPTO #if !defined(HWY_DISABLE_PCLMUL_AES) // Per-target flag to prevent generic_ops-inl.h from defining AESRound. #ifdef HWY_NATIVE_AES #undef HWY_NATIVE_AES #else #define HWY_NATIVE_AES #endif HWY_API Vec512 AESRound(Vec512 state, Vec512 round_key) { #if HWY_TARGET == HWY_AVX3_DL return Vec512{_mm512_aesenc_epi128(state.raw, round_key.raw)}; #else const Full512 d; const Half d2; return Combine(d, AESRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), AESRound(LowerHalf(state), LowerHalf(round_key))); #endif } HWY_API Vec512 AESLastRound(Vec512 state, Vec512 round_key) { #if HWY_TARGET == HWY_AVX3_DL return Vec512{_mm512_aesenclast_epi128(state.raw, round_key.raw)}; #else const Full512 d; const Half d2; return Combine(d, AESLastRound(UpperHalf(d2, state), UpperHalf(d2, round_key)), AESLastRound(LowerHalf(state), LowerHalf(round_key))); #endif } HWY_API Vec512 CLMulLower(Vec512 va, Vec512 vb) { #if HWY_TARGET == HWY_AVX3_DL return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x00)}; #else alignas(64) uint64_t a[8]; alignas(64) uint64_t b[8]; const Full512 d; const Full128 d128; Store(va, d, a); Store(vb, d, b); for (size_t i = 0; i < 8; i += 2) { const auto mul = CLMulLower(Load(d128, a + i), Load(d128, b + i)); Store(mul, d128, a + i); } return Load(d, a); #endif } HWY_API Vec512 CLMulUpper(Vec512 va, Vec512 vb) { #if HWY_TARGET == HWY_AVX3_DL return Vec512{_mm512_clmulepi64_epi128(va.raw, vb.raw, 0x11)}; #else alignas(64) uint64_t a[8]; alignas(64) uint64_t b[8]; const Full512 d; const Full128 d128; Store(va, d, a); Store(vb, d, b); for (size_t i = 0; i < 8; i += 2) { const auto mul = CLMulUpper(Load(d128, a + i), Load(d128, b + i)); Store(mul, d128, a + i); } return Load(d, a); #endif } #endif // HWY_DISABLE_PCLMUL_AES // ================================================== MISC // Returns a vector with lane i=[0, N) set to "first" + i. template Vec512 Iota(const Full512 d, const T2 first) { HWY_ALIGN T lanes[64 / sizeof(T)]; for (size_t i = 0; i < 64 / sizeof(T); ++i) { lanes[i] = static_cast(first + static_cast(i)); } return Load(d, lanes); } // ------------------------------ Mask testing // Beware: the suffix indicates the number of mask bits, not lane size! 
namespace detail { template HWY_INLINE bool AllFalse(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask64_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } template HWY_INLINE bool AllFalse(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestz_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0; #endif } } // namespace detail template HWY_API bool AllFalse(const Full512 /* tag */, const Mask512 mask) { return detail::AllFalse(hwy::SizeTag(), mask); } namespace detail { template HWY_INLINE bool AllTrue(hwy::SizeTag<1> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask64_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFFFFFFFFFFFFFull; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<2> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask32_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFFFFFull; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<4> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask16_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFFFull; #endif } template HWY_INLINE bool AllTrue(hwy::SizeTag<8> /*tag*/, const Mask512 mask) { #if HWY_COMPILER_HAS_MASK_INTRINSICS return _kortestc_mask8_u8(mask.raw, mask.raw); #else return mask.raw == 0xFFull; #endif } } // namespace detail template HWY_API bool AllTrue(const Full512 /* tag */, const Mask512 mask) { return detail::AllTrue(hwy::SizeTag(), mask); } // `p` points to at least 8 readable bytes, not all of which need be valid. template HWY_API Mask512 LoadMaskBits(const Full512 /* tag */, const uint8_t* HWY_RESTRICT bits) { Mask512 mask; CopyBytes<8 / sizeof(T)>(bits, &mask.raw); // N >= 8 (= 512 / 64), so no need to mask invalid bits. return mask; } // `p` points to at least 8 writable bytes. template HWY_API size_t StoreMaskBits(const Full512 /* tag */, const Mask512 mask, uint8_t* bits) { const size_t kNumBytes = 8 / sizeof(T); CopyBytes(&mask.raw, bits); // N >= 8 (= 512 / 64), so no need to mask invalid bits. return kNumBytes; } template HWY_API size_t CountTrue(const Full512 /* tag */, const Mask512 mask) { return PopCount(static_cast(mask.raw)); } template HWY_API intptr_t FindFirstTrue(const Full512 /* tag */, const Mask512 mask) { return mask.raw ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask.raw)) : -1; } template HWY_API intptr_t FindFirstTrue(const Full512 /* tag */, const Mask512 mask) { return mask.raw ? 
intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask.raw)) : -1; } // ------------------------------ Compress template HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { return Vec512{_mm512_maskz_compress_epi32(mask.raw, v.raw)}; } template HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { return Vec512{_mm512_maskz_compress_epi64(mask.raw, v.raw)}; } HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { return Vec512{_mm512_maskz_compress_ps(mask.raw, v.raw)}; } HWY_API Vec512 Compress(Vec512 v, Mask512 mask) { return Vec512{_mm512_maskz_compress_pd(mask.raw, v.raw)}; } // 16-bit may use the 32-bit Compress and must be defined after it. // // Ignore IDE redefinition error - this is not actually defined in x86_256 if // we are including x86_512-inl.h. template HWY_API Vec256 Compress(Vec256 v, Mask256 mask) { const Full256 d; const Rebind du; const auto vu = BitCast(du, v); // (required for float16_t inputs) #if HWY_TARGET == HWY_AVX3_DL // VBMI2 const Vec256 cu{_mm256_maskz_compress_epi16(mask.raw, vu.raw)}; #else // Promote to i32 (512-bit vector!) so we can use the native Compress. const auto vw = PromoteTo(Rebind(), vu); const Mask512 mask32{static_cast<__mmask16>(mask.raw)}; const auto cu = DemoteTo(du, Compress(vw, mask32)); #endif // HWY_TARGET == HWY_AVX3_DL return BitCast(d, cu); } // Expands to 32-bit, compresses, concatenate demoted halves. template HWY_API Vec512 Compress(Vec512 v, const Mask512 mask) { const Full512 d; const Rebind du; const auto vu = BitCast(du, v); // (required for float16_t inputs) #if HWY_TARGET == HWY_AVX3_DL // VBMI2 const Vec512 cu{_mm512_maskz_compress_epi16(mask.raw, v.raw)}; #else const Repartition dw; const Half duh; const auto promoted0 = PromoteTo(dw, LowerHalf(duh, vu)); const auto promoted1 = PromoteTo(dw, UpperHalf(duh, vu)); const uint32_t mask_bits{mask.raw}; const Mask512 mask0{static_cast<__mmask16>(mask_bits & 0xFFFF)}; const Mask512 mask1{static_cast<__mmask16>(mask_bits >> 16)}; const auto compressed0 = Compress(promoted0, mask0); const auto compressed1 = Compress(promoted1, mask1); const auto demoted0 = ZeroExtendVector(du, DemoteTo(duh, compressed0)); const auto demoted1 = ZeroExtendVector(du, DemoteTo(duh, compressed1)); // Concatenate into single vector by shifting upper with writemask. 
const size_t num0 = CountTrue(dw, mask0); const __mmask32 m_upper = ~((1u << num0) - 1); alignas(64) uint16_t iota[64] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}; const auto idx = LoadU(du, iota + 32 - num0); const Vec512 cu{_mm512_mask_permutexvar_epi16( demoted0.raw, m_upper, idx.raw, demoted1.raw)}; #endif // HWY_TARGET == HWY_AVX3_DL return BitCast(d, cu); } // ------------------------------ CompressBits template HWY_API Vec512 CompressBits(Vec512 v, const uint8_t* HWY_RESTRICT bits) { return Compress(v, LoadMaskBits(Full512(), bits)); } // ------------------------------ CompressStore template HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 d, T* HWY_RESTRICT unaligned) { const Rebind du; const auto vu = BitCast(du, v); // (required for float16_t inputs) const uint64_t mask_bits{mask.raw}; #if HWY_TARGET == HWY_AVX3_DL // VBMI2 _mm512_mask_compressstoreu_epi16(unaligned, mask.raw, v.raw); #else const Repartition dw; const Half duh; const auto promoted0 = PromoteTo(dw, LowerHalf(duh, vu)); const auto promoted1 = PromoteTo(dw, UpperHalf(duh, vu)); const uint64_t maskL = mask_bits & 0xFFFF; const uint64_t maskH = mask_bits >> 16; const Mask512 mask0{static_cast<__mmask16>(maskL)}; const Mask512 mask1{static_cast<__mmask16>(maskH)}; const auto compressed0 = Compress(promoted0, mask0); const auto compressed1 = Compress(promoted1, mask1); const Half dh; const auto demoted0 = BitCast(dh, DemoteTo(duh, compressed0)); const auto demoted1 = BitCast(dh, DemoteTo(duh, compressed1)); // Store 256-bit halves StoreU(demoted0, dh, unaligned); StoreU(demoted1, dh, unaligned + PopCount(maskL)); #endif return PopCount(mask_bits); } template HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, T* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_epi32(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw}); } template HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, T* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_epi64(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw}); } HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, float* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_ps(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw}); } HWY_API size_t CompressStore(Vec512 v, Mask512 mask, Full512 /* tag */, double* HWY_RESTRICT unaligned) { _mm512_mask_compressstoreu_pd(unaligned, mask.raw, v.raw); return PopCount(uint64_t{mask.raw}); } // ------------------------------ CompressBlendedStore template HWY_API size_t CompressBlendedStore(Vec512 v, Mask512 m, Full512 d, T* HWY_RESTRICT unaligned) { // AVX-512 already does the blending at no extra cost (latency 11, // rthroughput 2 - same as compress plus store). 
if (HWY_TARGET == HWY_AVX3_DL || sizeof(T) != 2) { return CompressStore(v, m, d, unaligned); } else { const size_t count = CountTrue(d, m); const Vec512 compressed = Compress(v, m); const Vec512 prev = LoadU(d, unaligned); StoreU(IfThenElse(FirstN(d, count), compressed, prev), d, unaligned); return count; } } // ------------------------------ CompressBitsStore template HWY_API size_t CompressBitsStore(Vec512 v, const uint8_t* HWY_RESTRICT bits, Full512 d, T* HWY_RESTRICT unaligned) { return CompressStore(v, LoadMaskBits(d, bits), d, unaligned); } // ------------------------------ StoreInterleaved3 (CombineShiftRightBytes, // TableLookupBytes) HWY_API void StoreInterleaved3(const Vec512 a, const Vec512 b, const Vec512 c, Full512 d, uint8_t* HWY_RESTRICT unaligned) { const auto k5 = Set(d, 5); const auto k6 = Set(d, 6); // Shuffle (a,b,c) vector bytes to (MSB on left): r5, bgr[4:0]. // 0x80 so lanes to be filled from other vectors are 0 for blending. alignas(16) static constexpr uint8_t tbl_r0[16] = { 0, 0x80, 0x80, 1, 0x80, 0x80, 2, 0x80, 0x80, // 3, 0x80, 0x80, 4, 0x80, 0x80, 5}; alignas(16) static constexpr uint8_t tbl_g0[16] = { 0x80, 0, 0x80, 0x80, 1, 0x80, // 0x80, 2, 0x80, 0x80, 3, 0x80, 0x80, 4, 0x80, 0x80}; const auto shuf_r0 = LoadDup128(d, tbl_r0); const auto shuf_g0 = LoadDup128(d, tbl_g0); // cannot reuse r0 due to 5 const auto shuf_b0 = CombineShiftRightBytes<15>(d, shuf_g0, shuf_g0); const auto r0 = TableLookupBytes(a, shuf_r0); // 5..4..3..2..1..0 const auto g0 = TableLookupBytes(b, shuf_g0); // ..4..3..2..1..0. const auto b0 = TableLookupBytes(c, shuf_b0); // .4..3..2..1..0.. const auto i = (r0 | g0 | b0).raw; // low byte in each 128bit: 30 20 10 00 // Second vector: g10,r10, bgr[9:6], b5,g5 const auto shuf_r1 = shuf_b0 + k6; // .A..9..8..7..6.. const auto shuf_g1 = shuf_r0 + k5; // A..9..8..7..6..5 const auto shuf_b1 = shuf_g0 + k5; // ..9..8..7..6..5. const auto r1 = TableLookupBytes(a, shuf_r1); const auto g1 = TableLookupBytes(b, shuf_g1); const auto b1 = TableLookupBytes(c, shuf_b1); const auto j = (r1 | g1 | b1).raw; // low byte in each 128bit: 35 25 15 05 // Third vector: bgr[15:11], b10 const auto shuf_r2 = shuf_b1 + k6; // ..F..E..D..C..B. const auto shuf_g2 = shuf_r1 + k5; // .F..E..D..C..B.. const auto shuf_b2 = shuf_g1 + k5; // F..E..D..C..B..A const auto r2 = TableLookupBytes(a, shuf_r2); const auto g2 = TableLookupBytes(b, shuf_g2); const auto b2 = TableLookupBytes(c, shuf_b2); const auto k = (r2 | g2 | b2).raw; // low byte in each 128bit: 3A 2A 1A 0A // To obtain 10 0A 05 00 in one vector, transpose "rows" into "columns". const auto k3_k0_i3_i0 = _mm512_shuffle_i64x2(i, k, _MM_PERM_DADA); const auto i1_i2_j0_j1 = _mm512_shuffle_i64x2(j, i, _MM_PERM_BCAB); const auto j2_j3_k1_k2 = _mm512_shuffle_i64x2(k, j, _MM_PERM_CDBC); // Alternating order, most-significant 128 bits from the second arg. 
const __mmask8 m = 0xCC; const auto i1_k0_j0_i0 = _mm512_mask_blend_epi64(m, k3_k0_i3_i0, i1_i2_j0_j1); const auto j2_i2_k1_j1 = _mm512_mask_blend_epi64(m, i1_i2_j0_j1, j2_j3_k1_k2); const auto k3_j3_i3_k2 = _mm512_mask_blend_epi64(m, j2_j3_k1_k2, k3_k0_i3_i0); StoreU(Vec512{i1_k0_j0_i0}, d, unaligned + 0 * 64); // 10 0A 05 00 StoreU(Vec512{j2_i2_k1_j1}, d, unaligned + 1 * 64); // 25 20 1A 15 StoreU(Vec512{k3_j3_i3_k2}, d, unaligned + 2 * 64); // 3A 35 30 2A } // ------------------------------ StoreInterleaved4 HWY_API void StoreInterleaved4(const Vec512 v0, const Vec512 v1, const Vec512 v2, const Vec512 v3, Full512 d8, uint8_t* HWY_RESTRICT unaligned) { const RepartitionToWide d16; const RepartitionToWide d32; // let a,b,c,d denote v0..3. const auto ba0 = ZipLower(d16, v0, v1); // b7 a7 .. b0 a0 const auto dc0 = ZipLower(d16, v2, v3); // d7 c7 .. d0 c0 const auto ba8 = ZipUpper(d16, v0, v1); const auto dc8 = ZipUpper(d16, v2, v3); const auto i = ZipLower(d32, ba0, dc0).raw; // 4x128bit: d..a3 d..a0 const auto j = ZipUpper(d32, ba0, dc0).raw; // 4x128bit: d..a7 d..a4 const auto k = ZipLower(d32, ba8, dc8).raw; // 4x128bit: d..aB d..a8 const auto l = ZipUpper(d32, ba8, dc8).raw; // 4x128bit: d..aF d..aC // 128-bit blocks were independent until now; transpose 4x4. const auto j1_j0_i1_i0 = _mm512_shuffle_i64x2(i, j, _MM_PERM_BABA); const auto l1_l0_k1_k0 = _mm512_shuffle_i64x2(k, l, _MM_PERM_BABA); const auto j3_j2_i3_i2 = _mm512_shuffle_i64x2(i, j, _MM_PERM_DCDC); const auto l3_l2_k3_k2 = _mm512_shuffle_i64x2(k, l, _MM_PERM_DCDC); constexpr _MM_PERM_ENUM k20 = _MM_PERM_CACA; constexpr _MM_PERM_ENUM k31 = _MM_PERM_DBDB; const auto l0_k0_j0_i0 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k20); const auto l1_k1_j1_i1 = _mm512_shuffle_i64x2(j1_j0_i1_i0, l1_l0_k1_k0, k31); const auto l2_k2_j2_i2 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k20); const auto l3_k3_j3_i3 = _mm512_shuffle_i64x2(j3_j2_i3_i2, l3_l2_k3_k2, k31); StoreU(Vec512{l0_k0_j0_i0}, d8, unaligned + 0 * 64); StoreU(Vec512{l1_k1_j1_i1}, d8, unaligned + 1 * 64); StoreU(Vec512{l2_k2_j2_i2}, d8, unaligned + 2 * 64); StoreU(Vec512{l3_k3_j3_i3}, d8, unaligned + 3 * 64); } // ------------------------------ MulEven/Odd (Shuffle2301, InterleaveLower) HWY_INLINE Vec512 MulEven(const Vec512 a, const Vec512 b) { const DFromV du64; const RepartitionToNarrow du32; const auto maskL = Set(du64, 0xFFFFFFFFULL); const auto a32 = BitCast(du32, a); const auto b32 = BitCast(du32, b); // Inputs for MulEven: we only need the lower 32 bits const auto aH = Shuffle2301(a32); const auto bH = Shuffle2301(b32); // Knuth double-word multiplication. We use 32x32 = 64 MulEven and only need // the even (lower 64 bits of every 128-bit block) results. 
// See https://github.com/hcs0/Hackers-Delight/blob/master/muldwu.c.txt
  const auto aLbL = MulEven(a32, b32);
  const auto w3 = aLbL & maskL;

  const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL);
  const auto w2 = t2 & maskL;
  const auto w1 = ShiftRight<32>(t2);

  const auto t = MulEven(a32, bH) + w2;
  const auto k = ShiftRight<32>(t);

  const auto mulH = MulEven(aH, bH) + w1 + k;
  const auto mulL = ShiftLeft<32>(t) + w3;
  return InterleaveLower(mulL, mulH);
}

HWY_INLINE Vec512<uint64_t> MulOdd(const Vec512<uint64_t> a,
                                   const Vec512<uint64_t> b) {
  const DFromV<decltype(a)> du64;
  const RepartitionToNarrow<decltype(du64)> du32;
  const auto maskL = Set(du64, 0xFFFFFFFFULL);
  const auto a32 = BitCast(du32, a);
  const auto b32 = BitCast(du32, b);
  // Inputs for MulEven: we only need bits [95:64] (= upper half of input)
  const auto aH = Shuffle2301(a32);
  const auto bH = Shuffle2301(b32);

  // Same as above, but we're using the odd results (upper 64 bits per block).
  const auto aLbL = MulEven(a32, b32);
  const auto w3 = aLbL & maskL;

  const auto t2 = MulEven(aH, b32) + ShiftRight<32>(aLbL);
  const auto w2 = t2 & maskL;
  const auto w1 = ShiftRight<32>(t2);

  const auto t = MulEven(a32, bH) + w2;
  const auto k = ShiftRight<32>(t);

  const auto mulH = MulEven(aH, bH) + w1 + k;
  const auto mulL = ShiftLeft<32>(t) + w3;
  return InterleaveUpper(du64, mulL, mulH);
}

// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)

HWY_API Vec512<float> ReorderWidenMulAccumulate(Full512<float> df32,
                                                Vec512<bfloat16_t> a,
                                                Vec512<bfloat16_t> b,
                                                const Vec512<float> sum0,
                                                Vec512<float>& sum1) {
  // TODO(janwas): _mm512_dpbf16_ps when available
  const Repartition<uint16_t, decltype(df32)> du16;
  const RebindToUnsigned<decltype(df32)> du32;
  const Vec512<uint16_t> zero = Zero(du16);
  // Lane order within sum0/1 is undefined, hence we can avoid the
  // longer-latency lane-crossing PromoteTo.
  const Vec512<uint32_t> a0 = ZipLower(du32, zero, BitCast(du16, a));
  const Vec512<uint32_t> a1 = ZipUpper(du32, zero, BitCast(du16, a));
  const Vec512<uint32_t> b0 = ZipLower(du32, zero, BitCast(du16, b));
  const Vec512<uint32_t> b1 = ZipUpper(du32, zero, BitCast(du16, b));
  sum1 = MulAdd(BitCast(df32, a1), BitCast(df32, b1), sum1);
  return MulAdd(BitCast(df32, a0), BitCast(df32, b0), sum0);
}

// ------------------------------ Reductions

// Returns the sum in each lane.
HWY_API Vec512<int32_t> SumOfLanes(Full512<int32_t> d, Vec512<int32_t> v) {
  return Set(d, _mm512_reduce_add_epi32(v.raw));
}
HWY_API Vec512<int64_t> SumOfLanes(Full512<int64_t> d, Vec512<int64_t> v) {
  return Set(d, _mm512_reduce_add_epi64(v.raw));
}
HWY_API Vec512<uint32_t> SumOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) {
  return Set(d, static_cast<uint32_t>(_mm512_reduce_add_epi32(v.raw)));
}
HWY_API Vec512<uint64_t> SumOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) {
  return Set(d, static_cast<uint64_t>(_mm512_reduce_add_epi64(v.raw)));
}
HWY_API Vec512<float> SumOfLanes(Full512<float> d, Vec512<float> v) {
  return Set(d, _mm512_reduce_add_ps(v.raw));
}
HWY_API Vec512<double> SumOfLanes(Full512<double> d, Vec512<double> v) {
  return Set(d, _mm512_reduce_add_pd(v.raw));
}

// Returns the minimum in each lane.
HWY_API Vec512<int32_t> MinOfLanes(Full512<int32_t> d, Vec512<int32_t> v) {
  return Set(d, _mm512_reduce_min_epi32(v.raw));
}
HWY_API Vec512<int64_t> MinOfLanes(Full512<int64_t> d, Vec512<int64_t> v) {
  return Set(d, _mm512_reduce_min_epi64(v.raw));
}
HWY_API Vec512<uint32_t> MinOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) {
  return Set(d, _mm512_reduce_min_epu32(v.raw));
}
HWY_API Vec512<uint64_t> MinOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) {
  return Set(d, _mm512_reduce_min_epu64(v.raw));
}
HWY_API Vec512<float> MinOfLanes(Full512<float> d, Vec512<float> v) {
  return Set(d, _mm512_reduce_min_ps(v.raw));
}
HWY_API Vec512<double> MinOfLanes(Full512<double> d, Vec512<double> v) {
  return Set(d, _mm512_reduce_min_pd(v.raw));
}
template <typename T, HWY_IF_LANE_SIZE(T, 2)>
HWY_API Vec512<T> MinOfLanes(Full512<T> d, Vec512<T> v) {
  const Repartition<int32_t, decltype(d)> d32;
  const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto min = MinOfLanes(d32, Min(even, odd));
  // Also broadcast into odd lanes.
  return BitCast(d, Or(min, ShiftLeft<16>(min)));
}

// Returns the maximum in each lane.
HWY_API Vec512<int32_t> MaxOfLanes(Full512<int32_t> d, Vec512<int32_t> v) {
  return Set(d, _mm512_reduce_max_epi32(v.raw));
}
HWY_API Vec512<int64_t> MaxOfLanes(Full512<int64_t> d, Vec512<int64_t> v) {
  return Set(d, _mm512_reduce_max_epi64(v.raw));
}
HWY_API Vec512<uint32_t> MaxOfLanes(Full512<uint32_t> d, Vec512<uint32_t> v) {
  return Set(d, _mm512_reduce_max_epu32(v.raw));
}
HWY_API Vec512<uint64_t> MaxOfLanes(Full512<uint64_t> d, Vec512<uint64_t> v) {
  return Set(d, _mm512_reduce_max_epu64(v.raw));
}
HWY_API Vec512<float> MaxOfLanes(Full512<float> d, Vec512<float> v) {
  return Set(d, _mm512_reduce_max_ps(v.raw));
}
HWY_API Vec512<double> MaxOfLanes(Full512<double> d, Vec512<double> v) {
  return Set(d, _mm512_reduce_max_pd(v.raw));
}
template <typename T, HWY_IF_LANE_SIZE(T, 2)>
HWY_API Vec512<T> MaxOfLanes(Full512<T> d, Vec512<T> v) {
  const Repartition<int32_t, decltype(d)> d32;
  const auto even = And(BitCast(d32, v), Set(d32, 0xFFFF));
  const auto odd = ShiftRight<16>(BitCast(d32, v));
  const auto max = MaxOfLanes(d32, Max(even, odd));
  // Also broadcast into odd lanes.
  return BitCast(d, Or(max, ShiftLeft<16>(max)));
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();
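
// ------------------------------ Example: horizontal sum (illustrative only)

// A minimal usage sketch, not part of the Highway API: reduces 16 floats to
// their sum via the SumOfLanes/GetLane defined in this header. The function
// name `SumArrayOf16Example` and the requirement that `p` point to at least
// 16 readable floats are assumptions for illustration only.
HWY_BEFORE_NAMESPACE();
namespace hwy {
namespace HWY_NAMESPACE {

HWY_INLINE float SumArrayOf16Example(const float* HWY_RESTRICT p) {
  const Full512<float> d;
  // SumOfLanes broadcasts the sum into every lane; GetLane extracts lane 0.
  return GetLane(SumOfLanes(d, LoadU(d, p)));
}

}  // namespace HWY_NAMESPACE
}  // namespace hwy
HWY_AFTER_NAMESPACE();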