{%- include "structs.wgsl" -%}

// Per-channel normalization (BatchNormalization, inference form):
//   Y = (X - mean) / sqrt(var + epsilon) * scale + B
// scale, B, mean and var are each indexed by channel only; X and Y are indexed
// per element.

struct Block {
	data: array<{{ elem_type }}>
};

// X (input)
@group(0) @binding(0)
var<storage, read> input_0: Block;

// Scale (indexed by channel)
@group(0) @binding(1)
var<storage, read> input_1: Array;

// B (bias, indexed by channel)
@group(0) @binding(2)
var<storage, read> input_2: Array;

// Input mean (indexed by channel)
@group(0) @binding(3)
var<storage, read> input_3: Array;

// Input variance (indexed by channel)
@group(1) @binding(0)
var<storage, read> input_4: Array;

// Y (output) - the only buffer this shader writes, hence read_write.
@group(1) @binding(1)
var<storage, read_write> output_0: Block;

// One invocation computes one output element.
//   global_id.x : offset of the element within its (batch, channel) slice
//   global_id.y : channel index
//   global_id.z : batch index
// NOTE(review): there is no bounds check, so this assumes the host dispatches
// exactly (elements per channel, channels, batch) workgroups - confirm at the
// call site.
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
	let channel = global_id.y;
	let batch = global_id.z;

	// batch_size / channel_size are template-substituted element strides for
	// one batch item and one channel respectively (presumably an NCHW-style
	// layout - verify against the host code that renders this template).
	let index = global_id.x + batch * {{ batch_size }}u + channel * {{ channel_size }}u;

	// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B
	let x = input_0.data[index];
	let channel_scale = input_1.data[channel];
	let channel_bias = input_2.data[channel];
	let channel_mean = input_3.data[channel];
	let channel_var = input_4.data[channel];

	output_0.data[index] = (x - channel_mean) / sqrt(channel_var + {{ epsilon }}) * channel_scale + channel_bias;
}