/* BLIS An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2018-2019, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - Neither the name of The University of Texas at Austin nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ #include "blis.h" #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM #include "immintrin.h" #define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm #define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1' #define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together. 
#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle

// NOTE(review): the *_SP constants below look like single-precision ("SP")
// size limits along M or N for the named kernels; they are not used in the
// part of the file visible here -- confirm before relying on this reading.
#define BLI_AlXB_M_SP 16
#define BLI_XAltB_N_SP 128
#define BLI_AutXB_M_SP 64
#define BLI_AutXB_N_SP 128

/*
 * Forward declarations of the small-matrix TRSM kernels.
 * Naming key (as spelled out in the per-declaration comments):
 *   bli_d... / bli_s... : double / single precision
 *   XA...               : solve X * op(A) = B  (A applied from the right)
 *   A...XB              : solve op(A) * X = B  (A applied from the left)
 *   l / u               : A is lower / upper triangular
 *   t                   : A is used transposed
 *   _unitDiag           : A has a unit diagonal
 */

// XA = B; A is lower-triangular; No transpose; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAlB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is lower-triangular; No transpose; double precision; unit-diagonal
static err_t bli_dtrsm_small_XAlB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal
static err_t bli_dtrsm_small_XAltB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal
static err_t bli_dtrsm_small_XAltB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is upper-triangular; No transpose; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAuB ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is upper-triangular; No transpose; double precision; unit-diagonal
static err_t bli_dtrsm_small_XAuB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAutB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// XA = B; A is upper-triangular; A is transposed; double precision; unit diagonal
static err_t bli_dtrsm_small_XAutB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// AX = B; A is lower-triangular; No transpose; double precision; non-unit diagonal
static err_t bli_dtrsm_small_AlXB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// AX = B; A is lower-triangular; No transpose; double precision; unit diagonal
static err_t bli_dtrsm_small_AlXB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// Single-precision microkernels. The *_alpha variants take an explicit
// alpha scalar; the *_unitDiag variants are for a unit-diagonal A.
// NOTE(review): fp_blis_strsm_microkernel appears to be a function pointer
// used to pick one of these variants at run time -- confirm at its use sites.
static void (*fp_blis_strsm_microkernel)( float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b );
static void blis_strsm_microkernel( float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b );
static void blis_strsm_microkernel_alpha( float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal );
static void blis_strsm_microkernel_unitDiag( float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b );
static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal );

// Single-precision XA' = B block kernels (same alpha / unitDiag variants).
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b);
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal);
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b);
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal);

// Double-precision microkernels (same alpha / unitDiag variants).
static void blis_dtrsm_microkernel( double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b );
static void blis_dtrsm_microkernel_alpha( double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, double alphaVal );
static void blis_dtrsm_microkernel_unitDiag( double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b );
static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, double alphaVal );

// Double-precision XA' = B block kernels (same alpha / unitDiag variants).
static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b);
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, double alphaVal);
static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b);
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l, double *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, double alphaVal);

// Single-precision A'X = B (A upper) block kernels (alpha / unitDiag variants).
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b);
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha);
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b);
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha);

// AX = B; A is lower-triangular; No transpose; single precision
static err_t bli_strsm_small_AlXB ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

// A.'X = B; A is upper-triangular; A has to be transposed; single precision
static err_t bli_strsm_small_AutXB ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl );

