/*
 * AVX2 polynomial multiplication core for sntrup653.
 *
 * The public entry point decodes a 653-coefficient polynomial, multiplies it
 * by a polynomial with coefficients in {-1, 0, 1}, folds the product so that
 * x^653 is replaced by x + 1, reduces coefficients modulo 4621 and encodes the
 * result.  The full product is computed modulo the two auxiliary primes 7681
 * and 10753 with size-512 NTTs and then recombined.
 */
#include "crypto_core_multsntrup653.h"
#include "crypto_core_multsntrup653_ntt.h"
#include "crypto_decode_653xint16.h"
#include "crypto_encode_653xint16.h"
#include <immintrin.h>

typedef int8_t int8;
typedef int16_t int16;

/* Short aliases: one 256-bit register holds 16 signed 16-bit lanes. */
#define int16x16 __m256i
/* Unaligned 256-bit load/store. */
#define load_x16(p) _mm256_loadu_si256((int16x16 *) (p))
#define store_x16(p,v) _mm256_storeu_si256((int16x16 *) (p),(v))
/* Broadcast one 16-bit constant to all lanes. */
#define const_x16 _mm256_set1_epi16
/* Lane-wise wrapping add/sub. */
#define add_x16 _mm256_add_epi16
#define sub_x16 _mm256_sub_epi16
/* Low / high 16 bits of the signed 32-bit lane product. */
#define mullo_x16 _mm256_mullo_epi16
#define mulhi_x16 _mm256_mulhi_epi16
/* Rounded (x*y) / 2^15 per lane. */
#define mulhrs_x16 _mm256_mulhrs_epi16
/* Arithmetic shift by 15: all-ones in negative lanes, zero elsewhere. */
#define signmask_x16(x) _mm256_srai_epi16((x),15)

/*
 * Working buffers.  The int16x16 member is never accessed; as a union member
 * it raises the alignment of the whole object to that of __m256i.
 */
typedef union {
    int16 v[6][512];
    int16x16 _dummy;
} vec6x512;

typedef union {
    int16 v[768];
    int16x16 _dummy;
} vec768;

typedef union {
    int16 v[3 * 512];
    int16x16 _dummy;
} vec1536;

/*
 * Partial reduction modulo 4621: subtracts 4621 * round(7*x / 2^15) from x.
 * 2^15/7 is roughly 4621, so the multiplier approximates x/4621.  Only a
 * multiple of 4621 is subtracted, so the result stays congruent to x.
 */
static inline int16x16 squeeze_4621_x16(int16x16 x) {
    return sub_x16(x, mullo_x16(mulhrs_x16(x, const_x16(7)), const_x16(4621)));
}

/* Same as squeeze_4621_x16 for modulus 7681, with multiplier 4 / 2^15. */
static inline int16x16 squeeze_7681_x16(int16x16 x) {
    return sub_x16(x, mullo_x16(mulhrs_x16(x, const_x16(4)), const_x16(7681)));
}

/* Same as squeeze_4621_x16 for modulus 10753, with multiplier 3 / 2^15. */
static inline int16x16 squeeze_10753_x16(int16x16 x) {
    return sub_x16(x, mullo_x16(mulhrs_x16(x, const_x16(3)), const_x16(10753)));
}

/*
 * Montgomery-style modular multiplication for q = 4621.
 * -29499 is the inverse of 4621 modulo 2^16.  d is chosen so that d*q has the
 * same low 16 bits as x*y; the difference of the two high halves is then
 * (x*y - d*q) / 2^16, which is congruent to x * y * 2^-16 modulo 4621.
 * Callers therefore pass constants pre-multiplied by 2^16 modulo q.
 */
static inline int16x16 mulmod_4621_x16(int16x16 x, int16x16 y) {
    int16x16 yqinv = mullo_x16(y, const_x16(-29499)); /* XXX: precompute */
    int16x16 b = mulhi_x16(x, y);
    int16x16 d = mullo_x16(x, yqinv);
    int16x16 e = mulhi_x16(d, const_x16(4621));
    return sub_x16(b, e);
}

