/*
Copyright (C) 2012 William Hart
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 2.1 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/
#include
#include
#include "flint.h"
#include "longlong.h"
#include "mpn_extras.h"
/*
TODO: speed up mpir's mullow and mulhigh and use them instead of mul/mul_n
*/
/*
    Writes an n-limb value to r, computed from the n-limb operands a and b,
    the n-limb modulus d and the n-limb value dinv.

    NOTE(review): judging by the name and the shape of the computation this
    is a*b mod d, with dinv a precomputed inverse of d used to estimate the
    quotient. The exact preconditions (d normalised, a and b already reduced
    and pre-shifted left by norm bits) are not visible in this file; confirm
    them against the mpn_extras documentation.

    Steps, identical in both code paths:

      1. P  = a*b, brought down by norm bits          (2n limbs at w)
      2. E  = high(P) * dinv                          (2n limbs at est)
         q  = high(E) + high(P), carry-out discarded  (n limbs at q)
      3. QD = q*d                                     (at qd)
      4. r  = low(P) - low(QD); top = P[n] - QD[n] - borrow.
         d is subtracted from (top, r) until top is zero, then once more
         if r >= d.

    Scratch layout (5n limbs, base w):

      w    = w + 0    product P
      xhi  = w + n    high half of P
      qd   = w + 2n   q*d; overwrites the low half of E, which is no
                      longer needed once q has been formed
      est  = w + 3n   E
      q    = w + 4n   high half of E, turned into q in place

    The n == 2 path open-codes the same sequence with longlong.h macros. It
    applies the right shift by norm to b before multiplying, whereas the
    general path shifts the full product afterwards; the two are guaranteed
    to agree when the low norm bits of b are zero.
    NOTE(review): confirm that callers guarantee this, and that
    norm < FLINT_BITS.
*/
void flint_mpn_mulmod_preinvn(mp_ptr r,
        mp_srcptr a, mp_srcptr b, mp_size_t n,
        mp_srcptr d, mp_srcptr dinv, ulong norm)
{
    mp_limb_t top, lo, hi, blo, bhi;
    mp_ptr w, xhi, qd, est, q;
    TMP_INIT;

    TMP_START;
    w = TMP_ALLOC(5*n*sizeof(mp_limb_t));

    xhi = w + n;
    qd = w + 2*n;
    est = w + 3*n;
    q = w + 4*n;

    if (n != 2)
    {
        /* step 1: full product, then undo the normalisation shift */
        if (a == b)
        {
            mpn_sqr(w, a, n);
        }
        else
        {
            mpn_mul_n(w, a, b, n);
        }

        if (norm != 0)
        {
            mpn_rshift(w, w, 2*n, norm);
        }

        /* step 2: quotient estimate */
        mpn_mul_n(est, xhi, dinv, n);
        mpn_add_n(q, q, xhi, n);

        /* step 3 */
        mpn_mul_n(qd, q, d, n);

        /* step 4: only limb n of each side is needed above the low n limbs */
        top = xhi[0] - qd[n] - mpn_sub_n(r, w, qd, n);

        while (top != 0)
        {
            top -= mpn_sub_n(r, r, d, n);
        }

        if (mpn_cmp(r, d, n) >= 0)
        {
            mpn_sub_n(r, r, d, n);
        }
    }
    else
    {
        /* two-limb right shift of b by norm, kept in locals (b is not modified) */
        blo = b[0];
        bhi = b[1];

        if (norm != 0)
        {
            blo = (blo >> norm) | (bhi << (FLINT_BITS - norm));
            bhi >>= norm;
        }

        /* step 1: 2x2 limb product a * (bhi:blo) into w[0..3] */
        umul_ppmm(w[1], w[0], a[0], blo);
        umul_ppmm(xhi[1], xhi[0], a[1], bhi);
        umul_ppmm(hi, lo, a[0], bhi);
        add_sssaaaaaa(xhi[1], xhi[0], w[1], xhi[1], xhi[0], w[1], 0, hi, lo);
        umul_ppmm(hi, lo, a[1], blo);
        add_sssaaaaaa(xhi[1], xhi[0], w[1], xhi[1], xhi[0], w[1], 0, hi, lo);

        /* step 2: E = high(P) * dinv in est[0..1], q[0..1] */
        umul_ppmm(est[1], est[0], xhi[0], dinv[0]);
        umul_ppmm(q[1], q[0], xhi[1], dinv[1]);
        umul_ppmm(hi, lo, xhi[0], dinv[1]);
        add_sssaaaaaa(q[1], q[0], est[1], q[1], q[0], est[1], 0, hi, lo);
        umul_ppmm(hi, lo, xhi[1], dinv[0]);
        add_sssaaaaaa(q[1], q[0], est[1], q[1], q[0], est[1], 0, hi, lo);

        /* q = high(E) + high(P) */
        add_ssaaaa(q[1], q[0], q[1], q[0], xhi[1], xhi[0]);

        /*
            step 3: low three limbs of q*d; the fourth is never used.
            qd[2] is the same cell as est[0], which is dead by now.
        */
        umul_ppmm(qd[1], qd[0], q[0], d[0]);
        qd[2] = q[1]*d[1];
        umul_ppmm(hi, lo, q[0], d[1]);
        add_ssaaaa(qd[2], qd[1], qd[2], qd[1], hi, lo);
        umul_ppmm(hi, lo, q[1], d[0]);
        add_ssaaaa(qd[2], qd[1], qd[2], qd[1], hi, lo);

        /* step 4: three-limb subtraction, top limb kept in a local */
        sub_dddmmmsss(top, r[1], r[0], xhi[0], w[1], w[0], qd[2], qd[1], qd[0]);

        while (top != 0)
        {
            sub_dddmmmsss(top, r[1], r[0], top, r[1], r[0], 0, d[1], d[0]);
        }

        /* r >= d as two-limb values: compare high limbs, low limbs break ties */
        if (r[1] != d[1] ? r[1] > d[1] : r[0] >= d[0])
        {
            sub_ddmmss(r[1], r[0], r[1], r[0], d[1], d[0]);
        }
    }

    TMP_END;
}