// vim: ft=arm

    .text
    .align 4

/* Z: 32x32

   z0[0]  .. z0[15]   z1[0]  .. z1[15]
   z4[0]  .. z4[15]   z5[0]  .. z5[15]
   ..
   z60[0] .. z60[15]  z61[0] .. z61[15]

   z2[0]  .. z2[15]   z3[0]  .. z3[15]
   z6[0]  .. z6[15]   z7[0]  .. z7[15]
   ..
   z62[0] .. z62[15]  z63[0] .. z63[15]
*/

.global {{G}}apple_amx_mmm_f32_32x32_{{suffix}}
{{G}}apple_amx_mmm_f32_32x32_{{suffix}}:

    {{ AMX_SET }}

    // point x1 at a 128-byte aligned scratch block below sp, used to stage AMX loads
    mov x1, sp
    lsr x1, x1, #7
    lsl x1, x1, #7
    sub x1, x1, 128

{% include "dispatcher.tmpliq" %}

.leaky_relu:
.q_scale:
.q_shl:
.q_shr:
    b .unsupported

.add_mat_mul:
    ldr x2, [x0, #24]               // b
    ldp x3, x4, [x0, #8]            // k, a

    cmp x3, #0
    beq .non_linear_loop

    orr x4, x4, {{0|setting:62}}    // load pairs (A)
    orr x2, x2, {{0|setting:62}}    // load pairs (B)

    eor x5, x5, x5                  // top left
    orr x6, x5, {{ 0|setting:20 }}  // Z row = 1
    orr x6, x6, {{ 0|setting:16 }}  // top right
    orr x7, x5, {{ 0|setting:21 }}  // Z row = 2
    orr x7, x7, {{ 0|setting:6 }}   // bottom left
    orr x8, x7, x6                  // bottom right

.packed_packed_loop_1:
    {% amx ldx x2 %}
    {% amx ldy x4 %}
    add x2, x2, 128
    add x4, x4, 128

    {% amx fma32 x5 %}
    {% amx fma32 x6 %}
    {% amx fma32 x7 %}
    {% amx fma32 x8 %}

    subs x3, x3, #1
    bne .packed_packed_loop_1

    b .non_linear_loop

.clear:
    // top left
    eor x2, x2, x2
    orr x2, x2, {{ 0|setting:27 }}
    orr x2, x2, {{ 0|setting:28 }}
    orr x2, x2, {{ 0|setting:29 }}  // Z = 0
    {% amx fma32 x2 %}
    // top right
    orr x2, x2, {{ 0|setting:20 }}  // Z row = 1
    {% amx fma32 x2 %}
    // bottom right
    orr x2, x2, {{ 0|setting:21 }}  // Z row = 3
    {% amx fma32 x2 %}
    // bottom left
    eor x2, x2, {{ 0|setting:20 }}  // Z row = 2
    {% amx fma32 x2 %}
    b .non_linear_loop

.per_col_sub:
    // performs a unary neg on Z, then falls through to per_col_add
    eor x2, x2, x2                  // X[0] = Z[0]
    mov x4, {{ 0|setting:63 }}      // vector mode
    orr x4, x4, {{ 0|setting:28 }}
    orr x4, x4, {{ 0|setting:27 }}  // Z = -X
    mov x6, 64
.per_col_sub_loop:
    {% amx extrx x2 %}
    {% amx fms32 x4 %}
    add x2, x2, {{0|setting:20}}    // next Z row
    add x4, x4, {{0|setting:20}}    // next Z row
    subs x6, x6, 1
    bne .per_col_sub_loop
    // falls through to per_col_add

.per_col_add:
    ldr x2, [x0, #8]
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x1, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldx x1 %}
    mov x2, {{ 0|setting:28 }}      // Z += X
    // top left
    {% amx fma32 x2 %}
    // bottom left
    orr x2, x2, {{ 0|setting:21 }}  // Z row = 2
    {% amx fma32 x2 %}
    // bottom right
    orr x2, x2, {{ 0|setting:16 }}  // X offset
    orr x2, x2, {{ 0|setting:20 }}  // Z row = 3
    {% amx fma32 x2 %}
    // top right
    eor x2, x2, {{ 0|setting:21 }}  // Z row = 1
    {% amx fma32 x2 %}
    b .non_linear_loop

.per_col_sub_flipped:
    ldr x2, [x0, #8]
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x1, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldx x1 %}
    mov x2, {{ 0|setting:28 }}      // Z -= X
    // top left
    {% amx fms32 x2 %}
    // bottom left
    orr x2, x2, {{ 0|setting:21 }}  // Z row = 2
    {% amx fms32 x2 %}
    // bottom right
    orr x2, x2, {{ 0|setting:16 }}  // X offset
    orr x2, x2, {{ 0|setting:20 }}  // Z row = 3
    {% amx fms32 x2 %}
    // top right
    eor x2, x2, {{ 0|setting:21 }}  // Z row = 1
    {% amx fms32 x2 %}
    b .non_linear_loop
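// Operand bit cheat-sheet for the fma32/fms32 and load/store calls in this
// file, as inferred from the code and the community reverse-engineering of
// the undocumented AMX unit (assumptions, not an official spec):
//   bit 6        Y offset += 64 bytes (second half of a loaded pair)
//   bit 16       X offset += 64 bytes
//   bits 20..25  Z row select (bit 20 = +1, bit 21 = +2, bit 22 = +4)
//   bit 27       skip the Z input (result overwrites Z)
//   bit 28       skip the Y input (e.g. bit 28 alone gives Z += X)
//   bit 29       skip the X input (e.g. bit 29 alone gives Z += Y)
//   bit 62       on ldx/ldy/stz: move a pair of 64-byte registers
//   bit 63       vector mode (single Z row) instead of outer product
// The 32x32 f32 tile is handled as four 16x16 quadrants selected by the
// Z row field plus the 64-byte X/Y offsets, per the layout comment above.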
.per_row_sub_flipped:
    ldr x2, [x0, #8]
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldy x2 %}
    mov x2, {{ 0|setting:29 }}      // Z -= Y
    // top left
    {% amx fms32 x2 %}
    // top right
    orr x2, x2, {{ 0|setting:20 }}  // Z row = 1
    {% amx fms32 x2 %}
    // bottom right
    orr x2, x2, {{ 0|setting:21 }}  // Z row = 3
    orr x2, x2, {{ 0|setting:6 }}   // Y offset
    {% amx fms32 x2 %}
    // bottom left
    eor x2, x2, {{ 0|setting:20 }}  // Z row = 2
    {% amx fms32 x2 %}
    b .non_linear_loop

.per_row_sub:
    // performs a unary neg on Z, then falls through to per_row_add
    eor x2, x2, x2                  // X[0] = Z[0]
    mov x4, {{ 0|setting:63 }}      // vector mode
    orr x4, x4, {{ 0|setting:28 }}
    orr x4, x4, {{ 0|setting:27 }}  // Z = -X
    mov x6, 64
.per_row_sub_loop:
    {% amx extrx x2 %}
    {% amx fms32 x4 %}
    add x2, x2, {{0|setting:20}}    // next Z row
    add x4, x4, {{0|setting:20}}    // next Z row
    subs x6, x6, 1
    bne .per_row_sub_loop
    // falls through to per_row_add

.per_row_add:
    ldr x2, [x0, #8]
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldy x2 %}
    mov x2, {{ 0|setting:29 }}      // Z += Y
    // top left
    {% amx fma32 x2 %}
    // top right
    orr x2, x2, {{ 0|setting:20 }}  // Z row = 1
    {% amx fma32 x2 %}
    // bottom right
    orr x2, x2, {{ 0|setting:21 }}  // Z row = 3
    orr x2, x2, {{ 0|setting:6 }}   // Y offset
    {% amx fma32 x2 %}
    // bottom left
    eor x2, x2, {{ 0|setting:20 }}  // Z row = 2
    {% amx fma32 x2 %}
    b .non_linear_loop

.per_row_min:
    mov x2, 5
    b .per_row_min_max
.per_row_max:
    mov x2, 7
.per_row_min_max:
    ldr x5, [x0, #8]
    add x6, x5, 64
    lsl x2, x2, 47                  // max(x,z) (or min)
    orr x2, x2, {{ 0|setting:44 }}  // f32
    orr x3, x2, {{ 0|setting:20 }}  // right half: z offset
    orr x8, x2, {{ 0|setting:21 }}  // bottom left
    orr x9, x3, {{ 0|setting:21 }}  // bottom right
    mov x4, 16
.loop_per_row_max:
    // top half
    ld1 { v0.s }[0], [x5], #4
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}
    {% amx vecfp x2 %}
    {% amx vecfp x3 %}
    add x2, x2, {{ 0|setting:22 }}
    add x3, x3, {{ 0|setting:22 }}
    // bottom half
    ld1 { v0.s }[0], [x6], #4
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}
    {% amx vecfp x8 %}
    {% amx vecfp x9 %}
    add x8, x8, {{ 0|setting:22 }}
    add x9, x9, {{ 0|setting:22 }}
    subs x4, x4, 1
    bne .loop_per_row_max
    b .non_linear_loop

.per_col_min:
    mov x2, 5
    b .per_col_min_max
.per_col_max:
    mov x2, 7
.per_col_min_max:
    ldr x4, [x0, #8]
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x4], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x4]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x3, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldx x3 %}
    lsl x2, x2, 47                  // max(x,z) (or min)
    orr x2, x2, {{ 0|setting:44 }}  // f32
    orr x3, x2, {{ 0|setting:16 }}  // right half: x offset
    orr x3, x3, {{ 0|setting:20 }}  // right half: z offset
    mov x4, 32
.loop_per_col_max:
    {% amx vecfp x2 %}
    {% amx vecfp x3 %}
    add x2, x2, {{ 0|setting:21 }}
    add x3, x3, {{ 0|setting:21 }}
    subs x4, x4, 1
    bne .loop_per_col_max
    b .non_linear_loop
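// Note on the vecfp encoding used by the min/max paths above and by the
// scalar_min/scalar_max paths further down: the ALU mode sits at bit 47
// (mode 5 and 7 are assumed to be min(x, z) and max(x, z) from the
// reverse-engineered notes) and {{ 0|setting:44 }} is assumed to select
// f32 lanes. The comparison operand is broadcast through the aligned
// scratch block at x1 because, as far as the known ISA goes, the AMX
// register file can only be filled from memory with ldx/ldy, not moved
// in directly from general or NEON registers.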
.per_col_mul:
    ldr x4, [x0, #8]
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x4], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x4]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldy x2 %}
    eor x2, x2, x2                  // X[0] = Z[0]
    eor x3, x3, x3
    orr x3, x3, {{0|setting:20 }}   // Z[1]
    orr x3, x3, {{0|setting:16 }}   // X[1]
    mov x4, {{ 0|setting:63 }}      // vector mode
    orr x4, x4, {{ 0|setting:27 }}  // Z = X*Y
    mov x5, {{ 0|setting:63 }}      // vector mode
    orr x5, x5, {{ 0|setting:27 }}  // Z = X*Y
    orr x5, x5, {{ 0|setting:20 }}  // Z right
    orr x5, x5, {{ 0|setting:16 }}  // X[1] (right)
    orr x5, x5, {{ 0|setting:6 }}   // Y[1] (right)
    mov x6, 32
.loop_per_col_mul:
    {% amx extrx x2 %}
    {% amx extrx x3 %}
    {% amx fma32 x4 %}
    {% amx fma32 x5 %}
    add x2, x2, {{0|setting:21}}
    add x3, x3, {{0|setting:21}}
    add x4, x4, {{0|setting:21}}
    add x5, x5, {{0|setting:21}}
    subs x6, x6, 1
    bne .loop_per_col_mul
    b .non_linear_loop

.per_row_mul:
    ldr x14, [x0, #8]
    add x15, x14, 64
    // extrx
    eor x2, x2, x2                  // X[0] = Z[0] (top left)
    eor x3, x3, x3
    orr x3, x3, {{0|setting:20 }}   // Z[1]
    orr x3, x3, {{0|setting:16 }}   // X[1] = Z[1] (top right)
    eor x4, x4, x4
    orr x4, x4, {{0|setting:21}}    // X[0] = Z[2] (bottom left)
    orr x5, x4, {{0|setting:20}}
    orr x5, x5, {{0|setting:16}}    // X[1] = Z[3] (bottom right)
    // fma32
    eor x6, x6, x6
    orr x6, x6, {{0|setting:63}}    // vector mode
    orr x6, x6, {{0|setting:27}}    // Z = X*Y, Z[0] = X[0]*Y[0]
    orr x7, x6, {{0|setting:20}}    // Z[1]
    orr x7, x7, {{0|setting:16}}    // X[1], Z[1] = X[1]*Y[0]
    orr x8, x6, {{0|setting:21}}    // Z[2]
    orr x9, x8, {{0|setting:20}}    // Z[3]
    orr x9, x9, {{0|setting:16}}    // X[1]
    mov x10, 16
.loop_per_row_mul:
    // top
    ld1 { v0.s }[0], [x14], #4
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldy x1 %}
    {% amx extrx x2 %}
    {% amx extrx x3 %}
    {% amx fma32 x6 %}
    {% amx fma32 x7 %}
    add x2, x2, {{ 0|setting:22 }}
    add x3, x3, {{ 0|setting:22 }}
    add x6, x6, {{ 0|setting:22 }}
    add x7, x7, {{ 0|setting:22 }}
    // bottom
    ld1 { v0.s }[0], [x15], #4
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldy x1 %}
    {% amx extrx x4 %}
    {% amx extrx x5 %}
    {% amx fma32 x8 %}
    {% amx fma32 x9 %}
    add x4, x4, {{ 0|setting:22 }}
    add x5, x5, {{ 0|setting:22 }}
    add x8, x8, {{ 0|setting:22 }}
    add x9, x9, {{ 0|setting:22 }}
    subs x10, x10, 1
    bne .loop_per_row_mul
    b .non_linear_loop

.scalar_sub:
    // performs a unary neg on Z, then falls through to scalar_add
    eor x2, x2, x2                  // X[0] = Z[0]
    mov x4, {{ 0|setting:63 }}      // vector mode
    orr x4, x4, {{ 0|setting:28 }}
    orr x4, x4, {{ 0|setting:27 }}  // Z = -X
    mov x6, 64
.scalar_sub_loop:
    {% amx extrx x2 %}
    {% amx fms32 x4 %}
    add x2, x2, {{0|setting:20}}    // next Z row
    add x4, x4, {{0|setting:20}}    // next Z row
    subs x6, x6, 1
    bne .scalar_sub_loop
    // falls through to scalar_add on purpose

.scalar_add:
    ldr w5, [x0, #8]
    fmov s0, w5
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}                // load 16 values
    mov x2, {{ 0|setting:28 }}      // Z += X
{% for chunk in (0..3) %}
    {% amx fma32 x2 %}
    add x2, x2, {{0|setting:20}}    // next Z row
{% endfor %}
    b .non_linear_loop

.scalar_sub_flipped:
    ldr w5, [x0, #8]
    fmov s0, w5
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}                // load 16 values
    mov x2, {{ 0|setting:28 }}      // Z -= X
{% for chunk in (0..3) %}
    {% amx fms32 x2 %}
    add x2, x2, {{0|setting:20}}    // next Z row
{% endfor %}
    b .non_linear_loop
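// The unary paths (scalar_sub above, scalar_mul below, plus the per_row
// and per_col variants) all rely on the same assumed extrx behaviour:
// extrx copies one Z row out into an X row, then a vector-mode
// fma32/fms32 writes it back negated or combined with Y. Stepping the
// Z row field ({{0|setting:20}}) 64 times walks every row of the
// 32x32 tile, matching the Z register layout described at the top of
// the file.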
.scalar_mul:
    ldr w5, [x0, #8]
    fmov s0, w5
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldy x1 %}                // load 16 values
    eor x2, x2, x2                  // X[0] = Z[0]
    mov x4, {{ 0|setting:63 }}      // vector mode
    orr x4, x4, {{ 0|setting:27 }}  // Z = X*Y
    mov x6, 64
.scalar_mul_loop:
    {% amx extrx x2 %}
    {% amx fma32 x4 %}
    add x2, x2, {{0|setting:20}}    // next Z row
    add x4, x4, {{0|setting:20}}    // next Z row
    subs x6, x6, 1
    bne .scalar_mul_loop
    b .non_linear_loop

.scalar_min:
    mov x2, 5
    b .scalar_min_max
.scalar_max:
    mov x2, 7
.scalar_min_max:
    ldr w5, [x0, #8]
    fmov s0, w5
    dup v0.4s, v0.s[0]
    dup v1.4s, v0.s[0]
    dup v2.4s, v0.s[0]
    dup v3.4s, v0.s[0]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}                // load 16 values
    lsl x2, x2, 47                  // max(x,z) (or min)
    orr x2, x2, {{ 0|setting:44 }}  // f32
    mov x3, 64
.loop_scalar_max:
    add x2, x2, {{ 0|setting:20}}   // next Z
    {% amx vecfp x2 %}
    subs x3, x3, 1
    bne .loop_scalar_max
    b .non_linear_loop

.add_unicast:
    ldp x5, x6, [x0, #8]            // c base ptr, rsc
    ldp x7, x8, [x0, #24]           // csc, item_size
    add x8, x1, 64                  // x8 = second half of the scratch block
    mov x3, 0                       // x3 is the row
.loop_load:
    and x9, x3, 0xf                 // x9 = row % 16
    lsl x9, x9, 2                   // x9 = (row % 16) * 4
    lsr x10, x3, 4                  // x10 = row / 16
    lsl x10, x10, 1                 // x10 = (row / 16) * 2
    add x9, x9, x10                 // x9 = Z row for this matrix row
    mov x4, x5
{% for neon in (0..3) %}
    {% for lane in (0..3) %}
        ld1 { v{{neon}}.s }[{{lane}}], [x4], x7
    {% endfor %}
{% endfor %}
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
{% for neon in (0..3) %}
    {% for lane in (0..3) %}
        ld1 { v{{neon}}.s }[{{lane}}], [x4], x7
    {% endfor %}
{% endfor %}
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x8]
    mov x2, x1
    orr x2, x2, {{ 0|setting:62 }}  // load 32 values
    {% amx ldy x2 %}
    lsl x2, x9, 20                  // left Z register to update
    orr x2, x2, {{ 0|setting:63 }}  // vector mode
    orr x2, x2, {{ 0|setting:29 }}  // perform Z += Y
    {% amx fma32 x2 %}
    add x2, x2, {{0|setting:20}}
    orr x2, x2, 64                  // offset Y by 16 values
    {% amx fma32 x2 %}
    add x5, x5, x6
    add x3, x3, 1
    cmp x3, 32
    bne .loop_load
    b .non_linear_loop

.add_row_col_products:
    ldp x5, x6, [x0, #8]            // a base ptr, b base ptr
    add x8, x1, 64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x5], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x5]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldy x2 %}
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x6], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x6]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64
    orr x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldx x2 %}
    // top left
    eor x2, x2, x2
    {% amx fma32 x2 %}
    // top right
    orr x2, x2, {{ 0|setting:20 }}  // Z row = 1
    orr x2, x2, {{ 0|setting:16 }}  // X offset
    {% amx fma32 x2 %}
    // bottom right
    orr x2, x2, {{ 0|setting:21 }}  // Z row = 3
    orr x2, x2, {{ 0|setting:6 }}   // Y offset
    {% amx fma32 x2 %}
    // bottom left
    eor x2, x2, {{ 0|setting:20 }}  // Z row = 2
    eor x2, x2, {{ 0|setting:16 }}  // clear X offset
    {% amx fma32 x2 %}
    b .non_linear_loop

.store:
    ldp x5, x6, [x0, #8]            // c base ptr, rsc
    ldp x7, x8, [x0, #24]           // csc, item_size
    cmp x7, 4
    bne .store_generic
    ands x8, x5, 0x7f
    bne .store_generic
    ands x8, x6, 0x7f
    bne .store_generic
    orr x5, x5, {{ 0|setting:62 }}  // pair
    lsl x8, x6, 4
    add x8, x8, x5                  // x8 = c + 16*rsc
    orr x8, x8, {{ 0|setting:57 }}  // stores via x8 start at Z row 2
    mov x4, {{0|setting:58}}        // Z row += 4
    add x4, x4, x6                  // +rsc
    mov x3, 16
.loop_store_direct:
    {% amx stz x5 %}
    {% amx stz x8 %}
    add x5, x5, x4
    add x8, x8, x4
    subs x3, x3, 1
    bne .loop_store_direct
    b .non_linear_loop
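// Fast path above: when csc == 4 (contiguous f32 columns) and both the
// base pointer and rsc are 128-byte aligned, Z row pairs are stored
// straight to C. The stz operand is assumed to pack the Z row select in
// bits 56..61, so a single add of x4 advances the row select by 4
// ({{0|setting:58}}) and the destination pointer by rsc at once. The
// generic path below instead round-trips each row pair through the
// scratch block and scatters it with NEON lane stores, honoring
// arbitrary rsc/csc strides.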
.store_generic:
    add x8, x1, 64
    mov x3, 0                       // row id
.loop_store:
    and x9, x3, 0xf                 // x9 = row % 16
    lsl x9, x9, 2                   // x9 = (row % 16) * 4
    lsr x10, x3, 4                  // x10 = row / 16
    lsl x10, x10, 1                 // x10 = (row / 16) * 2
    add x9, x9, x10                 // x9 = Z row for this matrix row
    lsl x2, x9, 56
    orr x2, x2, {{ 0|setting:62 }}
    orr x2, x2, x1
    {% amx stz x2 %}
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    mov x4, x5
{% for neon in (0..3) %}
    {% for lane in (0..3) %}
        st1 { v{{neon}}.s }[{{lane}}], [x4], x7
    {% endfor %}
{% endfor %}
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x8]
{% for neon in (0..3) %}
    {% for lane in (0..3) %}
        st1 { v{{neon}}.s }[{{lane}}], [x4], x7
    {% endfor %}
{% endfor %}
    add x5, x5, x6
    add x3, x3, 1
    cmp x3, 32
    bne .loop_store
    b .non_linear_loop

.return:
    {{ AMX_CLR }}
    ret
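// {{ AMX_SET }} and {{ AMX_CLR }} are assumed to expand to the
// undocumented enable/disable sequence for the per-thread AMX context.
// The kernel is expected to leave through .return so that the AMX state
// is always released before control goes back to the caller.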