//! Demonstrates how to build a custom [nn::Module] without using tuples use dfdx::{ nn::modules::{Linear, Module, ModuleVisitor, ReLU, TensorCollection}, prelude::BuildModule, shapes::{Dtype, Rank1, Rank2}, tensor::{AutoDevice, SampleTensor, Tape, Tensor, Trace}, tensor_ops::Device, }; /// Custom model struct /// This case is trivial and should be done with a tuple of linears and relus, /// but it demonstrates how to build models with custom behavior struct Mlp> { l1: Linear, l2: Linear, relu: ReLU, } // TensorCollection lets you do several operations on Modules, including constructing them with // randomized parameters, and iterating through or mutating all tensors in a model. impl> TensorCollection for Mlp where E: Dtype + num_traits::Float + rand_distr::uniform::SampleUniform, { // Type alias that specifies the how Mlp's type changes when using a different dtype and/or // device. type To> = Mlp; fn iter_tensors>( visitor: &mut V, ) -> Result>, V::Err> { visitor.visit_fields( ( // Define name of each field and how to access it, using ModuleField for Modules, // and TensorField for Tensors. Self::module("l1", |s| &s.l1, |s| &mut s.l1), Self::module("l2", |s| &s.l2, |s| &mut s.l2), ), // Define how to construct the collection given its fields in the order they are given // above. This conversion is done using the ModuleFields trait. |(l1, l2)| Mlp { l1, l2, relu: Default::default(), }, ) } } // impl Module for single item impl> Module, E, D>> for Mlp { type Output = Tensor, E, D>; type Error = D::Err; fn try_forward(&self, x: Tensor, E, D>) -> Result { let x = self.l1.try_forward(x)?; let x = self.relu.try_forward(x)?; self.l2.try_forward(x) } } // impl Module for batch of items impl< const BATCH: usize, const IN: usize, const INNER: usize, const OUT: usize, E: Dtype, D: Device, T: Tape, > Module, E, D, T>> for Mlp { type Output = Tensor, E, D, T>; type Error = D::Err; fn try_forward(&self, x: Tensor, E, D, T>) -> Result { let x = self.l1.try_forward(x)?; let x = self.relu.try_forward(x)?; self.l2.try_forward(x) } } fn main() { // Rng for generating model's params let dev = AutoDevice::default(); // Construct model let model = Mlp::<10, 512, 20, f32, AutoDevice>::build(&dev); // Forward pass with a single sample let item: Tensor, f32, _> = dev.sample_normal(); let _: Tensor, f32, _> = model.forward(item); // Forward pass with a batch of samples let batch: Tensor, f32, _> = dev.sample_normal(); let _: Tensor, f32, _, _> = model.forward(batch.leaky_trace()); }