/* Copyright (C) 2010,2012 Fredrik Johansson This file is part of FLINT. FLINT is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License (LGPL) as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. See . */ #include #include #include "flint.h" #include "nmod_mat.h" #include "nmod_vec.h" /* with op = 0, computes D = A*B with op = 1, computes D = C + A*B with op = -1, computes D = C - A*B */ static __inline__ void _nmod_mat_addmul_basic_op(mp_ptr * D, mp_ptr * const C, mp_ptr * const A, mp_ptr * const B, slong m, slong k, slong n, int op, nmod_t mod, int nlimbs) { slong i, j; mp_limb_t c; for (i = 0; i < m; i++) { for (j = 0; j < n; j++) { c = _nmod_vec_dot_ptr(A[i], B, j, k, mod, nlimbs); if (op == 1) c = nmod_add(C[i][j], c, mod); else if (op == -1) c = nmod_sub(C[i][j], c, mod); D[i][j] = c; } } } static __inline__ void _nmod_mat_addmul_transpose_op(mp_ptr * D, const mp_ptr * C, const mp_ptr * A, const mp_ptr * B, slong m, slong k, slong n, int op, nmod_t mod, int nlimbs) { mp_ptr tmp; mp_limb_t c; slong i, j; tmp = flint_malloc(sizeof(mp_limb_t) * k * n); for (i = 0; i < k; i++) for (j = 0; j < n; j++) tmp[j*k + i] = B[i][j]; for (i = 0; i < m; i++) { for (j = 0; j < n; j++) { c = _nmod_vec_dot(A[i], tmp + j*k, k, mod, nlimbs); if (op == 1) c = nmod_add(C[i][j], c, mod); else if (op == -1) c = nmod_sub(C[i][j], c, mod); D[i][j] = c; } } flint_free(tmp); } /* requires nlimbs = 1 */ void _nmod_mat_addmul_packed_op(mp_ptr * D, const mp_ptr * C, const mp_ptr * A, const mp_ptr * B, slong M, slong N, slong K, int op, nmod_t mod, int nlimbs) { slong i, j, k; slong Kpack; int pack, pack_bits; mp_limb_t c, d, mask; mp_ptr tmp; mp_ptr Aptr, Tptr; /* bound unreduced entry */ c = N * (mod.n-1) * (mod.n-1); pack_bits = FLINT_BIT_COUNT(c); pack = FLINT_BITS / pack_bits; Kpack = (K + pack - 1) / pack; if (pack_bits == FLINT_BITS) mask = UWORD(-1); else mask = (UWORD(1) << pack_bits) - 1; tmp = _nmod_vec_init(Kpack * N); /* pack and transpose B */ for (i = 0; i < Kpack; i++) { for (k = 0; k < N; k++) { c = B[k][i * pack]; for (j = 1; j < pack && i * pack + j < K; j++) c |= B[k][i * pack + j] << (pack_bits * j); tmp[i * N + k] = c; } } /* multiply */ for (i = 0; i < M; i++) { for (j = 0; j < Kpack; j++) { Aptr = A[i]; Tptr = tmp + j * N; c = 0; /* unroll by 4 */ for (k = 0; k + 4 <= N; k += 4) { c += Aptr[k + 0] * Tptr[k + 0]; c += Aptr[k + 1] * Tptr[k + 1]; c += Aptr[k + 2] * Tptr[k + 2]; c += Aptr[k + 3] * Tptr[k + 3]; } for ( ; k < N; k++) c += Aptr[k] * Tptr[k]; /* unpack and reduce */ for (k = 0; k < pack && j * pack + k < K; k++) { d = (c >> (k * pack_bits)) & mask; NMOD_RED(d, d, mod); if (op == 1) d = nmod_add(C[i][j * pack + k], d, mod); else if (op == -1) d = nmod_sub(C[i][j * pack + k], d, mod); D[i][j * pack + k] = d; } } } _nmod_vec_clear(tmp); } void _nmod_mat_mul_classical_op(nmod_mat_t D, const nmod_mat_t C, const nmod_mat_t A, const nmod_mat_t B, int op) { slong m, k, n; int nlimbs; nmod_t mod; mod = A->mod; m = A->r; k = A->c; n = B->c; if (k == 0) { if (op == 0) nmod_mat_zero(D); else nmod_mat_set(D, C); return; } nlimbs = _nmod_vec_dot_bound_limbs(k, mod); if (nlimbs == 1 && m > 10 && k > 10 && n > 10) { _nmod_mat_addmul_packed_op(D->rows, (op == 0) ? NULL : C->rows, A->rows, B->rows, m, k, n, op, D->mod, nlimbs); } else if (m < NMOD_MAT_MUL_TRANSPOSE_CUTOFF || n < NMOD_MAT_MUL_TRANSPOSE_CUTOFF || k < NMOD_MAT_MUL_TRANSPOSE_CUTOFF) { if ((mod.n & (mod.n - 1)) == 0) nlimbs = 1; _nmod_mat_addmul_basic_op(D->rows, (op == 0) ? NULL : C->rows, A->rows, B->rows, m, k, n, op, D->mod, nlimbs); } else { if ((mod.n & (mod.n - 1)) == 0) nlimbs = 1; _nmod_mat_addmul_transpose_op(D->rows, (op == 0) ? NULL : C->rows, A->rows, B->rows, m, k, n, op, D->mod, nlimbs); } } void nmod_mat_mul_classical(nmod_mat_t C, const nmod_mat_t A, const nmod_mat_t B) { _nmod_mat_mul_classical_op(C, NULL, A, B, 0); }