// Copyright 2021 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include "hwy/aligned_allocator.h" #include "hwy/base.h" // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc" #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" #include "hwy/contrib/dot/dot-inl.h" #include "hwy/tests/test_util-inl.h" // clang-format on HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { template HWY_NOINLINE T1 SimpleDot(const T1* pa, const T2* pb, size_t num) { float sum = 0.0f; for (size_t i = 0; i < num; ++i) { sum += ConvertScalarTo(pa[i]) * ConvertScalarTo(pb[i]); } return ConvertScalarTo(sum); } HWY_NOINLINE float SimpleDot(const float* pa, const hwy::bfloat16_t* pb, size_t num) { float sum = 0.0f; for (size_t i = 0; i < num; ++i) { sum += pa[i] * F32FromBF16(pb[i]); } return sum; } // Overload is required because the generic template hits an internal compiler // error on aarch64 clang. HWY_NOINLINE float SimpleDot(const bfloat16_t* pa, const bfloat16_t* pb, size_t num) { float sum = 0.0f; for (size_t i = 0; i < num; ++i) { sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]); } return sum; } class TestDot { // Computes/verifies one dot product. template void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, RandomState& rng) { using T = TFromD; const size_t N = Lanes(d); const auto random_t = [&rng]() { const int32_t bits = static_cast(Random32(&rng)) & 1023; return static_cast(bits - 512) * (1.0f / 64); }; const size_t padded = (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + padded); AlignedFreeUniquePtr pb = AllocateAligned(misalign_b + padded); HWY_ASSERT(pa && pb); T* a = pa.get() + misalign_a; T* b = pb.get() + misalign_b; size_t i = 0; for (; i < num; ++i) { a[i] = ConvertScalarTo(random_t()); b[i] = ConvertScalarTo(random_t()); } // Fill padding with NaN - the values are not used, but avoids MSAN errors. for (; i < padded; ++i) { ScalableTag df1; a[i] = ConvertScalarTo(GetLane(NaN(df1))); b[i] = ConvertScalarTo(GetLane(NaN(df1))); } const double expected = SimpleDot(a, b, num); const double magnitude = expected > 0.0 ? expected : -expected; const double actual = ConvertScalarTo(Dot::Compute(d, a, b, num)); const double max = static_cast(8 * 8 * num); HWY_ASSERT(-max <= actual && actual <= max); const double tolerance = 64.0 * ConvertScalarTo(Epsilon()) * HWY_MAX(magnitude, 1.0); HWY_ASSERT(expected - tolerance <= actual && actual <= expected + tolerance); } // Runs tests with various alignments. template void ForeachMisalign(D d, size_t num, RandomState& rng) { const size_t N = Lanes(d); const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; for (size_t ma : misalignments) { for (size_t mb : misalignments) { Test(d, num, ma, mb, rng); } } } // Runs tests with various lengths compatible with the given assumptions. template void ForeachCount(D d, RandomState& rng) { const size_t N = Lanes(d); const size_t counts[] = {1, 3, 7, 16, HWY_MAX(N / 2, 1), HWY_MAX(2 * N / 3, 1), N, N + 1, 4 * N / 3, 3 * N, 8 * N, 8 * N + 2}; for (size_t num : counts) { if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; ForeachMisalign(d, num, rng); } } public: // Must be inlined on aarch64 for bf16, else clang crashes. template HWY_INLINE void operator()(T /*unused*/, D d) { RandomState rng; // All 8 combinations of the three length-related flags: ForeachCount<0>(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); } }; class TestDotF32BF16 { // Computes/verifies one dot product. template void Test(D d, size_t num, size_t misalign_a, size_t misalign_b, RandomState& rng) { using T = TFromD; using T2 = hwy::bfloat16_t; const size_t N = Lanes(d); const auto random_t = [&rng]() { const int32_t bits = static_cast(Random32(&rng)) & 1023; return static_cast(bits - 512) * (1.0f / 64); }; const size_t padded = (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num; AlignedFreeUniquePtr pa = AllocateAligned(misalign_a + padded); AlignedFreeUniquePtr pb = AllocateAligned(misalign_b + padded); HWY_ASSERT(pa && pb); T* a = pa.get() + misalign_a; T2* b = pb.get() + misalign_b; size_t i = 0; for (; i < num; ++i) { a[i] = ConvertScalarTo(random_t()); b[i] = ConvertScalarTo(random_t()); } // Fill padding with NaN - the values are not used, but avoids MSAN errors. for (; i < padded; ++i) { ScalableTag df1; a[i] = ConvertScalarTo(GetLane(NaN(df1))); b[i] = ConvertScalarTo(GetLane(NaN(df1))); } const double expected = SimpleDot(a, b, num); const double magnitude = expected > 0.0 ? expected : -expected; const double actual = ConvertScalarTo(Dot::Compute(d, a, b, num)); const double max = static_cast(8 * 8 * num); HWY_ASSERT(-max <= actual && actual <= max); const double tolerance = 64.0 * ConvertScalarTo(Epsilon()) * HWY_MAX(magnitude, 1.0); HWY_ASSERT(expected - tolerance <= actual && actual <= expected + tolerance); } // Runs tests with various alignments. template void ForeachMisalign(D d, size_t num, RandomState& rng) { const size_t N = Lanes(d); const size_t misalignments[3] = {0, N / 4, 3 * N / 5}; for (size_t ma : misalignments) { for (size_t mb : misalignments) { Test(d, num, ma, mb, rng); } } } // Runs tests with various lengths compatible with the given assumptions. template void ForeachCount(D d, RandomState& rng) { const size_t N = Lanes(d); const size_t counts[] = {1, 3, 7, 16, HWY_MAX(N / 2, 1), HWY_MAX(2 * N / 3, 1), N, N + 1, 4 * N / 3, 3 * N, 8 * N, 8 * N + 2}; for (size_t num : counts) { if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue; if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue; ForeachMisalign(d, num, rng); } } public: // Must be inlined on aarch64 for bf16, else clang crashes. template HWY_INLINE void operator()(T /*unused*/, D d) { RandomState rng; // All 8 combinations of the three length-related flags: ForeachCount<0>(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); ForeachCount(d, rng); } }; // All floating-point types, both arguments same. void TestAllDot() { ForFloatTypes(ForPartialVectors()); } // Mixed f32 and bf16. void TestAllDotF32BF16() { ForPartialVectors test; test(float()); } // Both bf16. void TestAllDotBF16() { ForShrinkableVectors()(bfloat16_t()); } // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #if HWY_ONCE namespace hwy { HWY_BEFORE_TEST(DotTest); HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot); HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotF32BF16); HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16); } // namespace hwy #endif