#ifndef PQCLEAN_FALCON1024_AVX2_FPR_H
#define PQCLEAN_FALCON1024_AVX2_FPR_H

/*
 * Floating-point operations.
 *
 * ==========================(LICENSE BEGIN)============================
 *
 * Copyright (c) 2017-2019  Falcon Project
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *
 * ===========================(LICENSE END)=============================
 *
 * @author   Thomas Pornin <thomas.pornin@nccgroup.com>
 */

/* ====================================================================== */

/*
 * <stdint.h> is included explicitly so that this header is
 * self-contained: the code below uses int64_t, uint32_t and uint64_t.
 */
#include <stdint.h>

#include <immintrin.h>
#include <math.h>

/*
 * Multiply-then-add and multiply-then-subtract on four packed doubles.
 * These expand to a separate multiplication and addition/subtraction
 * (two roundings); they are not fused operations.
 */
#define FMADD(a, b, c)   _mm256_add_pd(_mm256_mul_pd(a, b), c)
#define FMSUB(a, b, c)   _mm256_sub_pd(_mm256_mul_pd(a, b), c)

/*
 * We wrap the native 'double' type into a structure so that the C compiler
 * complains if we inadvertently use raw arithmetic operators on the 'fpr'
 * type instead of using the inline functions below. This should have no
 * extra runtime cost, since all the functions below are 'inline'.
 */
typedef struct {
    double v;
} fpr;

/*
 * Wrap a raw 'double' into an 'fpr'.
 */
static inline fpr
FPR(double v) {
    fpr x;

    x.v = v;
    return x;
}

/*
 * Convert a signed 64-bit integer to 'fpr' (plain integer-to-double cast).
 */
static inline fpr
fpr_of(int64_t i) {
    return FPR((double)i);
}

/*
 * Named constants. The names describe the intended values; the literals
 * are the authoritative definition.
 */
static const fpr fpr_q = { 12289.0 };
static const fpr fpr_inverse_of_q = { 1.0 / 12289.0 };
static const fpr fpr_inv_2sqrsigma0 = { .150865048875372721532312163019 };
static const fpr fpr_inv_sigma = { .005819826392951607426919370871 };
static const fpr fpr_sigma_min_9 = { 1.291500756233514568549480827642 };
static const fpr fpr_sigma_min_10 = { 1.311734375905083682667395805765 };
static const fpr fpr_log2 = { 0.69314718055994530941723212146 };
static const fpr fpr_inv_log2 = { 1.4426950408889634073599246810 };
static const fpr fpr_bnorm_max = { 16822.4121 };
static const fpr fpr_zero = { 0.0 };
static const fpr fpr_one = { 1.0 };
static const fpr fpr_two = { 2.0 };
static const fpr fpr_onehalf = { 0.5 };
static const fpr fpr_invsqrt2 = { 0.707106781186547524400844362105 };
static const fpr fpr_invsqrt8 = { 0.353553390593273762200422181052 };
static const fpr fpr_ptwo31 = { 2147483648.0 };
static const fpr fpr_ptwo31m1 = { 2147483647.0 };
static const fpr fpr_mtwo31m1 = { -2147483647.0 };
static const fpr fpr_ptwo63m1 = { 9223372036854775807.0 };
static const fpr fpr_mtwo63m1 = { -9223372036854775807.0 };
static const fpr fpr_ptwo63 = { 9223372036854775808.0 };

/*
 * Round x to the nearest integer (ties go to the even integer), using
 * only casts, additions and bit masks (no data-dependent branch).
 * The conversion of x to int64_t must be defined, i.e. x must be within
 * the int64_t range.
 */
static inline int64_t
fpr_rint(fpr x) {
    /*
     * We do not want to use llrint() since it might be not
     * constant-time.
     *
     * Suppose that x >= 0. If x >= 2^52, then it is already an
     * integer. Otherwise, if x < 2^52, then computing x+2^52 will
     * yield a value that will be rounded to the nearest integer
     * with exactly the right rules (round-to-nearest-even).
     *
     * In order to have constant-time processing, we must do the
     * computation for both x >= 0 and x < 0 cases, and use a
     * cast to an integer to access the sign and select the proper
     * value. Such casts also allow us to find out if |x| < 2^52.
     */
    int64_t sx, tx, rp, rn, m;
    uint32_t ub;

    sx = (int64_t)(x.v - 1.0);
    tx = (int64_t)x.v;
    rp = (int64_t)(x.v + 4503599627370496.0) - 4503599627370496;
    rn = (int64_t)(x.v - 4503599627370496.0) + 4503599627370496;

    /*
     * If tx >= 2^52 or tx < -2^52, then result is tx.
     * Otherwise, if sx >= 0, then result is rp.
     * Otherwise, result is rn. We use the fact that when x is
     * close to 0 (|x| <= 0.25) then both rp and rn are correct;
     * and if x is not close to 0, then trunc(x-1.0) yields the
     * appropriate sign.
     */

    /*
     * Clamp rp to zero if sx < 0.
     * Clamp rn to zero if sx >= 0.
     * (m is all-ones when sx is negative, zero otherwise.)
     */
    m = sx >> 63;
    rn &= m;
    rp &= ~m;

    /*
     * Get the 12 upper bits of tx; if they are not all zeros or
     * all ones, then tx >= 2^52 or tx < -2^52, and we clamp both
     * rp and rn to zero. Otherwise, we clamp tx to zero.
     */
    ub = (uint32_t)((uint64_t)tx >> 52);
    m = -(int64_t)((((ub + 1) & 0xFFF) - 2) >> 31);
    rp &= m;
    rn &= m;
    tx &= ~m;

    /*
     * Only one of tx, rn or rp (at most) can be non-zero at this
     * point.
     */
    return tx | rn | rp;
}

