// vim: ft=arm

// Apple AMX f16 matrix-multiply kernel, 64x32 output tile.
// Syntax: GNU as (AArch64), rendered through a Liquid template first.
//
// Register roles, as used by the code below:
//   x0  base of the current operation record; handlers read their arguments
//       at offsets 8, 16, 24 and 32 (assumed to be maintained by the included
//       dispatcher, which is outside this file -- TODO confirm)
//   x1  128-byte aligned scratch block just below sp, used to stage data
//       for AMX ldx / ldy and to receive stz output
//   x2-x10, x14, x15, v0-v3  scratch
//
// Labels .non_linear_loop and .unsupported are not defined in this file
// (presumably provided by dispatcher.tmpliq).

.text
.align 4

/*
    Z: 64x32 tile. Each Z reg is f16x32 (one tile row).
    Z register for a tile row is (row % 32) * 2 + (row / 32):
        rows  0..31 -> Z0, Z2, ..., Z62
        rows 32..63 -> Z1, Z3, ..., Z63
*/

.global {{G}}apple_amx_mmm_f16_64x32_{{suffix}}
{{G}}apple_amx_mmm_f16_64x32_{{suffix}}:

{{ AMX_SET }}

    // set x1 to a 128 bytes aligned block for loads
    mov x1, sp
    lsr x1, x1, #7
    lsl x1, x1, #7
    sub x1, x1, 128

{% include "dispatcher.tmpliq" %}

// Operations this kernel does not implement.
.leaky_relu:
.q_scale:
.q_shl:
.q_shr:
    b .unsupported

// Z += A * B over k steps. Each step consumes 64 bytes of B (32 f16, via X)
// and 128 bytes of A (64 f16, via a Y register pair).
.add_mat_mul:
    ldr x2, [x0, #24]                   // b
    ldp x3, x4, [x0, #8]                // k, a

    cmp x3, #0
    beq .non_linear_loop                // k == 0: nothing to accumulate

    orr x4, x4, {{0|setting:62}}        // load pairs (A)

    eor x5, x5, x5                      // top half: Z row 0, Y offset 0

    orr x7, x5, {{ 0|setting:20 }}      // bottom half: Z row 1
    orr x7, x7, {{ 0|setting:6 }}       //   and Y offset 64 bytes

.packed_packed_loop_1:
        {% amx ldx x2 %}
        {% amx ldy x4 %}
        add x2, x2, 64
        add x4, x4, 128

        {% amx fma16 x5 %}
        {% amx fma16 x7 %}

        subs x3, x3, #1
        bne .packed_packed_loop_1

    b .non_linear_loop

// Z = 0. Uses fma32 with bits 27, 28 and 29 set, once for each of the four
// values of the low two Z row bits (0, 1, 3, 2).
.clear:
    eor x2, x2, x2
    orr x2, x2, {{ 0|setting:27 }}
    orr x2, x2, {{ 0|setting:28 }}
    orr x2, x2, {{ 0|setting:29 }}      // Z = 0
    {% amx fma32 x2 %}                  // Z row = 0

    orr x2, x2, {{ 0|setting:20 }}      // Z row = 1
    {% amx fma32 x2 %}

    orr x2, x2, {{ 0|setting:21 }}      // Z row = 3
    {% amx fma32 x2 %}

    eor x2, x2, {{ 0|setting:20 }}      // Z row = 2
    {% amx fma32 x2 %}

    // NOTE(review): this store into the scratch block has no visible reader
    // in this file; looks like a leftover -- confirm before removing.
    mov x3, #16
    str x3, [x1]
    b .non_linear_loop

// Negates every Z row, then falls through to .per_col_add.
.per_col_sub:
    // performs a unary neg on Z
    eor x2, x2, x2                      // X[0] = Z[0]

    mov x4, {{ 0|setting:63 }}          // vector mode
    orr x4, x4, {{ 0|setting:28 }}
    orr x4, x4, {{ 0|setting:27 }}      // Z=-X

    mov x6, 64
.per_col_sub_loop:
        {% amx extrx x2 %}
        {% amx fms16 x4 %}
        add x2, x2, {{0|setting:20}}    // next Z row
        add x4, x4, {{0|setting:20}}    // next Z row
        subs x6, x6, 1
        bne .per_col_sub_loop

    // continue (fall through on purpose)

// Adds the 32 f16 values at [x0 + 8] to every tile row.
.per_col_add:
    ldr x2, [x0, #8]

    // stage 64 bytes in the aligned scratch block, then load them into X
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x2]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}

    mov x2, {{ 0|setting:28 }}          // Z += X (same flag as .scalar_add)

    // top half (even Z registers)
    {% amx fma16 x2 %}

    // bottom half (odd Z registers)
    orr x2, x2, {{ 0|setting:20 }}      // Z row = 1
    {% amx fma16 x2 %}

    b .non_linear_loop

// Same data path as .per_col_add but with fms16: Z -= X.
.per_col_sub_flipped:
    ldr x2, [x0, #8]

    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x2]
    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}

    mov x2, {{ 0|setting:28 }}          // Z -= X
    {% amx fms16 x2 %}

    orr x2, x2, {{ 0|setting:20 }}      // Z row = 1
    {% amx fms16 x2 %}

    b .non_linear_loop

