use candle_core::{Result, Tensor};
use candle_nn::{Linear, Module, VarBuilder};

/// Number of input features per sample (784 = 28 × 28).
///
/// NOTE(review): presumably flattened 28×28 grayscale images (MNIST-style); confirm against the
/// data loader.
const IMAGE_DIM: usize = 784;

/// Number of output values (classes) every model in this file produces per sample.
const LABELS: usize = 10;

/// Width of the single hidden layer in [`Mlp`].
const MLP_HIDDEN_DIM: usize = 100;

/// Builds a [`Linear`] layer whose weight and bias are requested with a zero initializer.
///
/// The weight has shape `(out_dim, in_dim)` and the bias has shape `(out_dim,)`; they are looked
/// up under the names `"weight"` and `"bias"` within `vs`'s current prefix.
///
/// NOTE(review): the zero hint presumably only takes effect when the backing store creates the
/// variable (e.g. a `VarMap`); pre-existing/loaded values would be used as-is — confirm.
///
/// # Errors
///
/// Propagates any error returned by [`VarBuilder::get_with_hints`].
fn linear_z(in_dim: usize, out_dim: usize, vs: &VarBuilder) -> Result<Linear> {
    // Zero init is presumably acceptable because this helper is only used for the purely linear
    // model, which has no hidden units whose symmetry would need breaking.
    let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
    let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
    Ok(Linear::new(ws, Some(bs)))
}

/// Common interface for the models defined in this file.
pub trait Model: Sized {
    /// Creates the model, obtaining all of its parameters from `vs`.
    ///
    /// # Errors
    ///
    /// Returns an error if any parameter cannot be obtained from `vs`.
    fn new(vs: VarBuilder) -> Result<Self>;

    /// Applies the model to `xs` and returns its raw outputs (no softmax is applied here).
    ///
    /// NOTE(review): both implementations below expect the last dimension of `xs` to be
    /// `IMAGE_DIM` and produce `LABELS` values along it; presumably `xs` is
    /// `(batch, IMAGE_DIM)` — confirm with callers.
    fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

/// Single affine layer mapping `IMAGE_DIM` inputs to `LABELS` outputs.
///
/// Its parameters are zero-initialized (see [`linear_z`]) and stored directly under `vs`'s
/// prefix as `weight` and `bias`.
#[derive(Debug)]
pub struct LinearModel {
    linear: Linear,
}

impl Model for LinearModel {
    fn new(vs: VarBuilder) -> Result<Self> {
        let linear = linear_z(IMAGE_DIM, LABELS, &vs)?;
        Ok(Self { linear })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        self.linear.forward(xs)
    }
}

/// Two-layer perceptron: `IMAGE_DIM` → `MLP_HIDDEN_DIM` → `LABELS`, with a ReLU in between.
///
/// Parameters live under the `ln1` and `ln2` prefixes of `vs` and use
/// [`candle_nn::linear`]'s default initialization (not the zero init of [`linear_z`]).
#[derive(Debug)]
pub struct Mlp {
    ln1: Linear,
    ln2: Linear,
}

impl Model for Mlp {
    fn new(vs: VarBuilder) -> Result<Self> {
        let ln1 = candle_nn::linear(IMAGE_DIM, MLP_HIDDEN_DIM, vs.pp("ln1"))?;
        let ln2 = candle_nn::linear(MLP_HIDDEN_DIM, LABELS, vs.pp("ln2"))?;
        Ok(Self { ln1, ln2 })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = self.ln1.forward(xs)?;
        let xs = xs.relu()?;
        // The final layer's output is returned unnormalized; any softmax / loss is the caller's job.
        self.ln2.forward(&xs)
    }
}