/*
 * Return the largest integer not greater than x (rounding toward
 * minus infinity). x must be within the int64_t range.
 */
static inline int64_t
fpr_floor(fpr x) {
    int64_t r;

    /*
     * The cast performs a trunc() (rounding toward 0) and thus is
     * wrong by 1 for most negative values. The correction below is
     * constant-time as long as the compiler turns the
     * floating-point conversion result into a 0/1 integer without a
     * conditional branch or another non-constant-time construction.
     * This should hold on all modern architectures with an FPU (and
     * if it is false on a given arch, then chances are that the FPU
     * itself is not constant-time, making the point moot).
     */
    r = (int64_t)x.v;
    return r - (x.v < (double)r);
}

/*
 * Convert x to an integer, rounding toward zero (plain cast).
 */
static inline int64_t
fpr_trunc(fpr x) {
    return (int64_t)x.v;
}

/* Return x + y. */
static inline fpr
fpr_add(fpr x, fpr y) {
    return FPR(x.v + y.v);
}

/* Return x - y. */
static inline fpr
fpr_sub(fpr x, fpr y) {
    return FPR(x.v - y.v);
}

/* Return -x. */
static inline fpr
fpr_neg(fpr x) {
    return FPR(-x.v);
}

/* Return x / 2 (computed as x * 0.5). */
static inline fpr
fpr_half(fpr x) {
    return FPR(x.v * 0.5);
}

/* Return 2 * x (computed as x + x). */
static inline fpr
fpr_double(fpr x) {
    return FPR(x.v + x.v);
}

/* Return x * y. */
static inline fpr
fpr_mul(fpr x, fpr y) {
    return FPR(x.v * y.v);
}

/* Return x * x. */
static inline fpr
fpr_sqr(fpr x) {
    return FPR(x.v * x.v);
}

/* Return 1 / x. */
static inline fpr
fpr_inv(fpr x) {
    return FPR(1.0 / x.v);
}

/* Return x / y. */
static inline fpr
fpr_div(fpr x, fpr y) {
    return FPR(x.v / y.v);
}

/*
 * Replace *t with its square root, in place, using SSE2 intrinsics:
 * the value is broadcast into both lanes of a 128-bit register, the
 * packed square root is computed, and the low lane is stored back.
 */
static inline void
fpr_sqrt_avx2(double *t) {
    __m128d x;

    x = _mm_load1_pd(t);
    x = _mm_sqrt_pd(x);
    _mm_storel_pd(t, x);
}

/*
 * Return the square root of x.
 */
static inline fpr
fpr_sqrt(fpr x) {
    /*
     * We prefer not to have a dependency on libm when it can be
     * avoided. On x86, calling the sqrt() libm function inlines
     * the relevant opcode (fsqrt or sqrtsd, depending on whether
     * the 387 FPU or SSE2 is used for floating-point operations)
     * but then makes an optional call to the library function
     * for proper error handling, in case the operand is negative.
     *
     * To avoid this dependency, the generic Falcon code uses
     * intrinsics or inline assembly on recognized platforms:
     *
     *  - If AVX2 is explicitly enabled, then we use SSE2 intrinsics.
     *
     *  - On GCC/Clang with SSE maths, we use SSE2 intrinsics.
     *
     *  - On GCC/Clang on i386, or MSVC on i386, we use inline assembly
     *    to call the 387 FPU fsqrt opcode.
     *
     *  - On GCC/Clang/XLC on PowerPC, we use inline assembly to call
     *    the fsqrt opcode (Clang needs a special hack).
     *
     *  - On GCC/Clang on ARM with hardware floating-point, we use
     *    inline assembly to call the vsqrt.f64 opcode. Due to a
     *    complex ecosystem of compilers and assembly syntaxes, we
     *    have to call it "fsqrt" or "fsqrtd", depending on case.
     *
     * If the platform is not recognized, a call to the system
     * library function sqrt() is performed. On some compilers, this
     * may actually inline the relevant opcode, and call the library
     * function only when the input is invalid (e.g. negative);
     * Falcon never actually calls sqrt() on a negative value, but
     * the dependency to libm will still be there.
     *
     * This AVX2 variant always takes the first path (SSE2
     * intrinsics), through fpr_sqrt_avx2().
     */
    fpr_sqrt_avx2(&x.v);
    return x;
}

