/*
Copyright (C) 2010, 2012 Fredrik Johansson
Copyright (C) 2020 William Hart
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 2.1 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/
#include <stdio.h>
#include <stdlib.h>
#include "flint.h"
#include "nmod_mat.h"
#include "nmod_vec.h"
#include "thread_support.h"
#define FLINT_MUL_CLASSICAL_CACHE_SIZE 32768 /* size of L1 cache in words */
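/* the block sizes computed below are halved until two block-sized operand
   slices fit within this budget */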
/*
with op = 0, computes D = A*B
with op = 1, computes D = C + A*B
with op = -1, computes D = C - A*B
*/
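/* C is read only when op != 0; callers pass NULL for op == 0 */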
static __inline__ void
_nmod_mat_addmul_basic_op(mp_ptr * D, const mp_ptr * C, const mp_ptr * A,
                const mp_ptr * B, slong m, slong k, slong n, int op,
                nmod_t mod, int nlimbs)
{
slong i, j;
mp_limb_t c;
for (i = 0; i < m; i++)
{
for (j = 0; j < n; j++)
{
c = _nmod_vec_dot_ptr(A[i], B, j, k, mod, nlimbs);
if (op == 1)
c = nmod_add(C[i][j], c, mod);
else if (op == -1)
c = nmod_sub(C[i][j], c, mod);
D[i][j] = c;
}
}
}
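/* worker arguments; i and j point at the shared block cursor, guarded by
   mutex when pthreads are in use */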
typedef struct
{
slong block;
volatile slong * i;
volatile slong * j;
slong k;
slong m;
slong n;
slong nlimbs;
const mp_ptr * A;
const mp_ptr * C;
mp_ptr * D;
mp_ptr tmp;
nmod_t mod;
#if FLINT_USES_PTHREAD
pthread_mutex_t * mutex;
#endif
int op;
} nmod_mat_transpose_arg_t;
void
_nmod_mat_addmul_transpose_worker(void * arg_ptr)
{
nmod_mat_transpose_arg_t arg = *((nmod_mat_transpose_arg_t *) arg_ptr);
slong i, j, iend, jend, jstart;
slong block = arg.block;
slong k = arg.k;
slong m = arg.m;
slong n = arg.n;
slong nlimbs = arg.nlimbs;
const mp_ptr * A = arg.A;
const mp_ptr * C = arg.C;
mp_ptr * D = arg.D;
mp_ptr tmp = arg.tmp;
nmod_t mod = arg.mod;
int op = arg.op;
mp_limb_t c;
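    /* threads share a cursor (i, j) that sweeps a grid of block x block
       tiles of the output row by row; each pass claims one tile under the
       mutex and then computes it without further synchronisation */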
while (1)
{
#if FLINT_USES_PTHREAD
pthread_mutex_lock(arg.mutex);
#endif
i = *arg.i;
j = *arg.j;
if (j >= n)
{
i += block;
*arg.i = i;
j = 0;
}
*arg.j = j + block;
#if FLINT_USES_PTHREAD
pthread_mutex_unlock(arg.mutex);
#endif
if (i >= m)
return;
iend = FLINT_MIN(i + block, m);
jend = FLINT_MIN(j + block, n);
jstart = j;
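        /* process the claimed tile; tmp + j*k is column j of B stored
           contiguously, thanks to the transpose above */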
for ( ; i < iend; i++)
{
for (j = jstart ; j < jend; j++)
{
c = _nmod_vec_dot(A[i], tmp + j*k, k, mod, nlimbs);
if (op == 1)
c = nmod_add(C[i][j], c, mod);
else if (op == -1)
c = nmod_sub(C[i][j], c, mod);
D[i][j] = c;
}
}
}
}
static __inline__ void
_nmod_mat_addmul_transpose_threaded_pool_op(mp_ptr * D, const mp_ptr * C,
const mp_ptr * A, const mp_ptr * B, slong m,
slong k, slong n, int op, nmod_t mod, int nlimbs,
thread_pool_handle * threads, slong num_threads)
{
mp_ptr tmp;
slong i, j, block;
slong shared_i = 0, shared_j = 0;
nmod_mat_transpose_arg_t * args;
#if FLINT_USES_PTHREAD
pthread_mutex_t mutex;
#endif
tmp = flint_malloc(sizeof(mp_limb_t) * k * n);
/* transpose B */
for (i = 0; i < k; i++)
for (j = 0; j < n; j++)
tmp[j*k + i] = B[i][j];
/* compute optimal block width */
block = FLINT_MAX(FLINT_MIN(m/(num_threads + 1), n/(num_threads + 1)), 1);
while (2*block*k > FLINT_MUL_CLASSICAL_CACHE_SIZE && block > 1)
block >>= 1;
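    /* each tile touches about 2*block*k words of A and transposed B, so
       halving block keeps the working set inside the L1 budget */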
args = flint_malloc(sizeof(nmod_mat_transpose_arg_t) * (num_threads + 1));
for (i = 0; i < num_threads + 1; i++)
{
args[i].block = block;
args[i].i = &shared_i;
args[i].j = &shared_j;
args[i].k = k;
args[i].m = m;
args[i].n = n;
args[i].nlimbs = nlimbs;
args[i].A = A;
args[i].C = C;
args[i].D = D;
args[i].tmp = tmp;
args[i].mod = mod;
#if FLINT_USES_PTHREAD
args[i].mutex = &mutex;
#endif
args[i].op = op;
}
#if FLINT_USES_PTHREAD
pthread_mutex_init(&mutex, NULL);
#endif
for (i = 0; i < num_threads; i++)
{
thread_pool_wake(global_thread_pool, threads[i], 0,
_nmod_mat_addmul_transpose_worker, &args[i]);
}
_nmod_mat_addmul_transpose_worker(&args[num_threads]);
for (i = 0; i < num_threads; i++)
{
thread_pool_wait(global_thread_pool, threads[i]);
}
#if FLINT_USES_PTHREAD
pthread_mutex_destroy(&mutex);
#endif
flint_free(args);
flint_free(tmp);
}
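/* worker arguments for the packed method: pack entries of B are stored per
   limb, each in a field of pack_bits bits, extracted with mask */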
typedef struct
{
slong block;
volatile slong * i;
volatile slong * j;
slong M;
slong K;
slong N;
slong Kpack;
const mp_ptr * A;
const mp_ptr * C;
mp_ptr * D;
mp_ptr tmp;
nmod_t mod;
mp_limb_t mask;
#if FLINT_USES_PTHREAD
pthread_mutex_t * mutex;
#endif
int pack;
int pack_bits;
int op;
} nmod_mat_packed_arg_t;
void
_nmod_mat_addmul_packed_worker(void * arg_ptr)
{
nmod_mat_packed_arg_t arg = *((nmod_mat_packed_arg_t *) arg_ptr);
slong i, j, k, iend, jend, jstart;
slong block = arg.block;
slong M = arg.M;
slong K = arg.K;
slong N = arg.N;
slong Kpack = arg.Kpack;
const mp_ptr * A = arg.A;
const mp_ptr * C = arg.C;
mp_ptr * D = arg.D;
mp_ptr tmp = arg.tmp;
nmod_t mod = arg.mod;
mp_limb_t mask = arg.mask;
int pack = arg.pack;
int pack_bits = arg.pack_bits;
int op = arg.op;
mp_limb_t c, d;
mp_ptr Aptr, Tptr;
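    /* claim the next tile of the packed output grid, exactly as in the
       transpose worker above */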
while (1)
{
#if FLINT_USES_PTHREAD
pthread_mutex_lock(arg.mutex);
#endif
i = *arg.i;
j = *arg.j;
if (j >= Kpack)
{
i += block;
*arg.i = i;
j = 0;
}
*arg.j = j + block;
#if FLINT_USES_PTHREAD
pthread_mutex_unlock(arg.mutex);
#endif
if (i >= M)
return;
iend = FLINT_MIN(i + block, M);
jend = FLINT_MIN(j + block, Kpack);
jstart = j;
/* multiply */
for ( ; i < iend; i++)
{
for (j = jstart; j < jend; j++)
{
Aptr = A[i];
Tptr = tmp + j * N;
c = 0;
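                /* c accumulates pack dot products at once, one per
                   pack_bits-wide field; pack_bits bounds N*(mod.n-1)^2, so
                   no field can carry into its neighbour and the top field
                   cannot overflow the limb */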
/* unroll by 4 */
for (k = 0; k + 4 <= N; k += 4)
{
c += Aptr[k + 0] * Tptr[k + 0];
c += Aptr[k + 1] * Tptr[k + 1];
c += Aptr[k + 2] * Tptr[k + 2];
c += Aptr[k + 3] * Tptr[k + 3];
}
for ( ; k < N; k++)
c += Aptr[k] * Tptr[k];
/* unpack and reduce */
for (k = 0; k < pack && j * pack + k < K; k++)
{
d = (c >> (k * pack_bits)) & mask;
NMOD_RED(d, d, mod);
if (op == 1)
d = nmod_add(C[i][j * pack + k], d, mod);
else if (op == -1)
d = nmod_sub(C[i][j * pack + k], d, mod);
D[i][j * pack + k] = d;
}
}
}
}
}
/* requires nlimbs = 1 (the nlimbs argument is not used by the packed code) */
void
_nmod_mat_addmul_packed_threaded_pool_op(mp_ptr * D,
const mp_ptr * C, const mp_ptr * A, const mp_ptr * B,
slong M, slong N, slong K, int op, nmod_t mod, int nlimbs,
thread_pool_handle * threads, slong num_threads)
{
slong i, j, k;
slong Kpack, block;
int pack, pack_bits;
mp_limb_t c, mask;
mp_ptr tmp;
slong shared_i = 0, shared_j = 0;
nmod_mat_packed_arg_t * args;
#if FLINT_USES_PTHREAD
pthread_mutex_t mutex;
#endif
/* bound unreduced entry */
c = N * (mod.n-1) * (mod.n-1);
pack_bits = FLINT_BIT_COUNT(c);
pack = FLINT_BITS / pack_bits;
Kpack = (K + pack - 1) / pack;
if (pack_bits == FLINT_BITS)
mask = UWORD(-1);
else
mask = (UWORD(1) << pack_bits) - 1;
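    /* e.g. with mod.n = 5 and N = 100 on a 64-bit machine (FLINT_BITS = 64):
       c = 100*4*4 = 1600, pack_bits = 11, pack = 64/11 = 5, so five entries
       of B are packed per limb and mask = 0x7ff extracts one field */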
tmp = _nmod_vec_init(Kpack * N);
/* pack and transpose B */
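    /* tmp[i*N + k] holds columns i*pack .. i*pack + pack - 1 of row k of B,
       one entry per pack_bits-wide field */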
for (i = 0; i < Kpack; i++)
{
for (k = 0; k < N; k++)
{
c = B[k][i * pack];
for (j = 1; j < pack && i * pack + j < K; j++)
c |= B[k][i * pack + j] << (pack_bits * j);
tmp[i * N + k] = c;
}
}
/* compute optimal block width */
block = FLINT_MAX(FLINT_MIN(M/(num_threads + 1), Kpack/(num_threads + 1)), 1);
while (2*block*N > FLINT_MUL_CLASSICAL_CACHE_SIZE && block > 1)
block >>= 1;
args = flint_malloc(sizeof(nmod_mat_packed_arg_t) * (num_threads + 1));
for (i = 0; i < num_threads + 1; i++)
{
args[i].block = block;
args[i].i = &shared_i;
args[i].j = &shared_j;
args[i].M = M;
args[i].K = K;
args[i].N = N;
args[i].Kpack = Kpack;
args[i].A = A;
args[i].C = C;
args[i].D = D;
args[i].tmp = tmp;
args[i].mod = mod;
args[i].mask = mask;
#if FLINT_USES_PTHREAD
args[i].mutex = &mutex;
#endif
args[i].pack = pack;
args[i].pack_bits = pack_bits;
args[i].op = op;
}
#if FLINT_USES_PTHREAD
pthread_mutex_init(&mutex, NULL);
#endif
for (i = 0; i < num_threads; i++)
{
thread_pool_wake(global_thread_pool, threads[i],
0, _nmod_mat_addmul_packed_worker, &args[i]);
}
_nmod_mat_addmul_packed_worker(&args[num_threads]);
for (i = 0; i < num_threads; i++)
{
thread_pool_wait(global_thread_pool, threads[i]);
}
#if FLINT_USES_PTHREAD
pthread_mutex_destroy(&mutex);
#endif
flint_free(args);
_nmod_vec_clear(tmp);
}
void
_nmod_mat_mul_classical_threaded_pool_op(nmod_mat_t D, const nmod_mat_t C,
const nmod_mat_t A, const nmod_mat_t B, int op,
thread_pool_handle * threads, slong num_threads)
{
slong m, k, n;
int nlimbs;
nmod_t mod;
mod = A->mod;
m = A->r;
k = A->c;
n = B->c;
nlimbs = _nmod_vec_dot_bound_limbs(k, mod);
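    /* the packed method needs unreduced dot products to fit in a single
       limb, and only pays off once every dimension exceeds a small cutoff */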
if (nlimbs == 1 && m > 10 && k > 10 && n > 10)
{
_nmod_mat_addmul_packed_threaded_pool_op(D->rows, (op == 0) ? NULL : C->rows,
A->rows, B->rows, m, k, n, op, D->mod, nlimbs, threads, num_threads);
}
else
{
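        /* if mod.n is a power of two it divides 2^FLINT_BITS, so a dot
           product accumulated in a single limb wraps to the correct
           residue and one limb always suffices */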
if ((mod.n & (mod.n - 1)) == 0)
nlimbs = 1;
_nmod_mat_addmul_transpose_threaded_pool_op(D->rows, (op == 0) ? NULL : C->rows,
A->rows, B->rows, m, k, n, op, D->mod, nlimbs, threads, num_threads);
}
}
void
_nmod_mat_mul_classical_threaded_op(nmod_mat_t D, const nmod_mat_t C,
const nmod_mat_t A, const nmod_mat_t B, int op)
{
thread_pool_handle * threads;
slong num_threads;
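    /* empty inner dimension: A*B is the zero matrix */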
if (A->c == 0)
{
if (op == 0)
nmod_mat_zero(D);
else
nmod_mat_set(D, C);
return;
}
if (A->r < NMOD_MAT_MUL_TRANSPOSE_CUTOFF
|| A->c < NMOD_MAT_MUL_TRANSPOSE_CUTOFF
|| B->c < NMOD_MAT_MUL_TRANSPOSE_CUTOFF)
{
        int nlimbs = _nmod_vec_dot_bound_limbs(A->c, D->mod);
_nmod_mat_addmul_basic_op(D->rows, (op == 0) ? NULL : C->rows,
A->rows, B->rows, A->r, A->c, B->c, op, D->mod, nlimbs);
return;
}
num_threads = flint_request_threads(&threads, flint_get_num_threads());
_nmod_mat_mul_classical_threaded_pool_op(D, C, A, B, op, threads, num_threads);
flint_give_back_threads(threads, num_threads);
}
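/* C = A*B by classical multiplication, distributed over the available
   threads */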
void
nmod_mat_mul_classical_threaded(nmod_mat_t C, const nmod_mat_t A,
const nmod_mat_t B)
{
_nmod_mat_mul_classical_threaded_op(C, NULL, A, B, 0);
}