#[cfg(feature = "autotune-persistent-cache")] use rand::{distributions::Alphanumeric, Rng}; use std::sync::Arc; #[cfg(feature = "autotune-persistent-cache")] use burn_compute::tune::compute_checksum; use burn_compute::{ server::Handle, tune::{AutotuneOperation, AutotuneOperationSet}, }; use crate::dummy::{ CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition, DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer, OneKernelAutotuneOperation, }; use super::DummyElementwiseAdditionSlowWrong; pub struct AdditionAutotuneOperationSet { client: DummyClient, key: String, shapes: Vec>, handles: Vec>, } impl AdditionAutotuneOperationSet { #[allow(dead_code)] pub fn new( client: DummyClient, shapes: Vec>, handles: Vec>, ) -> Self { Self { client, key: format!("{}-{}", "add", log_shape_input_key(&shapes)), shapes, handles, } } } impl AutotuneOperationSet for AdditionAutotuneOperationSet { fn key(&self) -> String { self.key.clone() } fn autotunables(&self) -> Vec> { vec![ Box::new(OneKernelAutotuneOperation::new( Arc::new(DummyElementwiseAddition), self.client.clone(), self.shapes.clone(), self.handles.clone(), )), Box::new(OneKernelAutotuneOperation::new( Arc::new(DummyElementwiseAdditionSlowWrong), self.client.clone(), self.shapes.clone(), self.handles.clone(), )), ] } fn fastest(self: Box, fastest_index: usize) -> Box { self.autotunables()[fastest_index].clone() } } pub struct MultiplicationAutotuneOperationSet { client: DummyClient, key: String, shapes: Vec>, handles: Vec>, } impl MultiplicationAutotuneOperationSet { #[allow(dead_code)] pub fn new( client: DummyClient, shapes: Vec>, handles: Vec>, ) -> Self { Self { client, key: format!("{}-{}", "mul", log_shape_input_key(&shapes)), shapes, handles, } } } impl AutotuneOperationSet for MultiplicationAutotuneOperationSet { fn key(&self) -> String { self.key.clone() } fn autotunables(&self) -> Vec> { vec![ Box::new(OneKernelAutotuneOperation::new( Arc::new(DummyElementwiseMultiplicationSlowWrong), self.client.clone(), self.shapes.clone(), self.handles.clone(), )), Box::new(OneKernelAutotuneOperation::new( Arc::new(DummyElementwiseMultiplication), self.client.clone(), self.shapes.clone(), self.handles.clone(), )), ] } fn fastest(self: Box, fastest_index: usize) -> Box { self.autotunables()[fastest_index].clone() } } pub struct CacheTestAutotuneOperationSet { client: DummyClient, key: String, shapes: Vec>, handles: Vec>, pub generate_random_checksum: bool, } impl CacheTestAutotuneOperationSet { #[allow(dead_code)] pub fn new( client: DummyClient, shapes: Vec>, handles: Vec>, ) -> Self { Self { client, key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)), shapes, handles, generate_random_checksum: false, } } } impl AutotuneOperationSet for CacheTestAutotuneOperationSet { fn key(&self) -> String { self.key.clone() } fn autotunables(&self) -> Vec> { vec![ Box::new(OneKernelAutotuneOperation::new( Arc::new(CacheTestFastOn3), self.client.clone(), self.shapes.clone(), self.handles.clone(), )), Box::new(OneKernelAutotuneOperation::new( Arc::new(CacheTestSlowOn3), self.client.clone(), self.shapes.clone(), self.handles.clone(), )), ] } fn fastest(self: Box, fastest_index: usize) -> Box { self.autotunables()[fastest_index].clone() } #[cfg(feature = "std")] fn compute_checksum(&self) -> String { if self.generate_random_checksum { let rand_string: String = rand::thread_rng() .sample_iter(&Alphanumeric) .take(16) .map(char::from) .collect(); rand_string } else { compute_checksum(&self.autotunables()) } } } pub fn log_shape_input_key(shapes: &[Vec]) -> String { let mut hash = String::new(); let lhs = &shapes[0]; for size in lhs { let exp = f32::ceil(f32::log2(*size as f32)) as u32; hash.push_str(2_u32.pow(exp).to_string().as_str()); hash.push(','); } hash }