// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Per-target #if defined(HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE) == \ defined(HWY_TARGET_TOGGLE) #ifdef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE #undef HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE #else #define HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE #endif #include "hwy/contrib/sort/disabled_targets.h" #include "hwy/contrib/sort/shared-inl.h" // SortConstants #include "hwy/contrib/sort/vqsort.h" // SortDescending #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { namespace detail { // Highway does not provide a lane type for 128-bit keys, so we use uint64_t // along with an abstraction layer for single-lane vs. lane-pair, which is // independent of the order. struct KeyLane { constexpr size_t LanesPerKey() const { return 1; } // For HeapSort template HWY_INLINE void Swap(T* a, T* b) const { const T temp = *a; *a = *b; *b = temp; } // Broadcasts one key into a vector template HWY_INLINE Vec SetKey(D d, const TFromD* key) const { return Set(d, *key); } template HWY_INLINE Vec ReverseKeys(D d, Vec v) const { return Reverse(d, v); } template HWY_INLINE Vec ReverseKeys2(D d, Vec v) const { return Reverse2(d, v); } template HWY_INLINE Vec ReverseKeys4(D d, Vec v) const { return Reverse4(d, v); } template HWY_INLINE Vec ReverseKeys8(D d, Vec v) const { return Reverse8(d, v); } template HWY_INLINE Vec ReverseKeys16(D d, Vec v) const { static_assert(SortConstants::kMaxCols <= 16, "Assumes u32x16 = 512 bit"); return ReverseKeys(d, v); } template HWY_INLINE V OddEvenKeys(const V odd, const V even) const { return OddEven(odd, even); } template HWY_INLINE Vec SwapAdjacentPairs(D d, const Vec v) const { const Repartition du32; return BitCast(d, Shuffle2301(BitCast(du32, v))); } template HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { return Shuffle1032(v); } template HWY_INLINE Vec SwapAdjacentPairs(D /* tag */, const Vec v) const { return SwapAdjacentBlocks(v); } template HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { #if HWY_HAVE_FLOAT64 // in case D is float32 const RepartitionToWide dw; #else const RepartitionToWide> dw; #endif return BitCast(d, SwapAdjacentPairs(dw, BitCast(dw, v))); } template HWY_INLINE Vec SwapAdjacentQuads(D d, const Vec v) const { // Assumes max vector size = 512 return ConcatLowerUpper(d, v, v); } template HWY_INLINE Vec OddEvenPairs(D d, const Vec odd, const Vec even) const { #if HWY_HAVE_FLOAT64 // in case D is float32 const RepartitionToWide dw; #else const RepartitionToWide> dw; #endif return BitCast(d, OddEven(BitCast(dw, odd), BitCast(dw, even))); } template HWY_INLINE Vec OddEvenPairs(D /* tag */, Vec odd, Vec even) const { return OddEvenBlocks(odd, even); } template HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { #if HWY_HAVE_FLOAT64 // in case D is float32 const RepartitionToWide dw; #else const RepartitionToWide> dw; #endif return BitCast(d, OddEvenPairs(dw, BitCast(dw, odd), BitCast(dw, even))); } template HWY_INLINE Vec OddEvenQuads(D d, Vec odd, Vec even) const { return ConcatUpperLower(d, odd, even); } }; // Anything order-related depends on the key traits *and* the order (see // FirstOfLanes). We cannot implement just one Compare function because Lt128 // only compiles if the lane type is u64. Thus we need either overloaded // functions with a tag type, class specializations, or separate classes. // We avoid overloaded functions because we want all functions to be callable // from a SortTraits without per-function wrappers. Specializing would work, but // we are anyway going to specialize at a higher level. struct OrderAscending : public KeyLane { using Order = SortAscending; template HWY_INLINE bool Compare1(const T* a, const T* b) { return *a < *b; } template HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { return Lt(a, b); } // Two halves of Sort2, used in ScanMinMax. template HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { return Min(a, b); } template HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { return Max(a, b); } template HWY_INLINE Vec FirstOfLanes(D d, Vec v, TFromD* HWY_RESTRICT /* buf */) const { return MinOfLanes(d, v); } template HWY_INLINE Vec LastOfLanes(D d, Vec v, TFromD* HWY_RESTRICT /* buf */) const { return MaxOfLanes(d, v); } template HWY_INLINE Vec FirstValue(D d) const { return Set(d, hwy::LowestValue>()); } template HWY_INLINE Vec LastValue(D d) const { return Set(d, hwy::HighestValue>()); } }; struct OrderDescending : public KeyLane { using Order = SortDescending; template HWY_INLINE bool Compare1(const T* a, const T* b) { return *b < *a; } template HWY_INLINE Mask Compare(D /* tag */, Vec a, Vec b) const { return Lt(b, a); } template HWY_INLINE Vec First(D /* tag */, const Vec a, const Vec b) const { return Max(a, b); } template HWY_INLINE Vec Last(D /* tag */, const Vec a, const Vec b) const { return Min(a, b); } template HWY_INLINE Vec FirstOfLanes(D d, Vec v, TFromD* HWY_RESTRICT /* buf */) const { return MaxOfLanes(d, v); } template HWY_INLINE Vec LastOfLanes(D d, Vec v, TFromD* HWY_RESTRICT /* buf */) const { return MinOfLanes(d, v); } template HWY_INLINE Vec FirstValue(D d) const { return Set(d, hwy::HighestValue>()); } template HWY_INLINE Vec LastValue(D d) const { return Set(d, hwy::LowestValue>()); } }; // Shared code that depends on Order. template struct LaneTraits : public Base { constexpr bool Is128() const { return false; } // For each lane i: replaces a[i] with the first and b[i] with the second // according to Base. // Corresponds to a conditional swap, which is one "node" of a sorting // network. Min/Max are cheaper than compare + blend at least for integers. template HWY_INLINE void Sort2(D d, Vec& a, Vec& b) const { const Base* base = static_cast(this); const Vec a_copy = a; // Prior to AVX3, there is no native 64-bit Min/Max, so they compile to 4 // instructions. We can reduce it to a compare + 2 IfThenElse. #if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 if (sizeof(TFromD) == 8) { const Mask cmp = base->Compare(d, a, b); a = IfThenElse(cmp, a, b); b = IfThenElse(cmp, b, a_copy); return; } #endif a = base->First(d, a, b); b = base->Last(d, a_copy, b); } // Conditionally swaps even-numbered lanes with their odd-numbered neighbor. template HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { const Base* base = static_cast(this); Vec swapped = base->ReverseKeys2(d, v); // Further to the above optimization, Sort2+OddEvenKeys compile to four // instructions; we can save one by combining two blends. #if HWY_AVX3 < HWY_TARGET && HWY_TARGET <= HWY_SSSE3 const Vec cmp = VecFromMask(d, base->Compare(d, v, swapped)); return IfVecThenElse(DupOdd(cmp), swapped, v); #else Sort2(d, v, swapped); return base->OddEvenKeys(swapped, v); #endif } // (See above - we use Sort2 for non-64-bit types.) template HWY_INLINE Vec SortPairsDistance1(D d, Vec v) const { const Base* base = static_cast(this); Vec swapped = base->ReverseKeys2(d, v); Sort2(d, v, swapped); return base->OddEvenKeys(swapped, v); } // Swaps with the vector formed by reversing contiguous groups of 4 keys. template HWY_INLINE Vec SortPairsReverse4(D d, Vec v) const { const Base* base = static_cast(this); Vec swapped = base->ReverseKeys4(d, v); Sort2(d, v, swapped); return base->OddEvenPairs(d, swapped, v); } // Conditionally swaps lane 0 with 4, 1 with 5 etc. template HWY_INLINE Vec SortPairsDistance4(D d, Vec v) const { const Base* base = static_cast(this); Vec swapped = base->SwapAdjacentQuads(d, v); // Only used in Merge16, so this will not be used on AVX2 (which only has 4 // u64 lanes), so skip the above optimization for 64-bit AVX2. Sort2(d, v, swapped); return base->OddEvenQuads(d, swapped, v); } }; } // namespace detail // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #endif // HIGHWAY_HWY_CONTRIB_SORT_TRAITS_TOGGLE