/* As mulmod_4621_x16 with q = 7681; -7679 is 7681^-1 modulo 2^16. */
static inline int16x16 mulmod_7681_x16(int16x16 x, int16x16 y) {
    int16x16 yqinv = mullo_x16(y, const_x16(-7679)); /* XXX: precompute */
    int16x16 b = mulhi_x16(x, y);
    int16x16 d = mullo_x16(x, yqinv);
    int16x16 e = mulhi_x16(d, const_x16(7681));
    return sub_x16(b, e);
}

/* As mulmod_4621_x16 with q = 10753; -10751 is 10753^-1 modulo 2^16. */
static inline int16x16 mulmod_10753_x16(int16x16 x, int16x16 y) {
    int16x16 yqinv = mullo_x16(y, const_x16(-10751)); /* XXX: precompute */
    int16x16 b = mulhi_x16(x, y);
    int16x16 d = mullo_x16(x, yqinv);
    int16x16 e = mulhi_x16(d, const_x16(10753));
    return sub_x16(b, e);
}

/*
 * Lane selectors.  _mm256_set_epi16 lists lanes from 15 down to 0, so maskK
 * is all-ones exactly in the lanes whose index is congruent to K modulo 3.
 */
#define mask0 _mm256_set_epi16(-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1)
#define mask1 _mm256_set_epi16(0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0)
#define mask2 _mm256_set_epi16(0,-1,0,0,-1,0,0,-1,0,0,-1,0,0,-1,0,0)

/*
 * Scatter the 768 inputs into three rows of 512 so that f[i] ends up at
 * fpad[i mod 3][i mod 512]; every other slot is written as zero.
 * (Presumably the input permutation of Good's trick, hence the name -
 * NOTE(review): confirm against upstream.)
 *
 * 16 is congruent to 1 modulo 3, so the lane-to-row assignment rotates by one
 * for each 16-element chunk; the loops are unrolled three times to keep the
 * masks compile-time constants.
 */
static void good(int16 fpad[3][512], const int16 f[768]) {
    int j;
    int16x16 f0, f1;

    /* Columns 0..255 receive data from both f[j] and f[512 + j]. */
    j = 0;
    for (;;) {
        f0 = load_x16(f + j);
        f1 = load_x16(f + 512 + j);
        store_x16(&fpad[0][j], (f0 & mask0) | (f1 & mask1));
        store_x16(&fpad[1][j], (f0 & mask1) | (f1 & mask2));
        store_x16(&fpad[2][j], (f0 & mask2) | (f1 & mask0));
        j += 16;
        /* 256 = 16 + 5*48, so the exit falls after the first of three steps. */
        if (j == 256) {
            break;
        }
        f0 = load_x16(f + j);
        f1 = load_x16(f + 512 + j);
        store_x16(&fpad[0][j], (f0 & mask2) | (f1 & mask0));
        store_x16(&fpad[1][j], (f0 & mask0) | (f1 & mask1));
        store_x16(&fpad[2][j], (f0 & mask1) | (f1 & mask2));
        j += 16;
        f0 = load_x16(f + j);
        f1 = load_x16(f + 512 + j);
        store_x16(&fpad[0][j], (f0 & mask1) | (f1 & mask2));
        store_x16(&fpad[1][j], (f0 & mask2) | (f1 & mask0));
        store_x16(&fpad[2][j], (f0 & mask0) | (f1 & mask1));
        j += 16;
    }
    /*
     * Columns 256..511: f[512 + j] would lie past the 768 inputs, so only
     * f[j] contributes.  The rotation continues from where the first loop
     * stopped.
     */
    for (;;) {
        f0 = load_x16(f + j);
        store_x16(&fpad[0][j], f0 & mask2);
        store_x16(&fpad[1][j], f0 & mask0);
        store_x16(&fpad[2][j], f0 & mask1);
        j += 16;
        if (j == 512) {
            break;
        }
        f0 = load_x16(f + j);
        store_x16(&fpad[0][j], f0 & mask1);
        store_x16(&fpad[1][j], f0 & mask2);
        store_x16(&fpad[2][j], f0 & mask0);
        j += 16;
        f0 = load_x16(f + j);
        store_x16(&fpad[0][j], f0 & mask0);
        store_x16(&fpad[1][j], f0 & mask1);
        store_x16(&fpad[2][j], f0 & mask2);
        j += 16;
    }
}

/*
 * Inverse of the index map used by good(), over the full range:
 * f[i] = fpad[i mod 3][i mod 512] for i = 0..1535.
 * g0, g1 and g2 are the chunks written at offsets 0, 512 and 1024.
 */
