use candle_core::test_utils::{to_vec0_round, to_vec2_round}; use anyhow::Result; use candle_core::{Device, Tensor, Var}; use candle_nn::{Linear, Module, Optimizer}; use candle_optimisers::rmsprop::{ParamsRMSprop, RMSprop}; /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters()) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop::default(); // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[1.6650, 0.7867]]); assert_eq!(to_vec0_round(&b, 4)?, 1.3012); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), weight_decay = 0.4) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_weight_decay_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { weight_decay: Some(0.4), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[1.6643, 0.7867]]); assert_eq!(to_vec0_round(&b, 4)?, 1.2926); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), centered = True) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_centered_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { centered: true, ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[1.8892, 0.7617]]); assert_eq!(to_vec0_round(&b, 4)?, 1.3688); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), centered = True) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_centered_decay_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { centered: true, weight_decay: Some(0.4), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[1.8883, 0.7621]]); assert_eq!(to_vec0_round(&b, 4)?, 1.3558); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), momentum = 0.4) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_momentum_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { momentum: Some(0.4), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[2.3042, 0.6835]]); assert_eq!(to_vec0_round(&b, 4)?, 1.5441); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), momentum = 0.4, weight_decay = 0.4) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_momentum_decay_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { momentum: Some(0.4), weight_decay: Some(0.4), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[2.3028, 0.6858]]); assert_eq!(to_vec0_round(&b, 4)?, 1.5149); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), centered = True, momentum = 0.4) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_centered_momentum_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { centered: true, momentum: Some(0.4), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[2.4486, 0.6715]]); assert_eq!(to_vec0_round(&b, 4)?, 1.5045); Ok(()) } /* The results of this test have been checked against the following PyTorch code. import torch from torch import optim w_gen = torch.tensor([[3., 1.]]) b_gen = torch.tensor([-2.]) sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]]) sample_ys = sample_xs.matmul(w_gen.t()) + b_gen m = torch.nn.Linear(2, 1) with torch.no_grad(): m.weight.zero_() m.bias.zero_() optimiser = optim.RMSprop(m.parameters(), centered = True, momentum = 0.4, weight_decay = 0.4) # optimiser.zero_grad() for _step in range(100): optimiser.zero_grad() ys = m(sample_xs) loss = ((ys - sample_ys)**2).sum() loss.backward() optimiser.step() # print("Optimizer state begin") # print(optimiser.state) # print("Optimizer state end") print(m.weight) print(m.bias) */ #[test] fn rmsprop_centered_momentum_decay_test() -> Result<()> { // Generate some linear data, y = 3.x1 + x2 - 2. let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?; let b_gen = Tensor::new(-2f32, &Device::Cpu)?; let gen = Linear::new(w_gen, Some(b_gen)); let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?; let sample_ys = gen.forward(&sample_xs)?; let params = ParamsRMSprop { centered: true, momentum: Some(0.4), weight_decay: Some(0.4), ..Default::default() }; // Now use backprop to run a linear regression between samples and get the coefficients back. let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?; let b = Var::new(0f32, &Device::Cpu)?; let mut n_sgd = RMSprop::new(vec![w.clone(), b.clone()], params)?; let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone())); for _step in 0..100 { let ys = lin.forward(&sample_xs)?; let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?; n_sgd.backward_step(&loss)?; } assert_eq!(to_vec2_round(&w, 4)?, &[[2.4468, 0.6744]]); assert_eq!(to_vec0_round(&b, 4)?, 1.4695); Ok(()) }