//XA.'
= B; A is lower triangular; A has to be transposed; single precision static err_t bli_strsm_small_XAltB ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ); //A.'X = B; A is upper triangular; A has to be transposed; double precision static err_t bli_dtrsm_small_AutXB ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ); /* * The bli_trsm_small implements unpacked version of TRSM * Currently only column-major is supported, A & B are column-major * Input: A: MxM (triangular matrix) * B: MxN matrix * Output: X: MxN matrix such that AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B * Here the output X is stored in B * The custom-kernel will be called only when M*(M+N)* sizeof(Matrix Elements) < L3 cache */ err_t bli_trsm_small ( side_t side, obj_t* alpha, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { #ifdef BLIS_ENABLE_MULTITHREADING return BLIS_NOT_YET_IMPLEMENTED; #endif dim_t m = bli_obj_length(b); dim_t n = bli_obj_width(b); if(!(m && n)) return BLIS_SUCCESS; // If alpha is zero, B matrix will become zero after scaling & hence solution is also zero matrix if (bli_obj_equals(alpha, &BLIS_ZERO)) { return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha } // We have to call matrix scaling if alpha != 1.0 // if row major format return. Check this again. if ((bli_obj_row_stride(a) != 1) || (bli_obj_row_stride(b) != 1)) { return BLIS_INVALID_ROW_STRIDE; } num_t dt = ((*b).info & (0x7 << 0)); // only float and double datatypes are supported as of now. if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT) { return BLIS_EXPECTED_REAL_DATATYPE; } // A is expected to be triangular in trsm if (!bli_obj_is_upper_or_lower (a)) { return BLIS_EXPECTED_TRIANGULAR_OBJECT; } // can use other control structs - even can use array of function pointers, // indexed by a number with bits formed by f('side', 'uplo', 'transa', dt). 
// In the below implementation, based on the number of finally implemented // cases, can move the checks with more cases higher up. if(side == BLIS_LEFT) { if(bli_obj_has_trans(a)) { if(dt == BLIS_DOUBLE) { if(bli_obj_is_upper(a)) { //return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } else { //return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } } else { if(bli_obj_is_upper(a)) { return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl); } else { //return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } } } else { if(dt == BLIS_DOUBLE) { if(bli_obj_is_upper(a)) { //return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } else { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl); else return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl); } } else { if(bli_obj_is_upper(a)) { //return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } else { return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl); } } } } else { if(bli_obj_has_trans(a)) { if(dt == BLIS_DOUBLE) { if(bli_obj_is_upper(a)) { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl); else return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl); } else { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl); else return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl); } } else { if(bli_obj_is_upper(a)) { //return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } else { return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl); } } } else { if(dt == BLIS_DOUBLE) { if(bli_obj_is_upper(a)) { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl); else return 
bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl); } else { if(bli_obj_has_unit_diag(a)) return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl); else return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl); } } else { if(bli_obj_is_upper(a)) { //return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } else { //return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl); return BLIS_NOT_YET_IMPLEMENTED; } } } } return BLIS_NOT_YET_IMPLEMENTED; }; /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, no transpose * Dimensions: A: mxm X: mxn B:mxn */ static err_t dtrsm_small_AlXB ( double *A, double *B, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for (k = 0; k < M; k++) { double lkk_inv = 1.0/A[k+k*lda]; for (j = 0; j < N; j++) { B[k + j*ldb] *= lkk_inv; for (i = k+1; i < M; i++) { B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; } } }// k -loop return BLIS_SUCCESS; }// end of function /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, unit-diagonal, no transpose * Dimensions: A: mxm X: mxn B:mxn */ static err_t dtrsm_small_AlXB_unitDiag ( double *A, double *B, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for (k = 0; k < M; k++) { for (j = 0; j < N; j++) { for (i = k+1; i < M; i++) { B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; } } } return BLIS_SUCCESS; }// end of function /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, non-unit-diagonal no transpose * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAuB ( double *A, double *B, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(k = 0; k < N; k++) { double lkk_inv = 1.0/A[k+k*lda]; for(i = 0; i < M; i++) { B[i+k*ldb] *= lkk_inv; for(j = k+1; j < N; j++) { B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; } } } return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, non-unit triangular, no transpose * 
Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAlB ( double *A, double *B, double alpha, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(j = 0; j < N; j++) for(i = 0; i < M; i++) B[i+j*ldb] *= alpha; for(k = N;k--;) { double lkk_inv = 1.0/A[(k)+(k)*lda]; for(i = M;i--;) { B[(i)+(k)*ldb] *= lkk_inv; for(j = k;j--;) { B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A[(k)+(j)*lda]; } } } return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is lower-triangular, unit-diagonal, no transpose *Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAlB_unitDiag( double *A, double *B, double alpha, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(j = 0 ; j < N; j++) for(i = 0; i < M; i++) B[i+j*ldb] *= alpha; double A_k_j; for(k = N; k--;) { for(j = k; j--;) { A_k_j = A[(k)+(j)*lda]; for(i = M; i--;) { B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A_k_j; } } } return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B *A is upper-triangular, non-unit-diagonal, A is transposed * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAutB ( double *A, double *B, double alpha, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(j = 0; j < N; j++) for(i = 0; i < M; i++) B[i+j*ldb] *=alpha; for(k = N; k--;) { double lkk_inv = 1.0/A[(k)+(k)*lda]; for(i = M; i--;) { B[(i)+(k)*ldb] *= lkk_inv; for(j = k; j--;) { B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A[(j)+(k)*lda]; } } } return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAutB_unitDiag( double *A, double *B, double alpha, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; double A_k_j; for(j = 0; j< N; j++) for(i = 0; i< M; i++) B[i+j*ldb] *= alpha; for(k = N; k--;) { for(j = k; j--;) { A_k_j = A[(j)+(k)*lda]; for(i = M; i--;) { B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A_k_j; } } } return BLIS_SUCCESS; } /* TRSM scalar code for 
the case XA = alpha * B * A is lower-triangular, non-unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAltB ( double *A, double *B, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(k = 0; k < N; k++) { double lkk_inv = 1.0/A[k+k*lda]; for(i = 0; i < M; i++) { B[i+k*ldb] *= lkk_inv; for(j = k+1; j < N; j++) { B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; } } } return BLIS_SUCCESS; } /* TRSM scalar code for XA = alpha * B * A is lower-triangular, unit-diagonal, A has to be transposed * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAltB_unitDiag( double *A, double *B, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(k = 0; k < N; k++) { for(i = 0; i < M; i++) { for(j = k+1; j < N; j++) { B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; } } } return BLIS_SUCCESS; } /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, unit-diagonal, no transpose * Dimensions: X:mxn A:nxn B:mxn */ static err_t dtrsm_small_XAuB_unitDiag ( double *A, double *B, dim_t M, dim_t N, dim_t lda, dim_t ldb ) { dim_t i, j, k; for(k = 0; k < N; k++) { for(i = 0; i < M; i++) { for(j = k+1; j < N; j++) { B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; } } } return BLIS_SUCCESS; } /* TRSM for the case AX = alpha * B, Double precision * A is lower-triangular, no-transpose, non-unit diagonal * dimensions A: mxm X: mxn B: mxn b01---> * ***************** ** * * * * * * * * * * * * * * *b01* * * * * * * * * * * a10 ****** b11 ***************** | * * * | * * * * * | * * * | * * * * * | *a10*a11* | *b11* * * * v * * * v * * * * * *********** ***************** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * **************** ***************** a11---> */ static err_t bli_dtrsm_small_AlXB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 4; //size of block along 'M' dimpension dim_t D_NR = 8; //size of block along 'N' dimension dim_t m = bli_obj_length(b); // 
number of rows of matrix B dim_t n = bli_obj_width(b); // number of columns of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N) || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t m_remainder = m & 3; //number of remainder rows dim_t n_remainder = n & 7; //number of remainder columns dim_t cs_a = bli_obj_col_stride(a); // column stride of A dim_t cs_b = bli_obj_col_stride(b); // column stride of B dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM double *ptr_b01_dup; double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 double* f_temp; double ones = 1.0; //scratch registers __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); ymm11 = _mm256_setzero_pd(); ymm12 = _mm256_setzero_pd(); ymm13 = _mm256_setzero_pd(); ymm14 
= _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); ///GEMM code begins/// for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) ymm16 = 
_mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm5 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm7 = _mm256_broadcast_sd((double 
const *)(b01 + cs_b * 3)); //B01[3][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM } ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] 
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] ///implement TRSM/// ///transpose of B11// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] 
B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm0 = _mm256_broadcast_sd((double const *)&ones); //broadcast diagonal elements of A11 ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1] ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); //A11[3][3] ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] //extract a00 ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] //extract a11 ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] ymm2 = 
_mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] a11 += cs_a; //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] a11 += cs_a; //extract a22 ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A110[][0] 1/A11[2][2] 1/A11[2][2] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] //(ROw2): FMA operations ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] //perform mul operation ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /= A11[2][2] ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] a11 += cs_a; //extract a33 ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11);//1/A11[3][3] 1/A11[3][3] 
1/A11[3][3] 1/A11[3][3] //(ROw2): FMA operations ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] //perform mul operation ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3] ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3] //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] 
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] } if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) int iter; if((j+D_NR) == n) { for(iter = 0; iter < m_remainder; iter++) f_t[iter] = (b11 + cs_b * 7)[iter]; f_temp = f_t; } else f_temp = (b11 + cs_b * 7); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); ymm11 = _mm256_setzero_pd(); ymm12 = _mm256_setzero_pd(); ymm13 = _mm256_setzero_pd(); ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); ///GEMM code Begins/// for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] ymm3 = _mm256_broadcast_sd((double 
const *)(b01 + cs_b * 7)); //B01[0][7] b01 += 1; //move to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] b01 += 1; //move to next row of B01 ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] 
B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] b01 += 1; //move to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] 
B01[2][1]*A10[3][2]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] b01 += 1; //move to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) ymm11 = 
_mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); 
//B11[0-3][3] *alpha -= B01[0-3][3] ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] if(3 == m_remainder) { ///implement TRSM/// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm0 = _mm256_broadcast_sd((double const *)&ones); //broadcast diagonal elements of A11 ymm1 = _mm256_broadcast_sd((double const 
*)(a11+0)); //A11[0][0] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1] ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2] ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] ymm6 = _mm256_unpacklo_pd(ymm3, ymm0); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a00 ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] //extract a11 ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] a11 += cs_a; //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] a11 += cs_a; //extract a22 ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 
1/A11[2][2] 1/A11[2][2] 1/A11[2][2] //(ROw2): FMA operations ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] //perform mul operation ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2] ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] ymm11 = _mm256_broadcast_sd((double const *)(&ones)); ymm15 = _mm256_broadcast_sd((double const *)(&ones)); //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm8 = _mm256_loadu_pd((double 
const *)(b11 + cs_b * 0)); //load B11[0-3][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); } else if(2 == m_remainder) { ///implement TRSM/// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] 
B11[7][1] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm0 = _mm256_broadcast_sd((double const *)&ones); //broadcast diagonal elements of A11 ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1] ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] ymm5 = _mm256_blend_pd(ymm5, ymm0, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a00 ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] //extract a11 ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] a11 += cs_a; //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] ymm10 = _mm256_broadcast_sd((double 
const *)&ones); //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); ymm3 = _mm256_permute2f128_pd(ymm3, 
ymm11, 0x30); ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); } else if(1 == m_remainder) { ///implement TRSM/// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm0 = _mm256_broadcast_sd((double const *)&ones); //broadcast diagonal elements of A11 ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm0 = _mm256_div_pd(ymm0, ymm1); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a00 ymm1 = 
_mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] ymm9 = _mm256_broadcast_sd((double const *)(&ones)); //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm9); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm9, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm9, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm9, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm6 = _mm256_permute2f128_pd(ymm5, ymm9, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm9); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm9, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm7 = _mm256_permute2f128_pd(ymm12, ymm9, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] ymm13 = 
_mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); } _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) if((j+D_NR) == n) { for(iter = 0; iter < m_remainder; iter++) (b11 + cs_b * 7)[iter] = f_t[iter]; } } } if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) ///GEMM for previously calculated values /// //load 4x4 block from b11 ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] 
B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] b01 += 1; ymm4 = 
_mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, 
ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 ///implement TRSM/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] //2nd col a11 += cs_a; ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] //4th col a11 += cs_a; ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ////unpacklow//// ymm8 = 
_mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //extract a00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] //extract diag a11 from a ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] //extract diag a22 from a ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, 
ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] //extract diag a33 from a ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); 
//store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) dim_t iter; if((j+4) == n) { f_temp = f_t; for(iter = 0; iter < m_remainder; iter++) f_temp[iter] = (b11 + cs_b * 3)[iter]; } else f_temp = (b11 + cs_b * 3); ///GEMM for previously calculated values /// //load 4x4 block from b11 ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); for(k = 0; k < k_iter; k++) //looop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); 
//B01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += 
(B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 if(3 == m_remainder) { ///implement TRSM/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] //2nd col a11 += cs_a; ymm8 
= _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] //3rd col a11 += cs_a; ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] //4th col a11 += cs_a; ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //extract a00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] //extract diag a11 from a ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] 
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] //extract diag a22 from a ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] //load 4x4 block from b11 ymm4 = 
_mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); } else if( 2 == m_remainder ) { ///implement TRSM/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] //2nd col a11 += cs_a; ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] //compute reciprocals of L(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] ymm14 = _mm256_broadcast_sd((double const *)&ones); ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] 
//extract a00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] //extract diag a11 from a ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] ymm11 = _mm256_broadcast_sd((double const *)(&ones)); ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] //load 4x4 block from b11 ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm5 = _mm256_loadu_pd((double const 
*)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); } else if(1 == m_remainder) { ///implement TRSM/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm14 = _mm256_broadcast_sd((double const *)&ones); ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //extract a00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] ymm8 = _mm256_broadcast_sd((double const *)(&ones)); ymm11 = _mm256_broadcast_sd((double const *)(&ones)); ymm13 = 
_mm256_broadcast_sd((double const *)(&ones)); //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] //load 4x4 block from b11 ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) if((j+4) == n) { for(iter = 0; iter < m_remainder; iter++) (b11 + cs_b * 3)[iter] = f_temp[iter]; } } n_remainder -= 4; j += 4; } if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) { for(i = 
0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// //load 4x4 block from b11 if(3 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] for(k = 0; k < k_iter; k++) { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const 
*)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to 
find next block of B for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } else if(2 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] for(k = 0; k < k_iter; k++) { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 
+ cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } else if(1 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] for(k = 0; k < k_iter; k++) { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); 
//B01[0][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } ///implement TRSM/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] //2nd col a11 += cs_a; ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] 
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] //4th col a11 += cs_a; ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //extract a00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] //extract diag a11 from a ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] //(Row1): FMA operations of b1 with elements 
of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] //extract diag a22 from a ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] //extract diag a33 from a ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = 
_mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] if(3 == n_remainder) { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) } else if(2 == n_remainder) { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) } else if(1 == n_remainder) { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) } } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM operations to be performed dim_t iter; if((j+n_remainder) == n) { f_temp = f_t; for(iter = 0; iter < m_remainder; iter++) f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; } else f_temp = (b11 + cs_b * (n_remainder -1)); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// //load 4x4 block from b11 if(3 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + 
cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b 
* 0)); //B10[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 ///implement TRSM/// //determine correct values to store if(3 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); } else if(2 == m_remainder) { ymm0 = 
_mm256_permute2f128_pd(ymm8, ymm0, 0x30); ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); } else if(1 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) } if(2 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 
+= (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 ///implement TRSM/// //determine correct values to store if(3 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); } else if(2 == m_remainder) { ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); } else if(1 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) } if(n_remainder == 1) { ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] 
B11[1][0] B11[2][0] B11[3][0] for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 ///implement TRSM/// //determine correct values to store if(3 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); } else if(2 == m_remainder) { ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); } else if(1 
== m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); } _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) } if((j+n_remainder) == n) { for(iter = 0; iter < m_remainder; iter++) (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } ///scalar code for trsm without alpha/// dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } /* TRSM for the case AX = alpha * B, Double precision * A is lower-triangular, no-transpose, unit diagonal * dimensions A: mxm X: mxn B: mxn b01---> * ***************** ** * * * * * * * * * * * * * * *b01* * * * * * * * * * * a10 ****** b11 ***************** | * * * | * * * * * | * * * | * * * * * | *a10*a11* | *b11* * * * v * * * v * * * * * *********** ***************** * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * **************** ***************** a11---> */ static err_t bli_dtrsm_small_AlXB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 4; //size of block along 'M' dimpension dim_t D_NR = 8; //size of block along 'N' dimension dim_t m = bli_obj_length(b); // number of rows of matrix B dim_t n = bli_obj_width(b); // number of columns of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME) || (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N) || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t m_remainder = m & (3); //number of remainder rows dim_t n_remainder = n & (7); //number of remainder columns dim_t cs_a = bli_obj_col_stride(a); // column stride of A dim_t cs_b = bli_obj_col_stride(b); // column stride of B dim_t i, j, k; //loop variables dim_t k_iter; //number of times GEMM to be performed double AlphaVal = *(double 
*)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM double *ptr_b01_dup; double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 double* f_temp; double ones = 1.0; //scratch registers __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); ymm11 = _mm256_setzero_pd(); ymm12 = _mm256_setzero_pd(); ymm13 = _mm256_setzero_pd(); ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); ///GEMM code begins/// for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); 
//B01[0][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] 
B01[1][0]*A10[3][1]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm10 = 
_mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] b01 += 1; //mobe to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); 
//ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM } ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] ymm4 
= _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] ///implement TRSM/// ///transpose of B11// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] ymm4 = _mm256_broadcast_sd((double const 
*)(a11 +3)); //A11[3][0] a11 += cs_a; //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] a11 += cs_a; //(ROw2): FMA operations ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] a11 += cs_a; //(ROw1): FMA operations ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] 
B11[6][0] B11[7][0] ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] } if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) dim_t iter; if((j+D_NR) == n) { f_temp = f_t; for(iter = 0; iter < m_remainder; iter++) f_temp[iter] = (b11 + cs_b * 7)[iter]; } else f_temp 
= (b11 + cs_b * 7); ymm8 = _mm256_setzero_pd(); ymm9 = _mm256_setzero_pd(); ymm10 = _mm256_setzero_pd(); ymm11 = _mm256_setzero_pd(); ymm12 = _mm256_setzero_pd(); ymm13 = _mm256_setzero_pd(); ymm14 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); ///GEMM code Begins/// for(k = 0; k< k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] b01 += 1; //move to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] 
B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] b01 += 1; //move to next row of B01 ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] 
B01[1][7]*A10[3][1]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] b01 += 1; //move to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm4 = _mm256_broadcast_sd((double const *)(b01 + 
cs_b * 0)); //B01[3][0] ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] b01 += 1; //move to next row of B ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] 
B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] if(3 == m_remainder) { ///implement TRSM/// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] 
B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] a11 += cs_a; //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] a11 += cs_a; //(ROw2): FMA operations ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] ymm11 = _mm256_broadcast_sd((double const *)(&ones)); ymm15 = _mm256_broadcast_sd((double const *)(&ones)); //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] 
B11[5][2] ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm10, 
0x08); ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); } else if(2 == m_remainder) { ///implement TRSM/// ///unpacklow/// ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] //rearrange low elements ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] //rearrange high elements ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] a11 += cs_a; //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= 
B11[0-3][4]*A11[1][4] ymm10 = _mm256_broadcast_sd((double const *)&ones); //unpacklow// ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] ///unpack high/// ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); ymm2 = 
_mm256_permute2f128_pd(ymm2, ymm10, 0x30); ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); } else if(1 == m_remainder) { ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); } _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) _mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7]) if((j+D_NR) == n) { for(iter = 0; iter < m_remainder; iter++) (b11 + cs_b * 7)[iter] = f_temp[iter]; } } } 
if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4) { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) ///GEMM for previously calculated values /// //load 4x4 block from b11 ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = 
_mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, 
ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7 ///implement TRSM/// //1st col ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] //2nd col a11 += cs_a; ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, 
ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] 
B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) dim_t iter; if((j+4) == n) { f_temp = f_t; for(iter = 0; iter < m_remainder; iter++) f_temp[iter] = (b11 + cs_b * 3)[iter]; } else f_temp = (b11 + cs_b * 3); ///GEMM for previously calculated values /// //load 4x4 block from b11 ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); for(k = 0; k < k_iter; k++) //looop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] 
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) 
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha 
-= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7 if(3 == m_remainder) { ///implement TRSM/// //1st col ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] //2nd col a11 += cs_a; ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] 
B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] //load 4x4 block from b11 ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); } else if(2 == m_remainder) { ///implement TRSM/// //1st col ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); 
//B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] ymm11 = _mm256_broadcast_sd((double const *)(&ones)); ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] //load 4x4 block from b11 ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] //determine correct values to store ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); } else if(1 == m_remainder) { //load 4x4 block from b11 ymm4 = _mm256_loadu_pd((double const *)(b11)); 
//B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] //determine correct values to store ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3]) if((j+4) == n) { for(iter = 0; iter < m_remainder; iter++) (b11 + cs_b * 3)[iter] = f_temp[iter]; } } n_remainder -= 4; j += 4; } if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) { for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction { a10 = L +i; //pointer to block of A to be used for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// //load 4x4 block from b11 if(3 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] for(k = 0; k < k_iter; k++) { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double 
const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] 
B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } else if(2 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] for(k = 0; k < k_iter; k++) { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); 
//A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B 
for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } else if(1 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] for(k = 0; k < k_iter; k++) { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] b01 += 1; ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B 
for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4 ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5 ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6 ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7 } ///implement TRSM/// //1st col ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] //2nd col a11 += cs_a; ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] //3rd col a11 += cs_a; ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] ////unpacklow//// ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] //rearrange low elements ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] ////unpackhigh//// ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] //rearrange high elements ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] ymm13 = _mm256_fnmadd_pd(ymm7, 
ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] //--> Transpose and store results of columns of B block <--// ////unpacklow//// ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] //rearrange low elements ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ////unpackhigh//// ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] //rearrange high elements ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] if(3 == n_remainder) { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) } else if(2 == n_remainder) { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) } else if(1 == n_remainder) { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) } } if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) { a10 = L +i; //pointer to block of A to be used 
for GEMM a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM b01 = B + j*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM k_iter = i / D_MR; //number of times GEMM operations to be performed dim_t iter; if((j+n_remainder) == n) { f_temp = f_t; for(iter = 0; iter < m_remainder; iter++) f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter]; } else f_temp = (b11 + cs_b * (n_remainder -1)); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM for previously calculated values /// //load 4x4 block from b11 if(3 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, 
ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, 
ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6 ///implement TRSM/// //determine correct values to store if(3 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); } else if(2 == m_remainder) { ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); } else if(1 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2]) } else if(2 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const 
*)&AlphaVal); //register to hold alpha value ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5 ///implement TRSM/// //determine correct values to store if(3 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); } else if(2 == m_remainder) { ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); } else if(1 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1]) } else if(1 == n_remainder) { ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_b01_dup = b01; ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] 
B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] b01 += 1; //move to next row of B ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4 ///implement TRSM/// //determine correct values to store if(3 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); } else if(2 == m_remainder) { ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); } else if(1 == m_remainder) { ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); } _mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0]) } if((j+n_remainder) == n) { for(iter = 0; iter < m_remainder; iter++) (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; } ///scalar code for trsm without alpha/// dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; } /*implements TRSM for the case XA = alpha * B *A is upper triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn */ /* b11---> a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 | ***************** ********* | | * * * * * *a11* * | | * * * * * * * * | v ***************** ****** v * * * * * * * * * * * * * * ***************** * * * */ static err_t bli_dtrsm_small_XAuB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n 
& 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double ones = 1.0; double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 double* f_temp; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j*cs_a; //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 
= _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += 
(B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] 
B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, 
ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = 
_mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += cs_a; ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] //3rd col a11 += cs_a; ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] //4th col a11 += cs_a; ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 
1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] //extract a33 ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) //(Row3)FMA operations ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, 
ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + j*cs_a; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) ///load 4x4 block of b11 ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double 
const *)(a01 + cs_a * 2)); //A01[0][2] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm4 = 
_mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) //broadcast 
4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///GEMM code ends/// ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm12 = 
_mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += cs_a; ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] //3rd col a11 += cs_a; ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] //4th col a11 += cs_a; ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm13 = 
_mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) } else if(2 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] 
B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] 
B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///GEMM code ends/// ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm12 = _mm256_fmsub_pd(ymm12, 
ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += cs_a; ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1 ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1) //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) } else if(1 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { 
ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] 
B10[3][2]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///GEMM code ends/// ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) } } } if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = 0; (j+D_NR-1) a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 | 
***************** ********* | | * * * * * *a11* * | | * * * * * * * * | v ***************** ****** v * * * * * * * * * * * * * * ***************** * * * */ static err_t bli_dtrsm_small_XAuB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n & 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME) || (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double ones = 1.0; double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 double* f_temp; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = 0; (i+D_MR-1) < m; 
i += D_MR) //loop along 'M' direction { for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j*cs_a; //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += 
(B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += 
(B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] 
B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = 
_mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //2nd col a11 += cs_a; ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] //3rd col a11 += cs_a; ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] //4th col a11 += cs_a; ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] ymm13 = 
_mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] //(Row3)FMA operations ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + j*cs_a; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) ///load 4x4 block of b11 ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = 
_mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = 
_mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] 
B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///GEMM code ends/// ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const 
*)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //2nd col a11 += cs_a; ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] //3rd col a11 += cs_a; ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); 
//store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) } else if(2 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm4 = 
_mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm4 = 
_mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //2nd col a11 += cs_a; ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) } else if(1 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] a01 += 
1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] 
B10[7][2]*A01[0][0]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) } } } if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = 0; (j+D_NR-1) a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 | ***************** ********* | | * * * * * *a11* * | | * * * * * * * * | v ***************** ****** v * * * * * * * * * * * * * * ***************** * * * */ static err_t bli_dtrsm_small_XAltB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n & 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef 
BLIS_ENABLE_SMALL_MATRIX_ROME if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double ones = 1.0; double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 double* f_temp; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be done(in 
blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] 
B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = 
_mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] 
B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + 
D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] //3rd col a11 += 1; ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] //4th col a11 += 1; ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 
1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] //extract a33 ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 
0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) //(Row3)FMA operations ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + j; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) ///load 4x4 block of b11 ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = 
_mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] a01 += cs_a; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] 
B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] a01 += cs_a; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] 
B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); 
//B11[0-3][1] * alpha -= B10[0-3][1] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] //3rd col a11 += 1; ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] //4th col a11 += 1; ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = 
_mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) } else if(2 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] a01 += cs_a; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] 
B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] a01 += cs_a; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + 
D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); ymm13 = _mm256_loadu_pd((double const *)(b11 + 
cs_b + D_NR)); ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double 
*)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) } else if(1 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] a01 += cs_a; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] a01 += cs_a; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = 
_mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ///implement TRSM/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm8 = _mm256_mul_pd(ymm8, ymm7); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm7); //B11[4-7][0] /= A11[0][0] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) } } } if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = 0; (j+D_NR-1) a01 ----> ***************** *********** *b01*b11* * * * * * * b11 * * * * * **a01 * * a11 
| ***************** ********* | | * * * * * *a11* * | | * * * * * * * * | v ***************** ****** v * * * * * * * * * * * * * * ***************** * * * */ static err_t bli_dtrsm_small_XAltB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n & 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N) || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N) || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) || (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME) || (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double *L = a->buffer; //pointer to matrix A double *B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0 double* f_temp; cs_b_offset[0] = cs_b << 1; 
//cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += 
(B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] 
B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = 
_mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] 
B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] //3rd col a11 += 1; ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] //4th col a11 += 1; ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= 
B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] //(Row3)FMA operations ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + j; //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / 
D_NR; //number of GEMM operations to be performed(in blocks of 4x4) ///load 4x4 block of b11 ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] a01 += cs_a; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] 
B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] a01 += cs_a; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] 
B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const 
*)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ///implement TRSM/// ///read 4x4 block of A11/// //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] //3rd col a11 += 1; ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] //4th col a11 += 1; ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] //(Row2)FMA operations ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] ymm14 = _mm256_fnmadd_pd(ymm4, 
ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) } else if(2 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] a01 += cs_a; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] a01 += cs_a; //move to next row of A ymm0 = 
_mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] a01 += cs_a; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, 
ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ///implement TRSM/// ///read 4x4 block of A11/// //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] //(Row1): FMA operations ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) } else if(1 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number 
of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] a01 += cs_a; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] a01 += cs_a; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] 
B10[3][2]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) } } } if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = 0; (j+D_NR-1) D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double ones = 1.0; double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A double* restrict B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for 
GEMM and TRSM blocks double *ptr_a01_dup; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction { a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 
+= (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += 
(B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] 
B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next 
block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] //3rd col a11 += 1; ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] ymm4 = 
_mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] //4th col a11 += 1; ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //extract a33 ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); ymm10 = _mm256_mul_pd(ymm10, ymm0); ymm14 = _mm256_mul_pd(ymm14, ymm0); //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm8 = 
_mm256_fnmadd_pd(ymm10, ymm3, ymm8); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); ymm9 = _mm256_mul_pd(ymm9, ymm0); ymm13 = _mm256_mul_pd(ymm13, ymm0); //extract a00 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1): FMA operations ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) ///load 4x4 block of b11 ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //subtract the 
calculated GEMM block from current TRSM block //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next 
row of A ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm5 
= _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); 
//B11[4-7][1] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //1st col ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] //2nd col a11 += 1; ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] //3rd col a11 += 1; ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] //4th col a11 += 1; ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //extract a33 ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = 
_mm256_mul_pd(ymm15, ymm0); //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); ymm10 = _mm256_mul_pd(ymm10, ymm0); ymm14 = _mm256_mul_pd(ymm14, ymm0); //extract a11 ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(Row 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); ymm9 = _mm256_mul_pd(ymm9, ymm0); ymm13 = _mm256_mul_pd(ymm13, ymm0); _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } else if(2 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + 
cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); 
//(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + 
D_NR)); //B11[4-7][1] ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //3rd col a11 += 2; ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] //4th col a11 += 1; ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //extract a33 ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); //extract a22 ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm10 = _mm256_mul_pd(ymm10, ymm0); ymm14 = _mm256_mul_pd(ymm14, ymm0); _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + 
D_NR), ymm15); //store(B11[4-7][0]) } else if(1 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); 
//(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] ///implement TRSM/// ///read 4x4 block of A11/// ymm7 = _mm256_broadcast_sd((double const *)(&ones)); //4th col a11 += 3; ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] //compute reciprocals of L(i,i) and broadcast in registers ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm7); ymm15 = _mm256_mul_pd(ymm15, ymm7); _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } } } if(i<0) i += D_NR; if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = (n-D_NR); (j+1) > 0; j -=D_NR) 
//loop along n direction { a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ///GEMM for previous blocks /// ///load 4x4 block of b11 ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] 
B10[3][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, 
ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] //2nd col a11 += cs_a; ymm8 = 
_mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] //3rd col a11 += cs_a; ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] //4th col a11 += cs_a; ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); //extract a22 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); ymm2 = _mm256_mul_pd(ymm2, ymm15); //extract a11 ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); ymm1 = _mm256_mul_pd(ymm1, ymm15); //extract A00 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = 
_mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); ymm0 = _mm256_mul_pd(ymm0, ymm15); _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) { a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM for previous blocks /// if(3 == n_remainder) { ///load 4x4 block of b11 ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ///GEMM processing stars/// for(k = 0; k < k_iter; k++) { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] 
B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, 
ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //2nd col a11 += cs_a; ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] //3rd col a11 += cs_a; ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] //4th col a11 += cs_a; ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] 
A11[3][3] A11[2][2] A11[3][3] ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); //extract a22 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); ymm2 = _mm256_mul_pd(ymm2, ymm15); //extract a11 ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); ymm1 = _mm256_mul_pd(ymm1, ymm15); _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) } else if(2 == n_remainder) { ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ///GEMM processing stars/// for(k = 0; k < k_iter; k++) { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const 
*)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) 
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //3rd col a11 += 2 * cs_a; ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] //4th col a11 += cs_a; ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); //extract a22 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm2 = _mm256_mul_pd(ymm2, ymm15); _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) } else if(1 == n_remainder) { ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ///GEMM processing stars/// for(k = 0; k < k_iter; k++) { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = 
_mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //4th col a11 += 3 * cs_a; ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute 
reciprocals of A(i,i) and broadcast in registers ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm3 = _mm256_mul_pd(ymm3, ymm14); _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) } } m_remainder -= 4; i -= 4; } // if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } /*implements TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn */ /* <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * | ***************** | ******* | * * * * * | * * * | * * * * * a01* * * b10 ***************** ************* * * * * * * * * * * * * * * * * * * ***************** ******************* */ static err_t bli_dtrsm_small_XAlB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n & 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME) ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha 
double* restrict L = a->buffer; //pointer to matrix A double* restrict B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction { a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] 
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); 
//ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] 
B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] 
B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //2nd col a11 += 1; ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] //3rd col a11 += 1; ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 
1)); //A11[1][2] //4th col a11 += 1; ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); //(Row 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); //(Row 1): FMA operations ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) ///load 4x4 block of b11 ymm0 = 
_mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //subtract the calculated GEMM block from current TRSM block //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm9 = 
_mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, 
ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm9 = _mm256_loadu_pd((double 
const *)(b11+cs_b)); //B11[0-3][0] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] ///implement TRSM/// ///read 4x4 block of A11/// //3rd col a11 += 2; ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] //4th col a11 += 1; ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); //(Row 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); 
//store(B11[4-7][0]) } else if(2 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); 
//ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, 
ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] ///implement TRSM/// ///read 4x4 block of A11/// //4th col a11 += 3; ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] //(row 3):FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } else if(1 == n_remainder) { ///GEMM implementation begins/// for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double 
const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm3 = 
_mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } } } if(i<0) i += D_NR; if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ///GEMM for previous blocks /// ///load 4x4 block of b11 ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = 
_mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] 
B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] 
B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] //2nd col a11 += cs_a; ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] //3rd col a11 += cs_a; ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) { 
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM for previous blocks /// if(3 == n_remainder) { ///load 4x4 block of b11 ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ///GEMM processing stars/// for(k = 0; k < k_iter; k++) { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] 
B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] 
B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //2nd col a11 += cs_a; ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] //3rd col a11 += cs_a; ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) } else if(2 == n_remainder) { ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ///GEMM processing stars/// for(k = 0; k < k_iter; k++) { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] 
B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next 
block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //3rd col a11 += 2 * cs_a; ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) } else if(1 == n_remainder) { ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ///GEMM processing stars/// for(k = 0; k < k_iter; k++) { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] a01 += 1; //move to next row of A ymm7 = 
_mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] a01 += 1; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) } } m_remainder -= 4; i -= 4; } // if(i < 0) i = 0; if(m_remainder) ///implementation for remainder rows { dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } /*implements TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn */ /* <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * | ***************** | ******* | * * * * * | * * * | * * * * * a01* * * b10 ***************** ************* * * * * * * * * * * * * * * * * * * ***************** ******************* */ static err_t bli_dtrsm_small_XAutB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n & 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m < 
D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to be done dim_t cs_b_offset[2]; //pre-calculated strides double ones = 1.0; double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A double* restrict B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction { a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); 
//A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = 
_mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] 
B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = 
_mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = 
_mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] a11 += cs_a; //2nd col ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //3rd col ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] a11 += cs_a; //4th col ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm7 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //extract a33 ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm7); ymm15 = _mm256_mul_pd(ymm15, ymm7); //extract a22 ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = 
_mm256_fnmadd_pd(ymm11, ymm5, ymm9); ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); //(Row 3): FMA operations ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); ymm10 = _mm256_mul_pd(ymm10, ymm7); ymm14 = _mm256_mul_pd(ymm14, ymm7); //extract a11 ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); ymm9 = _mm256_mul_pd(ymm9, ymm7); ymm13 = _mm256_mul_pd(ymm13, ymm7); //extract A00 ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1):FMA operations ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); ymm8 = _mm256_mul_pd(ymm8, ymm7); ymm12 = _mm256_mul_pd(ymm12, ymm7); _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to 
be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = 
_mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm1 = 
_mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); 
//pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] a11 += cs_a; //2nd col ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //3rd col ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] a11 += cs_a; //4th col ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm7 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm6 = 
_mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //extract a33 ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm7); ymm15 = _mm256_mul_pd(ymm15, ymm7); //extract a22 ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); //(Row 3): FMA operations ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); ymm10 = _mm256_mul_pd(ymm10, ymm7); ymm14 = _mm256_mul_pd(ymm14, ymm7); //extract a11 ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); ymm9 = _mm256_mul_pd(ymm9, ymm7); ymm13 = _mm256_mul_pd(ymm13, ymm7); _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } else if(2 == n_remainder) { ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double 
const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 
= _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm10 = 
_mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col a11 += 2 * cs_a; //3rd col ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] a11 += cs_a; //4th col ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm7 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //extract a33 ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm11 = _mm256_mul_pd(ymm11, ymm7); ymm15 = _mm256_mul_pd(ymm15, ymm7); //extract a22 ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); //(Row 3): FMA operations ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm10 = _mm256_mul_pd(ymm10, ymm7); ymm14 = _mm256_mul_pd(ymm14, ymm7); _mm256_storeu_pd((double *)(b11 + 
cs_b_offset[0]), ymm10); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } else if(1 == n_remainder) { ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + 
D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// a11 += 3 * cs_a; //4th col ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm7 = _mm256_broadcast_sd((double const *)&ones); ymm0 = _mm256_div_pd(ymm7, ymm6); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ymm11 = _mm256_mul_pd(ymm11, ymm0); ymm15 = _mm256_mul_pd(ymm15, ymm0); _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), 
ymm15); //store(B11[4-7][0]) } } } if(i<0) i += D_NR; if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ///GEMM for previous blocks /// ///load 4x4 block of b11 ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; 
//move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] 
B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] a11 += cs_a; //2nd col ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += 
cs_a; //3rd col ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] a11 += cs_a; //4th col ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); //extract a22 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); ymm2 = _mm256_mul_pd(ymm2, ymm15); //extract a11 ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); ymm1 = _mm256_mul_pd(ymm1, ymm15); //extract A00 ymm15 = _mm256_permute_pd(ymm14, 
0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); ymm0 = _mm256_mul_pd(ymm0, ymm15); _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) { a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ///GEMM for previous blocks /// ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///load 4x4 block of b11 if(3 == n_remainder) { ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] 
B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] 
B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col ymm4 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] a11 += cs_a; //2nd col ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //3rd col ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] a11 += cs_a; //4th col ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm4 = 
_mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); //extract a22 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); ymm2 = _mm256_mul_pd(ymm2, ymm15); //extract a11 ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); ymm1 = _mm256_mul_pd(ymm1, ymm15); _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) } else if(2 == n_remainder) { ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] 
B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = 
_mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col a11 += 2 * cs_a; //3rd col ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] a11 += cs_a; //4th col ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] //extract a33 ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) ymm3 = _mm256_mul_pd(ymm3, ymm15); //extract a22 ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm2 = _mm256_mul_pd(ymm2, ymm15); _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) } else if(1 == n_remainder) { ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] 
B11[2][0] B11[3][0] ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// 
a11 += 3 * cs_a; //4th col ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] ymm14 = _mm256_broadcast_sd((double const *)&ones); //compute reciprocals of A(i,i) and broadcast in registers ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] ymm3 = _mm256_mul_pd(ymm3, ymm14); _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) } } m_remainder -= 4; i -= 4; } if(m_remainder) ///implementation for remainder rows { dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } /*implements TRSM for the case XA = alpha * B *A is lower triangular, unit-diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn */ /* <---b11 <---a11 ***************** * *b01*b11* * * * * ^ * * * * * ^ * * | ***************** | ******* | * * * * * | * * * | * * * * * a01* * * b10 ***************** ************* * * * * * * * * * * * * * * * * * * ***************** ******************* */ static err_t bli_dtrsm_small_XAutB_unitDiag( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { dim_t D_MR = 8; //block dimension along the rows dim_t D_NR = 4; //block dimension along the columns dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns dim_t m_remainder = m & 7; //number of corner rows dim_t n_remainder = n & 3; //number of corner columns dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B #ifdef BLIS_ENABLE_SMALL_MATRIX_ROME if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME) ||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N) ) return BLIS_NOT_YET_IMPLEMENTED; #else if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES) { return BLIS_NOT_YET_IMPLEMENTED; } #endif dim_t i, j, k; //loop variablse dim_t k_iter; //determines the number of GEMM operations to 
be done dim_t cs_b_offset[2]; //pre-calculated strides double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A double* restrict B = b->buffer; //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks double *ptr_a01_dup; cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; //ymm scratch reginsters __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; __m256d ymm16; for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction { for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction { a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = _mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] 
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] 
B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); 
//ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] 
B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //load 8x4 block of B11 ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// //1st col a11 += cs_a; //2nd col ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] a11 += cs_a; //3rd col ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm4 = 
_mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //4th col ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); //(Row 3): FMA operations ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); //(ROW 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); //(Row 1):FMA operations ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) } if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) { a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) ymm0 = 
_mm256_setzero_pd(); ymm1 = _mm256_setzero_pd(); ymm2 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); //load 8x4 block of B11 if(3 == n_remainder) { ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm10 = _mm256_broadcast_sd((double 
const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = 
_mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] ymm10 = 
_mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1 ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5 ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// a11 += 2 * cs_a; //3rd col ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //4th col ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); //(Row 3): FMA operations ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); //(ROW 2): FMA operations ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } else if(2 == n_remainder) { ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { 
ptr_a01_dup = a01; //broadcast 1st row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); 
//A01[2][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * 
cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2 ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6 ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// a11 += 3 * cs_a; //4th col ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); //(Row 3): FMA operations ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } else if(1 == n_remainder) { ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //broadcast 1st row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row //load 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + 
D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) //broadcast 2nd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) //broadcast 3rd row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A01 //load next 8x2 block of B10 ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) //broadcast 4th row of A01 ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A01 ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] 
B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code ends/// ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal); ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3 ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7 _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } } } if(i<0) i += D_NR; if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4) { for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction { a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ///GEMM for previous blocks /// ///load 4x4 block of b11 ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 
bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] 
B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for 
GEMM a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4 ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// a11 += cs_a; //2nd col ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] a11 += cs_a; //3rd col ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //4th col ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); //(Row 1):FMA operations ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) } if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) { a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-j-D_NR) / D_NR; //number of times GEMM 
operations to be performed(in blocks of 4x4) ///GEMM for previous blocks /// ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); ///load 4x4 block of b11 if(3 == n_remainder) { ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to 
next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); 
//register to store alpha ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5 ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// a11 += 2 * cs_a; //3rd col ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] a11 += cs_a; //4th col ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); //(ROW 2): FMA operations ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) } else if(2 == n_remainder) { ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] 
B10[3][0]*A01[0][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6 ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// ///read 4x4 block of A11/// 
a11 += 3 * cs_a; //4th col ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] //(Row 3): FMA operations ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) } else if(1 == n_remainder) { ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ///GEMM implementation starts/// for(k = 0; k < k_iter; k++) //loop for number of GEMM operations { ptr_a01_dup = a01; //load 4x4 bblock of b10 ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] //broadcast 1st row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] 
B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM } ///GEMM code end/// ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7 _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) } } m_remainder -= 4; i -= 4; } if(m_remainder) ///implementation for remainder rows { dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); } return BLIS_SUCCESS; } /* * AX = Alpha*B, Single precision, A: lower triangular * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is mutiple of 8 */ static err_t bli_strsm_small_AlXB ( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { obj_t alpha, beta; // gemm parameters obj_t Ga, Gb, Gc; // for GEMM int m = bli_obj_length(b); // number of rows of matrix B int n = bli_obj_width(b); // number of columns of matrix B int lda = bli_obj_col_stride(a); // column stride of A int ldb = bli_obj_col_stride(b); // column stride of B int rsa = bli_obj_row_stride(a); // row stride of A int rsb = bli_obj_row_stride(b); // row stride of B int i = 0; int j; int blk_size = 8; int isUnitDiag = bli_obj_has_unit_diag(a); float alphaVal; float* restrict L = a->buffer; float* restrict B = b->buffer; if (m != BLI_AlXB_M_SP || (n&7) != 0) { return BLIS_NOT_YET_IMPLEMENTED; } if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) { return BLIS_NOT_YET_IMPLEMENTED; } alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); /* Small _GEMM preparation code */ bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha ); bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta ); /* B = B - A*B */ bli_setsc( -(1.0), 0.0, &alpha ); bli_setsc( (1.0), 0.0, &beta ); bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, 
blk_size, a->buffer, rsa, lda, &Ga); bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb); bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc); bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga ); bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb ); bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc ); //first block of trsm Gb.buffer = (void*)(B + i); //trsm of first 8xn block if (alphaVal != 1) { if (isUnitDiag == 0) { blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); fp_blis_strsm_microkernel = blis_strsm_microkernel; } else { blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; } bli_setsc( alphaVal, 0.0, &beta ); } else { if (isUnitDiag == 0) { blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); fp_blis_strsm_microkernel = blis_strsm_microkernel; } else { blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; } } //gemm update for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT { Ga.buffer = (void*)(L + j + i*lda); Gc.buffer = (void*)(B + j); bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb } //trsm of remaining blocks for (i = blk_size; i < m; i += blk_size) { Gb.buffer = (void*)(B + i); fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT { Ga.buffer = (void*)(L + j + i*lda); Gc.buffer = (void*)(B + j); bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb } } // End of for loop - i return BLIS_SUCCESS; } /* * XA' = Alpha*B, Single precision, A: lower triangular * This kernel implementation supports matrices A 
and B such that * m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP */ static err_t bli_strsm_small_XAltB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { int m = bli_obj_length(a); // number of rows of matrix B int n = bli_obj_length(b); // number of columns of matrix B int lda = bli_obj_col_stride(a); // column stride of A int ldb = bli_obj_col_stride(b); // column stride of B int rsa = bli_obj_row_stride(a); // row stride of A int rsb = bli_obj_row_stride(b); // row stride of B int i = 0; int isUnitDiag = bli_obj_has_unit_diag(a); float alphaVal; float *L = a->buffer; float *B = b->buffer; if ((m&7) != 0 || (n&7) != 0) { return BLIS_NOT_YET_IMPLEMENTED; } if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) { return BLIS_NOT_YET_IMPLEMENTED; } alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); if (alphaVal != 1) { if (isUnitDiag == 0) { trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); } else { trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); } } else { if (isUnitDiag == 0) { trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); } else { trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); } } return BLIS_SUCCESS; } /* * A'X = Alpha*B, Single precision, A: upper triangular * This kernel implementation supports matrices A and B such that * m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP */ static err_t bli_strsm_small_AutXB( side_t side, obj_t* AlphaObj, obj_t* a, obj_t* b, cntx_t* cntx, cntl_t* cntl ) { int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken) int n = bli_obj_width(b); // number of columns of matrix B int lda = bli_obj_col_stride(a); // column 
stride of A int ldb = bli_obj_col_stride(b); // column stride of B int rsa = bli_obj_row_stride(a); // row stride of A int rsb = bli_obj_row_stride(b); // row stride of B int i = 0; int isUnitDiag = bli_obj_has_unit_diag(a); float alphaVal; float *L = a->buffer; float *B = b->buffer; if ((m&7) != 0 || (n&7) != 0) { return BLIS_NOT_YET_IMPLEMENTED; } if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) { return BLIS_NOT_YET_IMPLEMENTED; } alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); if (alphaVal != 1) { if (isUnitDiag == 0) { trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); } else { trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); } } else { if (isUnitDiag == 0) { trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); } else { trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); } } return BLIS_SUCCESS; } ///////////////////////////// AX=B /////////////////////////////// static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) { float ones = 1.0; int j; int cs_b_offset[6]; //int row2, row4, row6; float *ptr_b_dup; //70 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_cols[8]; __m256 mat_a_cols_rearr[36]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags; __m256 alphaReg; cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; //reciprocal_diags = _mm256_loadu_ps((float const *)ones); reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); alphaReg = 
_mm256_broadcast_ss((float const *)&alphaVal); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); //row2 = (cs_l << 1); //row4 = (cs_l << 2); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); //row6 = row2 + row4; mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); //reciprocal_diags = _mm256_loadu_ps((float const *)ones); //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ 
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); //1st col mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); //2nd col ptr_l += cs_l; mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //3rd col ptr_l += cs_l; mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //4rth col ptr_l += cs_l; mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[18] = 
_mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //5th col ptr_l += cs_l; mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //6th col ptr_l += cs_l; mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //7th col ptr_l += cs_l; mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //7th col ptr_l += cs_l; mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); numCols_b -= 8; // blk_width = 8 //compute reciprocals of L(i,i) and broadcast in registers mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); //reciprocal of diagnol elements reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); //Start loop for cols of B to be processed in size of blk_width for (j = 0; 
j < numCols_b; j += 8) { ptr_b_dup = ptr_b; /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], 
mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) 
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); 
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], 
mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //--> Transpose and store results of columns of B block <--// ////unpacklow//// mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); mat_a_cols[5] = 
_mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); #else mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); 
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); //Read next set of B columns ptr_b += (cs_b + cs_b_offset[5]); mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); //end loop of cols } //Last block trsm processing ptr_b_dup = ptr_b; /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], 
mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); 
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], 
mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B 
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //--> Transpose and store results of columns of B block <--// ////unpacklow//// mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); #else mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], 
mat_a_cols[3], 0x4E); mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_a_cols[3] = 
_mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); //end loop of cols } static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) { //float ones = 1.0; int j; int cs_b_offset[6]; //int row2, row4, row6; float *ptr_b_dup; //70 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_cols[8]; __m256 mat_a_cols_rearr[36]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags; __m256 alphaReg; cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; //reciprocal_diags = _mm256_loadu_ps((float const *)ones); //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); //row2 = (cs_l << 1); //row4 = (cs_l << 2); mat_b_col[1] = _mm256_loadu_ps((float 
const *)(ptr_b + (cs_b))); //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); //row6 = row2 + row4; mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); //reciprocal_diags = _mm256_loadu_ps((float const *)ones); //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. 
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); //1st col mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); //2nd col ptr_l += cs_l; mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //3rd col ptr_l += cs_l; mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //4rth col ptr_l += cs_l; mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //5th 
col ptr_l += cs_l; mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //6th col ptr_l += cs_l; mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //7th col ptr_l += cs_l; mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //8th col //ptr_l += cs_l; //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); numCols_b -= 8; // blk_width = 8 //compute reciprocals of L(i,i) and broadcast in registers //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); //reciprocal of diagnol elements //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); //Start loop for cols of B to be processed in size of blk_width for (j = 0; j < numCols_b; j += 8) { ptr_b_dup = ptr_b; /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ ////unpacklow//// mat_b_rearr[0] = 
_mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = 
_mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //extract diag a00 from a //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); //extract diag a11 from a //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = 
_mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //extract diag a22 from a //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //extract diag a33 from a //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //(Row3): FMA 
operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //extract diag a44 from a //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //extract diag a55 from a //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], 
mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //extract diag a66 from a //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //extract diag a77 from a //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //--> Transpose and store results of columns of B block <--// ////unpacklow//// mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); mat_a_cols[6] = 
_mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); #else mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); 
#endif //Merge rearranged high elements into complete rows mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); //Read next set of B columns ptr_b += (cs_b + cs_b_offset[5]); mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); //end loop of cols } //Last block trsm processing ptr_b_dup = ptr_b; /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], 
mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 
0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //extract diag a00 from a //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); //extract diag a11 from a //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], 
mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //extract diag a22 from a //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //extract diag a33 from a //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - 
(a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //extract diag a44 from a //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //extract diag a55 from a //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B //mat_b_rearr[5] = 
_mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //extract diag a66 from a //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //extract diag a77 from a //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //--> Transpose and store results of columns of B block <--// ////unpacklow//// mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); #else mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 
0x4E); mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], 
mat_b_rearr[7], 0x20);
	mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);

	//Store the computed B columns
	_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
	_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
	//end loop of cols
}

