use cubecl_core::ir::{Operation, Operator}; use cubecl_core::{ ir::{self as core, BinaryOperator, UnaryOperator}, ExecutionMode, }; use rspirv::spirv::{Capability, MemorySemantics, Scope, Word}; use crate::{ item::{Elem, Item}, lookups::Slice, variable::{ConstVal, IndexedVariable}, SpirvCompiler, SpirvTarget, }; impl SpirvCompiler { pub fn compile_operation(&mut self, op: Operation) { match op { Operation::Operator(operator) => self.compile_operator(operator), Operation::Branch(branch) => self.compile_branch(branch), Operation::Metadata(meta) => self.compile_meta(meta), Operation::Subcube(subcube) => self.compile_subcube(subcube), Operation::Synchronization(sync) => self.compile_sync(sync), Operation::CoopMma(cmma) => self.compile_cmma(cmma), } } pub fn compile_operator(&mut self, op: Operator) { match op { Operator::Index(op) => { let is_atomic = op.lhs.item().elem.is_atomic(); let value = self.compile_variable(op.lhs); let index = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); if is_atomic { let checked = matches!(self.mode, ExecutionMode::Checked) && value.has_len(); let ptr = match self.index(&value, &index, !checked) { IndexedVariable::Pointer(ptr, _) => ptr, _ => unreachable!("Atomic is always pointer"), }; let out_id = out.as_binding().unwrap(); // This isn't great but atomics can't currently be constructed so should be fine self.merge_binding(out_id, ptr); } else { let out_id = self.read_indexed(&out, &value, &index); self.write(&out, out_id); } } Operator::IndexAssign(op) => { let index = self.compile_variable(op.lhs); let value = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let value_id = self.read_as(&value, &out.indexed_item()); self.write_indexed(&out, &index, value_id); } Operator::UncheckedIndex(op) => { let value = self.compile_variable(op.lhs); let index = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_id = self.read_indexed_unchecked(&out, &value, &index); self.write(&out, out_id); } Operator::UncheckedIndexAssign(op) => { let index = self.compile_variable(op.lhs); let value = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let value_id = self.read_as(&value, &out.indexed_item()); self.write_indexed_unchecked(&out, &index, value_id); } Operator::Slice(op) => { let item = self.compile_item(op.input.item()); let input = self.compile_variable(op.input); let start = self.compile_variable(op.start); let end = self.compile_variable(op.end); let out = match op.out { core::Variable::Slice { id, depth, .. } => (id, depth), _ => unreachable!(), }; let start_id = self.read(&start); let end_id = self.read(&end); let const_len = match (start.as_const(), end.as_const()) { (Some(start), Some(end)) => { let len = end.as_u32() - start.as_u32(); Some(len) } _ => None, }; self.state.slices.insert( out, Slice { ptr: input, offset: start_id, end: end_id, const_len, item, }, ); } Operator::Assign(op) => { let input = self.compile_variable(op.input); let out = self.compile_variable(op.out); let ty = out.item().id(self); let in_id = self.read(&input); let out_id = self.write_id(&out); let is_cast = input.item() != out.item(); if is_cast { if let Some(as_const) = input.as_const() { let cast = self.static_cast(as_const, &input.elem(), &out.item()); self.copy_object(ty, Some(out_id), cast).unwrap(); } else { input.item().cast_to(self, Some(out_id), in_id, &out.item()); } } else { self.copy_object(ty, Some(out_id), in_id).unwrap(); } self.write(&out, out_id); } Operator::Equal(op) => { self.compile_binary_op_bool(op, |b, lhs_ty, ty, lhs, rhs, out| { match lhs_ty.elem() { Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs), Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs), Elem::Float(_) => b.f_ord_equal(ty, Some(out), lhs, rhs), Elem::Void => unreachable!(), } .unwrap(); }); } Operator::NotEqual(op) => { self.compile_binary_op_bool(op, |b, lhs_ty, ty, lhs, rhs, out| { match lhs_ty.elem() { Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs), Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs), Elem::Float(_) => b.f_ord_not_equal(ty, Some(out), lhs, rhs), Elem::Void => unreachable!(), } .unwrap(); }); } Operator::Lower(op) => { self.compile_binary_op_bool(op, |b, lhs_ty, ty, lhs, rhs, out| { match lhs_ty.elem() { Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs), Elem::Float(_) => b.f_ord_less_than(ty, Some(out), lhs, rhs), _ => unreachable!(), } .unwrap(); }); } Operator::LowerEqual(op) => { self.compile_binary_op_bool(op, |b, lhs_ty, ty, lhs, rhs, out| { match lhs_ty.elem() { Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs), Elem::Float(_) => b.f_ord_less_than_equal(ty, Some(out), lhs, rhs), _ => unreachable!(), } .unwrap(); }); } Operator::Greater(op) => { self.compile_binary_op_bool(op, |b, lhs_ty, ty, lhs, rhs, out| { match lhs_ty.elem() { Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs), Elem::Float(_) => b.f_ord_greater_than(ty, Some(out), lhs, rhs), _ => unreachable!(), } .unwrap(); }); } Operator::GreaterEqual(op) => { self.compile_binary_op_bool(op, |b, lhs_ty, ty, lhs, rhs, out| { match lhs_ty.elem() { Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs), Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs), Elem::Float(_) => b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs), _ => unreachable!(), } .unwrap(); }); } Operator::Add(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_add(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_add(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } Operator::Sub(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_sub(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_sub(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } Operator::Mul(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } Operator::Div(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, false) => b.u_div(ty, Some(out), lhs, rhs).unwrap(), Elem::Int(_, true) => b.s_div(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_div(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } Operator::Remainder(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(), Elem::Int(_, true) => b.s_mod(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_mod(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } Operator::Modulo(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, false) => b.u_mod(ty, Some(out), lhs, rhs).unwrap(), Elem::Int(_, true) => b.s_rem(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_rem(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } Operator::Dot(op) => { if op.lhs.item().vectorization.map(|it| it.get()).unwrap_or(1) == 1 { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { match out_ty.elem() { Elem::Int(_, _) => b.i_mul(ty, Some(out), lhs, rhs).unwrap(), Elem::Float(_) => b.f_mul(ty, Some(out), lhs, rhs).unwrap(), _ => unreachable!(), }; }); } else { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let ty = out.item().id(self); let lhs_id = self.read(&lhs); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); if matches!(lhs.elem(), Elem::Int(_, _)) { self.capabilities.insert(Capability::DotProduct); } match (lhs.elem(), rhs.elem()) { (Elem::Int(_, false), Elem::Int(_, false)) => { self.u_dot(ty, Some(out_id), lhs_id, rhs_id, None) } (Elem::Int(_, true), Elem::Int(_, false)) => { self.su_dot(ty, Some(out_id), lhs_id, rhs_id, None) } (Elem::Int(_, false), Elem::Int(_, true)) => { self.su_dot(ty, Some(out_id), rhs_id, lhs_id, None) } (Elem::Int(_, true), Elem::Int(_, true)) => { self.s_dot(ty, Some(out_id), lhs_id, rhs_id, None) } (Elem::Float(_), Elem::Float(_)) => { self.dot(ty, Some(out_id), lhs_id, rhs_id) } _ => unreachable!(), } .unwrap(); self.write(&out, out_id); } } Operator::Fma(op) => { let a = self.compile_variable(op.a); let b = self.compile_variable(op.b); let c = self.compile_variable(op.c); let out = self.compile_variable(op.out); let out_ty = out.item(); let a_id = self.read_as(&a, &out_ty); let b_id = self.read_as(&b, &out_ty); let c_id = self.read_as(&c, &out_ty); let out_id = self.write_id(&out); let ty = out_ty.id(self); let mul = self.f_mul(ty, None, a_id, b_id).unwrap(); self.f_add(ty, Some(out_id), mul, c_id).unwrap(); self.write(&out, out_id); } Operator::Recip(op) => { self.compile_unary_op_cast(op, |b, out_ty, ty, input, out| { let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty); b.f_div(ty, Some(out), one, input).unwrap(); }); } Operator::And(op) => { self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.logical_and(ty, Some(out), lhs, rhs).unwrap(); }); } Operator::Or(op) => { self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.logical_or(ty, Some(out), lhs, rhs).unwrap(); }); } Operator::Not(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| { b.logical_not(ty, Some(out), input).unwrap(); }); } Operator::Neg(op) => { self.compile_unary_op_cast(op, |b, out_ty, ty, input, out| { match out_ty.elem() { Elem::Int(_, true) => b.s_negate(ty, Some(out), input).unwrap(), Elem::Float(_) => b.f_negate(ty, Some(out), input).unwrap(), _ => unreachable!(), }; }); } Operator::BitwiseAnd(op) => self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.bitwise_and(ty, Some(out), lhs, rhs).unwrap(); }), Operator::BitwiseOr(op) => self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.bitwise_or(ty, Some(out), lhs, rhs).unwrap(); }), Operator::BitwiseXor(op) => self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap(); }), Operator::ShiftLeft(op) => self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap(); }), Operator::ShiftRight(op) => self.compile_binary_op(op, |b, _, ty, lhs, rhs, out| { b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(); }), Operator::Bitcast(op) => self.compile_unary_op(op, |b, _, ty, input, out| { b.bitcast(ty, Some(out), input).unwrap(); }), Operator::Erf(op) => self.compile_unary_op_cast(op, |b, out_ty, ty, input, out| { b.compile_erf(out_ty, ty, input, out); }), // Extension functions Operator::Normalize(op) => { self.compile_unary_op(op, |b, _, ty, input, out| { T::normalize(b, ty, input, out); }); } Operator::Magnitude(op) => { self.compile_unary_op(op, |b, _, ty, input, out| { T::magnitude(b, ty, input, out); }); } Operator::Abs(op) => { self.compile_unary_op_cast(op, |b, out_ty, ty, input, out| match out_ty.elem() { Elem::Int(_, _) => T::s_abs(b, ty, input, out), Elem::Float(_) => T::f_abs(b, ty, input, out), _ => unreachable!(), }); } Operator::Exp(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::exp(b, ty, input, out)); } Operator::Log(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::log(b, ty, input, out)) } Operator::Log1p(op) => { self.compile_unary_op_cast(op, |b, out_ty, ty, input, out| { let one = b.static_cast(ConstVal::Bit32(1), &Elem::Int(32, false), &out_ty); let add = match out_ty.elem() { Elem::Int(_, _) => b.i_add(ty, None, input, one).unwrap(), Elem::Float(_) => b.f_add(ty, None, input, one).unwrap(), _ => unreachable!(), }; T::log(b, ty, add, out) }); } Operator::Cos(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::cos(b, ty, input, out)) } Operator::Sin(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::sin(b, ty, input, out)) } Operator::Tanh(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::tanh(b, ty, input, out)) } Operator::Powf(op) => self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| { let bool = match out_ty { Item::Scalar(_) => Elem::Bool.id(b), Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor).id(b), _ => unreachable!(), }; let zero = out_ty.const_u32(b, 0); let one = out_ty.const_u32(b, 1); let two = out_ty.const_u32(b, 2); let modulo = b.f_rem(ty, None, rhs, two).unwrap(); let is_zero = b.f_ord_equal(bool, None, modulo, zero).unwrap(); let abs = b.id(); T::f_abs(b, ty, lhs, abs); let even = b.id(); T::pow(b, ty, abs, rhs, even); let cond2_0 = b.f_ord_equal(bool, None, modulo, one).unwrap(); let cond2_1 = b.f_ord_less_than(bool, None, lhs, zero).unwrap(); let cond2 = b.logical_and(bool, None, cond2_0, cond2_1).unwrap(); let neg_lhs = b.f_negate(ty, None, lhs).unwrap(); let pow2 = b.id(); T::pow(b, ty, neg_lhs, rhs, pow2); let pow2_neg = b.f_negate(ty, None, pow2).unwrap(); let default = b.id(); T::pow(b, ty, lhs, rhs, default); let sel1 = b.select(ty, None, cond2, pow2_neg, default).unwrap(); b.select(ty, Some(out), is_zero, even, sel1).unwrap(); }), Operator::Sqrt(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::sqrt(b, ty, input, out)) } Operator::Round(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::round(b, ty, input, out)) } Operator::Floor(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::floor(b, ty, input, out)) } Operator::Ceil(op) => { self.compile_unary_op_cast(op, |b, _, ty, input, out| T::ceil(b, ty, input, out)) } Operator::Clamp(op) => { let input = self.compile_variable(op.input); let min = self.compile_variable(op.min_value); let max = self.compile_variable(op.max_value); let out = self.compile_variable(op.out); let out_ty = out.item(); let input = self.read_as(&input, &out_ty); let min = self.read_as(&min, &out_ty); let max = self.read_as(&max, &out_ty); let out_id = self.write_id(&out); let ty = out_ty.id(self); match out_ty.elem() { Elem::Int(_, false) => T::u_clamp(self, ty, input, min, max, out_id), Elem::Int(_, true) => T::s_clamp(self, ty, input, min, max, out_id), Elem::Float(_) => T::f_clamp(self, ty, input, min, max, out_id), _ => unreachable!(), } self.write(&out, out_id); } Operator::Max(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() { Elem::Int(_, false) => T::u_max(b, ty, lhs, rhs, out), Elem::Int(_, true) => T::s_max(b, ty, lhs, rhs, out), Elem::Float(_) => T::f_max(b, ty, lhs, rhs, out), _ => unreachable!(), }) } Operator::Min(op) => { self.compile_binary_op(op, |b, out_ty, ty, lhs, rhs, out| match out_ty.elem() { Elem::Int(_, false) => T::u_min(b, ty, lhs, rhs, out), Elem::Int(_, true) => T::s_min(b, ty, lhs, rhs, out), Elem::Float(_) => T::f_min(b, ty, lhs, rhs, out), _ => unreachable!(), }) } // Atomic ops Operator::AtomicLoad(op) => { let input = self.compile_variable(op.input); let out = self.compile_variable(op.out); let out_ty = out.item(); let input_id = input.id(self); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_load(ty, Some(out_id), input_id, memory, semantics) .unwrap(); self.write(&out, out_id); } Operator::AtomicStore(op) => { let input = self.compile_variable(op.input); let out = self.compile_variable(op.out); let input_id = self.read(&input); let out_id = out.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_store(out_id, memory, semantics, input_id) .unwrap(); } Operator::AtomicSwap(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_exchange(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(); self.write(&out, out_id); } Operator::AtomicCompareAndSwap(op) => { let atomic = self.compile_variable(op.input); let cmp = self.compile_variable(op.cmp); let val = self.compile_variable(op.val); let out = self.compile_variable(op.out); let out_ty = out.item(); let atomic_id = atomic.id(self); let cmp_id = self.read(&cmp); let val_id = self.read(&val); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics_success = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); let semantics_failure = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_compare_exchange( ty, Some(out_id), atomic_id, memory, semantics_success, semantics_failure, val_id, cmp_id, ) .unwrap(); self.write(&out, out_id); } Operator::AtomicAdd(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_i_add(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(); self.write(&out, out_id); } Operator::AtomicSub(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_i_sub(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(); self.write(&out, out_id); } Operator::AtomicMax(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); match out_ty.elem() { Elem::Int(_, false) => self .atomic_u_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(), Elem::Int(_, true) => self .atomic_s_max(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(), _ => unreachable!(), }; self.write(&out, out_id); } Operator::AtomicMin(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); match out_ty.elem() { Elem::Int(_, false) => self .atomic_u_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(), Elem::Int(_, true) => self .atomic_s_min(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(), _ => unreachable!(), }; self.write(&out, out_id); } Operator::AtomicAnd(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_and(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(); self.write(&out, out_id); } Operator::AtomicOr(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_or(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(); self.write(&out, out_id); } Operator::AtomicXor(op) => { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = lhs.id(self); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); let memory = self.const_u32(Scope::Device as u32); let semantics = self.const_u32(MemorySemantics::UNIFORM_MEMORY.bits()); self.atomic_xor(ty, Some(out_id), lhs_id, memory, semantics, rhs_id) .unwrap(); self.write(&out, out_id); } Operator::InitLine(op) => { let values = op .inputs .into_iter() .map(|input| self.compile_variable(input)) .collect::>() .into_iter() .map(|it| self.read(&it)) .collect::>(); let out = self.compile_variable(op.out); let out_id = self.write_id(&out); let item = self.compile_item(op.out.item()); let ty = item.id(self); self.composite_construct(ty, Some(out_id), values).unwrap(); self.write(&out, out_id); } Operator::Copy(op) => { let input = self.compile_variable(op.input); let in_index = self.compile_variable(op.in_index); let out = self.compile_variable(op.out); let out_index = self.compile_variable(op.out_index); let in_ptr = self.index_ptr(&input, &in_index); let out_ptr = self.index_ptr(&out, &out_index); let checked = matches!(self.mode, ExecutionMode::Checked) && input.has_len() && out.has_len(); if checked { let in_index = self.read(&in_index); let out_index = self.read(&out_index); self.compile_copy_bound(&input, &out, in_index, out_index, None, |b| { b.copy_memory(out_ptr, in_ptr, None, None, vec![]).unwrap(); }); } else { self.copy_memory(out_ptr, in_ptr, None, None, vec![]) .unwrap(); } } Operator::CopyBulk(op) => { self.capabilities.insert(Capability::Addresses); let input = self.compile_variable(op.input); let in_index = self.compile_variable(op.in_index); let out = self.compile_variable(op.out); let out_index = self.compile_variable(op.out_index); let source = self.index_ptr(&input, &in_index); let target = self.index_ptr(&out, &out_index); let size = self.const_u32(op.len * out.item().size()); let checked = matches!(self.mode, ExecutionMode::Checked) && input.has_len() && out.has_len(); if checked { let in_index = self.read(&in_index); let out_index = self.read(&out_index); self.compile_copy_bound(&input, &out, in_index, out_index, Some(size), |b| { b.copy_memory_sized(target, source, size, None, None, vec![]) .unwrap(); }); } else { self.copy_memory_sized(target, source, size, None, None, vec![]) .unwrap(); } } } } pub fn compile_unary_op_cast( &mut self, op: UnaryOperator, exec: impl FnOnce(&mut Self, Item, Word, Word, Word), ) { let input = self.compile_variable(op.input); let out = self.compile_variable(op.out); let out_ty = out.item(); let input_id = self.read_as(&input, &out_ty); let out_id = self.write_id(&out); let ty = out_ty.id(self); exec(self, out_ty, ty, input_id, out_id); self.write(&out, out_id); } pub fn compile_unary_op( &mut self, op: UnaryOperator, exec: impl FnOnce(&mut Self, Item, Word, Word, Word), ) { let input = self.compile_variable(op.input); let out = self.compile_variable(op.out); let out_ty = out.item(); let input_id = self.read(&input); let out_id = self.write_id(&out); let ty = out_ty.id(self); exec(self, out_ty, ty, input_id, out_id); self.write(&out, out_id); } pub fn compile_unary_op_bool( &mut self, op: UnaryOperator, exec: impl FnOnce(&mut Self, Item, Word, Word, Word), ) { let input = self.compile_variable(op.input); let out = self.compile_variable(op.out); let in_ty = input.item(); let input_id = self.read(&input); let out_id = self.write_id(&out); let ty = out.item().id(self); exec(self, in_ty, ty, input_id, out_id); self.write(&out, out_id); } pub fn compile_binary_op( &mut self, op: BinaryOperator, exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word), ) { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = self.read_as(&lhs, &out_ty); let rhs_id = self.read_as(&rhs, &out_ty); let out_id = self.write_id(&out); let ty = out_ty.id(self); exec(self, out_ty, ty, lhs_id, rhs_id, out_id); self.write(&out, out_id); } pub fn compile_binary_op_no_cast( &mut self, op: BinaryOperator, exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word), ) { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let out_ty = out.item(); let lhs_id = self.read(&lhs); let rhs_id = self.read(&rhs); let out_id = self.write_id(&out); let ty = out_ty.id(self); exec(self, out_ty, ty, lhs_id, rhs_id, out_id); self.write(&out, out_id); } pub fn compile_binary_op_bool( &mut self, op: BinaryOperator, exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word), ) { let lhs = self.compile_variable(op.lhs); let rhs = self.compile_variable(op.rhs); let out = self.compile_variable(op.out); let lhs_ty = lhs.item(); let lhs_id = self.read(&lhs); let rhs_id = self.read_as(&rhs, &lhs_ty); let out_id = self.write_id(&out); let ty = out.item().id(self); exec(self, lhs_ty, ty, lhs_id, rhs_id, out_id); self.write(&out, out_id); } fn compile_erf(&mut self, out_ty: Item, ty: Word, input: Word, out: Word) { let bool = match out_ty { Item::Scalar(_) => Item::Scalar(Elem::Bool), Item::Vector(_, factor) => Item::Vector(Elem::Bool, factor), _ => unreachable!(), } .id(self); let mut cast = |val: f64| self.static_cast(ConstVal::Bit64(val.to_bits()), &Elem::Float(64), &out_ty); let p = cast(0.3275911); let a1 = cast(0.254829592); let a2 = cast(-0.284496736); let a3 = cast(1.421413741); let a4 = cast(-1.453152027); let a5 = cast(1.061405429); let one = cast(1.0); let zero = cast(0.0); let mul = |b: &mut Self, lhs: Word, rhs: Word| b.f_mul(ty, None, lhs, rhs).unwrap(); let add = |b: &mut Self, lhs: Word, rhs: Word| b.f_add(ty, None, lhs, rhs).unwrap(); let erf = |b: &mut Self, input: Word| { let abs = b.id(); T::f_abs(b, ty, input, abs); let t_0 = mul(b, p, abs); let t_1 = add(b, t_0, one); let t = b.f_div(ty, None, one, t_1).unwrap(); let tmp_1 = mul(b, a5, t); let tmp_2 = add(b, tmp_1, a4); let tmp_3 = mul(b, tmp_2, t); let tmp_4 = add(b, tmp_3, a3); let tmp_5 = mul(b, tmp_4, t); let tmp_6 = add(b, tmp_5, a2); let tmp_7 = mul(b, tmp_6, t); let tmp = add(b, tmp_7, a1); let ret_0 = b.f_negate(ty, None, input).unwrap(); let ret_1 = mul(b, ret_0, input); let ret_2 = b.id(); T::exp(b, ty, ret_1, ret_2); let ret_3 = mul(b, tmp, t); let ret_4 = mul(b, ret_2, ret_3); b.f_sub(ty, None, one, ret_4).unwrap() }; let cond = self.f_ord_less_than(bool, None, input, zero).unwrap(); let neg = { let neg_in = self.f_negate(ty, None, input).unwrap(); let res = erf(self, neg_in); self.f_negate(ty, None, res).unwrap() }; let pos = erf(self, input); self.select(ty, Some(out), cond, neg, pos).unwrap(); } }