/*
Copyright (C) 2010, 2012 Fredrik Johansson
Copyright (C) 2020 William Hart
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 2.1 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/
#include <stdio.h>
#include <stdlib.h>

#include "flint.h"
#include "fmpz_mod_mat.h"
#include "fmpz_vec.h"
#include "thread_support.h"
#define FLINT_FMPZ_MUL_CLASSICAL_CACHE_SIZE 32768 /* size of L1 cache in words */
/*
with op = 0, computes D = A*B
with op = 1, computes D = C + A*B
with op = -1, computes D = C - A*B
*/
/*
   Classical (non-transposed, single-threaded) kernel, used below the
   transpose cutoff.

   op ==  0:  D = A*B
   op ==  1:  D = C + A*B
   op == -1:  D = C - A*B

   A is m x k, B is k x n; every entry of D is reduced modulo p.
   C may be NULL when op == 0, since it is then never read.
*/
static __inline__ void
_fmpz_mod_mat_addmul_basic_op(fmpz ** D, fmpz ** const C, fmpz ** const A,
             fmpz ** const B, slong m, slong k, slong n, int op, fmpz_t p)
{
    slong row, col;
    fmpz_t acc;

    fmpz_init(acc);

    for (row = 0; row < m; row++)
    {
        for (col = 0; col < n; col++)
        {
            /* acc = (row-th row of A) . (col-th column of B) */
            _fmpz_vec_dot_ptr(acc, A[row], B, col, k);

            switch (op)
            {
                case 1:
                    fmpz_add(acc, C[row] + col, acc);
                    break;
                case -1:
                    fmpz_sub(acc, C[row] + col, acc);
                    break;
                default:
                    /* op == 0: plain product, nothing to accumulate */
                    break;
            }

            fmpz_mod(D[row] + col, acc, p);
        }
    }

    fmpz_clear(acc);
}
/*
   Per-thread argument bundle for the transposed multiplication workers.
   All workers share the same (i, j) counters (guarded by mutex) and pull
   block x block tiles of the output from them until the matrix is done.
*/
typedef struct
{
    slong block;              /* side length of the tiles handed out */
    volatile slong * i;       /* shared: row index of the next tile */
    volatile slong * j;       /* shared: column index of the next tile */
    slong k;                  /* inner dimension (columns of A / rows of B) */
    slong m;                  /* rows of A and D */
    slong n;                  /* columns of B and D */
    fmpz ** A;                /* left operand, row pointers */
    fmpz ** C;                /* accumulation input (may be NULL when op == 0) */
    fmpz ** D;                /* output, row pointers */
    fmpz * tmp;               /* B stored transposed, contiguous: column j of B at tmp + j*k */
    fmpz * p;                 /* modulus */
#if FLINT_USES_PTHREAD
    pthread_mutex_t * mutex;  /* protects *i and *j */
#endif
    int op;                   /* 0: D = A*B, 1: D = C + A*B, -1: D = C - A*B */
} fmpz_mod_mat_transpose_arg_t;
/*
   Worker routine: repeatedly claims a block x block tile of the output
   matrix from the shared (i, j) counters and computes it via dot products
   of rows of A against the pre-transposed copy of B (arg.tmp).

   Tiles are handed out in row-major order; each worker claims one tile
   per mutex acquisition, so the tile size (arg.block) balances scheduling
   overhead against cache reuse. Returns once all rows are claimed.
*/
void
_fmpz_mod_mat_addmul_transpose_worker(void * arg_ptr)
{
    /* copy the descriptor by value; only *arg.i / *arg.j remain shared */
    fmpz_mod_mat_transpose_arg_t arg = *((fmpz_mod_mat_transpose_arg_t *) arg_ptr);
    slong i, j, iend, jend, jstart;
    slong block = arg.block;
    slong k = arg.k;
    slong m = arg.m;
    slong n = arg.n;
    fmpz ** const A = arg.A;
    fmpz ** const C = arg.C;
    fmpz ** D = arg.D;
    fmpz * tmp = arg.tmp;
    fmpz * p = arg.p;
    int op = arg.op;
    fmpz_t c;

    fmpz_init(c);

    while (1)
    {
        /* claim the next tile under the lock */
#if FLINT_USES_PTHREAD
        pthread_mutex_lock(arg.mutex);
#endif
        i = *arg.i;
        j = *arg.j;
        if (j >= n)
        {
            /* current row band exhausted: advance to the next band */
            i += block;
            *arg.i = i;
            j = 0;
        }
        /* reserve columns [j, j + block) of band i for this worker */
        *arg.j = j + block;
#if FLINT_USES_PTHREAD
        pthread_mutex_unlock(arg.mutex);
#endif

        /* all row bands claimed: nothing left to do */
        if (i >= m)
        {
            fmpz_clear(c);
            return;
        }

        /* clip the tile to the matrix boundary */
        iend = FLINT_MIN(i + block, m);
        jend = FLINT_MIN(j + block, n);
        jstart = j;

        for ( ; i < iend; i++)
        {
            for (j = jstart ; j < jend; j++)
            {
                /* row i of A dotted with column j of B (stored contiguously) */
                _fmpz_vec_dot(c, A[i], tmp + j*k, k);
                if (op == 1)
                    fmpz_add(c, C[i] + j, c);
                else if (op == -1)
                    fmpz_sub(c, C[i] + j, c);
                fmpz_mod(D[i] + j, c, p);
            }
        }
    }
}
/*
   Threaded classical multiplication over a transposed copy of B.

   Copies B into a contiguous transposed buffer (so each dot product walks
   memory linearly), picks a tile size aimed at keeping two operand strips
   within the L1 cache, then runs num_threads pool workers plus the calling
   thread over the shared tile counters. The caller supplies the thread
   handles; this function wakes and then joins all of them before returning.

   Dimensions: A is m x k, B is k x n, D (and C when op != 0) are m x n.
*/
static __inline__ void
_fmpz_mod_mat_addmul_transpose_threaded_pool_op(fmpz ** D, fmpz ** const C,
                            fmpz ** const A, fmpz ** const B, slong m,
                            slong k, slong n, int op, fmpz_t p,
                            thread_pool_handle * threads, slong num_threads)
{
    fmpz * tmp;
    slong i, j, block, nlimbs;
    slong shared_i = 0, shared_j = 0;
    fmpz_mod_mat_transpose_arg_t * args;
#if FLINT_USES_PTHREAD
    pthread_mutex_t mutex;
#endif

    tmp = _fmpz_vec_init(k*n);

    /* transpose B: column j of B becomes the contiguous run tmp + j*k */
    for (i = 0; i < k; i++)
        for (j = 0; j < n; j++)
            fmpz_set(tmp + j*k + i, B[i] + j);

    nlimbs = fmpz_size(p);

    /* compute optimal block width: start from an even split among the
       threads, then halve until two k-long strips of block rows fit in L1 */
    block = FLINT_MAX(FLINT_MIN(m/(num_threads + 1), n/(num_threads + 1)), 1);

    while (2*block*k*nlimbs > FLINT_FMPZ_MUL_CLASSICAL_CACHE_SIZE && block > 1)
        block >>= 1;

    /* one argument struct per pool worker plus one for the calling thread */
    args = flint_malloc(sizeof(fmpz_mod_mat_transpose_arg_t) * (num_threads + 1));

    for (i = 0; i < num_threads + 1; i++)
    {
        args[i].block   = block;
        args[i].i       = &shared_i;
        args[i].j       = &shared_j;
        args[i].k       = k;
        args[i].m       = m;
        args[i].n       = n;
        args[i].A       = A;
        args[i].C       = C;
        args[i].D       = D;
        args[i].tmp     = tmp;
        args[i].p       = p;
#if FLINT_USES_PTHREAD
        args[i].mutex   = &mutex;
#endif
        args[i].op      = op;
    }

#if FLINT_USES_PTHREAD
    pthread_mutex_init(&mutex, NULL);
#endif

    for (i = 0; i < num_threads; i++)
    {
        thread_pool_wake(global_thread_pool, threads[i], 0,
                         _fmpz_mod_mat_addmul_transpose_worker, &args[i]);
    }

    /* the calling thread participates with the last argument slot */
    _fmpz_mod_mat_addmul_transpose_worker(&args[num_threads]);

    for (i = 0; i < num_threads; i++)
    {
        thread_pool_wait(global_thread_pool, threads[i]);
    }

#if FLINT_USES_PTHREAD
    pthread_mutex_destroy(&mutex);
#endif

    flint_free(args);
    _fmpz_vec_clear(tmp, k*n);
}
/*
   Matrix-level wrapper around the transposed threaded kernel: unpacks the
   row pointers and dimensions from the fmpz_mod_mat structures and forwards
   them. With op == 0 the addend C is unused, so NULL is passed in its place.
*/
void
_fmpz_mod_mat_mul_classical_threaded_pool_op(fmpz_mod_mat_t D, const fmpz_mod_mat_t C,
                            const fmpz_mod_mat_t A, const fmpz_mod_mat_t B, int op,
                            thread_pool_handle * threads, slong num_threads)
{
    _fmpz_mod_mat_addmul_transpose_threaded_pool_op(D->mat->rows,
                  (op == 0) ? NULL : C->mat->rows, A->mat->rows, B->mat->rows,
                  A->mat->r, A->mat->c, B->mat->c, op, D->mod,
                  threads, num_threads);
}
/*
   Dispatcher for the classical threaded multiplication.

   op ==  0:  D = A*B
   op ==  1:  D = C + A*B
   op == -1:  D = C - A*B

   Degenerate and small cases are handled inline or by the basic kernel;
   otherwise worker threads are requested, the threaded kernel is run, and
   the threads are returned to the pool.
*/
void
fmpz_mod_mat_mul_classical_threaded_op(fmpz_mod_mat_t D, const fmpz_mod_mat_t C,
                            const fmpz_mod_mat_t A, const fmpz_mod_mat_t B, int op)
{
    thread_pool_handle * handles;
    slong nworkers;
    slong nrows = A->mat->r;
    slong inner = A->mat->c;
    slong ncols = B->mat->c;

    /* empty inner dimension: the product is the zero matrix */
    if (inner == 0)
    {
        if (op == 0)
            fmpz_mod_mat_zero(D);
        else
            fmpz_mod_mat_set(D, C);
        return;
    }

    /* small dimensions: transposing B is not worth it, use the basic kernel */
    if (nrows < FMPZ_MOD_MAT_MUL_TRANSPOSE_CUTOFF
        || inner < FMPZ_MOD_MAT_MUL_TRANSPOSE_CUTOFF
        || ncols < FMPZ_MOD_MAT_MUL_TRANSPOSE_CUTOFF)
    {
        _fmpz_mod_mat_addmul_basic_op(D->mat->rows,
                                      (op == 0) ? NULL : C->mat->rows,
                                      A->mat->rows, B->mat->rows,
                                      nrows, inner, ncols, op, D->mod);
        return;
    }

    nworkers = flint_request_threads(&handles, flint_get_num_threads());

    _fmpz_mod_mat_mul_classical_threaded_pool_op(D, C, A, B, op, handles, nworkers);

    flint_give_back_threads(handles, nworkers);
}
/* Public entry point: C = A*B using the classical threaded algorithm. */
void
fmpz_mod_mat_mul_classical_threaded(fmpz_mod_mat_t C, const fmpz_mod_mat_t A,
                                    const fmpz_mod_mat_t B)
{
    /* op == 0 selects the plain product; no addend matrix is needed */
    fmpz_mod_mat_mul_classical_threaded_op(C, NULL, A, B, 0);
}