use rand::prelude::*; use auto_diff::var::Var; use auto_diff::optim::{SGD}; use auto_diff::op::Linear; use auto_diff::op::OpCall; use auto_diff::err::AutoDiffError; extern crate openblas_src; fn main() { fn func(input: &Var) -> Result { let input = input.clone(); input.set_grad(false); let result = input.matmul(&Var::new(&vec![2., 3.], &vec![2, 1]))? + Var::new(&vec![1.], &vec![1]); result.set_grad(true); Ok(result) } let n = 15; let mut rng = StdRng::seed_from_u64(671); let data = Var::normal(&mut rng, &vec![n, 2], 0., 2.); let label = func(&data).unwrap(); let mut op1 = Linear::new(Some(2), Some(1), true); op1.set_weight(Var::normal(&mut rng, &[2, 1], 0., 2.)); op1.set_bias(Var::normal(&mut rng, &[1, ], 0., 2.)); let output = op1.call(&[&data]).unwrap().pop().unwrap(); let loss = output.mse_loss(&label).unwrap(); let mut opt = SGD::new(3.); for i in 0..200 { println!("index: {}, loss: {:?}", i, loss); loss.rerun().unwrap(); loss.bp().unwrap(); loss.step(&mut opt).unwrap(); let weight = op1.weight(); let bias = op1.bias(); println!("{:?}, {:?}", weight, bias); } }