use burn::{
    module::Param,
    prelude::*,
    tensor::{activation::relu, backend::AutodiffBackend},
};
use nn::{Linear, LinearConfig};
use rl::algo::dqn::DQNModel;

/// Three-layer MLP mapping CartPole observations (4 features) to Q-values
/// for the 2 discrete actions.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    fc1: Linear<B>,
    fc2: Linear<B>,
    fc3: Linear<B>,
}

#[derive(Config, Debug)]
pub struct ModelConfig {
    fc1_out: usize,
    fc2_out: usize,
}

impl ModelConfig {
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        Model {
            fc1: LinearConfig::new(4, self.fc1_out).init(device),
            fc2: LinearConfig::new(self.fc1_out, self.fc2_out).init(device),
            fc3: LinearConfig::new(self.fc2_out, 2).init(device),
        }
    }
}

impl<B: AutodiffBackend> DQNModel<B> for Model<B> {
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = relu(self.fc1.forward(input));
        let x = relu(self.fc2.forward(x));
        self.fc3.forward(x)
    }

    // Polyak (soft) update towards `other`: theta = (1 - tau) * theta + tau * theta_other.
    fn soft_update(self, other: &Self, tau: f32) -> Self {
        Self {
            fc1: soft_update_linear(self.fc1, &other.fc1, tau),
            fc2: soft_update_linear(self.fc2, &other.fc2, tau),
            fc3: soft_update_linear(self.fc3, &other.fc3, tau),
        }
    }
}

// Blend a single parameter tensor towards its counterpart in the other network.
fn soft_update_tensor<B: Backend, const D: usize>(
    this: Param<Tensor<B, D>>,
    that: &Param<Tensor<B, D>>,
    tau: f32,
) -> Param<Tensor<B, D>> {
    this.map(|tensor| tensor * (1.0 - tau) + that.val() * tau)
}

// Soft-update both the weight and the (optional) bias of a linear layer.
fn soft_update_linear<B: Backend>(mut this: Linear<B>, that: &Linear<B>, tau: f32) -> Linear<B> {
    this.weight = soft_update_tensor(this.weight, &that.weight, tau);
    this.bias = match (this.bias, &that.bias) {
        (Some(b1), Some(b2)) => Some(soft_update_tensor(b1, b2, tau)),
        _ => None,
    };
    this
}