use candle_core::Result as CResult; use candle_datasets::vision::Dataset; use candle_optimisers::{ adadelta::Adadelta, adagrad::Adagrad, adam::Adam, adamax::Adamax, esgd::SGD, nadam::NAdam, radam::RAdam, rmsprop::RMSprop, }; use criterion::{criterion_group, criterion_main, Criterion}; use training::Mlp; // mod models; // mod optim; mod training; fn load_data() -> CResult { candle_datasets::vision::mnist::load() } #[allow(clippy::missing_panics_doc)] pub fn criterion_benchmark_std(c: &mut Criterion) { let mut group = c.benchmark_group("std-optimisers"); let m = &load_data().expect("Failed to load data"); // let m = Rc::new(m); group.significance_level(0.1).sample_size(100); group.bench_function("adadelta", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.bench_function("adagrad", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.bench_function("adam", |b| { b.iter(|| training::run_training::(m).expect("Failed to setup training")); }); group.bench_function("adamax", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.bench_function("esgd", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.bench_function("nadam", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.bench_function("radam", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.bench_function("rmsprop", |b| { b.iter(|| { training::run_training::(m).expect("Failed to setup training"); }); }); group.finish(); } #[allow(clippy::missing_panics_doc)] pub fn criterion_benchmark_lbfgs(c: &mut Criterion) { let mut group = c.benchmark_group("lbfgs-optimser"); let m = load_data().expect("Failed to load data"); // let m = Rc::new(m); group.significance_level(0.1).sample_size(10); group.bench_function("lbfgs", |b| { b.iter(|| training::run_lbfgs_training::(&m).expect("Failed to setup training")); }); group.finish(); } criterion_group!(benches, criterion_benchmark_std, criterion_benchmark_lbfgs); criterion_main!(benches);