use cubecl_core as cubecl; use cubecl_core::prelude::*; /// Traits used in Cube kernels must expose an _expand variant /// for all their methods. However, one does not need to provide its /// implementation, see examples below. #[cube] pub trait Strategy { fn operation(input_1: T, input_2: T) -> T; } struct AddStrategy; #[cube] /// The actual implementation of AddStrategy's operation /// Automatically generated an _expand variant pub fn add_strategy_operation(input_1: T, input_2: T) -> T { input_1 + input_2 } #[cube] impl Strategy for AddStrategy { fn operation(input_1: T, input_2: T) -> T { add_strategy_operation::(input_1, input_2) } } struct SubStrategy; #[cube] impl Strategy for SubStrategy { fn operation(input_1: T, input_2: T) -> T { input_1 - input_2 } } #[cube] pub fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { S::operation(x, y) } #[cube] pub fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { let z = S1::operation(x, y); S2::operation(z, y) } pub trait MethodTypedStrategy { fn operation(input_1: T, input_2: T) -> T; fn __expand_operation( _context: &mut CubeContext, input_1: ::ExpandType, input_2: ::ExpandType, ) -> ::ExpandType; } impl MethodTypedStrategy for AddStrategy { fn operation(input_1: T, input_2: T) -> T { add_strategy_operation(input_1, input_2) } fn __expand_operation( context: &mut CubeContext, input_1: ::ExpandType, input_2: ::ExpandType, ) -> ::ExpandType { add_strategy_operation::expand::(context, input_1, input_2) } } #[cube] pub fn with_trait_generic_method(x: T, y: T) -> T { S::operation::(x, y) } mod tests { use super::*; use cubecl_core::{ cpa, ir::{Item, Variable}, }; use pretty_assertions::assert_eq; type ElemType = f32; #[test] fn cube_strategy_trait_add_test() { let mut context = CubeContext::default(); let x = context.create_local_binding(Item::new(ElemType::as_elem())); let y = context.create_local_binding(Item::new(ElemType::as_elem())); with_strategy_trait::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( format!("{:#?}", scope.operations), inline_macro_ref_one(true) ); } #[test] fn cube_strategy_trait_sub_test() { let mut context = CubeContext::default(); let x = context.create_local_binding(Item::new(ElemType::as_elem())); let y = context.create_local_binding(Item::new(ElemType::as_elem())); with_strategy_trait::expand::(&mut context, x.into(), y.into()); let scope = context.into_scope(); assert_eq!( format!("{:#?}", scope.operations), inline_macro_ref_one(false) ); } #[test] fn cube_two_strategy_traits_test() { let mut context = CubeContext::default(); let x = context.create_local_binding(Item::new(ElemType::as_elem())); let y = context.create_local_binding(Item::new(ElemType::as_elem())); two_strategy_traits::expand::( &mut context, x.into(), y.into(), ); let scope = context.into_scope(); assert_eq!(format!("{:#?}", scope.operations), inline_macro_ref_two()); } #[test] fn cube_trait_generic_method_test() { let mut context = CubeContext::default(); let x = context.create_local_binding(Item::new(ElemType::as_elem())); let y = context.create_local_binding(Item::new(ElemType::as_elem())); with_trait_generic_method::expand::( &mut context, x.into(), y.into(), ); let scope = context.into_scope(); assert_eq!( format!("{:#?}", scope.operations), inline_macro_ref_one(true) ); } fn inline_macro_ref_one(is_add_strategy: bool) -> String { let mut context = CubeContext::default(); let item = Item::new(ElemType::as_elem()); let x = context.create_local_binding(item); let y = context.create_local_binding(item); let mut scope = context.into_scope(); let x: Variable = x.into(); let y: Variable = y.into(); match is_add_strategy { true => cpa!(scope, y = x + y), false => cpa!(scope, y = x - y), } format!("{:#?}", scope.operations) } fn inline_macro_ref_two() -> String { let mut context = CubeContext::default(); let item = Item::new(ElemType::as_elem()); let x = context.create_local_binding(item); let y = context.create_local_binding(item); let mut scope = context.into_scope(); let x: Variable = x.into(); let y: Variable = y.into(); cpa!(scope, x = x - y); cpa!(scope, y = x + y); format!("{:#?}", scope.operations) } }