/*
    Copyright (C) 2008, Martin Albrecht
    Copyright (C) 2008, 2009 William Hart.
    Copyright (C) 2010, Fredrik Johansson

    This file is part of FLINT.

    FLINT is free software: you can redistribute it and/or modify it under
    the terms of the GNU Lesser General Public License (LGPL) as published
    by the Free Software Foundation; either version 2.1 of the License, or
    (at your option) any later version.  See <https://www.gnu.org/licenses/>.
*/

#include /* NOTE(review): header name is missing in this copy of the file -- restore from upstream */
#include /* NOTE(review): header name is missing in this copy of the file -- restore from upstream */
#include "flint.h"
#include "nmod_vec.h"
#include "nmod_mat.h"

/*
    Sets C to the matrix product A * B using one level of the
    Strassen-Winograd scheme (7 block multiplications instead of 8).

    With a = A->r, b = A->c and c = B->c, the operands are split into
    2 x 2 grids of blocks of sizes (a/2) x (b/2) and (b/2) x (c/2).  The
    seven block products are delegated to nmod_mat_mul.  When a, b or c is
    odd, the row/column left over by the halving is handled by three
    ordinary products at the end of the function.

    If any of a, b, c is <= 4 the whole product is handed to nmod_mat_mul.

    The quadrants of C are used as scratch space while the products are
    being formed, and A and B are read again after the first writes to C.
    NOTE(review): there is no aliasing check in this function, so C
    presumably must not share storage with A or B -- confirm against the
    callers.
    NOTE(review): the inner dimension is taken from A->c only; B->r is
    never read, so compatible dimensions are assumed.
*/
void
nmod_mat_mul_strassen(nmod_mat_t C, const nmod_mat_t A, const nmod_mat_t B)
{
    slong a, b, c;              /* A is a x b, B has c columns, C is a x c */
    slong anr, anc, bnr, bnc;   /* block sizes: A blocks anr x anc, B blocks bnr x bnc */

    nmod_mat_t A11, A12, A21, A22;
    nmod_mat_t B11, B12, B21, B22;
    nmod_mat_t C11, C12, C21, C22;
    nmod_mat_t X1, X2;          /* the only two temporaries that own memory */

    a = A->r;
    b = A->c;
    c = B->c;

    /* Too small to be worth splitting: use the general multiplication. */
    if (a <= 4 || b <= 4 || c <= 4)
    {
        nmod_mat_mul(C, A, B);
        return;
    }

    /* Integer halving; an odd dimension leaves one row/column over, which
       is dealt with after the block computation. */
    anr = a / 2;
    anc = b / 2;
    bnr = anc;      /* inner dimension: rows of a B block = columns of an A block */
    bnc = c / 2;

    /* Window arguments are (r1, c1, r2, c2): rows r1..r2-1, columns
       c1..c2-1 of the parent.  Together the four windows of each matrix
       cover its top-left (2*rows) x (2*cols) part. */
    nmod_mat_window_init(A11, A, 0, 0, anr, anc);
    nmod_mat_window_init(A12, A, 0, anc, anr, 2*anc);
    nmod_mat_window_init(A21, A, anr, 0, 2*anr, anc);
    nmod_mat_window_init(A22, A, anr, anc, 2*anr, 2*anc);

    nmod_mat_window_init(B11, B, 0, 0, bnr, bnc);
    nmod_mat_window_init(B12, B, 0, bnc, bnr, 2*bnc);
    nmod_mat_window_init(B21, B, bnr, 0, 2*bnr, bnc);
    nmod_mat_window_init(B22, B, bnr, bnc, 2*bnr, 2*bnc);

    nmod_mat_window_init(C11, C, 0, 0, anr, bnc);
    nmod_mat_window_init(C12, C, 0, bnc, anr, 2*bnc);
    nmod_mat_window_init(C21, C, anr, 0, 2*anr, bnc);
    nmod_mat_window_init(C22, C, anr, bnc, 2*anr, 2*bnc);

    /* X1 is allocated wide enough for both of its roles: first it holds
       anr x anc combinations of A blocks, later the anr x bnc product
       A11*B11.  X2 holds bnr x bnc combinations of B blocks (bnr == anc). */
    nmod_mat_init(X1, anr, FLINT_MAX(bnc, anc), A->mod.n);
    nmod_mat_init(X2, anc, bnc, A->mod.n);

    /* Make X1 present itself as anr x anc for the first role by overriding
       its column count directly. */
    X1->c = anc;

    /*
        See Jean-Guillaume Dumas, Clement Pernet, Wei Zhou; "Memory
        efficient scheduling of Strassen-Winograd's matrix multiplication
        algorithm"; https://arxiv.org/pdf/0707.2347v3 for reference on the
        operation scheduling used here.

        The comments below name the seven block products P1..P7:
            P1 = A11*B11                    P2 = A12*B21
            P3 = (A12 - A21 - A22 + A11)*B22
            P4 = A22*(B22 - B12 + B11 - B21)
            P5 = (A21 + A22)*(B12 - B11)
            P6 = (A21 + A22 - A11)*(B22 - B12 + B11)
            P7 = (A11 - A21)*(B22 - B12)
        The statement order matters: every C block and both temporaries
        are overwritten several times.
    */

    nmod_mat_sub(X1, A11, A21);     /* X1 = A11 - A21 */
    nmod_mat_sub(X2, B22, B12);     /* X2 = B22 - B12 */
    nmod_mat_mul(C21, X1, X2);      /* C21 = P7 */

    nmod_mat_add(X1, A21, A22);     /* X1 = A21 + A22 */
    nmod_mat_sub(X2, B12, B11);     /* X2 = B12 - B11 */
    nmod_mat_mul(C22, X1, X2);      /* C22 = P5 */

    nmod_mat_sub(X1, X1, A11);      /* X1 = A21 + A22 - A11 */
    nmod_mat_sub(X2, B22, X2);      /* X2 = B22 - B12 + B11 */
    nmod_mat_mul(C12, X1, X2);      /* C12 = P6 */

    nmod_mat_sub(X1, A12, X1);      /* X1 = A12 - A21 - A22 + A11 */
    nmod_mat_mul(C11, X1, B22);     /* C11 = P3 */

    /* X1 is done with A-shaped data; switch it to anr x bnc so it can
       hold a product. */
    X1->c = bnc;
    nmod_mat_mul(X1, A11, B11);     /* X1 = P1 */

    nmod_mat_add(C12, X1, C12);     /* C12 = P1 + P6 */
    nmod_mat_add(C21, C12, C21);    /* C21 = P1 + P6 + P7 */
    nmod_mat_add(C12, C12, C22);    /* C12 = P1 + P6 + P5 */
    nmod_mat_add(C22, C21, C22);    /* C22 = P1 + P6 + P7 + P5   (final) */
    nmod_mat_add(C12, C12, C11);    /* C12 = P1 + P6 + P5 + P3   (final) */
    nmod_mat_sub(X2, X2, B21);      /* X2 = B22 - B12 + B11 - B21 */
    nmod_mat_mul(C11, A22, X2);     /* C11 = P4 (C11 used as scratch) */

    nmod_mat_clear(X2);             /* X2 is not needed past this point */

    nmod_mat_sub(C21, C21, C11);    /* C21 = P1 + P6 + P7 - P4   (final) */
    nmod_mat_mul(C11, A12, B21);    /* C11 = P2 */

    nmod_mat_add(C11, X1, C11);     /* C11 = P1 + P2             (final) */

    nmod_mat_clear(X1);

    nmod_mat_window_clear(A11);
    nmod_mat_window_clear(A12);
    nmod_mat_window_clear(A21);
    nmod_mat_window_clear(A22);

    nmod_mat_window_clear(B11);
    nmod_mat_window_clear(B12);
    nmod_mat_window_clear(B21);
    nmod_mat_window_clear(B22);

    nmod_mat_window_clear(C11);
    nmod_mat_window_clear(C12);
    nmod_mat_window_clear(C21);
    nmod_mat_window_clear(C22);

    /* Fix-ups for odd dimensions.  Each condition holds exactly when the
       corresponding dimension is odd, i.e. the halving above left one
       row or column outside the 2 x 2 block grid. */

    /* c odd: the last column of C is the full A times the last column of
       B (all a rows, all b inner terms). */
    if (c > 2*bnc) /* A by last col of B -> last col of C */
    {
        nmod_mat_t Bc, Cc;
        nmod_mat_window_init(Bc, B, 0, 2*bnc, b, c);
        nmod_mat_window_init(Cc, C, 0, 2*bnc, a, c);
        nmod_mat_mul(Cc, A, Bc);
        nmod_mat_window_clear(Bc);
        nmod_mat_window_clear(Cc);
    }

    /* a odd: the last row of C is the last row of A times the full B.
       If c is odd as well, the bottom-right entry is written by both this
       branch and the previous one. */
    if (a > 2*anr) /* last row of A by B -> last row of C */
    {
        nmod_mat_t Ar, Cr;
        nmod_mat_window_init(Ar, A, 2*anr, 0, a, b);
        nmod_mat_window_init(Cr, C, 2*anr, 0, a, c);
        nmod_mat_mul(Cr, Ar, B);
        nmod_mat_window_clear(Ar);
        nmod_mat_window_clear(Cr);
    }

    /* b odd: the block computation only summed over the first 2*anc inner
       indices, so add the contribution of the remaining column of A and
       row of B to the top-left (2*anr) x (2*bnc) part of C.  The last row
       and column of C were formed from the full A and B above and are
       therefore excluded from this window. */
    if (b > 2*anc) /* last col of A by last row of B -> C */
    {
        nmod_mat_t Ac, Br, Cb;
        nmod_mat_window_init(Ac, A, 0, 2*anc, 2*anr, b);
        nmod_mat_window_init(Br, B, 2*bnr, 0, b, 2*bnc);
        nmod_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc);
        nmod_mat_addmul(Cb, Cb, Ac, Br);    /* Cb = Cb + Ac*Br */
        nmod_mat_window_clear(Ac);
        nmod_mat_window_clear(Br);
        nmod_mat_window_clear(Cb);
    }
}