/*
 * blis_strsm_microkernel_unitDiag
 *
 * Single-precision AVX/FMA micro-kernel that solves L * X = B in place by
 * forward substitution. L is an 8x8 lower-triangular block whose diagonal is
 * treated as all ones. B is an 8-row panel that is overwritten with X, eight
 * columns at a time.
 *
 * It has the same structure as blis_strsm_microkernel, except that the
 * per-row scaling by 1/L(i,i) is omitted.
 *
 * Parameters:
 *   ptr_l      - top-left element of the 8x8 block of L. L(r,c) is read from
 *                ptr_l[r + c*cs_l]. Only r > c is used in the solve. The
 *                diagonal entries L(c,c) for c < 7 are also loaded, but their
 *                values are never used.
 *   ptr_b      - top-left element of the 8-row panel of B. B(r,c) is at
 *                ptr_b[r + c*cs_b]. The panel is read and overwritten in place.
 *   numRows_lb - unused; the kernel is hard-wired to 8 rows.
 *   numCols_b  - number of columns of B.
 *                NOTE(review): there is no tail handling. The kernel always
 *                processes max(1, ceil(numCols_b / 8)) full 8-column blocks,
 *                so callers presumably pass a positive multiple of 8. Confirm
 *                this against the callers.
 *   rs_l, rs_b - unused. A unit row stride is assumed for both L and B,
 *                because each column is read with one contiguous 8-float load.
 *   cs_l, cs_b - column strides of L and B, in elements.
 */