/*
 * Return 1 if x < y, 0 otherwise.
 */
static inline int
fpr_lt(fpr x, fpr y) {
    return x.v < y.v;
}

/*
 * Compute ccs*exp(-x)*2^63, truncated to an integer. exp(-x) is
 * evaluated with a degree-12 polynomial; the stated accuracy applies
 * to x in the 0..log(2) range. The caller is expected to provide
 * x >= 0 and ccs < 1 so that the result fits in 63 bits.
 */
static inline uint64_t
fpr_expm_p63(fpr x, fpr ccs) {
    /*
     * Polynomial approximation of exp(-x) is taken from FACCT:
     *   https://eprint.iacr.org/2018/1234
     * Specifically, values are extracted from the implementation
     * referenced from the FACCT article, and available at:
     *   https://github.com/raykzhao/gaussian
     * Tests over more than 24 billions of random inputs in the
     * 0..log(2) range have never shown a deviation larger than
     * 2^(-50) from the true mathematical value.
     */

    /*
     * AVX2 implementation uses more operations than Horner's method,
     * but with a lower expression tree depth. This helps because
     * additions and multiplications have a latency of 4 cycles on
     * a Skylake, but the CPU can issue two of them per cycle.
     */

    /*
     * Coefficients for the powers 1 to 12 of (-x); the constant term
     * (1.0) is added separately at the end. The __m256d member of the
     * union is never read; the table is accessed through unaligned
     * loads on the 'd' member.
     */
    static const union {
        double d[12];
        __m256d v[3];
    } c = {
        {
            0.999999999999994892974086724280,
            0.500000000000019206858326015208,
            0.166666666666984014666397229121,
            0.041666666666110491190622155955,
            0.008333333327800835146903501993,
            0.001388888894063186997887560103,
            0.000198412739277311890541063977,
            0.000024801566833585381209939524,
            0.000002755586350219122514855659,
            0.000000275607356160477811864927,
            0.000000025299506379442070029551,
            0.000000002073772366009083061987
        }
    };

    double d1, d2, d4, d8, y;
    __m256d d14, d58, d9c;

    /*
     * d14, d58 and d9c hold the powers 1..4, 5..8 and 9..12 of
     * d1 = -x, lowest power in the lowest lane.
     */
    d1 = -x.v;
    d2 = d1 * d1;
    d4 = d2 * d2;
    d8 = d4 * d4;
    d14 = _mm256_set_pd(d4, d2 * d1, d2, d1);
    d58 = _mm256_mul_pd(d14, _mm256_set1_pd(d4));
    d9c = _mm256_mul_pd(d14, _mm256_set1_pd(d8));

    /*
     * Multiply each power by its coefficient and accumulate the three
     * groups lane by lane into d9c.
     */
    d14 = _mm256_mul_pd(d14, _mm256_loadu_pd(&c.d[0]));
    d58 = FMADD(d58, _mm256_loadu_pd(&c.d[4]), d14);
    d9c = FMADD(d9c, _mm256_loadu_pd(&c.d[8]), d58);

    /*
     * Horizontal sum of the four lanes: hadd sums adjacent pairs
     * within each 128-bit half, then the low lanes of the two halves
     * are added together with the constant term.
     */
    d9c = _mm256_hadd_pd(d9c, d9c);
    y = 1.0
        + _mm_cvtsd_f64(_mm256_castpd256_pd128(d9c)) /* alternative: _mm256_cvtsd_f64(d9c) */
        + _mm_cvtsd_f64(_mm256_extractf128_pd(d9c, 1));
    y *= ccs.v;

    /*
     * Final conversion goes through int64_t first, because that's what
     * the underlying opcode (vcvttsd2si) will do, and we know that the
     * result will fit, since x >= 0 and ccs < 1. If we did the
     * conversion directly to uint64_t, then the compiler would add some
     * extra code to cover the case of a source value of 2^63 or more,
     * and though the alternate path would never be exercised, the
     * extra comparison would cost us some cycles.
     */
    return (uint64_t)(int64_t)(y * fpr_ptwo63.v);
}

/*
 * Constant tables, defined in another translation unit. The macros
 * give the symbols a namespaced external name.
 */
#define fpr_gm_tab   PQCLEAN_FALCON1024_AVX2_fpr_gm_tab
extern const fpr fpr_gm_tab[];

#define fpr_p2_tab   PQCLEAN_FALCON1024_AVX2_fpr_p2_tab
extern const fpr fpr_p2_tab[];

/* ====================================================================== */

#endif