static void ungood(int16 f[1536], const int16 fpad[3][512]) {
    int j;
    int16x16 f0, f1, f2, g0, g1, g2;

    j = 0;
    for (;;) {
        f0 = load_x16(&fpad[0][j]);
        f1 = load_x16(&fpad[1][j]);
        f2 = load_x16(&fpad[2][j]);
        g0 = (f0 & mask0) | (f1 & mask1) | (f2 & mask2);
        g1 = (f0 & mask1) | (f1 & mask2) | (f2 & mask0);
        /* In each lane g0 and g1 cancel two of the three inputs, leaving the third. */
        g2 = f0 ^ f1 ^ f2 ^ g0 ^ g1; /* same as (f0&mask2)|(f1&mask0)|(f2&mask1) */
        store_x16(f + 0 + j, g0);
        store_x16(f + 512 + j, g1);
        store_x16(f + 1024 + j, g2);
        j += 16;

        f0 = load_x16(&fpad[0][j]);
        f1 = load_x16(&fpad[1][j]);
        f2 = load_x16(&fpad[2][j]);
        g0 = (f0 & mask2) | (f1 & mask0) | (f2 & mask1);
        g1 = (f0 & mask0) | (f1 & mask1) | (f2 & mask2);
        g2 = f0 ^ f1 ^ f2 ^ g0 ^ g1; /* same as (f0&mask1)|(f1&mask2)|(f2&mask0) */
        store_x16(f + 0 + j, g0);
        store_x16(f + 512 + j, g1);
        store_x16(f + 1024 + j, g2);
        j += 16;
        /* 512 = 32 + 10*48, so the exit falls after the second of three steps. */
        if (j == 512) {
            break;
        }

        f0 = load_x16(&fpad[0][j]);
        f1 = load_x16(&fpad[1][j]);
        f2 = load_x16(&fpad[2][j]);
        g0 = (f0 & mask1) | (f1 & mask2) | (f2 & mask0);
        g1 = (f0 & mask2) | (f1 & mask0) | (f2 & mask1);
        g2 = f0 ^ f1 ^ f2 ^ g0 ^ g1; /* same as (f0&mask0)|(f1&mask1)|(f2&mask2) */
        store_x16(f + 0 + j, g0);
        store_x16(f + 512 + j, g1);
        store_x16(f + 1024 + j, g2);
        j += 16;
    }
}

/*
 * Multiply two 768-coefficient inputs into a 1536-coefficient result h.
 *
 * The same pipeline runs once per auxiliary prime (7681, then 10753):
 *   good() both inputs -> forward transform of all six rows -> per-column
 *   product of the three f rows with the three g rows -> inverse transform of
 *   the three result rows -> ungood().
 * The two residue arrays are then combined per coefficient and reduced
 * modulo 4621.
 *
 * NOTE(review): the ntt512/invntt512 routines are defined elsewhere; their
 * in-place behaviour and output scaling are assumed here, not shown.
 */
