// vim: ft=arm

.text
.align 4

/* Z: 64x1   z0[0] .. z0[31]   z1[0] .. z1[31] */

.global {{G}}apple_amx_mmm_f16_64x1_{{suffix}}
{{G}}apple_amx_mmm_f16_64x1_{{suffix}}:

{{ AMX_SET }}

    // set x1 to a 128-byte aligned block for loads
    mov     x1, sp
    lsr     x1, x1, #7
    lsl     x1, x1, #7
    sub     x1, x1, 128

{% include "dispatcher.tmpliq" %}

.leaky_relu:
.q_scale:
.q_shl:
.q_shr:
    b       .unsupported

.add_mat_mul:
    ldr     x2, [x0, #24]               // b
    ldp     x3, x4, [x0, #8]            // k, a

    cmp     x3, #0
    beq     .non_linear_loop

    orr     x4, x4, {{ 0|setting:62 }}  // load a pair of A

    mov     x5, {{ 0|setting:43 }}      // f16
    orr     x5, x5, {{ 0|setting:38 }}  // broadcast Y
    orr     x6, x5, {{ 0|setting:20 }}  // z offset
    orr     x6, x6, {{ 0|setting:16 }}  // x offset

    cmp     x3, #32
    blt     .packed_packed_loop_1

    mov     x9, {{ 0|setting:32 }}      // Y broadcast offset += 1

.packed_packed_loop_32:
    mov     x7, x5
    mov     x8, x6

    {% amx ldy x2 %}

    {% for k in (0..31) %}
        {% amx ldx x4 %}
        add     x4, x4, 128
        {% amx vecfp x7 %}
        {% amx vecfp x8 %}
        add     x7, x7, x9
        add     x8, x8, x9
    {% endfor %}

    add     x2, x2, #64
    sub     x3, x3, #32
    cmp     x3, #32
    bge     .packed_packed_loop_32

    cmp     x3, #0
    beq     .non_linear_loop

.packed_packed_loop_1:
    ldr     w7, [x2], #2
    str     w7, [x1]

    {% amx ldx x4 %}
    {% amx ldy x1 %}
    {% amx vecfp x5 %}
    {% amx vecfp x6 %}

    add     x4, x4, 128
    subs    x3, x3, #1
    bne     .packed_packed_loop_1

    b       .non_linear_loop

.clear:
    // top left
    eor     x2, x2, x2
    orr     x2, x2, {{ 0|setting:27 }}
    orr     x2, x2, {{ 0|setting:28 }}
    orr     x2, x2, {{ 0|setting:29 }}  // Z = 0
    {% amx fma32 x2 %}
    // top right
    orr     x2, x2, {{ 0|setting:20 }}  // Z row = 1
    {% amx fma32 x2 %}
    // bottom right
    orr     x2, x2, {{ 0|setting:21 }}  // Z row = 3
    {% amx fma32 x2 %}
    // bottom left
    eor     x2, x2, {{ 0|setting:20 }}  // Z row = 2
    {% amx fma32 x2 %}

    b       .non_linear_loop

.per_col_sub:
    // performs a unary neg on Z
    eor     x2, x2, x2                  // X[0] = Z[0]
    // extr[hxyz] is confusing
    mov     x4, {{ 0|setting:63 }}      // vector mode
    orr     x4, x4, {{ 0|setting:28 }}
    orr     x4, x4, {{ 0|setting:27 }}  // Z = -X
    {% amx extrx x2 %}
    {% amx fms16 x4 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z row
    add     x4, x4, {{ 0|setting:20 }}  // next Z row
    {% amx extrx x2 %}                  // extr[hxyz] is confusing
    {% amx fms16 x4 %}
    // continue

.per_col_add:
    ldr     x2, [x0, #8]                // broadcast value to x0
    ld1     { v0.h }[0], [x2]
    dup     v0.8h, v0.h[0]
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    sub     x1, x1, #64
    {% amx ldx x1 %}                    // load into x0 by default

    mov     x2, {{ 0|setting:28 }}      // z += x
    {% amx fma16 x2 %}
    orr     x2, x2, {{ 0|setting:20 }}  // target is now z1
    {% amx fma16 x2 %}

    b       .non_linear_loop

.per_col_sub_flipped:
    ldr     x2, [x0, #8]                // broadcast value to x0
    ld1     { v0.h }[0], [x2]
    dup     v0.8h, v0.h[0]
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    sub     x1, x1, #64
    {% amx ldx x1 %}                    // load into x0 by default

    mov     x2, {{ 0|setting:28 }}      // z -= x
    {% amx fms16 x2 %}
    orr     x2, x2, {{ 0|setting:20 }}  // target is now z1
    {% amx fms16 x2 %}

    b       .non_linear_loop

.per_row_sub_flipped:
    ldr     x2, [x0, #8]
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x2], #64
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1], #64
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x2]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    sub     x1, x1, #64

    orr     x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldy x2 %}

    mov     x2, {{ 0|setting:63 }}      // vector mode
    orr     x2, x2, {{ 0|setting:29 }}  // z -= y
    // top left
    {% amx fms16 x2 %}
    // bottom left
    orr     x2, x2, {{ 0|setting:20 }}  // Z row = 1
    orr     x2, x2, {{ 0|setting:6 }}   // Y offset
    {% amx fms16 x2 %}

    b       .non_linear_loop
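// The per_col_* / per_row_* fused ops all follow the same pattern: the argument is
// staged through the 128-byte aligned scratch block at x1 (broadcast with NEON for
// per-column, copied for per-row), loaded into an AMX X or Y register, then folded
// into both Z rows with fma16/fms16 in vector mode (bit 20 selects the second Z row).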
.per_row_sub:
    // performs a unary neg on Z
    eor     x2, x2, x2                  // X[0] = Z[0]
    mov     x4, {{ 0|setting:63 }}      // vector mode
    orr     x4, x4, {{ 0|setting:28 }}
    orr     x4, x4, {{ 0|setting:27 }}  // Z = -X
    {% amx extrx x2 %}
    {% amx fms16 x4 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z row
    add     x4, x4, {{ 0|setting:20 }}  // next Z row
    {% amx extrx x2 %}
    {% amx fms16 x4 %}
    // continue

.per_row_add:
    ldr     x2, [x0, #8]
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x2], #64
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1], #64
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x2]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    sub     x1, x1, #64

    orr     x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldy x2 %}

    mov     x2, {{ 0|setting:63 }}      // vector mode
    orr     x2, x2, {{ 0|setting:29 }}  // z += y
    // top left
    {% amx fma16 x2 %}
    // bottom left
    orr     x2, x2, {{ 0|setting:20 }}  // Z row = 1
    orr     x2, x2, {{ 0|setting:6 }}   // Y offset
    {% amx fma16 x2 %}

    b       .non_linear_loop

.per_row_min:
    mov     x2, 5
    b       .per_row_min_max
.per_row_max:
    mov     x2, 7
.per_row_min_max:
    ldr     x5, [x0, #8]
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x5], #64
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1], #64
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x5]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    sub     x1, x1, #64

    orr     x5, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldx x5 %}

    lsl     x2, x2, 47                  // max(x,z) (or min)
    orr     x2, x2, {{ 0|setting:44 }}  // f32
    {% amx vecfp x2 %}
    orr     x2, x2, {{ 0|setting:16 }}  // x1
    orr     x2, x2, {{ 0|setting:20 }}  // z1
    {% amx vecfp x2 %}

    b       .non_linear_loop

.per_col_min:
    mov     x2, 5
    b       .per_col_min_max
.per_col_max:
    mov     x2, 7
.per_col_min_max:
    ldr     x4, [x0, #8]                // broadcast value to x0
    ld1     { v0.h }[0], [x4]
    dup     v0.8h, v0.h[0]
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    sub     x1, x1, #64
    {% amx ldx x1 %}

    lsl     x2, x2, 47                  // max(x,z) (or min)
    orr     x2, x2, {{ 0|setting:43 }}  // f16
    {% amx vecfp x2 %}
    orr     x2, x2, {{ 0|setting:20 }}  // z offset
    {% amx vecfp x2 %}

    b       .non_linear_loop

.per_col_mul:
    ldr     x4, [x0, #8]                // broadcast value to y0
    ld1     { v0.h }[0], [x4]
    dup     v0.8h, v0.h[0]
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    st1     { v0.8h }, [x1], #16
    sub     x1, x1, #64
    {% amx ldy x1 %}

    eor     x2, x2, x2                  // X[0] = Z[0]
    {% amx extrx x2 %}
    mov     x4, {{ 0|setting:63 }}      // vector mode
    orr     x4, x4, {{ 0|setting:27 }}  // Z = X*Y
    {% amx fma16 x4 %}
    orr     x2, x2, {{ 0|setting:20 }}  // Z1
    {% amx extrx x2 %}
    orr     x4, x4, {{ 0|setting:20 }}  // Z1
    {% amx fma16 x4 %}

    b       .non_linear_loop

.per_row_mul:
    ldr     x2, [x0, #8]
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x2], #64
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1], #64
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x2]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    sub     x1, x1, #64

    orr     x2, x1, {{ 0|setting:62 }}  // pair
    {% amx ldy x2 %}

    eor     x2, x2, x2                  // X[0] = Z[0]
    {% amx extrx x2 %}
    mov     x4, {{ 0|setting:63 }}      // vector mode
    orr     x4, x4, {{ 0|setting:27 }}  // Z = X*Y
    {% amx fma16 x4 %}
    orr     x2, x2, {{ 0|setting:20 }}  // Z1
    {% amx extrx x2 %}
    orr     x4, x4, {{ 0|setting:20 }}  // Z1
    orr     x4, x4, {{ 0|setting:6 }}   // Y1
    {% amx fma16 x4 %}

    b       .non_linear_loop

.scalar_sub:
    // performs a unary neg on Z, then falls through to scalar_add
    eor     x2, x2, x2                  // X[0] = Z[0]
    mov     x4, {{ 0|setting:63 }}      // vector mode
    orr     x4, x4, {{ 0|setting:28 }}
    orr     x4, x4, {{ 0|setting:27 }}  // Z = -X
    {% amx extrx x2 %}
    {% amx fms16 x4 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z row
    add     x4, x4, {{ 0|setting:20 }}  // next Z row
    {% amx extrx x2 %}
    {% amx fms16 x4 %}
    // continue on purpose
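// .scalar_sub above negates Z in place (X = Z via extrx, then Z = -X via fms16) and
// falls through to .scalar_add, so the combined effect is Z = scalar - Z.
// .scalar_sub_flipped further down computes Z = Z - scalar directly.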
.scalar_add:
    ldr     w5, [x0, #8]
    fmov    h0, w5
    dup     v0.8h, v0.h[0]
    dup     v1.8h, v0.h[0]
    dup     v2.8h, v0.h[0]
    dup     v3.8h, v0.h[0]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}                    // load 32 f16 values

    mov     x2, {{ 0|setting:28 }}      // Z += X
    {% amx fma16 x2 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z row
    {% amx fma16 x2 %}

    b       .non_linear_loop

.scalar_sub_flipped:
    ldr     w5, [x0, #8]
    fmov    s0, w5
    dup     v0.8h, v0.h[0]
    dup     v1.8h, v0.h[0]
    dup     v2.8h, v0.h[0]
    dup     v3.8h, v0.h[0]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}                    // load 32 f16 values

    mov     x2, {{ 0|setting:28 }}      // Z -= X
    {% amx fms16 x2 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z row
    {% amx fms16 x2 %}

    b       .non_linear_loop

.scalar_mul:
    ldr     w5, [x0, #8]
    fmov    h0, w5
    dup     v0.8h, v0.h[0]
    dup     v1.8h, v0.h[0]
    dup     v2.8h, v0.h[0]
    dup     v3.8h, v0.h[0]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldy x1 %}

    eor     x2, x2, x2                  // X[0] = Z[0]
    mov     x4, {{ 0|setting:63 }}      // vector mode
    orr     x4, x4, {{ 0|setting:27 }}  // Z = X*Y
    {% amx extrx x2 %}
    {% amx fma16 x4 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z row
    add     x4, x4, {{ 0|setting:20 }}  // next Z row
    {% amx extrx x2 %}
    {% amx fma16 x4 %}

    b       .non_linear_loop

.scalar_min:
    mov     x2, 5
    b       .scalar_min_max
.scalar_max:
    mov     x2, 7
.scalar_min_max:
    ldr     w5, [x0, #8]
    fmov    h0, w5
    dup     v0.8h, v0.h[0]
    dup     v1.8h, v0.h[0]
    dup     v2.8h, v0.h[0]
    dup     v3.8h, v0.h[0]
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x1]
    {% amx ldx x1 %}                    // load 32 f16 values

    lsl     x2, x2, 47
    orr     x2, x2, {{ 0|setting:43 }}  // f16
    {% amx vecfp x2 %}
    add     x2, x2, {{ 0|setting:20 }}  // next Z
    {% amx vecfp x2 %}

    b       .non_linear_loop

.add_unicast:
    ldp     x5, x6, [x0, #8]            // c base ptr, rsc
    ldp     x7, x8, [x0, #24]           // csc, item_size

    {% for neon in (0..7) %}
        {% for lane in (0..7) %}
            ld1     { v{{neon}}.h }[{{lane}}], [x5], x6
        {% endfor %}
    {% endfor %}

    mov     x8, x1
    st1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x8], #64
    st1     { v4.8h, v5.8h, v6.8h, v7.8h }, [x8], #64

    orr     x8, x1, {{ 0|setting:62 }}  // pair
    {% amx ldy x8 %}

    eor     x2, x2, x2
    orr     x2, x2, {{ 0|setting:63 }}  // vector mode
    orr     x2, x2, {{ 0|setting:29 }}  // perform Z0 += Y0
    {% amx fma16 x2 %}
    orr     x2, x2, {{ 0|setting:20 }}  // Z1
    orr     x2, x2, 64                  // offset Y
    {% amx fma16 x2 %}

    b       .non_linear_loop

.add_row_col_products:
    ldp     x5, x6, [x0, #8]            // a base ptr, b base ptr

    ld1     { v0.h }[0], [x6]
    st1     { v0.h }[0], [x1]
    {% amx ldy x1 %}

    ld1     { v0.4s, v1.4s, v2.4s, v3.4s }, [x5], #64
    st1     { v0.4s, v1.4s, v2.4s, v3.4s }, [x1], #64
    ld1     { v0.4s, v1.4s, v2.4s, v3.4s }, [x5]
    st1     { v0.4s, v1.4s, v2.4s, v3.4s }, [x1]
    sub     x1, x1, #64

    orr     x2, x1, {{ 0|setting:62 }}  // load a pair
    {% amx ldx x2 %}

    mov     x2, {{ 0|setting:43 }}      // f16
    orr     x2, x2, {{ 0|setting:38 }}  // broadcast Y
    {% amx vecfp x2 %}
    orr     x2, x2, {{ 0|setting:20 }}  // Z row = 1
    orr     x2, x2, {{ 0|setting:16 }}  // X offset
    {% amx vecfp x2 %}

    b       .non_linear_loop

.store:
    ldp     x5, x6, [x0, #8]            // c base ptr, rsc
    ldp     x7, x8, [x0, #24]           // csc, item_size

    ands    x8, x5, 0x7f
    bne     .store_generic
    cmp     x6, 4
    bne     .store_generic
    cmp     x7, 4
    bne     .store_generic

    orr     x5, x5, {{ 0|setting:62 }}  // pair
    {% amx stz x5 %}

    b       .non_linear_loop

.store_generic:
    orr     x8, x1, {{ 0|setting:62 }}  // pair
    {% amx stz x8 %}

    mov     x8, x1
    ld1     { v0.8h, v1.8h, v2.8h, v3.8h }, [x8], #64
    ld1     { v4.8h, v5.8h, v6.8h, v7.8h }, [x8], #64

    {% for neon in (0..7) %}
        {% for lane in (0..7) %}
            st1     { v{{neon}}.h }[{{lane}}], [x5], x6
        {% endfor %}
    {% endfor %}

    b       .non_linear_loop

.return:
    {{ AMX_CLR }}
    ret
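// Assumed expansion (not defined in this file): the amx template tags above are expected
// to emit the undocumented Apple AMX coprocessor instructions, bracketed by AMX_SET /
// AMX_CLR, and the setting:N filter to produce a 64-bit operand with bit N set.
// Operand bits, as this kernel uses them:
//   bit 62          ldx/ldy/stz: transfer a pair of 64-byte registers
//   bit 63          fma16/fms16: per-lane vector mode
//   bits 20..21     Z row selector
//   bits 27/28/29   fma16/fms16: drop the Z / Y / X input respectively
//   bit 38          vecfp: broadcast a single Y lane
//   bit 43          vecfp: f16 operands
//   bits 47..       vecfp: ALU mode (5 for the min paths, 7 for max)
//   bit 32          vecfp: per-step Y offset increment in .packed_packed_loop_32
//   bit 16 / bit 6  X / Y register offset (second register of a loaded pair)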