use cubecl_core::ir as core; use cubecl_opt::{BasicBlock, NodeIndex, Optimizer}; use std::{ collections::HashSet, env, fmt::Debug, mem::take, ops::{Deref, DerefMut}, }; use cubecl_core::{ ir::{HybridAllocator, KernelDefinition, LocalAllocator}, Compiler, ExecutionMode, }; use rspirv::{ dr::{Builder, InsertPoint, Instruction, Module, Operand}, spirv::{BuiltIn, Capability, Decoration, FunctionControl, Op, StorageClass, Word}, }; use crate::{ item::Item, lookups::LookupTables, target::{GLCompute, SpirvTarget}, SpirvKernel, }; pub struct SpirvCompiler { pub target: Target, builder: Builder, pub mode: ExecutionMode, pub debug: bool, global_invocation_id: Word, num_workgroups: Word, pub setup_block: usize, pub opt: Optimizer, pub current_block: Option, pub visited: HashSet, pub capabilities: HashSet, pub state: LookupTables, } unsafe impl Send for SpirvCompiler {} unsafe impl Sync for SpirvCompiler {} impl Clone for SpirvCompiler { fn clone(&self) -> Self { Self { target: self.target.clone(), builder: Builder::new_from_module(self.module_ref().clone()), mode: self.mode, global_invocation_id: self.global_invocation_id, num_workgroups: self.num_workgroups, setup_block: self.setup_block, opt: self.opt.clone(), current_block: self.current_block, capabilities: self.capabilities.clone(), state: self.state.clone(), debug: self.debug, visited: self.visited.clone(), } } } impl Default for SpirvCompiler { fn default() -> Self { Self { target: Default::default(), builder: Builder::new(), mode: Default::default(), global_invocation_id: Default::default(), num_workgroups: Default::default(), capabilities: Default::default(), state: Default::default(), setup_block: Default::default(), opt: Default::default(), current_block: Default::default(), debug: env::var("CUBECL_DEBUG_LOG").is_ok(), visited: Default::default(), } } } impl Deref for SpirvCompiler { type Target = Builder; fn deref(&self) -> &Self::Target { &self.builder } } impl DerefMut for SpirvCompiler { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.builder } } impl Compiler for SpirvCompiler { type Representation = SpirvKernel; fn compile(kernel: KernelDefinition, mode: ExecutionMode) -> Self::Representation { let num_bindings = kernel.inputs.len() + kernel.outputs.len() + kernel.named.len(); let (module, optimizer) = Self { mode, ..Default::default() } .compile_kernel(kernel); SpirvKernel { module, optimizer, num_bindings, } } fn elem_size(elem: core::Elem) -> usize { elem.size() } fn local_allocator() -> impl LocalAllocator { HybridAllocator::default() } fn max_shared_memory_size() -> usize { 32768 } } impl Debug for SpirvCompiler { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "spirv<{:?}>", self.target) } } impl SpirvCompiler { pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) { self.set_version(1, 6); self.init_state(kernel.clone()); let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z]; let mut target = self.target.clone(); let extensions = target.extensions(self); self.state.extensions = extensions; let void = self.type_void(); let voidf = self.type_function(void, vec![]); let main = self .begin_function(void, None, FunctionControl::NONE, voidf) .unwrap(); let setup = self.id(); self.debug_name(setup, "setup"); self.opt = Optimizer::new(kernel.body, kernel.cube_dim, self.mode); let entry = self.opt.entry(); let body = self.label(entry); let setup_block = self.setup(setup); self.setup_block = setup_block; self.compile_block(entry); let ret = self.opt.ret; self.compile_block(ret); if self.selected_block().is_some() { let label = self.label(ret); self.branch(label).unwrap(); } self.select_block(Some(setup_block)).unwrap(); self.branch(body).unwrap(); self.end_function().unwrap(); self.declare_shared_memories(); let builtins = self .state .used_builtins .clone() .into_iter() .map(|(builtin, (id, item))| { let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self); self.variable(ty, Some(id), StorageClass::Input, None); self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]); id }) .collect::>(); target.set_modes(self, main, builtins, cube_dims); let module = take(&mut self.builder).module(); (module, self.opt.clone()) } fn setup(&mut self, label: Word) -> usize { self.begin_block(Some(label)).unwrap(); let setup_block = self.selected_block().unwrap(); self.select_block(None).unwrap(); setup_block } #[track_caller] pub fn current_block(&self) -> &BasicBlock { self.opt.block(self.current_block.unwrap()) } pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word { if let Some(existing) = self.state.used_builtins.get(&builtin) { existing.0 } else { let id = self.id(); self.state.used_builtins.insert(builtin, (id, item)); id } } pub fn compile_block(&mut self, block: NodeIndex) { if self.visited.contains(&block) { return; } self.visited.insert(block); self.current_block = Some(block); let label = self.label(block); self.begin_block(Some(label)).unwrap(); let block_id = self.selected_block().unwrap(); let operations = self.current_block().ops.borrow().clone(); for (_, operation) in operations { self.compile_operation(operation); } let control_flow = self.current_block().control_flow.borrow().clone(); self.compile_control_flow(control_flow); let current = self.selected_block(); self.select_block(Some(block_id)).unwrap(); let phi = { self.opt.block(block).phi_nodes.borrow().clone() }; for phi in phi { let out = self.compile_variable(phi.out); let ty = out.item().id(self); let out_id = self.write_id(&out); let entries: Vec<_> = phi .entries .into_iter() .map(|it| { let label = self.end_label(it.block); let value = self.compile_variable(it.value); let value = self.read(&value); (value, label) }) .collect(); self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries) .unwrap(); } self.select_block(current).unwrap(); } // Declare variable in the first block of the function pub fn declare_function_variable(&mut self, ty: Word) -> Word { let setup = self.setup_block; let id = self.id(); let var = Instruction::new( Op::Variable, Some(ty), Some(id), vec![Operand::StorageClass(StorageClass::Function)], ); let current_block = self.selected_block(); self.select_block(Some(setup)).unwrap(); self.insert_into_block(InsertPoint::Begin, var).unwrap(); self.select_block(current_block).unwrap(); id } fn declare_shared_memories(&mut self) { let shared_memories = self.state.shared_memories.clone(); for (id, memory) in shared_memories { let arr_ty = Item::Array(Box::new(memory.item), memory.len); let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self); self.debug_name(memory.id, format!("shared({id})")); self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None); } } pub fn debug_name(&mut self, var: Word, name: impl Into) { if self.debug { self.name(var, name); } } }