static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
	int j;
	int cs_b_offset[6];
	float *ptr_b_dup;

	// 60 __m256 locals are declared. AVX only has 16 ymm registers, so expect
	// the compiler to keep some of these (mostly the broadcasts of L) in memory.
	__m256 mat_b_col[8];        // 8 columns of the current B block, as loaded
	__m256 mat_b_rearr[8];      // after transposition: mat_b_rearr[r] = row r of the block
	__m256 mat_a_cols[8];       // scratch/output: solved block transposed back to columns
	__m256 mat_a_cols_rearr[36]; // packed broadcasts: [r*(r+1)/2 + c] = L(r,c); index 35 (L(7,7)) is never filled

	// cs_b_offset[k] = (k + 2) * cs_b, the offsets of columns 2..7 within an
	// 8-column block. Columns 0 and 1 use offsets 0 and cs_b directly.
	cs_b_offset[0] = (cs_b << 1);
	cs_b_offset[1] = cs_b + cs_b_offset[0];
	cs_b_offset[2] = (cs_b << 2);
	cs_b_offset[3] = cs_b + cs_b_offset[2];
	cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
	cs_b_offset[5] = cs_b + cs_b_offset[4];

	// Load the first 8x8 block of B. mat_b_col[c] holds the 8 contiguous
	// floats at ptr_b + c*cs_b, i.e. rows 0..7 of column c.
	mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
	mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
	mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
	mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
	mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
	mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
	mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
	mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));

	// Broadcast every needed element of L into its own register. The storage
	// is packed row-wise: mat_a_cols_rearr[r*(r+1)/2 + c] = L(r,c). L is walked
	// column by column; ptr_l advances by cs_l per column.

	// Column 0: L(0..7, 0). Index 0 is the diagonal L(0,0), which is unused.
	mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
	mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
	mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
	mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
	mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
	mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
	mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
	mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));

	// Column 1: L(1..7, 1). Index 2 is the unused diagonal L(1,1).
	ptr_l += cs_l;
	mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
	mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
	mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
	mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
	mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
	mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
	mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));

	// Column 2: L(2..7, 2). Index 5 is the unused diagonal L(2,2).
	ptr_l += cs_l;
	mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
	mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
	mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
	mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
	mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
	mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));

	// Column 3: L(3..7, 3). Index 9 is the unused diagonal L(3,3).
	ptr_l += cs_l;
	mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
	mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
	mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
	mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
	mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));

	// Column 4: L(4..7, 4). Index 14 is the unused diagonal L(4,4).
	ptr_l += cs_l;
	mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
	mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
	mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
	mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));

	// Column 5: L(5..7, 5). Index 20 is the unused diagonal L(5,5).
	ptr_l += cs_l;
	mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
	mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
	mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));

	// Column 6: L(6..7, 6). Index 27 is the unused diagonal L(6,6).
	ptr_l += cs_l;
	mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
	mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));

	// Column 7 contains only the diagonal L(7,7). It is not needed for a
	// unit-diagonal solve, so it is not loaded.

	// Reserve the last 8-column block for the peeled epilogue after the loop.
	numCols_b -= 8; // blk_width = 8

	// Software-pipelined loop over all 8-column blocks except the last one.
	// Each iteration:
	//   1. transposes and solves the block already in mat_b_col;
	//   2. loads the next block into mat_b_col;
	//   3. stores the solved block through ptr_b_dup.
	// The final block is handled after the loop, without the look-ahead load.
	for (j = 0; j < numCols_b; j += 8)
	{
		ptr_b_dup = ptr_b;

		/* Transpose the 8x8 block in registers. Afterwards mat_b_rearr[r]
		 * holds row r of the block across all 8 columns. */

		// Interleave the low halves of column pairs.
		mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
		mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
		mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
		mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);

		// Gather 4-element row fragments. REARRANGE_SHFL picks one of two
		// equivalent sequences (shuffle-only, or shuffle + blend); both give
		// the same lane layout.