static void mult768(int16 h[1536], const int16 f[768], const int16 g[768]) {
    /* Rows 0..2 hold f, rows 3..5 hold g; the product overwrites rows 0..2. */
    vec6x512 fgpad;
#define fpad (fgpad.v)
#define gpad (fgpad.v+3)
#define hpad fpad
    vec1536 aligned_h_7681;
    vec1536 aligned_h_10753;
#define h_7681 (aligned_h_7681.v)
#define h_10753 (aligned_h_10753.v)
    int i;

    good(fpad, f);
    good(gpad, g);

    PQCLEAN_SNTRUP653_AVX2_ntt512_7681(fgpad.v[0], 6);

    for (i = 0; i < 512; i += 16) {
        int16x16 f0 = squeeze_7681_x16(load_x16(&fpad[0][i]));
        int16x16 f1 = squeeze_7681_x16(load_x16(&fpad[1][i]));
        int16x16 f2 = squeeze_7681_x16(load_x16(&fpad[2][i]));
        int16x16 g0 = squeeze_7681_x16(load_x16(&gpad[0][i]));
        int16x16 g1 = squeeze_7681_x16(load_x16(&gpad[1][i]));
        int16x16 g2 = squeeze_7681_x16(load_x16(&gpad[2][i]));
        int16x16 d0 = mulmod_7681_x16(f0, g0);
        int16x16 d1 = mulmod_7681_x16(f1, g1);
        int16x16 d2 = mulmod_7681_x16(f2, g2);
        int16x16 dsum = add_x16(add_x16(d0, d1), d2);
        /*
         * Length-3 cyclic convolution with six multiplications instead of nine:
         *   h0 = f0*g0 + f1*g2 + f2*g1
         *   h1 = f0*g1 + f1*g0 + f2*g2
         *   h2 = f0*g2 + f1*g1 + f2*g0
         * Each is dsum plus one product of differences.
         */
        int16x16 h0 = add_x16(dsum, mulmod_7681_x16(sub_x16(f2, f1), sub_x16(g1, g2)));
        int16x16 h1 = add_x16(dsum, mulmod_7681_x16(sub_x16(f1, f0), sub_x16(g0, g1)));
        int16x16 h2 = add_x16(dsum, mulmod_7681_x16(sub_x16(f0, f2), sub_x16(g2, g0)));
        /* hpad aliases fpad; all inputs for column i were loaded above. */
        store_x16(&hpad[0][i], squeeze_7681_x16(h0));
        store_x16(&hpad[1][i], squeeze_7681_x16(h1));
        store_x16(&hpad[2][i], squeeze_7681_x16(h2));
    }

    PQCLEAN_SNTRUP653_AVX2_invntt512_7681(hpad[0], 3);
    ungood(h_7681, (const int16(*)[512]) hpad);

    /* The padded inputs were overwritten above, so rebuild them for 10753. */
    good(fpad, f);
    good(gpad, g);

    PQCLEAN_SNTRUP653_AVX2_ntt512_10753(fgpad.v[0], 6);

    for (i = 0; i < 512; i += 16) {
        int16x16 f0 = squeeze_10753_x16(load_x16(&fpad[0][i]));
        int16x16 f1 = squeeze_10753_x16(load_x16(&fpad[1][i]));
        int16x16 f2 = squeeze_10753_x16(load_x16(&fpad[2][i]));
        int16x16 g0 = squeeze_10753_x16(load_x16(&gpad[0][i]));
        int16x16 g1 = squeeze_10753_x16(load_x16(&gpad[1][i]));
        int16x16 g2 = squeeze_10753_x16(load_x16(&gpad[2][i]));
        int16x16 d0 = mulmod_10753_x16(f0, g0);
        int16x16 d1 = mulmod_10753_x16(f1, g1);
        int16x16 d2 = mulmod_10753_x16(f2, g2);
        int16x16 dsum = add_x16(add_x16(d0, d1), d2);
        /* Same six-multiplication cyclic convolution as in the 7681 pass. */
        int16x16 h0 = add_x16(dsum, mulmod_10753_x16(sub_x16(f2, f1), sub_x16(g1, g2)));
        int16x16 h1 = add_x16(dsum, mulmod_10753_x16(sub_x16(f1, f0), sub_x16(g0, g1)));
        int16x16 h2 = add_x16(dsum, mulmod_10753_x16(sub_x16(f0, f2), sub_x16(g2, g0)));
        store_x16(&hpad[0][i], squeeze_10753_x16(h0));
        store_x16(&hpad[1][i], squeeze_10753_x16(h1));
        store_x16(&hpad[2][i], squeeze_10753_x16(h2));
    }

    PQCLEAN_SNTRUP653_AVX2_invntt512_10753(hpad[0], 3);
    ungood(h_10753, (const int16(*)[512]) hpad);

    /*
     * Recombine the residues modulo 10753 (u1) and modulo 7681 (u2):
     *   t = (u2 - u1) / 10753  modulo 7681   (-2539 is 2^16 / 10753 mod 7681)
     *   h = u1 + 10753 * t     modulo 4621   (1487 is 2^16 * 10753 mod 4621)
     * The constants carry a factor 2^16 that the mulmod routines remove.
     * NOTE(review): 1268 and 956 presumably undo the scaling left by the
     * transforms and the pointwise Montgomery products - confirm upstream.
     */
    for (i = 0; i < 1536; i += 16) {
        int16x16 u1 = load_x16(&h_10753[i]);
        int16x16 u2 = load_x16(&h_7681[i]);
        int16x16 t;
        u1 = mulmod_10753_x16(u1, const_x16(1268));
        u2 = mulmod_7681_x16(u2, const_x16(956));
        t = mulmod_7681_x16(sub_x16(u2, u1), const_x16(-2539));
        t = add_x16(u1, mulmod_4621_x16(t, const_x16(1487)));
        store_x16(&h[i], t);
    }
}

