/* Copyright (C) 2010 Fredrik Johansson Copyright (C) 2020 William Hart Copyright (C) 2020 Daniel Schultz This file is part of FLINT. FLINT is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License (LGPL) as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. See . */ #include #include #include "flint.h" #include "nmod_mat.h" #include "nmod_vec.h" #include "thread_support.h" #if FLINT_USES_BLAS #include "cblas.h" #endif void nmod_mat_mul(nmod_mat_t C, const nmod_mat_t A, const nmod_mat_t B) { slong m = A->r; slong k = A->c; slong n = B->c; slong min_dim = FLINT_MIN(FLINT_MIN(m, k), n); slong cutoff; slong flint_num_threads = flint_get_num_threads(); #if FLINT_USES_BLAS /* tuning is based on several assumptions: (1) blas_num_threads >= flint_num_threads. (2) nmod_mat_mul_blas (with crt) only beats nmod_mat_mul_classical on square multiplications of large enough dimension (3) if nmod_mat_mul_blas beats nmod_mat_mul_classical on square multiplications of size d, then it beats it on rectangular muliplications as long as all dimensions are >= d */ if (FLINT_BITS == 64 && min_dim > 100) { flint_bitcnt_t bits = FLINT_BIT_COUNT(A->mod.n); if (FLINT_BIT_COUNT(k) + 2*bits < 53 + 5) { /* mul_blas definitely avoids the slow crt */ cutoff = 100; } else if (flint_num_threads > 1) { /* mul_blas with crt is competing against mul_classical_threaded */ bits = FLINT_MAX(bits, 32); cutoff = 100 + 5*flint_num_threads*bits/2; } else { /* mul_blas with crt is competing against mul_strassen */ cutoff = 450; } if (min_dim > cutoff && nmod_mat_mul_blas(C, A, B)) return; } #endif if (C == A || C == B) { nmod_mat_t T; nmod_mat_init(T, m, n, A->mod.n); nmod_mat_mul(T, A, B); nmod_mat_swap_entrywise(C, T); nmod_mat_clear(T); return; } if (FLINT_BITS == 64 && C->mod.n < 2048) cutoff = 400; else cutoff = 200; if (flint_num_threads > 1) nmod_mat_mul_classical_threaded(C, A, B); else if (min_dim < cutoff) nmod_mat_mul_classical(C, A, B); else nmod_mat_mul_strassen(C, A, B); }