use candle_core::{Result, Tensor, Var};
use candle_nn::Optimizer;
use candle_optimisers::{
    adadelta::{Adadelta, ParamsAdaDelta},
    adagrad::{Adagrad, ParamsAdaGrad},
    adam::{Adam, ParamsAdam},
    adamax::{Adamax, ParamsAdaMax},
    esgd::{ParamsSGD, SGD},
    nadam::{NAdam, ParamsNAdam},
    radam::{ParamsRAdam, RAdam},
    rmsprop::{ParamsRMSprop, RMSprop},
};

/// Uniform interface over the candle optimisers: construct from a set of
/// `Var`s and a learning rate (all other hyperparameters take their
/// defaults), then step on a loss tensor.
pub trait Optim: Sized {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self>;
    fn back_step(&mut self, loss: &Tensor) -> Result<()>;
}

impl Optim for Adadelta {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <Adadelta as Optimizer>::new(
            vars,
            ParamsAdaDelta {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for Adagrad {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <Adagrad as Optimizer>::new(
            vars,
            ParamsAdaGrad {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for Adamax {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <Adamax as Optimizer>::new(
            vars,
            ParamsAdaMax {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for SGD {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <SGD as Optimizer>::new(
            vars,
            ParamsSGD {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for NAdam {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <NAdam as Optimizer>::new(
            vars,
            ParamsNAdam {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for RAdam {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <RAdam as Optimizer>::new(
            vars,
            ParamsRAdam {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for RMSprop {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <RMSprop as Optimizer>::new(
            vars,
            ParamsRMSprop {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}

impl Optim for Adam {
    fn new(vars: Vec<Var>, lr: f64) -> Result<Self> {
        <Adam as Optimizer>::new(
            vars,
            ParamsAdam {
                lr,
                ..Default::default()
            },
        )
    }

    fn back_step(&mut self, loss: &Tensor) -> Result<()> {
        self.backward_step(loss)
    }
}
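
// A minimal usage sketch, not part of the original file: since every
// optimiser above implements `Optim`, a training loop can be written once
// and instantiated with any of them via a type parameter. The `loss_fn`
// closure is a hypothetical stand-in for a real forward pass; in practice
// the `Var`s would come from something like a `candle_nn::VarMap`.
#[allow(dead_code)]
fn train_loop<O: Optim>(
    vars: Vec<Var>,
    lr: f64,
    steps: usize,
    loss_fn: impl Fn() -> Result<Tensor>,
) -> Result<()> {
    // Build the chosen optimiser with default hyperparameters except `lr`.
    let mut optimiser = O::new(vars, lr)?;
    for _ in 0..steps {
        // Recompute the loss and take one backward step.
        let loss = loss_fn()?;
        optimiser.back_step(&loss)?;
    }
    Ok(())
}

// Example call, assuming a `forward` closure exists:
// train_loop::<Adam>(vars, 0.004, 100, || forward(&xs, &ys))?;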