// Copyright 2023 Matthew Kolbe // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \ defined(HWY_TARGET_TOGGLE) #ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ #undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ #else #define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_ #endif #include // std::abs #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace hwy { namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; template struct UnrollerUnit { static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T)); using LargerT = SignedFromSize; // only the size matters. DERIVED* me() { return static_cast(this); } static constexpr size_t MaxUnitLanes() { return HWY_MAX_LANES_D(hn::ScalableTag); } static size_t ActualLanes() { return Lanes(hn::ScalableTag()); } using LargerD = hn::CappedTag; using IT = hn::Rebind; using OT = hn::Rebind; IT d_in; OT d_out; using Y_VEC = hn::Vec; using X_VEC = hn::Vec; Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) { return me()->Func(idx, x, y); } X_VEC X0Init() { return me()->X0InitImpl(); } X_VEC X0InitImpl() { return hn::Zero(d_in); } Y_VEC YInit() { return me()->YInitImpl(); } Y_VEC YInitImpl() { return hn::Zero(d_out); } X_VEC Load(const ptrdiff_t idx, IN_T* from) { return me()->LoadImpl(idx, from); } X_VEC LoadImpl(const ptrdiff_t idx, IN_T* from) { return hn::LoadU(d_in, from + idx); } // MaskLoad can take in either a positive or negative number for `places`. if // the number is positive, then it loads the top `places` values, and if it's // negative, it loads the bottom |places| values. example: places = 3 // | o | o | o | x | x | x | x | x | // example places = -3 // | x | x | x | x | x | o | o | o | X_VEC MaskLoad(const ptrdiff_t idx, IN_T* from, const ptrdiff_t places) { return me()->MaskLoadImpl(idx, from, places); } X_VEC MaskLoadImpl(const ptrdiff_t idx, IN_T* from, const ptrdiff_t places) { auto mask = hn::FirstN(d_in, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d_in, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; return hn::MaskedLoad(mask, d_in, from + idx); } bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { return me()->StoreAndShortCircuitImpl(idx, to, x); } bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { hn::StoreU(x, d_out, to + idx); return true; } ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, ptrdiff_t const places) { return me()->MaskStoreImpl(idx, to, x, places); } ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, const ptrdiff_t places) { auto mask = hn::FirstN(d_out, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d_out, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; hn::BlendedStore(x, mask, d_out, to + idx); return std::abs(places); } ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { // default does nothing (void)x; (void)to; return 0; } void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { me()->ReduceImpl(x0, x1, x2, y); } void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { // default does nothing (void)x0; (void)x1; (void)x2; (void)y; } }; template struct UnrollerUnit2D { DERIVED* me() { return static_cast(this); } static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T))); using LargerT = SignedFromSize; // only the size matters. static constexpr size_t MaxUnitLanes() { return HWY_MAX_LANES_D(hn::ScalableTag); } static size_t ActualLanes() { return Lanes(hn::ScalableTag()); } using LargerD = hn::CappedTag; using I0T = hn::Rebind; using I1T = hn::Rebind; using OT = hn::Rebind; I0T d_in0; I1T d_in1; OT d_out; using Y_VEC = hn::Vec; using X0_VEC = hn::Vec; using X1_VEC = hn::Vec; hn::Vec Func(const ptrdiff_t idx, const hn::Vec x0, const hn::Vec x1, const Y_VEC y) { return me()->Func(idx, x0, x1, y); } X0_VEC X0Init() { return me()->X0InitImpl(); } X0_VEC X0InitImpl() { return hn::Zero(d_in0); } X1_VEC X1Init() { return me()->X1InitImpl(); } X1_VEC X1InitImpl() { return hn::Zero(d_in1); } Y_VEC YInit() { return me()->YInitImpl(); } Y_VEC YInitImpl() { return hn::Zero(d_out); } X0_VEC Load0(const ptrdiff_t idx, IN0_T* from) { return me()->Load0Impl(idx, from); } X0_VEC Load0Impl(const ptrdiff_t idx, IN0_T* from) { return hn::LoadU(d_in0, from + idx); } X1_VEC Load1(const ptrdiff_t idx, IN1_T* from) { return me()->Load1Impl(idx, from); } X1_VEC Load1Impl(const ptrdiff_t idx, IN1_T* from) { return hn::LoadU(d_in1, from + idx); } // maskload can take in either a positive or negative number for `places`. if // the number is positive, then it loads the top `places` values, and if it's // negative, it loads the bottom |places| values. example: places = 3 // | o | o | o | x | x | x | x | x | // example places = -3 // | x | x | x | x | x | o | o | o | X0_VEC MaskLoad0(const ptrdiff_t idx, IN0_T* from, const ptrdiff_t places) { return me()->MaskLoad0Impl(idx, from, places); } X0_VEC MaskLoad0Impl(const ptrdiff_t idx, IN0_T* from, const ptrdiff_t places) { auto mask = hn::FirstN(d_in0, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d_in0, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; return hn::MaskedLoad(mask, d_in0, from + idx); } hn::Vec MaskLoad1(const ptrdiff_t idx, IN1_T* from, const ptrdiff_t places) { return me()->MaskLoad1Impl(idx, from, places); } hn::Vec MaskLoad1Impl(const ptrdiff_t idx, IN1_T* from, const ptrdiff_t places) { auto mask = hn::FirstN(d_in1, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d_in1, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; return hn::MaskedLoad(mask, d_in1, from + idx); } // store returns a bool that is `false` when bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { return me()->StoreAndShortCircuitImpl(idx, to, x); } bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) { hn::StoreU(x, d_out, to + idx); return true; } ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, const ptrdiff_t places) { return me()->MaskStoreImpl(idx, to, x, places); } ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x, const ptrdiff_t places) { auto mask = hn::FirstN(d_out, static_cast(places)); auto maskneg = hn::Not(hn::FirstN( d_out, static_cast(places + static_cast(ActualLanes())))); if (places < 0) mask = maskneg; hn::BlendedStore(x, mask, d_out, to + idx); return std::abs(places); } ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); } ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) { // default does nothing (void)x; (void)to; return 0; } void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { me()->ReduceImpl(x0, x1, x2, y); } void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) { // default does nothing (void)x0; (void)x1; (void)x2; (void)y; } }; template inline void Unroller(FUNC& f, IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y, const ptrdiff_t n) { auto xx = f.X0Init(); auto yy = f.YInit(); ptrdiff_t i = 0; #if HWY_MEM_OPS_MIGHT_FAULT constexpr auto lane_sz = static_cast(RemoveRef::MaxUnitLanes()); if (n < lane_sz) { const DFromV d; // this may not fit on the stack for HWY_RVV, but we do not reach this code // there HWY_ALIGN IN_T xtmp[static_cast(lane_sz)]; HWY_ALIGN OUT_T ytmp[static_cast(lane_sz)]; CopyBytes(x, xtmp, static_cast(n) * sizeof(IN_T)); xx = f.MaskLoad(0, xtmp, n); yy = f.Func(0, xx, yy); Store(Zero(d), d, ytmp); i += f.MaskStore(0, ytmp, yy, n); i += f.Reduce(yy, ytmp); CopyBytes(ytmp, y, static_cast(i) * sizeof(OUT_T)); return; } #endif const ptrdiff_t actual_lanes = static_cast(RemoveRef::ActualLanes()); if (n > 4 * actual_lanes) { auto xx1 = f.X0Init(); auto yy1 = f.YInit(); auto xx2 = f.X0Init(); auto yy2 = f.YInit(); auto xx3 = f.X0Init(); auto yy3 = f.YInit(); while (i + 4 * actual_lanes - 1 < n) { xx = f.Load(i, x); i += actual_lanes; xx1 = f.Load(i, x); i += actual_lanes; xx2 = f.Load(i, x); i += actual_lanes; xx3 = f.Load(i, x); i -= 3 * actual_lanes; yy = f.Func(i, xx, yy); yy1 = f.Func(i + actual_lanes, xx1, yy1); yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2); yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3); if (!f.StoreAndShortCircuit(i, y, yy)) return; i += actual_lanes; if (!f.StoreAndShortCircuit(i, y, yy1)) return; i += actual_lanes; if (!f.StoreAndShortCircuit(i, y, yy2)) return; i += actual_lanes; if (!f.StoreAndShortCircuit(i, y, yy3)) return; i += actual_lanes; } f.Reduce(yy3, yy2, yy1, &yy); } while (i + actual_lanes - 1 < n) { xx = f.Load(i, x); yy = f.Func(i, xx, yy); if (!f.StoreAndShortCircuit(i, y, yy)) return; i += actual_lanes; } if (i != n) { xx = f.MaskLoad(n - actual_lanes, x, i - n); yy = f.Func(n - actual_lanes, xx, yy); f.MaskStore(n - actual_lanes, y, yy, i - n); } f.Reduce(yy, y); } template inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0, IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y, const ptrdiff_t n) { const ptrdiff_t lane_sz = static_cast(RemoveRef::ActualLanes()); auto xx00 = f.X0Init(); auto xx10 = f.X1Init(); auto yy = f.YInit(); ptrdiff_t i = 0; #if HWY_MEM_OPS_MIGHT_FAULT if (n < lane_sz) { const DFromV d; // this may not fit on the stack for HWY_RVV, but we do not reach this code // there constexpr auto max_lane_sz = static_cast(RemoveRef::MaxUnitLanes()); HWY_ALIGN IN0_T xtmp0[static_cast(max_lane_sz)]; HWY_ALIGN IN1_T xtmp1[static_cast(max_lane_sz)]; HWY_ALIGN OUT_T ytmp[static_cast(max_lane_sz)]; CopyBytes(x0, xtmp0, static_cast(n) * sizeof(IN0_T)); CopyBytes(x1, xtmp1, static_cast(n) * sizeof(IN1_T)); xx00 = f.MaskLoad0(0, xtmp0, n); xx10 = f.MaskLoad1(0, xtmp1, n); yy = f.Func(0, xx00, xx10, yy); Store(Zero(d), d, ytmp); i += f.MaskStore(0, ytmp, yy, n); i += f.Reduce(yy, ytmp); CopyBytes(ytmp, y, static_cast(i) * sizeof(OUT_T)); return; } #endif if (n > 4 * lane_sz) { auto xx01 = f.X0Init(); auto xx11 = f.X1Init(); auto yy1 = f.YInit(); auto xx02 = f.X0Init(); auto xx12 = f.X1Init(); auto yy2 = f.YInit(); auto xx03 = f.X0Init(); auto xx13 = f.X1Init(); auto yy3 = f.YInit(); while (i + 4 * lane_sz - 1 < n) { xx00 = f.Load0(i, x0); xx10 = f.Load1(i, x1); i += lane_sz; xx01 = f.Load0(i, x0); xx11 = f.Load1(i, x1); i += lane_sz; xx02 = f.Load0(i, x0); xx12 = f.Load1(i, x1); i += lane_sz; xx03 = f.Load0(i, x0); xx13 = f.Load1(i, x1); i -= 3 * lane_sz; yy = f.Func(i, xx00, xx10, yy); yy1 = f.Func(i + lane_sz, xx01, xx11, yy1); yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2); yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3); if (!f.StoreAndShortCircuit(i, y, yy)) return; i += lane_sz; if (!f.StoreAndShortCircuit(i, y, yy1)) return; i += lane_sz; if (!f.StoreAndShortCircuit(i, y, yy2)) return; i += lane_sz; if (!f.StoreAndShortCircuit(i, y, yy3)) return; i += lane_sz; } f.Reduce(yy3, yy2, yy1, &yy); } while (i + lane_sz - 1 < n) { xx00 = f.Load0(i, x0); xx10 = f.Load1(i, x1); yy = f.Func(i, xx00, xx10, yy); if (!f.StoreAndShortCircuit(i, y, yy)) return; i += lane_sz; } if (i != n) { xx00 = f.MaskLoad0(n - lane_sz, x0, i - n); xx10 = f.MaskLoad1(n - lane_sz, x1, i - n); yy = f.Func(n - lane_sz, xx00, xx10, yy); f.MaskStore(n - lane_sz, y, yy, i - n); } f.Reduce(yy, y); } } // namespace HWY_NAMESPACE } // namespace hwy HWY_AFTER_NAMESPACE(); #endif // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_