#if REARRANGE_SHFL == 1
		mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
		mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
		mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
		mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
		mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
		mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
		mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
		mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
		mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
		mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif

		// Join 128-bit lanes into full rows 0, 4, 1 and 5.
		mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
		mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
		mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
		mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);

		// Same steps on the high halves. mat_b_col is free to reuse as
		// scratch here.
		mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
		mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
		mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
		mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);

#if REARRANGE_SHFL == 1
		mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
		mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
		mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
		mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
		mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
		mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
		mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
		mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
		mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
		mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif

		// Join 128-bit lanes into full rows 2, 6, 3 and 7.
		mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
		mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
		mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
		mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);

		/* Forward substitution with an implicit unit diagonal.
		 * Row 0 is already final. For k = 0..6, row i -= L(i,k) * row k for
		 * every i > k. _mm256_fnmadd_ps(a, b, c) computes c - a*b. */

		// k = 0: eliminate using row 0.
		mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);
		mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);
		mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);
		mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);
		mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);
		mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);

		// k = 1: eliminate using row 1.
		mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);
		mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);
		mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);
		mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);
		mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);

		// k = 2: eliminate using row 2.
		mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);
		mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);
		mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);
		mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);

		// k = 3: eliminate using row 3.
		mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);
		mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);
		mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);

		// k = 4: eliminate using row 4.
		mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);
		mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);

		// k = 5: eliminate using row 5.
		mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);

		// k = 6: eliminate using row 6.
		mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);

		/* Transpose the solved rows back to columns. Afterwards mat_a_cols[c]
		 * holds column c of the solved block, ready to store. */

		// Interleave the low halves of row pairs.
		mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
		mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
		mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
		mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);

