// Copyright (c) Facebook, Inc. and its affiliates
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Tests for the `Net` abstraction on top of ArrayFire.
//!
//! Two networks compute `input × weights`:
//! - a hand-written one (`TestNet`);
//! - an equivalent one assembled from combinators (`make_net`).
//!
//! Each is trained with a square loss until `a × weights` matches the identity matrix
//! for a fixed 3×3 input `a`.

#![cfg(feature = "arrayfire")]

use arrayfire as af;
use gad::prelude::*;

/// Upper bound on gradient steps before a training loop is declared non-convergent.
/// A regression then fails the test instead of hanging it forever.
const MAX_TRAINING_STEPS: usize = 100_000;

/// Hand-written linear network computing `input × weights` for inputs of a fixed shape.
struct TestNet<T: Float> {
    /// Dimensions every input must have (checked in `eval_with_gradient_info`).
    dims: af::Dim4,
    /// Trainable weight matrix, right-multiplied onto the input.
    weights: af::Array<T>,
}

impl<T> TestNet<T>
where
    T: Float,
{
    /// Creates a network accepting inputs of dimensions `dims` with initial `weights`.
    pub fn new(dims: af::Dim4, weights: af::Array<T>) -> Self {
        Self { dims, weights }
    }
}

impl<T, A> Net<A> for TestNet<T>
where
    T: Float,
    A: AfAlgebra<T>,
{
    type Input = af::Array<T>;
    type Output = <A as AfAlgebra<T>>::Value;
    type Weights = af::Array<T>;
    type GradientInfo = <<A as AfAlgebra<T>>::Value as HasGradientId>::GradientId;

    /// Computes `input × weights` in the algebra `g`.
    ///
    /// Returns the output together with the gradient id of the weights.
    ///
    /// # Errors
    /// Fails if the input dimensions differ from `self.dims`, or if the matrix
    /// product or gradient-id lookup fails.
    fn eval_with_gradient_info(
        &self,
        g: &mut A,
        input: Self::Input,
    ) -> Result<(Self::Output, Self::GradientInfo)> {
        // Report a shape mismatch as an error, like `update_weights`/`set_weights` do.
        check_equal_dimensions(func_name!(), &[&input.dims(), &self.dims])?;
        // The input is registered as a constant and the weights as a variable, so only
        // the weights are tracked for gradients.
        let input = g.constant(input);
        let weights = g.variable(self.weights.clone());
        let output = g.matmul_nn(&input, &weights)?;
        let id = weights.gid()?;
        Ok((output, id))
    }

    /// Returns a copy of the current weights.
    fn get_weights(&self) -> Self::Weights {
        self.weights.clone()
    }

    /// Adds `delta` to the weights in place.
    ///
    /// # Errors
    /// Fails if `delta` does not have the same dimensions as the weights.
    fn update_weights(&mut self, delta: Self::Weights) -> Result<()> {
        check_equal_dimensions(func_name!(), &[&delta.dims(), &self.weights.dims()])?;
        self.weights += delta;
        Ok(())
    }

    /// Replaces the weights with `weights`.
    ///
    /// # Errors
    /// Fails if `weights` does not have the same dimensions as the current weights.
    fn set_weights(&mut self, weights: Self::Weights) -> Result<()> {
        check_equal_dimensions(func_name!(), &[&weights.dims(), &self.weights.dims()])?;
        self.weights = weights;
        Ok(())
    }

    /// Reads (a copy of) the gradient recorded for the weights identified by `info`.
    ///
    /// # Errors
    /// Fails with a missing-gradient error if `reader` has no entry for `info`.
    fn read_weight_gradients(
        &self,
        info: Self::GradientInfo,
        reader: &<A as HasGradientReader>::GradientReader,
    ) -> Result<Self::Weights> {
        Ok(reader
            .read(info)
            .ok_or_else(|| Error::missing_gradient(func_name!()))?
            .clone())
    }
}

/// Builds the combinator equivalent of `TestNet`.
///
/// The network has an `n × n` input placeholder, multiplied on the right by an `n × n`
/// weight matrix initialized with `af::randn!`.
fn make_net<T, A>(
    n: u64,
) -> impl Net<
    A,
    Input = af::Array<T>,
    Output = <A as AfAlgebra<T>>::Value,
    Weights = impl WeightOps<T>,