// Subtracts the 64 per-row f16 values at [x0 + 8] from Z (fms16 with Y).
.per_row_sub_flipped:
    ldr x2, [x0, #8]

    // copy 128 bytes (64 f16) into the scratch block
    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x2], #64
    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1], #64
    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x2]
    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    sub x1, x1, #64                     // restore scratch base

    orr x2, x1, {{ 0|setting:62 }}      // load a pair
    {% amx ldy x2 %}

    mov x2, {{ 0|setting:29 }}          // Z -= Y

    // top half
    {% amx fms16 x2 %}

    // bottom half
    orr x2, x2, {{ 0|setting:20 }}      // Z row = 1
    orr x2, x2, {{ 0|setting:6 }}       // Y offset
    {% amx fms16 x2 %}

    b .non_linear_loop

// Negates every Z row, then falls through to .per_row_add.
.per_row_sub:
    // performs a unary neg on Z
    eor x2, x2, x2                      // X[0] = Z[0]

    mov x4, {{ 0|setting:63 }}          // vector mode
    orr x4, x4, {{ 0|setting:28 }}
    orr x4, x4, {{ 0|setting:27 }}      // Z=-X

    mov x6, 64
.per_row_sub_loop:
        {% amx extrx x2 %}
        {% amx fms16 x4 %}
        add x2, x2, {{0|setting:20}}    // next Z row
        add x4, x4, {{0|setting:20}}    // next Z row
        subs x6, x6, 1
        bne .per_row_sub_loop

    // continue (fall through on purpose)

// Adds the 64 per-row f16 values at [x0 + 8] to Z.
.per_row_add:
    ldr x2, [x0, #8]

    // copy 128 bytes (64 f16) into the scratch block
    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x2], #64
    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1], #64
    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x2]
    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    sub x1, x1, #64                     // restore scratch base

    orr x2, x1, {{ 0|setting:62 }}      // load a pair
    {% amx ldy x2 %}

    mov x2, {{ 0|setting:29 }}          // Z += Y

    // top half
    {% amx fma16 x2 %}

    // bottom half
    orr x2, x2, {{ 0|setting:20 }}      // Z row = 1
    orr x2, x2, {{ 0|setting:6 }}       // Y offset
    {% amx fma16 x2 %}

    b .non_linear_loop

// Per-row min / max: x2 carries the ALU mode selector (5 = min, 7 = max).
.per_row_min:
    mov x2, 5
    b .per_row_min_max
.per_row_max:
    mov x2, 7
.per_row_min_max:
    ldr x5, [x0, #8]                    // values for tile rows 0..31
    add x6, x5, 64                      // values for tile rows 32..63

    lsl x2, x2, 47                      // max(x,z) (or min)
    orr x2, x2, {{ 0|setting:43 }}      // f16

    orr x8, x2, {{ 0|setting:20 }}      // bottom half starts at Z row 1

    mov x4, 32
.loop_per_row_max:
        // top half: broadcast one f16 over 64 bytes and load it into X
        ld1 { v0.h }[0], [x5], #2
        dup v0.8h, v0.h[0]
        dup v1.8h, v0.h[0]
        dup v2.8h, v0.h[0]
        dup v3.8h, v0.h[0]
        st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
        {% amx ldx x1 %}
        {% amx vecfp x2 %}
        add x2, x2, {{ 0|setting:21 }}  // Z row += 2

        // bottom half
        ld1 { v0.h }[0], [x6], #2
        dup v0.8h, v0.h[0]
        dup v1.8h, v0.h[0]
        dup v2.8h, v0.h[0]
        dup v3.8h, v0.h[0]
        st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
        {% amx ldx x1 %}
        {% amx vecfp x8 %}
        add x8, x8, {{ 0|setting:21 }}  // Z row += 2

        subs x4, x4, 1
        bne .loop_per_row_max

    b .non_linear_loop

// Per-column min / max against the 32 f16 values at [x0 + 8].
.per_col_min:
    mov x2, 5
    b .per_col_min_max
.per_col_max:
    mov x2, 7
.per_col_min_max:
    ldr x4, [x0, #8]

    ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x4]
    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}

    lsl x2, x2, 47                      // max(x,z) (or min)
    orr x2, x2, {{ 0|setting:43 }}      // f16

    mov x4, 64
