// Copyright (c) Facebook, Inc. and its affiliates // SPDX-License-Identifier: MIT OR Apache-2.0 #![allow(clippy::many_single_char_names)] use gad::prelude::*; use std::sync::Arc; /// Symbolic evaluation. #[derive(Clone, Default)] struct SymEval; /// Symbolic expression of type T (unboxed). #[derive(Debug, PartialEq)] enum Exp_ { Zero, One, Num(T), Neg(Exp), Add(Exp, Exp), Mul(Exp, Exp), } /// Symbolic expression of type T (boxed.) type Exp = Arc>; impl Exp_ { fn num(x: T) -> Exp { Arc::new(Exp_::Num(x)) } } impl HasDims for Exp_ { type Dims = (); #[inline] fn dims(&self) {} } impl CoreAlgebra> for SymEval { type Value = Exp; fn variable(&mut self, data: Exp) -> Self::Value { data } fn constant(&mut self, data: Exp) -> Self::Value { data } fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result { Ok(Arc::new(Exp_::Add(v1.clone(), v2.clone()))) } } impl ArithAlgebra> for SymEval { fn zeros(&mut self, _v: &Exp) -> Exp { Arc::new(Exp_::Zero) } fn ones(&mut self, _v: &Exp) -> Exp { Arc::new(Exp_::One) } fn neg(&mut self, v: &Exp) -> Exp { Arc::new(Exp_::Neg(v.clone())) } fn sub(&mut self, v1: &Exp, v2: &Exp) -> Result> { let v2 = self.neg(v2); Ok(Arc::new(Exp_::Add(v1.clone(), v2))) } fn mul(&mut self, v1: &Exp, v2: &Exp) -> Result> { Ok(Arc::new(Exp_::Mul(v1.clone(), v2.clone()))) } } impl std::fmt::Display for Exp_ { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use Exp_::*; match self { Zero => write!(f, "0"), One => write!(f, "1"), Num(x) => write!(f, "{}", x), Neg(e) => write!(f, "(-{})", *e), Add(e1, e2) => write!(f, "({}+{})", *e1, *e2), Mul(e1, e2) => write!(f, "{}{}", *e1, *e2), } } } type SymGraph1 = Graph>; // type SymGraphN = Graph>; #[test] fn test_symgraph1() -> Result<()> { let mut g = SymGraph1::new(); let a = CoreAlgebra::variable(&mut g, Exp_::num("a")); let b = g.variable(Exp_::num("b")); let c = g.mul(&a, &b)?; let d = g.mul(&a, &c)?; assert_eq!(format!("{}", d.data()), "aab"); let gradients = g.evaluate_gradients_once(d.gid()?, Exp_::num("1"))?; assert_eq!(format!("{}", gradients.get(a.gid()?).unwrap()), "(1ab+a1b)"); assert_eq!(format!("{}", gradients.get(b.gid()?).unwrap()), "aa1"); Ok(()) }