/* Copyright 2009 William Hart Copyright 2010 Fredrik Johansson Copyright 2020 Daniel Schultz This file is part of FLINT. FLINT is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License (LGPL) as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. See . */ #include #include #include "profiler.h" #include "flint.h" #include "nmod_mat.h" #include "ulong_extras.h" #include "thread_support.h" #include "cblas.h" typedef struct { slong dim_m; slong dim_n; slong dim_k; mp_limb_t modulus; int algorithm; } mat_mul_t; void sample(void * arg, ulong count) { mat_mul_t * params = (mat_mul_t *) arg; int algorithm = params->algorithm; nmod_mat_t A, B, C; ulong i; flint_rand_t state; flint_randinit(state); nmod_mat_init(A, params->dim_m, params->dim_k, params->modulus); nmod_mat_init(B, params->dim_k, params->dim_n, params->modulus); nmod_mat_init(C, params->dim_m, params->dim_n, params->modulus); nmod_mat_randfull(A, state); nmod_mat_randfull(B, state); nmod_mat_randfull(C, state); prof_start(); if (algorithm == 0) for (i = 0; i < count; i++) nmod_mat_mul(C, A, B); else if (algorithm == 1) for (i = 0; i < count; i++) nmod_mat_mul_classical(C, A, B); else if (algorithm == 2) for (i = 0; i < count; i++) nmod_mat_mul_classical_threaded(C, A, B); else if (algorithm == 3) for (i = 0; i < count; i++) nmod_mat_mul_blas(C, A, B); else for (i = 0; i < count; i++) nmod_mat_mul_strassen(C, A, B); prof_stop(); nmod_mat_clear(A); nmod_mat_clear(B); nmod_mat_clear(C); flint_randclear(state); } int main(void) { double max; mat_mul_t params; slong dim, i, flint_num, blas_num; flint_printf("nmod_mat_mul:\n"); for (dim = 2; dim <= 100; dim += dim/4 + 1) { double min_classical, min_strassen; params.dim_m = dim; params.dim_n = dim; params.dim_k = dim; params.modulus = 40000; params.algorithm = 1; prof_repeat(&min_classical, &max, sample, ¶ms); params.algorithm = 4; prof_repeat(&min_strassen, &max, sample, ¶ms); flint_printf("dim = %wd, classical %.2f us strassen %.2f us\n", dim, min_classical, min_strassen); } /* output floating point ratios time(mul_blas)/time(mul_blas) */ for (dim = 200; dim <= 1200; dim += 200) { flint_printf("dimension %wd\n", dim); for (flint_num = 2; flint_num <= 8; flint_num += 1) { flint_set_num_threads(flint_num); for (blas_num = flint_num; blas_num <= flint_num; blas_num *= 2) { double min_old, min_new, min_ratio = 100; openblas_set_num_threads(blas_num); flint_printf("[flint %wd, blas %wd]: (", flint_num, blas_num); for (i = 7; i < FLINT_BITS; i += 8) { params.dim_m = dim; params.dim_n = dim; params.dim_k = dim; params.modulus = 2*(UWORD(1) << i) - 1; params.algorithm = 2; prof_repeat(&min_old, &max, sample, ¶ms); params.algorithm = 0; prof_repeat(&min_new, &max, sample, ¶ms); min_ratio = FLINT_MIN(min_ratio, min_old/min_new); flint_printf(" %.2f ", min_old/min_new); fflush(stdout); } flint_printf(") min %0.2f\n", min_ratio); /* assume that blas gets faster with more threads */ if (min_ratio > 1) break; } } } return 0; }