//! Intro to dfdx::nn use dfdx::{ nn::builders::*, shapes::{Const, Rank1, Rank2}, tensor::{AsArray, AutoDevice, SampleTensor, Tensor, ZerosTensor}, }; fn main() { let dev = AutoDevice::default(); // `nn::builders` exposes many different neural network types, like the Linear layer! type Model = Linear<4, 2>; // you can use DeviceBuildExt::build_module to construct an initialized model. // the type of this is actually one of the structures under `nn::modules`. let mut m = dev.build_module::(); // `ResetParams::reset_params` also allows you to re-randomize the weights m.reset_params(); // Modules act on tensors using either: // 1. `Module::forward`, which does not mutate the module let _: Tensor, f32, _> = m.forward(dev.zeros::>()); // 2. `ModuleMut::forward_mut()`, which may mutate the module let _: Tensor, f32, _> = m.forward_mut(dev.zeros::>()); // most of them can also act on many different shapes of tensors // here we see that Linear can also accept batches of inputs // Note: the Rank2 with a batch size of 10 in the input AND the output let _: Tensor, f32, _> = m.forward(dev.zeros::>()); // Even dynamic size is supported; let batch_size = 3; let _: Tensor<(usize, Const<2>), f32, _> = m.forward(dev.zeros_like(&(batch_size, Const))); // you can also combine multiple modules with tuples type Mlp = (Linear<4, 2>, ReLU, Linear<2, 1>); let mlp = Mlp::build_on_device(&dev); // and of course forward passes the input through each module sequentially: let x: Tensor, f32, _> = dev.sample_normal(); let a = mlp.forward(x.clone()); let b = mlp.2.forward(mlp.1.forward(mlp.0.forward(x))); assert_eq!(a.array(), b.array()); // There are actually two ways to specify nn types, depending on whether // you want device/dtype agnostic types, or not. // For device agnostic structure, you can use the `nn::builders` api { use dfdx::nn::builders::*; type Model = (Linear<5, 2>, ReLU, Tanh, Linear<2, 3>); let _ = dev.build_module::(); } // The other way is to specify modules with device & dtype generics // using `nn::modules`. The structures in modules are actually // what the nn::builders turn into! { use dfdx::nn::modules::*; type Model = (Linear<5, 2, E, D>, ReLU, Tanh, Linear<2, 3, E, D>); let _: Model = BuildModule::build(&dev); } }