#if REARRANGE_SHFL == 1
		mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
		mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
		mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
		mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
		mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
		mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
		mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
		mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
		mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
		mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif

		// Join 128-bit lanes into full columns 0, 4, 1 and 5.
		mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
		mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
		mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
		mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);

		// High halves. mat_b_rearr is reused as scratch; each source pair is
		// read before it is overwritten.
		mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
		mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
		mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
		mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);

#if REARRANGE_SHFL == 1
		mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
		mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
		mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
		mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
		mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
		mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
		mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
		mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
		mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
		mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif

		// Join 128-bit lanes into full columns 2, 6, 3 and 7.
		mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
		mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
		mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
		mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);

		// Look-ahead: advance 8 columns (cs_b + 7*cs_b) and load the next
		// block before storing the current one. The two blocks occupy
		// different columns of B, so the early load cannot read stale data.
		ptr_b += (cs_b + cs_b_offset[5]);
		mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
		mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
		mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
		mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
		mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
		mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
		mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
		mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));

		// Write the solved block back over its B columns.
		_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
		_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
		_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
		_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
		_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
		_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
		_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
		_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
	}

	/* Epilogue: the last 8-column block, already loaded into mat_b_col. This
	 * repeats the loop body exactly, minus the look-ahead load. It always runs,
	 * even when numCols_b <= 8. */
	ptr_b_dup = ptr_b;

	// Transpose columns -> rows (low halves).
	mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
	mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
	mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
	mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);

#if REARRANGE_SHFL == 1
	mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
	mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
	mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
	mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
	mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
	mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
	mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
	mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
	mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
	mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif

	// Rows 0, 4, 1, 5.
	mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
	mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
	mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
	mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);

	// Transpose columns -> rows (high halves).
	mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
	mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
	mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
	mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);

#if REARRANGE_SHFL == 1
	mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
	mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
	mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
	mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
	mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
	mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
	mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
	mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
	mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
	mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif

	// Rows 2, 6, 3, 7.
	mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
	mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
	mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
	mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);

	// Forward substitution, unit diagonal: row i -= L(i,k) * row k for i > k.
	// k = 0
	mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);
	mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);
	mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);
	mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);
	mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);
	mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);

	// k = 1
	mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);
	mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);
	mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);
	mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);
	mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);

	// k = 2
	mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);
	mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);
	mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);
	mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);

	// k = 3
	mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);
	mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);
	mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);

	// k = 4
	mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);
	mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);

	// k = 5
	mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);

	// k = 6
	mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);

	// Transpose rows -> columns (low halves).
	mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
	mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
	mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
	mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);

#if REARRANGE_SHFL == 1
	mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
	mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
	mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
	mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
	mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
	mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
	mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
	mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
	mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
	mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif

	// Columns 0, 4, 1, 5.
	mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
	mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
	mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
	mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);

	// Transpose rows -> columns (high halves).
	mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
	mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
	mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
	mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);

#if REARRANGE_SHFL == 1
	mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
	mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
	mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
	mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
	mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
	mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
	mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
	mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
	mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
	mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif

	// Columns 2, 6, 3, 7.
	mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
	mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
	mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
	mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);

	// Write the last solved block back over its B columns.
	_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
	_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
	_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
}

static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
	float ones = 1.0;
	int j;
	int cs_b_offset[6];
	//int row2, row4, row6;
	float *ptr_b_dup;

	//70 number of ymm(256 bits) registers used
