mod data;
mod model;

use crate::data::MnistBatcher;
use burn::{
    backend::{Autodiff, Wgpu},
    data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset},
    optim::AdamWConfig,
    prelude::*,
    record::{CompactRecorder, NoStdTrainingRecorder},
    tensor::backend::AutodiffBackend,
    train::{
        metric::{
            store::{Aggregate, Direction, Split},
            AccuracyMetric, CpuMemory, CpuTemperature, CpuUse, LossMetric,
        },
        LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
    },
};
use burn_efficient_kan::KanOptions;

static ARTIFACT_DIR: &str = "/tmp/burn-example-mnist";

#[derive(Config)]
pub struct KanTrainingConfig {
    #[config(default = 42)]
    pub seed: u64,
    #[config(default = 10)]
    pub num_epochs: usize,
    #[config(default = 64)]
    pub batch_size: usize,
    #[config(default = 4)]
    pub num_workers: usize,
    pub optimizer: AdamWConfig,
    pub kan_options: KanOptions,
    pub lr_scheduler: f64,
}

fn create_artifact_dir(artifact_dir: &str) {
    // Remove existing artifacts first so the learner summary is accurate.
    std::fs::remove_dir_all(artifact_dir).ok();
    std::fs::create_dir_all(artifact_dir).ok();
}

pub fn run<B: AutodiffBackend>(device: B::Device)
where
    B::FloatElem: ndarray_linalg::Scalar + ndarray_linalg::Lapack,
{
    create_artifact_dir(ARTIFACT_DIR);

    // Config
    let config_optimizer = AdamWConfig::new().with_weight_decay(1e-4);
    // KAN layer widths: 28x28 flattened input, one hidden layer of 64, 10 output classes.
    let config = KanTrainingConfig::new(config_optimizer, KanOptions::new([784, 64, 10]), 1e-4);
    B::seed(config.seed);

    // Data
    let batcher_train = MnistBatcher::<B>::new(device.clone());
    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::train());

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MnistDataset::test());

    // Model
    let learner = LearnerBuilder::new(ARTIFACT_DIR)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(CpuUse::new())
        .metric_valid_numeric(CpuUse::new())
        .metric_train_numeric(CpuMemory::new())
        .metric_valid_numeric(CpuMemory::new())
        .metric_train_numeric(CpuTemperature::new())
        .metric_valid_numeric(CpuTemperature::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        // Stop training once the mean validation loss has not improved for one epoch.
        .early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
            Aggregate::Mean,
            Direction::Lowest,
            Split::Valid,
            StoppingCondition::NoImprovementSince { n_epochs: 1 },
        ))
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary()
        .build(
            model::Kan::new(&config.kan_options, &device),
            config.optimizer.init(),
            // A plain f64 serves as a constant learning rate.
            config.lr_scheduler,
        );

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    config
        .save(format!("{ARTIFACT_DIR}/config.json").as_str())
        .unwrap();

    model_trained
        .save_file(
            format!("{ARTIFACT_DIR}/model"),
            &NoStdTrainingRecorder::new(),
        )
        .expect("Failed to save trained model");
}

fn main() {
    let device = <Wgpu as Backend>::Device::default();
    run::<Autodiff<Wgpu>>(device);
}