// Copyright (c) Facebook, Inc. and its affiliates // SPDX-License-Identifier: MIT OR Apache-2.0 use gad::prelude::*; pub trait UserAlgebra { fn square(&mut self, v: &Value) -> Result; } #[cfg(feature = "arrayfire")] mod af_arith { use super::*; use arrayfire as af; impl UserAlgebra> for Eval where T: af::HasAfEnum + af::ImplicitPromote + af::ConstGenerator + num::Zero, { #[inline] fn square(&mut self, v: &af::Array) -> Result> { Ok(v * v) } } impl UserAlgebra for Check { #[inline] fn square(&mut self, v: &af::Dim4) -> Result { Ok(v.dims()) } } } // Sadly, we cannot quantify over T: Number until negative traits are available in Rust: // https://github.com/rust-lang/rust/issues/68318 is fixed. macro_rules! impl_eval { ($T:ident) => { impl UserAlgebra<$T> for Eval { #[inline] fn square(&mut self, v: &$T) -> Result<$T> { Ok((*v) * (*v)) } } }; } impl_eval!(i32); impl_eval!(i64); impl_eval!(f32); impl_eval!(f64); impl UserAlgebra<()> for Check { #[inline] fn square(&mut self, _v: &()) -> Result<()> { Ok(()) } } macro_rules! impl_graph { ($config:ident) => { impl UserAlgebra> for Graph<$config> where E: Default + Clone + CoreAlgebra + UserAlgebra + ArithAlgebra + LinkedAlgebra, D>, D: HasDims + Clone + 'static + Send + Sync, Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync, { fn square(&mut self, v: &Value) -> Result> { let result = self.eval().square(v.data())?; let value = self.make_node(result, vec![v.input()], { let v = v.clone(); move |graph, store, gradient| { if let Some(id) = v.id() { let c = graph.link(&v); let grad1 = graph.mul(&gradient, c)?; let grad2 = graph.mul(c, &gradient)?; store.add_gradient(graph, id, &grad1)?; store.add_gradient(graph, id, &grad2)?; } Ok(()) } }); Ok(value) } } }; } impl_graph!(Config1); impl_graph!(ConfigN); #[test] fn test_square() -> Result<()> { let mut g = Graph1::new(); let a = g.variable(3i32); let b = g.square(&a)?; assert_eq!(*b.data(), 9); let gradients = g.evaluate_gradients_once(b.gid()?, 1)?; assert_eq!(*gradients.get(a.gid()?).unwrap(), 6); Ok(()) }