// Copyright Google LLC 2021 // Matthew Kolbe 2023 // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "hwy/base.h" // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "hwy/contrib/unroller/unroller_test.cc" //NOLINT #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" #include "hwy/contrib/unroller/unroller-inl.h" #include "hwy/tests/test_util-inl.h" // clang-format on HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { template T SimpleDot(const T* pa, const T* pb, size_t num) { T sum = 0; for (size_t i = 0; i < num; ++i) { // For reasons unknown, fp16 += does not compile on clang (Arm). sum = ConvertScalarTo(sum + pa[i] * pb[i]); } return sum; } template T SimpleAcc(const T* pa, size_t num) { T sum = 0; for (size_t i = 0; i < num; ++i) { sum += pa[i]; } return sum; } template T SimpleMin(const T* pa, size_t num) { T min = HighestValue(); for (size_t i = 0; i < num; ++i) { if (min > pa[i]) min = pa[i]; } return min; } template struct MultiplyUnit : UnrollerUnit2D, T, T, T> { using TT = hn::ScalableTag; HWY_INLINE hn::Vec Func(ptrdiff_t idx, const hn::Vec x0, const hn::Vec x1, const hn::Vec y) { (void)idx; (void)y; return hn::Mul(x0, x1); } }; template struct ConvertUnit : UnrollerUnit, FROM_T, TO_T> { using Base = UnrollerUnit, FROM_T, TO_T>; using Base::MaxUnitLanes; using typename Base::LargerD; using TT_FROM = hn::Rebind; using TT_TO = hn::Rebind; template < class ToD, class FromV, hwy::EnableIf<(sizeof(TFromV) > sizeof(TFromD))>* = nullptr> static HWY_INLINE hn::Vec DoConvertVector(ToD d, FromV v) { return hn::DemoteTo(d, v); } template < class ToD, class FromV, hwy::EnableIf<(sizeof(TFromV) == sizeof(TFromD))>* = nullptr> static HWY_INLINE hn::Vec DoConvertVector(ToD d, FromV v) { return hn::ConvertTo(d, v); } template < class ToD, class FromV, hwy::EnableIf<(sizeof(TFromV) < sizeof(TFromD))>* = nullptr> static HWY_INLINE hn::Vec DoConvertVector(ToD d, FromV v) { return hn::PromoteTo(d, v); } hn::Vec Func(ptrdiff_t idx, const hn::Vec x, const hn::Vec y) { (void)idx; (void)y; TT_TO d; return DoConvertVector(d, x); } }; // Returns a value that does not compare equal to `value`. template HWY_INLINE Vec OtherValue(D d, TFromD /*value*/) { return NaN(d); } template HWY_INLINE Vec OtherValue(D d, TFromD value) { return hn::Set(d, hwy::AddWithWraparound(value, 1)); } // Caveat: stores lane indices as MakeSigned, which may overflow for 8-bit T // on HWY_RVV. template struct FindUnit : UnrollerUnit, T, MakeSigned> { using TI = MakeSigned; using Base = UnrollerUnit, T, TI>; using Base::ActualLanes; using Base::MaxUnitLanes; using D = hn::CappedTag; T to_find; D d; using DI = RebindToSigned; DI di; FindUnit(T find) : to_find(find) {} hn::Vec Func(ptrdiff_t idx, const hn::Vec x, const hn::Vec y) { const Mask msk = hn::Eq(x, hn::Set(d, to_find)); const TI first_idx = static_cast(hn::FindFirstTrue(d, msk)); if (first_idx > -1) return hn::Set(di, static_cast(static_cast(idx) + first_idx)); else return y; } hn::Vec X0InitImpl() { return OtherValue(D(), to_find); } hn::Vec YInitImpl() { return hn::Set(di, TI{-1}); } hn::Vec MaskLoadImpl(const ptrdiff_t idx, T* from, const ptrdiff_t places) { auto mask = hn::FirstN(d, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; return hn::IfThenElse(mask, hn::MaskedLoad(mask, d, from + idx), X0InitImpl()); } bool StoreAndShortCircuitImpl(const ptrdiff_t idx, TI* to, const hn::Vec x) { (void)idx; TI a = hn::GetLane(x); to[0] = a; if (a == -1) return true; return false; } ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, TI* to, const hn::Vec x, const ptrdiff_t places) { (void)idx; (void)places; TI a = hn::GetLane(x); to[0] = a; return 1; } }; template struct AccumulateUnit : UnrollerUnit, T, T> { using TT = hn::ScalableTag; hn::Vec Func(ptrdiff_t idx, const hn::Vec x, const hn::Vec y) { (void)idx; return hn::Add(x, y); } bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to, const hn::Vec x) { // no stores in a reducer (void)idx; (void)to; (void)x; return true; } ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec x, const ptrdiff_t places) { // no stores in a reducer (void)idx; (void)to; (void)x; (void)places; return 0; } ptrdiff_t ReduceImpl(const hn::Vec x, T* to) { const hn::ScalableTag d; (*to) = hn::ReduceSum(d, x); return 1; } void ReduceImpl(const hn::Vec x0, const hn::Vec x1, const hn::Vec x2, hn::Vec* y) { (*y) = hn::Add(hn::Add(*y, x0), hn::Add(x1, x2)); } }; template struct MinUnit : UnrollerUnit, T, T> { using Base = UnrollerUnit, T, T>; using Base::ActualLanes; using TT = hn::ScalableTag; TT d; hn::Vec Func(const ptrdiff_t idx, const hn::Vec x, const hn::Vec y) { (void)idx; return hn::Min(y, x); } hn::Vec YInitImpl() { return hn::Set(d, HighestValue()); } hn::Vec MaskLoadImpl(const ptrdiff_t idx, T* from, const ptrdiff_t places) { auto mask = hn::FirstN(d, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; auto def = YInitImpl(); return hn::MaskedLoadOr(def, mask, d, from + idx); } bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to, const hn::Vec x) { // no stores in a reducer (void)idx; (void)to; (void)x; return true; } ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec x, const ptrdiff_t places) { // no stores in a reducer (void)idx; (void)to; (void)x; (void)places; return 0; } ptrdiff_t ReduceImpl(const hn::Vec x, T* to) { const hn::ScalableTag d; auto minvect = hn::MinOfLanes(d, x); (*to) = hn::ExtractLane(minvect, 0); return 1; } void ReduceImpl(const hn::Vec x0, const hn::Vec x1, const hn::Vec x2, hn::Vec* y) { auto a = hn::Min(x1, x0); auto b = hn::Min(*y, x2); (*y) = hn::Min(a, b); } }; template struct DotUnit : UnrollerUnit2D, T, T, T> { using TT = hn::ScalableTag; hn::Vec Func(const ptrdiff_t idx, const hn::Vec x0, const hn::Vec x1, const hn::Vec y) { (void)idx; return hn::MulAdd(x0, x1, y); } bool StoreAndShortCircuitImpl(const ptrdiff_t idx, T* to, const hn::Vec x) { // no stores in a reducer (void)idx; (void)to; (void)x; return true; } ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, T* to, const hn::Vec x, const ptrdiff_t places) { // no stores in a reducer (void)idx; (void)to; (void)x; (void)places; return 0; } ptrdiff_t ReduceImpl(const hn::Vec x, T* to) { const hn::ScalableTag d; (*to) = hn::ReduceSum(d, x); return 1; } void ReduceImpl(const hn::Vec x0, const hn::Vec x1, const hn::Vec x2, hn::Vec* y) { (*y) = hn::Add(hn::Add(*y, x0), hn::Add(x1, x2)); } }; template std::vector Counts(D d) { const size_t N = Lanes(d); return std::vector{1, 3, 7, 16, HWY_MAX(N / 2, 1), HWY_MAX(2 * N / 3, 1), N, N + 1, 4 * N / 3, 3 * N, 8 * N, 8 * N + 2, 256 * N - 1, 256 * N}; } struct TestDot { template HWY_NOINLINE void operator()(T /*unused*/, D d) { // TODO(janwas): avoid internal compiler error #if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_COMPILER_MSVC (void)d; #else RandomState rng; const auto random_t = [&rng]() { const int32_t bits = static_cast(Random32(&rng)) & 1023; return static_cast(bits - 512) * (1.0f / 64); }; for (size_t num : Counts(d)) { AlignedFreeUniquePtr pa = AllocateAligned(num); AlignedFreeUniquePtr pb = AllocateAligned(num); AlignedFreeUniquePtr py = AllocateAligned(num); HWY_ASSERT(pa && pb && py); T* a = pa.get(); T* b = pb.get(); T* y = py.get(); size_t i = 0; for (; i < num; ++i) { a[i] = ConvertScalarTo(random_t()); b[i] = ConvertScalarTo(random_t()); } const T expected_dot = SimpleDot(a, b, num); MultiplyUnit multfn; Unroller(multfn, a, b, y, static_cast(num)); AccumulateUnit accfn; T dot_via_mul_acc; Unroller(accfn, y, &dot_via_mul_acc, static_cast(num)); const double tolerance = 32.0 * ConvertScalarTo(hwy::Epsilon()) * ScalarAbs(expected_dot); HWY_ASSERT(ScalarAbs(expected_dot - dot_via_mul_acc) < tolerance); DotUnit dotfn; T dotr; Unroller(dotfn, a, b, &dotr, static_cast(num)); HWY_ASSERT(ConvertScalarTo(ScalarAbs((expected_dot - dotr))) < tolerance); auto expected_min = SimpleMin(a, num); MinUnit minfn; T minr; Unroller(minfn, a, &minr, static_cast(num)); HWY_ASSERT(ConvertScalarTo(ScalarAbs(expected_min - minr)) < 1e-7); } #endif } }; void TestAllDot() { ForFloatTypes(ForPartialVectors()); } struct TestConvert { template HWY_NOINLINE void operator()(T /*unused*/, D d) { // TODO(janwas): avoid internal compiler error #if HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 || HWY_COMPILER_MSVC (void)d; #else for (size_t num : Counts(d)) { AlignedFreeUniquePtr pa = AllocateAligned(num); AlignedFreeUniquePtr pto = AllocateAligned(num); HWY_ASSERT(pa && pto); T* HWY_RESTRICT a = pa.get(); int* HWY_RESTRICT to = pto.get(); for (size_t i = 0; i < num; ++i) { a[i] = ConvertScalarTo(static_cast(i) * 0.25); } ConvertUnit cvtfn; Unroller(cvtfn, a, to, static_cast(num)); for (size_t i = 0; i < num; ++i) { // TODO(janwas): RVV QEMU fcvt_rtz appears to 'truncate' 4.75 to 5. HWY_ASSERT( static_cast(a[i]) == to[i] || (HWY_TARGET == HWY_RVV && static_cast(a[i]) == to[i] - 1)); } ConvertUnit cvtbackfn; Unroller(cvtbackfn, to, a, static_cast(num)); for (size_t i = 0; i < num; ++i) { HWY_ASSERT_EQ(ConvertScalarTo(to[i]), a[i]); } } #endif } }; void TestAllConvert() { ForFloat3264Types(ForPartialVectors()); } struct TestFind { template HWY_NOINLINE void operator()(T /*unused*/, D d) { for (size_t num : Counts(d)) { AlignedFreeUniquePtr pa = AllocateAligned(num); HWY_ASSERT(pa); T* a = pa.get(); for (size_t i = 0; i < num; ++i) a[i] = ConvertScalarTo(i); FindUnit cvtfn(ConvertScalarTo(num - 1)); MakeSigned idx = 0; Unroller(cvtfn, a, &idx, static_cast(num)); HWY_ASSERT(static_cast>(idx) < num); HWY_ASSERT(a[idx] == ConvertScalarTo(num - 1)); FindUnit cvtfnzero((T)(0)); Unroller(cvtfnzero, a, &idx, static_cast(num)); HWY_ASSERT(static_cast>(idx) < num); HWY_ASSERT(a[idx] == (T)(0)); // For f16, we cannot search for `num` because it may round to a value // that is actually in the (large) array. FindUnit cvtfnnotin(HighestValue()); Unroller(cvtfnnotin, a, &idx, static_cast(num)); HWY_ASSERT(idx == -1); } } }; void TestAllFind() { ForFloatTypes(ForPartialVectors()); } } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace hwy { HWY_BEFORE_TEST(UnrollerTest); HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllDot); HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllConvert); HWY_EXPORT_AND_TEST_P(UnrollerTest, TestAllFind); } // namespace hwy #endif