>
where
    T: Float,
    A: AfAlgebra<T>,
{
    let input = InputData::<af::Array<T>, A>::new(af::dim4!(n, n));
    let weight = WeightData::new(af::randn!(T; n, n));
    input.using(weight).map(|g, (i, w)| g.matmul_nn(&i, &w))
}

/// Repeatedly runs `step` until the reported loss drops below `1e-6`.
///
/// Each call to `step` is expected to perform one gradient step and return the loss.
///
/// # Panics
/// Panics if a step reports a non-finite loss.
///
/// # Errors
/// Propagates any error returned by `step`.
/// Also fails if the loss is still at or above the threshold after
/// `MAX_TRAINING_STEPS` steps.
fn train_until_converged<L>(mut step: impl FnMut() -> anyhow::Result<L>) -> anyhow::Result<()>
where
    L: Into<f64>,
{
    for _ in 0..MAX_TRAINING_STEPS {
        let loss: f64 = step()?.into();
        assert!(loss.is_finite());
        if loss < 0.000001 {
            return Ok(());
        }
    }
    anyhow::bail!(
        "training did not converge within {} steps",
        MAX_TRAINING_STEPS
    )
}

/// Trains a `TestNet` so that `a × weights ≈ I`.
///
/// The learned weights are then loaded into a freshly constructed `TestNet`, and the
/// test checks that evaluating it on `a` yields the identity.
#[test]
fn test_testnet() -> anyhow::Result<()> {
    let mut train = TestNet::new(af::dim4!(3, 3), af::randn!(f32; 3, 3)).add_square_loss();
    // Fixed 3×3 input. It is invertible (determinant 2), so exact weights exist.
    let a = af::Array::<f32>::new(
        &[1.0, 2.0, 1.0, 1.0, 0.0, 1.0, 0.0, -2.0, -1.0],
        af::dim4!(3, 3),
    );
    let i = af::identity(af::dim4!(3, 3));
    let samples = vec![(a.clone(), i.clone())];
    train_until_converged(|| Ok(train.apply_gradient_step(-0.01, samples.clone())?))?;

    // Evaluate with a separately constructed net, so the check also covers
    // `set_weights` transferring the trained state.
    let mut net = TestNet::new(af::dim4!(3, 3), af::randn!(f32; 3, 3));
    net.set_weights(train.get_weights())?;
    let i2 = net.evaluate(a)?;
    testing::assert_almost_all_equal(&i, &i2, 0.01);
    Ok(())
}

/// Same scenario as `test_testnet`, but for the combinator-built `make_net`.
///
/// Weights move between nets through a bincode round-trip.
#[test]
fn test_make_net() -> anyhow::Result<()> {
    let mut train = make_net(3).add_square_loss();
    let a = af::Array::<f32>::new(
        &[1.0, 2.0, 1.0, 1.0, 0.0, 1.0, 0.0, -2.0, -1.0],
        af::dim4!(3, 3),
    );
    let i = af::identity(af::dim4!(3, 3));
    let samples = vec![(a.clone(), i.clone())];
    train_until_converged(|| Ok(train.apply_gradient_step(-0.01, samples.clone())?))?;

    // Because `make_net` returns `impl WeightOps<T>` as its weight type, calls with
    // different algebras `A` generate incomparable types. The trained weights are
    // therefore round-tripped through bincode to reach a net built with another algebra.
    let bytes = bincode::serialize(&train.get_weights())?;

    // The type of `weights` is inferred from the `set_weights` call below; it cannot
    // be written down in Rust at the moment.
    let weights = bincode::deserialize(&bytes)?;
    let mut net = make_net(3);
    net.set_weights(weights)?;
    let i2 = net.evaluate(a.clone())?;
    testing::assert_almost_all_equal(&i, &i2, 0.01);

    // Check dimensions: `check` on the same input should report the dimensions of
    // the expected output.
    let weights = bincode::deserialize(&bytes)?;
    let mut net = make_net(3);
    net.set_weights(weights)?;
    let output_dims = net.check(a)?;
    assert_eq!(i.dims(), output_dims);
    Ok(())
}