/*
Copyright (C) 2014 Fredrik Johansson
Copyright (C) 2020 William Hart
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 2.1 of the License, or
(at your option) any later version. See .
*/
#include
#include
#include "flint.h"
#include "fmpz.h"
#include "fmpz_poly.h"
#include "thread_support.h"
typedef struct
{
fmpz * vec;
mp_ptr * residues;
slong n0;
slong n1;
mp_srcptr primes;
slong num_primes;
int crt; /* reduce if 0, lift if 1 */
}
mod_ui_arg_t;
void
_fmpz_vec_multi_mod_ui_worker(void * arg_ptr)
{
mod_ui_arg_t arg = *((mod_ui_arg_t *) arg_ptr);
mp_ptr tmp;
slong i, j;
fmpz_comb_t comb;
fmpz_comb_temp_t comb_temp;
tmp = flint_malloc(sizeof(mp_limb_t) * arg.num_primes);
fmpz_comb_init(comb, arg.primes, arg.num_primes);
fmpz_comb_temp_init(comb_temp, comb);
for (i = arg.n0; i < arg.n1; i++)
{
if (arg.crt)
{
for (j = 0; j < arg.num_primes; j++)
tmp[j] = arg.residues[j][i];
fmpz_multi_CRT_ui(arg.vec + i, tmp, comb, comb_temp, 1);
}
else
{
fmpz_multi_mod_ui(tmp, arg.vec + i, comb, comb_temp);
for (j = 0; j < arg.num_primes; j++)
arg.residues[j][i] = tmp[j];
}
}
flint_free(tmp);
fmpz_comb_clear(comb);
fmpz_comb_temp_clear(comb_temp);
}
void
_fmpz_vec_multi_mod_ui_threaded(mp_ptr * residues, fmpz * vec, slong len,
mp_srcptr primes, slong num_primes, int crt)
{
mod_ui_arg_t * args;
slong i, num_threads;
thread_pool_handle * threads;
num_threads = flint_request_threads(&threads, flint_get_num_threads());
args = (mod_ui_arg_t *)
flint_malloc(sizeof(mod_ui_arg_t)*(num_threads + 1));
for (i = 0; i < num_threads + 1; i++)
{
args[i].vec = vec;
args[i].residues = residues;
args[i].n0 = (len * i) / (num_threads + 1);
args[i].n1 = (len * (i + 1)) / (num_threads + 1);
args[i].primes = (mp_ptr) primes;
args[i].num_primes = num_primes;
args[i].crt = crt;
}
for (i = 0; i < num_threads; i++)
thread_pool_wake(global_thread_pool, threads[i], 0,
_fmpz_vec_multi_mod_ui_worker, &args[i]);
_fmpz_vec_multi_mod_ui_worker(&args[num_threads]);
for (i = 0; i < num_threads; i++)
thread_pool_wait(global_thread_pool, threads[i]);
flint_give_back_threads(threads, num_threads);
flint_free(args);
}
typedef struct
{
mp_ptr * residues;
slong len;
mp_srcptr primes;
slong num_primes;
slong p0;
slong p1;
fmpz * c;
}
taylor_shift_arg_t;
void
_fmpz_poly_multi_taylor_shift_worker(void * arg_ptr)
{
taylor_shift_arg_t arg = *((taylor_shift_arg_t *) arg_ptr);
slong i;
for (i = arg.p0; i < arg.p1; i++)
{
nmod_t mod;
mp_limb_t p, cm;
p = arg.primes[i];
nmod_init(&mod, p);
cm = fmpz_fdiv_ui(arg.c, p);
_nmod_poly_taylor_shift(arg.residues[i], cm, arg.len, mod);
}
}
void
_fmpz_poly_multi_taylor_shift_threaded(mp_ptr * residues, slong len,
const fmpz_t c, mp_srcptr primes, slong num_primes)
{
taylor_shift_arg_t * args;
slong i, num_threads;
thread_pool_handle * threads;
num_threads = flint_request_threads(&threads, flint_get_num_threads());
args = (taylor_shift_arg_t *)
flint_malloc(sizeof(taylor_shift_arg_t)*(num_threads + 1));
for (i = 0; i < num_threads + 1; i++)
{
args[i].residues = residues;
args[i].len = len;
args[i].p0 = (num_primes * i) / (num_threads + 1);
args[i].p1 = (num_primes * (i + 1)) / (num_threads + 1);
args[i].primes = (mp_ptr) primes;
args[i].num_primes = num_primes;
args[i].c = (fmpz *) c;
}
for (i = 0; i < num_threads; i++)
thread_pool_wake(global_thread_pool, threads[i], 0,
_fmpz_poly_multi_taylor_shift_worker, &args[i]);
_fmpz_poly_multi_taylor_shift_worker(&args[num_threads]);
for (i = 0; i < num_threads; i++)
thread_pool_wait(global_thread_pool, threads[i]);
flint_give_back_threads(threads, num_threads);
flint_free(args);
}
void
_fmpz_poly_taylor_shift_multi_mod(fmpz * poly, const fmpz_t c, slong len)
{
slong xbits, ybits, num_primes, i;
mp_ptr primes;
mp_ptr * residues;
if (len <= 1 || fmpz_is_zero(c))
return;
xbits = _fmpz_vec_max_bits(poly, len);
if (xbits == 0)
return;
/* If poly has degree D and coefficients at most |C|, the
output has coefficient at most D * |C| * 2^D * c^D */
xbits = FLINT_ABS(xbits) + 1;
ybits = xbits + len + FLINT_BIT_COUNT(len);
if (!fmpz_is_pm1(c))
{
fmpz_t t;
fmpz_init(t);
fmpz_pow_ui(t, c, len);
ybits += fmpz_bits(t);
fmpz_clear(t);
}
/* Use primes greater than 2^(FLINT_BITS-1) */
num_primes = (ybits + (FLINT_BITS - 1) - 1) / (FLINT_BITS - 1);
primes = flint_malloc(sizeof(mp_limb_t) * num_primes);
primes[0] = n_nextprime(UWORD(1) << (FLINT_BITS - 1), 1);
for (i = 1; i < num_primes; i++)
primes[i] = n_nextprime(primes[i-1], 1);
/* Space for poly reduced modulo the primes */
residues = flint_malloc(sizeof(mp_ptr) * num_primes);
for (i = 0; i < num_primes; i++)
residues[i] = flint_malloc(sizeof(mp_limb_t) * len);
_fmpz_vec_multi_mod_ui_threaded(residues, poly, len, primes,
num_primes, 0);
_fmpz_poly_multi_taylor_shift_threaded(residues, len, c,
primes, num_primes);
_fmpz_vec_multi_mod_ui_threaded(residues, poly, len, primes,
num_primes, 1);
for (i = 0; i < num_primes; i++)
flint_free(residues[i]);
flint_free(residues);
flint_free(primes);
}