{#
  Split kernel template (Tera, renders WGSL).

  Each invocation takes one flat input index `gidx`, decomposes it into
  per-dimension coordinates d_0 .. d_(rank-1), picks the output whose range
  along dimension `axis` contains d_axis, and copies input_0.data[gidx] into
  that output at an index built from that output's per-dimension multipliers.

  Context variables read by this template:
    i_lens[0]   - upper bound on gidx; invocations at or past it do nothing
                  (presumably the input element count - confirm with caller)
    i_chunks[0] - per-dimension divisors used to decompose gidx
                  (presumably row-major strides of the input, last one 1)
    o_lens      - one entry per output; only its length is used here
    o_chunks    - o_chunks[k] = per-dimension multipliers for output k's
                  flat index (presumably output k's row-major strides)
    axis        - index of the dimension being split
    split       - read as cumulative end offsets along `axis`:
                  output 0 receives d_axis < split[0], output k (k > 0)
                  receives split[k-1] <= d_axis < split[k]

  Binding layout: input_0 occupies flat slot 0 (group 0, binding 0). Output k
  occupies flat slot k+1 (Tera's loop.index is 1-based), packed four bindings
  per group: group = (k+1) / 4, binding = (k+1) % 4.
#}
{%- include "structs.wgsl" -%}

@group(0) @binding(0)
var<storage, read> input_0: Array;

{% for output in o_lens %}
@group({{ loop.index / 4 | int }}) @binding({{ loop.index % 4 }})
var<storage, read_write> output_{{ loop.index0 }}: Array;
{% endfor %}

@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
	let gidx = global_id.x;

	if (gidx < {{ i_lens[0] }}u) {
		{# Peel one coordinate per dimension off the flat index. Taking the
		   remainder of `rest` (not of gidx) gives the same value when every
		   divisor divides the previous one, and states the intent directly. #}
		var rest = gidx;
		{%- for chunks in i_chunks[0] -%}
		{% if loop.last %}
		let d_{{ loop.index0 }} = rest;
		{% else %}
		let d_{{ loop.index0 }} = rest / {{ chunks }}u;
		rest = rest % {{ chunks }}u;
		{% endif %}
		{%- endfor -%}

		{# One guarded copy per output; the guards select disjoint ranges of
		   d_axis. If d_axis is not below the last split value, nothing is
		   written. #}
		{% for output in o_lens %}
		{%- if loop.first %}
		{# Output 0 starts at offset 0 along `axis`, so no offset is subtracted. #}
		if (d_{{ axis }} < {{ split | first }}u) {
			let index = {%- for chunk in o_chunks | first -%}
				{%- if not loop.first %} + {%- endif -%}
				d_{{ loop.index0 }} * {{ chunk }}u
			{%- endfor -%}
			;
			output_{{ loop.index0 }}.data[index] = input_0.data[gidx];
		}
		{%- else %}
		{# Output k covers [split[k-1], split[k]); rebase d_axis by split[k-1]. #}
		{% set split_output = split | nth(n=loop.index0 - 1) %}
		if ((d_{{ axis }} >= {{ split_output }}u) && (d_{{ axis }} < {{ split | nth(n=loop.index0) }}u)) {
			let index = {%- for chunk in o_chunks | nth(n=loop.index0) -%}
				{%- if not loop.first %} + {%- endif -%}
				{%- if loop.index0 == axis %}
				(d_{{ loop.index0 }} - {{ split_output }}u) * {{ chunk }}u
				{% else %}
				d_{{ loop.index0 }} * {{ chunk }}u
				{%- endif -%}
			{%- endfor -%}
			;
			output_{{ loop.index0 }}.data[index] = input_0.data[gidx];
		}
		{% endif %}
		{% endfor %}
	}
}