#define crypto_decode_pxint16 PQCLEAN_SNTRUP653_AVX2_crypto_decode_653xint16
#define crypto_encode_pxint16 PQCLEAN_SNTRUP653_AVX2_crypto_encode_653xint16

/* Number of coefficients and coefficient modulus. */
#define p 653
#define q 4621

/*
 * Map each lane to the centered representative modulo q.
 * Negative lanes first get q added; lanes that are then >= (q+1)/2 get q
 * subtracted.  For inputs strictly between -q and q the result lies in
 * [-(q-1)/2, (q-1)/2].
 */
static inline int16x16 freeze_4621_x16(int16x16 x) {
    int16x16 mask, xq;
    x = add_x16(x, const_x16(q)&signmask_x16(x));
    /* mask is all-ones where x < (q+1)/2, i.e. where x is kept unchanged. */
    mask = signmask_x16(sub_x16(x, const_x16((q + 1) / 2)));
    xq = sub_x16(x, const_x16(q));
    /* blendv picks its second operand where the mask's top bit is set. */
    x = _mm256_blendv_epi8(xq, x, mask);
    return x;
}

/*
 * Multiply the polynomial encoded in inbytes by the small polynomial in
 * kbytes and write the encoded product to outbytes.
 *
 * outbytes: receives the encoding of the p result coefficients.
 * inbytes:  encoded int16 coefficients, decoded by crypto_decode_pxint16.
 * kbytes:   p bytes, one per coefficient; only the two low bits are used
 *           (even -> 0, low bits 01 -> 1, low bits 11 -> -1).
 * Always returns 0.
 */
int PQCLEAN_SNTRUP653_AVX2_crypto_core_multsntrup653(unsigned char *outbytes, const unsigned char *inbytes, const unsigned char *kbytes) {
    vec768 x1, x2;
    vec1536 x3;
#define f (x1.v)
#define g (x2.v)
#define fg (x3.v)
    /* The result reuses f's storage; f is no longer needed after mult768. */
#define h f
    int i;
    int16x16 x;

    /*
     * Zero both buffers from the last 16-aligned index below p (640) up to
     * 768, so everything beyond the p coefficients written below is zero.
     */
    x = const_x16(0);
    for (i = p & ~15; i < 768; i += 16) {
        store_x16(&f[i], x);
    }
    for (i = p & ~15; i < 768; i += 16) {
        store_x16(&g[i], x);
    }

    crypto_decode_pxint16(f, inbytes);

    /* Bring every coefficient of f to its centered representative modulo q. */
    for (i = 0; i < 768; i += 16) {
        x = load_x16(&f[i]);
        x = freeze_4621_x16(squeeze_4621_x16(x));
        store_x16(&f[i], x);
    }
    for (i = 0; i < p; ++i) {
        int8 gi = (int8) kbytes[i];
        int8 gi0 = gi & 1;
        /* gi0 is 0 or 1; subtracting bit 1 (value 2) when gi0 is set gives -1. */
        g[i] = (int16) (gi0 - (gi & (gi0 << 1)));
    }

    mult768(fg, f, g);

    /*
     * Fold the high half back: the coefficient at index i + p is added to
     * positions i and i + 1 (x^p is replaced by x + 1), which the loop does
     * as h[i] = fg[i] + fg[i + p] + fg[i + p - 1].  At i = 0 that formula
     * wrongly picks up fg[p - 1], so it is subtracted from fg[0] beforehand.
     */
    fg[0] = (int16) (fg[0] - fg[p - 1]);
    for (i = 0; i < 768; i += 16) {
        int16x16 fgi = load_x16(&fg[i]);
        int16x16 fgip = load_x16(&fg[i + p]);
        int16x16 fgip1 = load_x16(&fg[i + p - 1]);
        x = add_x16(fgi, add_x16(fgip, fgip1));
        x = freeze_4621_x16(squeeze_4621_x16(x));
        store_x16(&h[i], x);
    }

    crypto_encode_pxint16(outbytes, h);

    return 0;
}