use burn::{
    module::Module,
    nn::loss::CrossEntropyLossConfig,
    tensor::{
        backend::{AutodiffBackend, Backend},
        Tensor,
    },
    train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};
use burn_efficient_kan::{Kan as EfficientKan, KanOptions};

use crate::data::MnistBatch;

/// MNIST classifier that wraps the efficient KAN layer stack.
#[derive(Module, Debug)]
pub struct Kan<B: Backend> {
    kan: EfficientKan<B>,
}

impl<B: Backend> Kan<B> {
    /// Builds the underlying efficient KAN from the given options on `device`.
    /// The `ndarray_linalg` bounds are required by `EfficientKan::new`.
    pub fn new(options: &KanOptions, device: &B::Device) -> Self
    where
        B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
    {
        let kan = EfficientKan::new(options, device);
        Self { kan }
    }
}

impl<B: Backend> Kan<B> {
    /// Flattens each `height x width` image into a single feature vector and
    /// runs it through the KAN, producing one row of class logits per image.
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = input.dims();
        let x = input.reshape([batch_size, height * width]);
        self.kan.forward(x)
    }

    /// Runs a forward pass and computes the cross-entropy loss against the
    /// batch targets, packaging everything for Burn's training metrics.
    pub fn forward_classification(&self, item: MnistBatch<B>) -> ClassificationOutput<B> {
        let targets = item.targets;
        let output = self.forward(item.images);
        let loss = CrossEntropyLossConfig::new()
            .init(&output.device())
            .forward(output.clone(), targets.clone());

        ClassificationOutput {
            loss,
            output,
            targets,
        }
    }
}

impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Kan<B> {
    fn step(&self, item: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        let item = self.forward_classification(item);

        // Backpropagate through the loss and hand the gradients to the learner.
        TrainOutput::new(self, item.loss.backward(), item)
    }
}

impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Kan<B> {
    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {
        self.forward_classification(batch)
    }
}
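
// ---------------------------------------------------------------------------
// A minimal usage sketch (not part of the original module). It assumes Burn's
// `ndarray` backend feature is enabled (hence the `burn::backend::ndarray`
// paths) and that a suitable `KanOptions` value has been built elsewhere; its
// constructor arguments are specific to the burn_efficient_kan crate and are
// not shown here. NdArray's default f32 float element satisfies the
// `Scalar + Lapack` bounds required by `Kan::new`.
#[allow(dead_code)]
fn usage_sketch(options: &KanOptions) {
    use burn::backend::ndarray::{NdArray, NdArrayDevice};
    use burn::tensor::Distribution;

    let device = NdArrayDevice::Cpu;
    let model: Kan<NdArray> = Kan::new(options, &device);

    // A dummy batch of four 28x28 "images"; real inputs come from MnistBatch.
    let images = Tensor::<NdArray, 3>::random([4, 28, 28], Distribution::Default, &device);
    let logits = model.forward(images);
    assert_eq!(logits.dims()[0], 4); // one row of class logits per image
}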