#version 450

// Tiled matrix multiply: result = a * b.
//
// Indexing used below fixes the storage layout:
//   a      is read as a[row * K + k]        (M rows, K columns, row-major)
//   b      is read as b[k * N + col]        (K rows, N columns, row-major)
//   result is written as result[row * N + col]
//
// Each invocation produces one element of `result`. A workgroup walks the K
// dimension in TILE_SIZE-wide steps, staging one tile of `a` and one tile of
// `b` in shared memory per step so every loaded value is reused by the whole
// workgroup. Out-of-range tile elements are loaded as 0.0, so M, N and K do
// not have to be multiples of TILE_SIZE.
//
// NOTE(review): presumably dispatched with ceil(N / TILE_SIZE) groups in x and
// ceil(M / TILE_SIZE) groups in y -- confirm against the host code.

// Tile edge length. Kept as a macro (not a `const`) so the one value can size
// the workgroup, the shared arrays and the loop bounds alike; all of them must
// agree for the tiling to be correct.
#define TILE_SIZE 16

layout(local_size_x = TILE_SIZE, local_size_y = TILE_SIZE) in;

layout(set = 0, binding = 0) readonly buffer InputA { float a[]; };
layout(set = 0, binding = 1) readonly buffer InputB { float b[]; };
layout(set = 0, binding = 2) buffer Output { float result[]; };

// Matrix dimensions: a is M x K, b is K x N, result is M x N.
layout(push_constant) uniform PushConstants {
    uint M;
    uint N;
    uint K;
} dims;

// Per-workgroup staging tiles for the current step along K.
shared float sharedA[TILE_SIZE][TILE_SIZE];
shared float sharedB[TILE_SIZE][TILE_SIZE];

void main() {
    uint row = gl_GlobalInvocationID.y;
    uint col = gl_GlobalInvocationID.x;
    uint localRow = gl_LocalInvocationID.y;
    uint localCol = gl_LocalInvocationID.x;

    // These do not change across tiles, so evaluate them once.
    bool rowInRange = row < dims.M;
    bool colInRange = col < dims.N;

    // ceil(K / TILE_SIZE). Depends only on push constants, so every invocation
    // in the workgroup runs the same number of iterations and reaches the
    // barriers below together.
    uint numTiles = (dims.K + TILE_SIZE - 1) / TILE_SIZE;

    float sum = 0.0;

    for (uint t = 0; t < numTiles; t++) {
        uint aCol = t * TILE_SIZE + localCol;  // column of `a` this invocation stages
        uint bRow = t * TILE_SIZE + localRow;  // row of `b` this invocation stages

        // Stage one element of each tile. Invocations that fall outside the
        // matrices still take part (they are needed at the barriers) and
        // contribute zeros, which leave the dot product unchanged.
        if (rowInRange && aCol < dims.K) {
            sharedA[localRow][localCol] = a[row * dims.K + aCol];
        } else {
            sharedA[localRow][localCol] = 0.0;
        }

        if (bRow < dims.K && colInRange) {
            sharedB[localRow][localCol] = b[bRow * dims.N + col];
        } else {
            sharedB[localRow][localCol] = 0.0;
        }

        // Order this invocation's shared writes first, then wait for the whole
        // workgroup, so both tiles are complete before anyone reads them.
        memoryBarrierShared();
        barrier();

        // Partial dot product for this tile.
        for (uint k = 0; k < TILE_SIZE; k++) {
            sum += sharedA[localRow][k] * sharedB[k][localCol];
        }

        // Everyone must finish reading this tile before the next iteration
        // overwrites the shared arrays.
        memoryBarrierShared();
        barrier();
    }

    // Only invocations that map to a real output element write a result.
    if (rowInRange && colInRange) {
        result[row * dims.N + col] = sum;
    }
}