.loop_per_col_max:
        {% amx vecfp x2 %}
        add x2, x2, {{ 0|setting:20 }}  // next Z row
        subs x4, x4, 1
        bne .loop_per_col_max

    b .non_linear_loop

// Multiplies every Z row element-wise by the 32 f16 values at [x0 + 8].
.per_col_mul:
    ldr x4, [x0, #8]

    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x4]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldy x1 %}

    eor x2, x2, x2                      // X[0] = Z[0]

    mov x4, {{ 0|setting:63 }}          // vector mode
    orr x4, x4, {{ 0|setting:27 }}      // Z=X*Y

    mov x6, 64
.loop_per_col_mul:
        {% amx extrx x2 %}
        {% amx fma16 x4 %}
        add x2, x2, {{0|setting:20}}    // next Z row
        add x4, x4, {{0|setting:20}}    // next Z row
        subs x6, x6, 1
        bne .loop_per_col_mul

    b .non_linear_loop

// Multiplies each tile row by its own f16 scalar from the 64 values at
// [x0 + 8].
.per_row_mul:
    ldr x14, [x0, #8]                   // scalars for tile rows 0..31
    add x15, x14, 64                    // scalars for tile rows 32..63

    // extrx operands
    eor x2, x2, x2                      // X[0] = Z[0] (top half)

    eor x4, x4, x4
    orr x4, x4, {{0|setting:20}}        // X[0] = Z[1] (bottom half)

    // fma16 operands
    eor x6, x6, x6
    orr x6, x6, {{0|setting:63}}        // vector mode
    orr x6, x6, {{0|setting:27}}        // Z=X*Y, Z[0]=X[0]*Y[0]

    orr x8, x6, {{0|setting:20}}        // Z[1]

    mov x10, 32
.loop_per_row_mul:
        // top: broadcast one scalar into Y, then Z[row] = Z[row] * Y
        ld1 { v0.h }[0], [x14], #2
        dup v0.8h, v0.h[0]
        dup v1.8h, v0.h[0]
        dup v2.8h, v0.h[0]
        dup v3.8h, v0.h[0]
        st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
        {% amx ldy x1 %}

        {% amx extrx x2 %}
        {% amx fma16 x6 %}
        add x2, x2, {{ 0|setting:21 }}  // Z row += 2
        add x6, x6, {{ 0|setting:21 }}

        // bottom
        ld1 { v0.h }[0], [x15], #2
        dup v0.8h, v0.h[0]
        dup v1.8h, v0.h[0]
        dup v2.8h, v0.h[0]
        dup v3.8h, v0.h[0]
        st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
        {% amx ldy x1 %}

        {% amx extrx x4 %}
        {% amx fma16 x8 %}
        add x4, x4, {{ 0|setting:21 }}  // Z row += 2
        add x8, x8, {{ 0|setting:21 }}

        subs x10, x10, 1
        bne .loop_per_row_mul

    b .non_linear_loop

// Negates every Z row, then falls through to .scalar_add.
.scalar_sub:
    // performs a unary neg on Z, then go to scalar_add
    eor x2, x2, x2                      // X[0] = Z[0]

    mov x4, {{ 0|setting:63 }}          // vector mode
    orr x4, x4, {{ 0|setting:28 }}
    orr x4, x4, {{ 0|setting:27 }}      // Z=-X

    mov x6, 64
.scalar_sub_loop:
        {% amx extrx x2 %}
        {% amx fms16 x4 %}
        add x2, x2, {{0|setting:20}}    // next Z row
        add x4, x4, {{0|setting:20}}    // next Z row
        subs x6, x6, 1
        bne .scalar_sub_loop

    // continue on purpose

// Adds the f16 scalar held in the low half-word at [x0 + 8] to all of Z.
.scalar_add:
    ldr w5, [x0, #8]

    fmov h0, w5
    dup v0.8h, v0.h[0]
    dup v1.8h, v0.h[0]
    dup v2.8h, v0.h[0]
    dup v3.8h, v0.h[0]

    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}                    // load 32 values

    mov x2, {{ 0|setting:28 }}          // Z+=X
    {% amx fma16 x2 %}
    add x2, x2, {{0|setting:20}}        // Z row = 1
    {% amx fma16 x2 %}

    b .non_linear_loop

// Subtracts the f16 scalar at [x0 + 8] from all of Z.
.scalar_sub_flipped:
    ldr w5, [x0, #8]

    fmov h0, w5
    dup v0.8h, v0.h[0]
    dup v1.8h, v0.h[0]
    dup v2.8h, v0.h[0]
    dup v3.8h, v0.h[0]

    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}                    // load 32 values

    mov x2, {{ 0|setting:28 }}          // Z-=X
    {% amx fms16 x2 %}
    add x2, x2, {{0|setting:20}}        // Z row = 1
    {% amx fms16 x2 %}

    b .non_linear_loop

// Multiplies all of Z by the f16 scalar at [x0 + 8].
.scalar_mul:
    ldr w5, [x0, #8]

    fmov h0, w5
    dup v0.8h, v0.h[0]
    dup v1.8h, v0.h[0]
    dup v2.8h, v0.h[0]
    dup v3.8h, v0.h[0]

    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldy x1 %}                    // load 32 values

    eor x2, x2, x2                      // X[0] = Z[0]

    mov x4, {{ 0|setting:63 }}          // vector mode
    orr x4, x4, {{ 0|setting:27 }}      // Z=X*Y

    mov x6, 64
.scalar_mul_loop:
        {% amx extrx x2 %}
        {% amx fma16 x4 %}
        add x2, x2, {{0|setting:20}}    // next Z row
        add x4, x4, {{0|setting:20}}    // next Z row
        subs x6, x6, 1
        bne .scalar_mul_loop

    b .non_linear_loop

// Scalar min / max: x2 carries the ALU mode selector (5 = min, 7 = max).
.scalar_min:
    mov x2, 5
    b .scalar_min_max
.scalar_max:
    mov x2, 7
.scalar_min_max:
    ldr w5, [x0, #8]

    fmov h0, w5
    dup v0.8h, v0.h[0]
    dup v1.8h, v0.h[0]
    dup v2.8h, v0.h[0]
    dup v3.8h, v0.h[0]

    st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}                    // load 32 values

    lsl x2, x2, 47                      // max(x,z) (or min)
    orr x2, x2, {{ 0|setting:43 }}      // f16

    // Apply to Z rows 0..63. The operation is issued BEFORE the row field is
    // advanced (same shape as .loop_per_col_max) so that row 0 is the first
    // row processed and the row field never counts past 63 while in use.
    mov x3, 64
.loop_scalar_max:
        {% amx vecfp x2 %}
        add x2, x2, {{ 0|setting:20}}   // next Z
        subs x3, x3, 1
        bne .loop_scalar_max

    b .non_linear_loop

// Z += C, where C is read from memory one tile row at a time using the
// row and column byte strides from the operation record.
.add_unicast:
    ldp x5, x6, [x0, #8]                // c base ptr, rsc
    ldp x7, x8, [x0, #24]               // csc, item_size

    mov x3, 0                           // x3 is the row
.loop_load:
        // z reg is (row % 32) * 2 + (row / 32)
        and x9, x3, 0x1f
        lsl x9, x9, 1
        lsr x10, x3, 5
        add x9, x9, x10

        // gather 32 f16 of this row, stepping by csc, into v0..v3
        mov x4, x5
        {% for neon in (0..3) %}
            {% for lane in (0..7) %}
                ld1 { v{{neon}}.h }[{{lane}}], [x4], x7
            {% endfor %}
        {% endfor %}
        st1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
        {% amx ldy x1 %}

        lsl x2, x9, 20                  // Z register to update
        orr x2, x2, {{ 0|setting:63 }}  // vector mode
        orr x2, x2, {{ 0|setting:29 }}  // perform Z+=Y
        {% amx fma16 x2 %}

        add x5, x5, x6                  // next row of C
        add x3, x3, 1
        cmp x3, 64
        bne .loop_load

/*
    Disabled variant kept from the original file (32-bit lanes).

    mov x3, 0                           // x3 is the row
    .loop_load:
        and x9, x3, 0xf                 // x9 = row % 16
        lsl x9, x9, 2                   // x9 = (row % 16) * 4
        lsr x10, x3, 4                  // x10 = row / 16
        lsl x10, x10, 1                 // x10 = (row / 16) * 2
        add x9, x9, x10                 // x9 = x9 + x10

        mov x4, x5
        {% for neon in (0..3) %}
            {% for lane in (0..3) %}
                ld1 { v{{neon}}.s }[{{lane}}], [x4], x7
            {% endfor %}
        {% endfor %}
        st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]

        {% for neon in (0..3) %}
            {% for lane in (0..3) %}
                ld1 { v{{neon}}.s }[{{lane}}], [x4], x7
            {% endfor %}
        {% endfor %}
        st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x8]

        mov x2, x1
        orr x2, x2, {{ 0|setting:62 }}  // load 32 values
        {% amx ldy x2 %}

        lsl x2, x9, 20                  // left Z register to update
        orr x2, x2, {{ 0|setting:63 }}  // vector mode
        orr x2, x2, {{ 0|setting:29 }}  // perform Z+=Y
        {% amx fma32 x2 %}

        add x2, x2, {{0|setting:20}}
        orr x2, x2, 64                  // offset Y by 16 values
        {% amx fma32 x2 %}

        add x5, x5, x6
        add x3, x3, 1
        cmp x3, 32
        bne .loop_load
*/

    b .non_linear_loop

// Z += a (64 f16, via a Y pair) outer-product b (32 f16, via X).
.add_row_col_products:
    ldp x5, x6, [x0, #8]                // a base ptr, b base ptr

    add x8, x1, 64                      // NOTE(review): x8 is not read again in this handler

    // copy 128 bytes of a into the scratch block and load them as a Y pair
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x5], #64
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x5]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub x1, x1, #64                     // restore scratch base

    orr x2, x1, {{ 0|setting:62 }}      // load a pair
    {% amx ldy x2 %}

    // copy 64 bytes of b into the scratch block and load them into X
    ld1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x6]
    st1 { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    {% amx ldx x1 %}

    // top half
    eor x2, x2, x2
    {% amx fma16 x2 %}

    // bottom half
    orr x2, x2, {{ 0|setting:20 }}      // Z row = 1
    orr x2, x2, {{ 0|setting:6 }}       // Y offset
    {% amx fma16 x2 %}

    b .non_linear_loop

// Writes the tile to C. The direct path issues stz straight to memory and is
// taken only when csc == 2 and both the base pointer and rsc have their low
// 7 bits clear; anything else goes through .store_generic.
.store:
    ldp x5, x6, [x0, #8]                // c base ptr, rsc
    ldp x7, x8, [x0, #24]               // csc, item_size

    cmp x7, 2
    bne .store_generic
    ands x8, x5, 0x7f
    bne .store_generic
    ands x8, x6, 0x7f
    bne .store_generic

    lsl x8, x6, 5
    add x8, x8, x5                      // x8 = c + 32*rsc (tile row 32)
    orr x8, x8, {{ 0|setting:56 }}      // first to x8 is z1

    mov x4, {{0|setting:57}}            // Zreg += 2
    add x4, x4, x6                      // +rsc

    mov x3, 32
.loop_store_direct:
        {% amx stz x5 %}
        {% amx stz x8 %}
        add x5, x5, x4
        add x8, x8, x4
        subs x3, x3, 1
        bne .loop_store_direct

    b .non_linear_loop

// Strided store: each Z register is written to the scratch block with stz,
// reloaded into v0..v3, then scattered lane by lane with a csc step.
.store_generic:
    mov x3, 0                           // row id
.loop_store:
        // z reg is (row % 32) * 2 + (row / 32)
        and x9, x3, 0x1f
        lsl x9, x9, 1
        lsr x10, x3, 5
        add x9, x9, x10

        lsl x2, x9, 56                  // Z register index in bits 56 and up
        orr x2, x2, x1                  // destination = scratch block
        {% amx stz x2 %}                // f16 x 32
        ld1 { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]

        mov x4, x5
        {% for neon in (0..3) %}
            {% for lane in (0..7) %}
                st1 { v{{neon}}.h }[{{lane}}], [x4], x7
            {% endfor %}
        {% endfor %}

        add x5, x5, x6                  // next row of C
        add x3, x3, 1
        cmp x3, 64
        bne .loop_store

    b .non_linear_loop

.return:
{{ AMX_CLR }}
    ret