__m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_cols[8]; __m256 mat_a_cols_rearr[36]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags; cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; //reciprocal_diags = _mm256_loadu_ps((float const *)ones); reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); //row2 = (cs_l << 1); //row4 = (cs_l << 2); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); //row6 = row2 + row4; mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); //reciprocal_diags = _mm256_loadu_ps((float const *)ones); //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[1] = 
_mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); ptr_l += cs_l; mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); //1st col mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); //2nd col ptr_l += cs_l; mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //3rd col ptr_l += cs_l; mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[12] = 
_mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //4rth col ptr_l += cs_l; mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //5th col ptr_l += cs_l; mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //6th col ptr_l += cs_l; mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //7th col ptr_l += cs_l; mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //7th col ptr_l += cs_l; mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); numCols_b -= 8; // blk_width = 8 //compute reciprocals of L(i,i) and broadcast in registers mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); //mat_a_diag_inv[3] = 
_mm256_permute_ps(mat_a_diag_inv[3], 0x55); mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); //reciprocal of diagnol elements reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); //Start loop for cols of B to be processed in size of blk_width for (j = 0; j < numCols_b; j += 8) { ptr_b_dup = ptr_b; /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = 
_mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //(Row1): FMA operations of b1 with elements of indices from (1, 0) 
uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = 
_mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = 
_mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //--> Transpose and store results of columns of B block <--// ////unpacklow//// mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_a_cols[4] = 
_mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); #else mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); 
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); //Read next set of B columns ptr_b += (cs_b + cs_b_offset[5]); mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); //end loop of cols } //Last block trsm processing ptr_b_dup = ptr_b; /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = 
_mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 
0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], 
mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) 
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = 
_mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //--> Transpose and store results of columns of B block <--// ////unpacklow//// mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); #else mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); #endif //Merge rearranged low elements into complete 
rows mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); 
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); //end loop of cols } #if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) { float ones = 1.0; int i, i1, i2, i3, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags[2]; reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); //read diag elems of L 16x16 block mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + 
cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); reciprocal_diags[1] = reciprocal_diags[0]; //pack first 8 diags together mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = 
_mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = 
_mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); 
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_col[3] = 
_mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //(Row6): FMA operations of b6 with 
elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; //Read next 8x8 block of A to get diag elements i3 += cs_l_offset[6]; mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); 
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); //pack 8 diags of A together reciprocal_diags[0] = reciprocal_diags[1]; mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a 
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A i = i1 + r; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const 
*)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); #endif i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B ptr_l_dup = ptr_l; i4 = i2 + r; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); i4 = k >> 3; ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; 
#if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) 
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) 
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row14): FMA 
operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols } i2 += cs_b_offset[6]; i += cs_l_offset[6]; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) { i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; #if GEMM_ACCUM_A //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d 
= c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[1] = 
_mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), 
mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; } } } //numRows of A ///////////////////loop ends ///////////////////// } static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) { float ones = 1.0; int i, i1, i2, i3, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags[2]; __m256 alphaReg; reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); alphaReg = _mm256_broadcast_ss((float const *)&alpha); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); //read diag elems of L 16x16 block mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = 
cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); reciprocal_diags[1] = reciprocal_diags[0]; //pack first 8 diags together mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); #if 0 //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = 
_mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float 
const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], 
mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] 
+ 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const 
*)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); 
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; //Read next 8x8 block of A to get diag elements i3 += cs_l_offset[6]; mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); //pack 8 diags of A together reciprocal_diags[0] = reciprocal_diags[1]; mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], 
mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], 
mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A i = i1 + r; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); #endif i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B ptr_l_dup = ptr_l; i4 = i2 + r; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + 
cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); i4 = k >> 3; ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) 
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = 
_mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = 
_mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = 
_mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = 
_mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols } i2 += cs_b_offset[6]; i += cs_l_offset[6]; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) { i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); 
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; #if GEMM_ACCUM_A //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[4] = 
_mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], 
mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] 
= _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; } } } //numRows of A ///////////////////loop ends ///////////////////// } static 
void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) { //float ones = 1.0; int i, i1, i2, i3, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags[2]; // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); //(Row0) mat_a_blk_elems[0] = 
_mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) 
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + 
cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; i3 += cs_l_offset[6]; i = 0; i2 = 0; for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A i = i1 + r; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); #endif i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B ptr_l_dup = ptr_l; i4 = i2 + r; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float 
const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); i4 = k >> 3; ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], 
mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else 
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 
to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = 
_mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols } i2 += cs_b_offset[6]; i += cs_l_offset[6]; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) { i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b 
+ cs_b_offset[4] + i2)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; #if GEMM_ACCUM_A //(Row0): already done #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with 
elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers 
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); 
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; } } } 
//numRows of A ///////////////////loop ends ///////////////////// } static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) { //float ones = 1.0; int i, i1, i2, i3, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags[2]; __m256 alphaReg; alphaReg = _mm256_broadcast_ss((float const *)&alpha); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); #if 0 //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); 
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); #endif /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block 
trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); //(Row0) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill 
(7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l_offset[3] + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; i3 += cs_l_offset[6]; i = 0; i2 = 0; for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A i = i1 + r; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); 
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); #endif i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B ptr_l_dup = ptr_l; i4 = i2 + r; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); i4 = k >> 3; ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif 
//Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = 
_mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 
to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); ptr_l_dup += cs_l; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = 
_mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols } i2 += cs_b_offset[6]; i += cs_l_offset[6]; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) { i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = 
_mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; #if GEMM_ACCUM_A //(Row0): already done #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], 
mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - 
(a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[1] = 
_mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); //printf("writing B => 
m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; } } } //numRows of A ///////////////////loop ends ///////////////////// } #else //rel 1.0 intrisic kernels (NOT OPT_CACHE_BLOCKING_L1) static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) { float ones = 1.0; int i, i1, i2, i3, i4, j, k, l; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[16][8]; __m256 mat_a_cols_rearr[8]; __m256 mat_a_blk_elems[64]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags[2]; reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); //read diag elems of L 16x16 block mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); 
reciprocal_diags[1] = reciprocal_diags[0]; //pack first 8 diags together mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l 
+ cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 
0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3][0] = 
_mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], 
mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], 
mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup 
= ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; //Read next 8x8 block of A to get diag elements i3 += cs_l_offset[6]; mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); //pack 8 diags of A together reciprocal_diags[0] = reciprocal_diags[1]; mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); i = 0; i2 = 0; for (k = 0; k < numCols_b; k += 8) { i = i1 + k; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b 
+ i); mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); i2++; } i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); //Broadcast A8,2 to A15,2 to registers 
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); // _mm256_permute2f128_ps() //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l_offset[2] + i + 4)); mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); mat_a_blk_elems[57] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); i += cs_l_offset[6]; for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B i4 = i2 + k; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); i4 = k >> 3; //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) 
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d 
= c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], 
mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) //end loop of cols } i2 += cs_b_offset[6]; } //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv2[0] = 
_mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], 
mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); k = 0; for (i = 0; i < numCols_b; i+=8) { /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[k][0] = 
_mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); //(Row3): FMA 
operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[k][5] = 
_mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; } } ///////////////////loop ends ///////////////////// } static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) { float ones = 1.0; int i, i1, i2, i3, i4, j, k, l; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup; //57 number of 
ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[16][8]; __m256 mat_a_cols_rearr[8]; __m256 mat_a_blk_elems[64]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags[2]; __m256 alphaReg; reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); alphaReg = _mm256_broadcast_ss((float const *)&alpha); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); //read diag elems of L 16x16 block mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); reciprocal_diags[1] = reciprocal_diags[0]; //pack first 8 diags together mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = 
_mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); 
//Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = 
_mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + 
cs_b_offset[5] + i)); mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); 
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; 
//ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; //Read next 8x8 block of A to get diag elements i3 += cs_l_offset[6]; mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); //pack 8 diags of A together reciprocal_diags[0] = reciprocal_diags[1]; mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); i = 0; i2 = 0; 
for (k = 0; k < numCols_b; k += 8) { i = i1 + k; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); i2++; } i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //Broadcast A21 to A71 to registers 
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 
i + 5)); mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); // _mm256_permute2f128_ps() //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); 
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); i += cs_l_offset[6]; for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B i4 = i2 + k; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + 
cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); i4 = k >> 3; //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) 
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) 
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) //end loop of cols } i2 += cs_b_offset[6]; } //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); 
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); 
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = 
_mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); k = 0; for (i = 0; i < numCols_b; i+=8) { /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], 
mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) //Perform 
mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), 
mat_b_rearr[k][4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); k++; } } ///////////////////loop ends ///////////////////// } static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) { //float ones = 1.0; int i, i1, i2, i3, i4, j, k, l; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[16][8]; //__m256 mat_a_cols_rearr[8]; __m256 mat_a_blk_elems[64]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags[2]; // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to 
registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); /***************** first set of 
8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); //(Row0) mat_b_col[0] = mat_b_rearr[0][0]; //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], 
mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) //(Row6): FMA 
operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; i3 += cs_l_offset[6]; i = 0; i2 = 0; for (k = 0; k < numCols_b; k += 8) { i = i1 + k; //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[i2][4] = _mm256_loadu_ps((float 
const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); i2++; } i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); 
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); // _mm256_permute2f128_ps() //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); mat_a_blk_elems[39] = _mm256_broadcast_ss((float const 
*)(ptr_l + cs_l_offset[2] + i + 7)); //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); 
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); i += cs_l_offset[6]; for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B i4 = i2 + k; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); i4 = k >> 3; //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) 
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], 
mat_b_rearr[i4][7]);//d = c - (a*b) //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row15): FMA operations of b1 with elements of 
indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) //end loop of cols } i2 += cs_b_offset[6]; } //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); 
i += cs_l; //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); k = 0; for (i = 0; i < numCols_b; i+=8) { /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A //(Row0): already done //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], 
mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = 
_mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, 
*(ptr_b_dup + k)); k++; } } ///////////////////loop ends ///////////////////// } static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) { //float ones = 1.0; int i, i1, i2, i3, i4, j, k, l; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[16][8]; //__m256 mat_a_cols_rearr[8]; __m256 mat_a_blk_elems[64]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags[2]; __m256 alphaReg; alphaReg = _mm256_broadcast_ss((float const *)&alpha); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 
block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); //(Row0) mat_b_col[0] = mat_b_rearr[0][0]; //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) 
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) //(Row5): FMA operations of b5 with elements of indices from 
(5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); //i += cs_b_offset[6]; //ptr_b_dup += cs_b_offset[6]; i += 8; ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += 8; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += cs_b_offset[6]; i1 += cs_b_offset[6]; i3 += cs_l_offset[6]; i = 0; i2 = 0; for (k = 0; k < numCols_b; k += 8) { i = i1 + k; //Read 8 cols of B columns of Block-to-be-solved 
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); i2++; } i = 0; i2 = 0; for (l = 0; l < j; l += 8) // move across m { //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); mat_a_blk_elems[9] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 
6)); mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); // _mm256_permute2f128_ps() //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); mat_a_blk_elems[51] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); i += cs_l_offset[6]; for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) { /////////////////// Partial Lower 8x8 block trsm of B i4 = i2 + k; //Read current 8 cols of B columns from specified 8x8 current-block of B mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_b_col[7] = 
_mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); i4 = k >> 3; //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row10): FMA operations 
of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) 
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) //end loop of cols } i2 += cs_b_offset[6]; } //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float 
const *)(ptr_l + i + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); mat_a_blk_elems[26] 
= _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); i += cs_l; //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); k = 0; for (i = 0; i < numCols_b; i+=8) { /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A //(Row0): already done //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) //(Row3): FMA operations of b3 with elements 
of indices from (3, 0) uptill (7, 0) mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], 
mat_b_rearr[k][7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; } } ///////////////////loop ends ///////////////////// } #endif //OPT_CACHE_BLOCKING_L1 //////////////////////////// AutX=B /////////////////////// static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) { float ones = 1.0; int i, i1, i2, i3, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags[2]; reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); //read diag elems of L 16x16 block mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); mat_a_blk_elems[1] = 
_mm256_loadu_ps((float const *)ptr_l + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); reciprocal_diags[1] = reciprocal_diags[0]; //pack first 8 diags together mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); #if 0 //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = 
_mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], 
mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], 
mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], 
mat_b_rearr[7], 0x31); /* transpose steps end */ //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 
1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) 
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = 
_mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = 
_mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup 
+ cs_b_offset[0]), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); i += cs_b_offset[6]; ptr_b_dup += cs_b_offset[6]; //i += 8; //ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += cs_l_offset[6]; //Read next 8x8 block of A to get diag elements i3 += 8; mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); //pack 8 diags of A together reciprocal_diags[0] = reciprocal_diags[1]; mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = 
_mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += 8; i1 += 8; i = i1; i2 = 0; //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = 
_mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = 
_mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 
0x31); /* transpose steps end */ #endif //i = 0; ptr_l_dup = ptr_l; i4 = i2; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) //{ /////////////////// Partial Lower 8x8 block trsm of B //Read current 8 cols of B columns from specified 8x8 current-block of B mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); 
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); #else mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); mat_b_col[7] = 
_mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); /* transpose steps end */ //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i4 = k >> 3; ptr_l_dup++; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], 
mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c 
- (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], 
mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float 
const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d 
= c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], 
mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = 
c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols //} //i2 += cs_b_offset[6]; i4 += 8; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) //{ //i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); 
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = 
_mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i += cs_l; #if GEMM_ACCUM_A //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], 
mat_b_rearr[0]); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 
2 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); 
//////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = 
_mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; //} i += cs_b_offset[6]; i2 += cs_b_offset[6]; } } //numRows of A ///////////////////loop ends ///////////////////// } static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) { float ones = 1.0; int i, i1, i2, i3, i4, j, k, l, r; int 
cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; __m256 mat_a_diag_inv[8]; __m256 reciprocal_diags[2]; __m256 alphaReg; reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); alphaReg = _mm256_broadcast_ss((float const *)&alpha); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); //read diag elems of L 16x16 block mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); reciprocal_diags[1] = reciprocal_diags[0]; //pack first 8 diags together mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); #if 0 //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + 
cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); #endif //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); //extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 
0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + 
cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = 
_mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); mat_a_blk_elems[6] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) 
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = 
_mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); i += cs_b_offset[6]; ptr_b_dup += cs_b_offset[6]; //i += 8; //ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i3 = 0; i1 = 0; //Start loop for cols of B to be processed in 
size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += cs_l_offset[6]; //Read next 8x8 block of A to get diag elements i3 += 8; mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); //pack 8 diags of A together reciprocal_diags[0] = reciprocal_diags[1]; mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += 8; i1 += 8; i = i1; i2 = 0; //extract diag a00 from a mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); 
//extract diag a11 from a mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); //extract diag a22 from a mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); //extract diag a33 from a mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); //extract diag a44 from a mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); //extract diag a55 from a mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); //extract diag a66 from a mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); //extract diag a77 from a mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); 
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); 
////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); /* transpose steps end */ mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); #endif //i = 0; ptr_l_dup = ptr_l; i4 = i2; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the 
same value of l (index of m) //{ /////////////////// Partial Lower 8x8 block trsm of B //Read current 8 cols of B columns from specified 8x8 current-block of B mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 
0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); #else mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); /* transpose steps end */ //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i4 = k >> 3; ptr_l_dup++; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], 
mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c 
- (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], 
mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = 
_mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 
cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup 
+ cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], 
mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = 
_mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - 
(a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols //} //i2 += cs_b_offset[6]; i4 += 8; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) //{ //i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const 
*)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); 
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i += cs_l; #if 
GEMM_ACCUM_A //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = 
_mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); //i += cs_l; //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], 
mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); //////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if 
REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; //} i += cs_b_offset[6]; i2 += cs_b_offset[6]; } } //numRows of A ///////////////////loop ends ///////////////////// } static void 
trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) { //float ones = 1.0; int i, i1, i2, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags[2]; // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); #if 0 //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); #endif /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = 
_mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], 
mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ //(Row0) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); //(Row1): FMA 
operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + 
cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + 
cs_l_offset[5])); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = 
_mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 
0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); i += cs_b_offset[6]; ptr_b_dup += cs_b_offset[6]; //i += 8; //ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += cs_l_offset[6]; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += 8; i1 += 8; i = i1; i2 = 0; for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = 
_mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], 
mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); /* transpose steps end */ #endif //i = 0; ptr_l_dup = ptr_l; i4 = i2; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) //{ /////////////////// Partial Lower 8x8 block trsm of B //Read current 8 cols of B columns from specified 8x8 current-block of B mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = 
_mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); #else mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], 
mat_a_blk_elems[6], 0xCC); mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); /* transpose steps end */ //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i4 = k >> 3; ptr_l_dup++; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], 
mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - 
(a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup 
+ cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 
cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = 
_mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], 
mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = 
_mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = 
c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols //} //i2 += 
cs_b_offset[6]; i4 += 8; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) //{ //i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] 
= _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[2] = 
_mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i += cs_l; #if GEMM_ACCUM_A //(Row0): already done #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) 
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); //i += cs_l; //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup 
+ 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); //i += cs_l; //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); //i += cs_l; //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup 
+ 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); //i += cs_l; //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); //i += cs_l; //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = 
_mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = 
_mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; //} i += cs_b_offset[6]; i2 += cs_b_offset[6]; } } //numRows of A ///////////////////loop ends ///////////////////// } static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) { //float ones = 1.0; int i, i1, i2, i4, j, k, l, r; int cs_b_offset[7]; int cs_l_offset[7]; float *ptr_b_dup, *ptr_l_dup; //57 number of ymm(256 bits) registers used __m256 mat_b_col[8]; __m256 mat_b_rearr[8]; __m256 mat_a_blk_elems[8]; //__m256 mat_a_diag_inv[8]; //__m256 reciprocal_diags[2]; __m256 alphaReg; alphaReg = _mm256_broadcast_ss((float const *)&alpha); // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // //L matrix offsets cs_l_offset[0] = (cs_l << 1); cs_l_offset[1] = cs_l + cs_l_offset[0]; cs_l_offset[2] = (cs_l << 2); cs_l_offset[3] = cs_l + cs_l_offset[2]; cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; cs_l_offset[5] = cs_l + 
cs_l_offset[4]; cs_l_offset[6] = (cs_l_offset[5] + cs_l); cs_b_offset[0] = (cs_b << 1); cs_b_offset[1] = cs_b + cs_b_offset[0]; cs_b_offset[2] = (cs_b << 2); cs_b_offset[3] = cs_b + cs_b_offset[2]; cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; cs_b_offset[5] = cs_b + cs_b_offset[4]; cs_b_offset[6] = (cs_b_offset[5] + cs_b); #if 0 //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); //Broadcast A21 to A71 to registers mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); //Broadcast A32 to A72 to registers mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); //Broadcast A43 to A73 to registers mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); mat_a_blk_elems[19] = 
_mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); //Broadcast A54 to A74 to registers mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); //Broadcast A65 to A75 to registers mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); //Broadcast A76 to register mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); #endif /***************** first set of 8 rows of B processing starts *****************/ ptr_b_dup = ptr_b; i = 0; for (j = 0; j < numCols_b; j += 8) { /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A //read 8x8 block of B into registers mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], 
mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = 
_mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); //(Row0) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = 
_mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); 
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) 
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low 
elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); _mm256_storeu_ps((float *)(ptr_b_dup + 
cs_b_offset[1]), mat_b_rearr[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); i += cs_b_offset[6]; ptr_b_dup += cs_b_offset[6]; //i += 8; //ptr_b_dup += 8; } //c = 0; /***************** first set of 8 cols of B processing done *****************/ ptr_b_dup = ptr_b; i1 = 0; //Start loop for cols of B to be processed in size of blk_width for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row { ptr_l += cs_l_offset[6]; //ptr_b += j; //ptr_b_dup += 8; ptr_b_dup += 8; i1 += 8; i = i1; i2 = 0; for (r = 0; r < numCols_b; r += GEMM_BLK_V1) { #if GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], 
mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); ////unpackhigh//// mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_rearr[2] = 
_mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); /* transpose steps end */ mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); #endif //i = 0; ptr_l_dup = ptr_l; i4 = i2; for (l = 0; l < j; l += 8) // move across m { //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) //{ /////////////////// Partial Lower 8x8 block trsm of B //Read current 8 cols of B columns from specified 8x8 current-block of B mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); 
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); #else mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); mat_a_blk_elems[7] 
= _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); /* transpose steps end */ //Broadcast A8,0 to A15,0 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i4 = k >> 3; ptr_l_dup++; #if GEMM_ACCUM_A //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = 
_mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], 
mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,2 to A15,2 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); 
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,3 to A15,3 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const 
*)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], 
mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,4 to A15,4 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = 
_mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,5 to A15,5 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - 
(a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,6 to A15,6 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row14): FMA operations of b2 with elements of 
indices from (2, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A8,7 to A15,7 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); 
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); ptr_l_dup++; #if GEMM_ACCUM_A //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = 
_mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) #endif //end loop of cols //} //i2 += cs_b_offset[6]; i4 += 8; } //trsm solve k = 0; //for (i2 = 0; i2 < numCols_b; i2 += 8) //{ //i2 = i1 + r; /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A #if !GEMM_ACCUM_A //Read 8 cols of B columns of Block-to-be-solved mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); /* transpose steps start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = 
_mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); mat_b_col[2] = 
_mm256_mul_ps(mat_b_col[2], alphaReg); mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); #endif //Broadcast A10 to A70 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); //i += cs_l; #if GEMM_ACCUM_A //(Row0): already done #else mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); #endif #if GEMM_ACCUM_A mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #else mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], 
mat_b_rearr[3]); mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) #endif //Broadcast A21 to A71 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); //i += cs_l; //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A32 to A72 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); //i += cs_l; //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A43 to A73 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); //i += cs_l; //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) mat_b_rearr[4] = 
_mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A54 to A74 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); //i += cs_l; //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A65 to A75 to registers mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); //i += cs_l; //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) //Broadcast A76 to register mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); //(Row7): FMA operations of b7 with elements of index (7, 0) mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) //////////////////////////////////////////////////////////////////////////////// /* transpose steps 
start */ ////unpacklow//// mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange low elements #if REARRANGE_SHFL == 1 mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); #else mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); #endif //Merge rearranged low elements into complete rows mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); ////unpackhigh//// mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); //Rearrange high elements #if REARRANGE_SHFL == 1 mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], 
mat_b_rearr[3], 0xEE); #else mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); #endif //Merge rearranged high elements into complete rows mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); /* transpose steps end */ //Store the computed B columns _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); k++; //} i += cs_b_offset[6]; i2 += cs_b_offset[6]; } } //numRows of A ///////////////////loop ends ///////////////////// } #endif