| Field | Value |
|---|---|
| Crates.io | gad |
| lib.rs | gad |
| version | 0.2.0 |
| source | src |
| created_at | 2021-04-20 19:16:12.761606 |
| updated_at | 2021-04-27 18:15:17.866565 |
| description | Generic automatic differentiation for Rust |
| repository | https://github.com/facebookresearch/gad |
| id | 387271 |
| size | 205,076 |
This library provides automatic differentiation by backward propagation (aka "autograd") in Rust. It was designed to allow first-class user extensions (e.g. with new array types or new operators) and to support multiple modes of execution with minimal overhead.
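The following modes of execution are currently supported for all library-defined operators:

- forward evaluation;
- dimension checking;
- first-order differentiation;
- higher-order differentiation.
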
The core of this library consists of a tape-based implementation of automatic differentiation in reverse mode.
We have chosen to prioritize idiomatic Rust in order to make this library as re-usable as possible:
- The core differentiation algorithm does not use unsafe Rust features or interior mutability (e.g. `RefCell`). All differentiable expressions explicitly mutate a tape when they are constructed. (Below, the tape variable is named `graph` or `g`.)
- Fallible operations never panic and always return a `Result` type. For instance, the sum of two arrays `x` and `y` may be written `g.add(&x, &y)?`.
- All structures and values implement `Send` and `Sync` to support concurrent programming.
- Generic programming is encouraged so that user formulas can be interpreted in different modes of execution (forward evaluation, dimension checking, etc.) with minimal overhead. (See the section below for a code example.)
While this library is primarily motivated by machine learning applications, it is meant to cover other use cases of automatic differentiation in reverse mode. In the sections below, we show how a user may define new operators and add new modes of execution while retaining automatic differentiability.
Currently, the usual operator syntax (`+`, `-`, `*`, etc.) is not available for differentiable values. All operations are method calls of the form `g.op(x1, .. xN)` (or typically `g.op(x1, .. xN)?` for fallible operations). Because of a current limitation of the Rust borrow checker, expressions cannot be nested: `g.add(&x, &g.mul(&y, &z)?)?` must be written `let v = g.mul(&y, &z)?; g.add(&x, &v)?`.
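To see why the nested form is rejected, note that every tape operation takes `&mut self`, so the inner call would need a second mutable borrow of `g` while the outer call already holds one. The std-only sketch below uses a hypothetical toy `Tape` (not the library's `Graph1`) to show the accepted, unnested pattern; the rejected nested call is left as a comment.

```rust
/// Hypothetical toy tape, used only to illustrate the borrowing pattern.
#[derive(Default)]
struct Tape {
    // Forward values of every node recorded so far.
    values: Vec<f64>,
}

/// Handle to a node on the toy tape.
#[derive(Clone, Copy)]
struct Var(usize);

impl Tape {
    fn variable(&mut self, x: f64) -> Var {
        self.values.push(x);
        Var(self.values.len() - 1)
    }

    // Operations take `&mut self` because they append a node to the tape.
    fn mul(&mut self, a: &Var, b: &Var) -> Result<Var, String> {
        let v = self.values[a.0] * self.values[b.0];
        Ok(self.variable(v))
    }

    fn add(&mut self, a: &Var, b: &Var) -> Result<Var, String> {
        let v = self.values[a.0] + self.values[b.0];
        Ok(self.variable(v))
    }

    fn value(&self, v: &Var) -> f64 {
        self.values[v.0]
    }
}

/// Computes `x + y * z` on the toy tape and returns the forward value.
fn x_plus_y_times_z(g: &mut Tape, x: f64, y: f64, z: f64) -> Result<f64, String> {
    let (x, y, z) = (g.variable(x), g.variable(y), g.variable(z));
    // Rejected (two simultaneous mutable borrows of `g`):
    // let w = g.add(&x, &g.mul(&y, &z)?)?;
    // Accepted: bind the intermediate node first.
    let v = g.mul(&y, &z)?;
    let w = g.add(&x, &v)?;
    Ok(g.value(&w))
}
```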
We believe that this state of affairs could be improved in the future using Rust macros. Alternatively, future extensions of the library could define a new category of differentiable values that contain an implicit `RefCell` reference to a common tape and provide (implicitly fallible, thread-unsafe) operator traits for these values.
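The `RefCell`-based alternative mentioned above can be sketched as follows. This is a hypothetical design, not part of the library: values share an `Rc<RefCell<_>>` handle to one tape, which makes `+` and `*` available through the standard `std::ops` traits at the cost of thread safety (`Rc` is neither `Send` nor `Sync`). The sketch keeps operations infallible to stay short; a real version would need a policy for reporting errors, since operator traits cannot return `Result`.

```rust
use std::cell::RefCell;
use std::ops::{Add, Mul};
use std::rc::Rc;

/// Hypothetical shared tape. Node i stores its forward value and the
/// (parent id, local partial derivative) pairs used by backpropagation.
#[derive(Default)]
struct Tape {
    nodes: Vec<(f64, Vec<(usize, f64)>)>,
}

/// A value carrying an implicit reference to the common tape.
#[derive(Clone)]
struct Var {
    tape: Rc<RefCell<Tape>>,
    id: usize,
    value: f64,
}

fn push(tape: &Rc<RefCell<Tape>>, value: f64, parents: Vec<(usize, f64)>) -> Var {
    let mut t = tape.borrow_mut();
    t.nodes.push((value, parents));
    let id = t.nodes.len() - 1;
    Var { tape: Rc::clone(tape), id, value }
}

fn variable(tape: &Rc<RefCell<Tape>>, value: f64) -> Var {
    push(tape, value, Vec::new())
}

impl Add for &Var {
    type Output = Var;
    fn add(self, rhs: Self) -> Var {
        // d(a + b)/da = 1 and d(a + b)/db = 1.
        push(&self.tape, self.value + rhs.value, vec![(self.id, 1.0), (rhs.id, 1.0)])
    }
}

impl Mul for &Var {
    type Output = Var;
    fn mul(self, rhs: Self) -> Var {
        // d(a * b)/da = b and d(a * b)/db = a.
        push(&self.tape, self.value * rhs.value, vec![(self.id, rhs.value), (rhs.id, self.value)])
    }
}

/// Reverse pass from `out` with initial gradient 1. Parents are always
/// created before their children, so reverse creation order is topological.
fn gradients(out: &Var) -> Vec<f64> {
    let t = out.tape.borrow();
    let mut grads = vec![0.0; t.nodes.len()];
    grads[out.id] = 1.0;
    for i in (0..=out.id).rev() {
        let gi = grads[i];
        for &(p, d) in &t.nodes[i].1 {
            grads[p] += gi * d;
        }
    }
    grads
}
```

With this design, `(x + y) * x` can be written with ordinary operators, e.g. `let z = &(&x + &y) * &x;`.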
To compute gradients, we first build an expression using operations provided by a fresh tape `g` of type `Graph1`. Successive algebraic operations modify the internal state of `g` to track all relevant computations and enable future backward propagation passes.

We then call `g.evaluate_gradients(..)` to run a backward propagation algorithm from the desired starting point, using an initial gradient value `direction`.

Unless the one-time optimized variant `g.evaluate_gradients_once(..)` is used, backward propagation with `g.evaluate_gradients(..)` does not modify `g`. This allows successive (or concurrent) backward propagations to be run from different starting points or with different gradient values.
```rust
// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values.
let a = g.variable(1f32);
let b = g.variable(2f32);
let c = g.mul(&a, &b)?;
// Compute the derivatives of `c` relative to `a` and `b`
let gradients = g.evaluate_gradients(c.gid()?, 1f32)?;
// Read the `dc/da` component, which equals `b = 2`.
assert_eq!(*gradients.get(a.gid()?).unwrap(), 2.0);
```
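The following std-only sketch mirrors this workflow with a hypothetical `Tape` type (not the library's `Graph1`; it uses plain `f64` scalars and `usize` node ids instead of `gid()`): forward operations append nodes, and a read-only backward pass can be run several times with different starting points and initial gradients, as described for `g.evaluate_gradients(..)`.

```rust
use std::collections::HashMap;

/// Hypothetical first-order tape, for illustration only.
#[derive(Default)]
struct Tape {
    values: Vec<f64>,
    // For node i: (input node, local partial derivative) pairs.
    partials: Vec<Vec<(usize, f64)>>,
}

impl Tape {
    fn record(&mut self, value: f64, partials: Vec<(usize, f64)>) -> usize {
        self.values.push(value);
        self.partials.push(partials);
        self.values.len() - 1
    }

    fn variable(&mut self, value: f64) -> usize {
        self.record(value, Vec::new())
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a], self.values[b]);
        // d(a * b)/da = b and d(a * b)/db = a.
        self.record(va * vb, vec![(a, vb), (b, va)])
    }

    /// Backward pass from `start` with initial gradient `direction`.
    /// It takes `&self`, so the tape is left untouched and the pass can be
    /// repeated from other starting points or with other directions.
    fn evaluate_gradients(&self, start: usize, direction: f64) -> HashMap<usize, f64> {
        let mut grads = HashMap::new();
        grads.insert(start, direction);
        // Inputs are always recorded before their users, so visiting nodes
        // in reverse creation order is a valid topological order.
        for i in (0..=start).rev() {
            let g = match grads.get(&i) {
                Some(&g) => g,
                None => continue,
            };
            for &(input, partial) in &self.partials[i] {
                *grads.entry(input).or_insert(0.0) += g * partial;
            }
        }
        grads
    }
}
```

For example, after building `c = a * b` and `d = c * a`, one can call `g.evaluate_gradients(c, 1.0)` and then `g.evaluate_gradients(d, 1.0)` on the same tape.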
Because `Graph1`, the type of `g`, provides algebraic operations as methods, below we refer to such a type as an "algebra". GAD uses particular Rust traits to represent the set of operations supported by a given algebra.
The default array operations of the library are currently based on Arrayfire, a portable array library supporting GPUs and JIT-compilation.
```rust
use arrayfire as af;

// A new tape supporting first-order differentials (aka gradients)
let mut g = Graph1::new();
// Compute forward values using Arrayfire arrays
let dims = af::Dim4::new(&[4, 3, 1, 1]);
let a = g.variable(af::randu::<f32>(dims));
let b = g.variable(af::randu::<f32>(dims));
let c = g.mul(&a, &b)?;
// Compute gradient of c
let direction = af::constant(1f32, dims);
let gradients = g.evaluate_gradients_once(c.gid()?, direction)?;
```
After installing the Arrayfire library on your system, make sure to select the package feature "arrayfire" in your build file `Cargo.toml` (e.g. `gad = { version = "XX", features = ["arrayfire"]}`) and to run `cargo` with the environment variable `AF_PATH` set appropriately (e.g. after `export AF_PATH=/usr/local`).
The algebra `Graph1` used in the examples above is one choice amongst several "default" algebras offered by the library:

- We also provide a special algebra `Eval` for forward evaluation, that is, running only primitive operations and dimension checks (no tape, no gradients);
- Similarly, using the algebra `Check` will check dimensions without evaluating or allocating any array data;
- Finally, differentiation is obtained using `Graph1` for first-order differentials and `GraphN` for higher-order differentials.
Users are encouraged to program formulas in a generic way so that any of the default algebras can be chosen.
The following example illustrates such a programming style in the case of array operations:
```rust
use arrayfire as af;

// Generic over any algebra `A` supporting Arrayfire operations on `f32`.
fn get_value<A>(g: &mut A) -> Result<<A as AfAlgebra<f32>>::Value>
where
    A: AfAlgebra<f32>,
{
    let dims = af::Dim4::new(&[4, 3, 1, 1]);
    let a = g.variable(af::randu::<f32>(dims));
    let b = g.variable(af::randu::<f32>(dims));
    g.mul(&a, &b)
}

// Direct evaluation. The result type is a primitive (non-differentiable) value.
let mut g = Eval::default();
let c: af::Array<f32> = get_value(&mut g)?;

// Fast dimension checking. The result type is a dimension.
let mut g = Check::default();
let d: af::Dim4 = get_value(&mut g)?;

assert_eq!(c.dims(), d);
```
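The same generic style can be reproduced without Arrayfire. The following std-only sketch uses hypothetical names (`MiniArith`, `MiniEval`, `MiniCheck`) rather than the library's traits: one formula is written against a trait with an associated `Value` type, and each "algebra" interprets it differently, either multiplying actual data or only comparing lengths. Dimension mismatches are reported as `Err` values rather than panics.

```rust
/// Hypothetical mini-trait standing in for an `*Algebra` trait.
trait MiniArith {
    type Value;
    fn variable(&mut self, data: Vec<f32>) -> Self::Value;
    fn mul(&mut self, a: &Self::Value, b: &Self::Value) -> Result<Self::Value, String>;
}

/// Evaluates pointwise products on actual data.
struct MiniEval;

/// Tracks only lengths ("dimensions"); its operations never touch the data.
struct MiniCheck;

impl MiniArith for MiniEval {
    type Value = Vec<f32>;
    fn variable(&mut self, data: Vec<f32>) -> Vec<f32> {
        data
    }
    fn mul(&mut self, a: &Vec<f32>, b: &Vec<f32>) -> Result<Vec<f32>, String> {
        if a.len() != b.len() {
            return Err(format!("dimension mismatch: {} vs {}", a.len(), b.len()));
        }
        Ok(a.iter().zip(b).map(|(x, y)| x * y).collect())
    }
}

impl MiniArith for MiniCheck {
    type Value = usize;
    fn variable(&mut self, data: Vec<f32>) -> usize {
        data.len()
    }
    fn mul(&mut self, a: &usize, b: &usize) -> Result<usize, String> {
        if a != b {
            return Err(format!("dimension mismatch: {} vs {}", a, b));
        }
        Ok(*a)
    }
}

/// A formula written once, generically over the algebra.
fn formula<A: MiniArith>(g: &mut A, n: usize) -> Result<A::Value, String> {
    let a = g.variable(vec![2.0; n]);
    let b = g.variable(vec![3.0; n]);
    g.mul(&a, &b)
}
```

Calling `formula(&mut MiniEval, 4)` yields the product array, while `formula(&mut MiniCheck, 4)` yields only its length.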
Higher-order differentials are computed using the algebra `GraphN`. In this case, gradients are values whose computation is also tracked.
```rust
// A new tape supporting differentials of any order.
let mut g = GraphN::new();

// Compute forward values using floats.
let x = g.variable(1.0f32);
let y = g.variable(0.4f32);
// z = x * y^2
let z = {
    let h = g.mul(&x, &y)?;
    g.mul(&h, &y)?
};
// Use short names for gradient ids.
let (x, y, z) = (x.gid()?, y.gid()?, z.gid()?);

// Compute gradient.
let dz = g.constant(1f32);
let dz_d = g.compute_gradients(z, dz)?;
let dz_dx = dz_d.get(x).unwrap();

// Compute some 2nd-order differentials.
let ddz = g.constant(1f32);
let ddz_dxd = g.compute_gradients(dz_dx.gid()?, ddz)?;
let ddz_dxdy = ddz_dxd.get(y).unwrap().data();
assert_eq!(*ddz_dxdy, 0.8); // 2y

// Compute some 3rd-order differentials.
let dddz = g.constant(1f32);
let dddz_dxdyd = g.compute_gradients(ddz_dxd.get(y).unwrap().gid()?, dddz)?;
let dddz_dxdydy = dddz_dxdyd.get(y).unwrap().data();
assert_eq!(*dddz_dxdydy, 2.0);
```
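As a cross-check, the asserted values follow from differentiating $z = x y^2$ by hand and evaluating at $(x, y) = (1, 0.4)$:

$$
\frac{\partial z}{\partial x} = y^2 = 0.16, \qquad
\frac{\partial^2 z}{\partial x\,\partial y} = 2y = 0.8, \qquad
\frac{\partial^3 z}{\partial x\,\partial y^2} = 2.
$$

Each step differentiates the result of the previous one, which is why `GraphN` must record the gradient computations themselves on the tape.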
The default algebras `Eval`, `Check`, `Graph1`, and `GraphN` are meant to provide interchangeable sets of operations in each of the default modes of operation (respectively, evaluation, dimension checking, first-order differentiation, and higher-order differentiation). Default operations are grouped into several traits named `*Algebra` and implemented by each of the four default algebras above.
The special trait `CoreAlgebra<Data>` defines the mapping from underlying data (e.g. arrays) to differentiable values. In particular, the method `fn variable(&mut self, data: Data) -> Self::Value` creates differentiable variables `x` whose gradient value can be referred to later by an id written `x.gid()?` (assuming the algebra is `Graph1` or `GraphN`).
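To make the data-to-value mapping concrete, here is a minimal std-only sketch with hypothetical names (`Core`, `PlainEval`, `IdTape`, `Tracked`). It is not the library's `CoreAlgebra`; it only illustrates the idea that each algebra picks its own `Value` type and that only tape-based algebras hand out gradient ids.

```rust
/// Hypothetical stand-in for a "core" trait mapping raw data to values.
trait Core<Data> {
    type Value;
    fn variable(&mut self, data: Data) -> Self::Value;
}

/// Evaluation-only: values are the data itself, with no gradient id.
struct PlainEval;

impl<D> Core<D> for PlainEval {
    type Value = D;
    fn variable(&mut self, data: D) -> D {
        data
    }
}

/// Tape-based: each variable receives a fresh id for later gradient lookups.
#[derive(Default)]
struct IdTape {
    next_id: usize,
}

#[derive(Debug, Clone)]
struct Tracked<D> {
    data: D,
    id: usize,
}

impl<D> Tracked<D> {
    /// Gradient id, analogous in spirit to `x.gid()`.
    fn gid(&self) -> usize {
        self.id
    }
}

impl<D> Core<D> for IdTape {
    type Value = Tracked<D>;
    fn variable(&mut self, data: D) -> Tracked<D> {
        let id = self.next_id;
        self.next_id += 1;
        Tracked { data, id }
    }
}
```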
Other traits are parameterized over one or several value types. For example, `ArithAlgebra<Value>` provides pointwise negation, multiplication, subtraction, etc. over `Value`.
The motivation for using several `*Algebra` traits is twofold:

- Users may define their own operations (see below).
- Certain operations are more broadly applicable than others.

The following example illustrates gradient computations over integers:
```rust
let mut g = Graph1::new();
let a = g.variable(1i32);
let b = g.variable(2i32);
let c = g.sub(&a, &b)?;
assert_eq!(*c.data(), -1);
let gradients = g.evaluate_gradients_once(c.gid()?, 1)?;
assert_eq!(*gradients.get(a.gid()?).unwrap(), 1);
assert_eq!(*gradients.get(b.gid()?).unwrap(), -1);
```
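The point that some operations are more broadly applicable than others can be illustrated with a std-only sketch. The traits `Arith` and `Analytic` and the struct `Eval` below are hypothetical stand-ins, not the library's definitions: integers and floats both get subtraction, but only floats get `exp`, and generic functions request only the traits they actually use.

```rust
/// Hypothetical trait for operations that make sense on integers and floats.
trait Arith<V> {
    fn sub(&mut self, a: &V, b: &V) -> V;
}

/// Hypothetical trait for operations that only make sense on floats.
trait Analytic<V> {
    fn exp(&mut self, a: &V) -> V;
}

struct Eval;

impl Arith<i32> for Eval {
    fn sub(&mut self, a: &i32, b: &i32) -> i32 {
        a - b
    }
}

impl Arith<f64> for Eval {
    fn sub(&mut self, a: &f64, b: &f64) -> f64 {
        a - b
    }
}

// Only floats get an `Analytic` implementation.
impl Analytic<f64> for Eval {
    fn exp(&mut self, a: &f64) -> f64 {
        a.exp()
    }
}

/// Needs only `Arith`, so it accepts `i32` as well as `f64`.
fn diff<A: Arith<V>, V>(g: &mut A, a: &V, b: &V) -> V {
    g.sub(a, b)
}

/// Also needs `Analytic`; `exp_diff(&mut Eval, &1i32, &2i32)` would not
/// compile because `Eval` does not implement `Analytic<i32>`.
fn exp_diff<A: Arith<V> + Analytic<V>, V>(g: &mut A, a: &V, b: &V) -> V {
    let d = g.sub(a, b);
    g.exp(&d)
}
```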
Users may define new differentiable operations by defining their own `*Algebra` trait and providing implementations for the default algebras `Eval`, `Check`, `Graph1`, and `GraphN`.

In the following example, we define a new operation `square` over integers and af-arrays and add support for first-order differentials:
```rust
use arrayfire as af;

pub trait UserAlgebra<Value> {
    fn square(&mut self, v: &Value) -> Result<Value>;
}

impl UserAlgebra<i32> for Eval {
    fn square(&mut self, v: &i32) -> Result<i32> {
        Ok(v * v)
    }
}

impl<T> UserAlgebra<af::Array<T>> for Eval
where
    T: af::HasAfEnum + af::ImplicitPromote<T, Output = T>,
{
    fn square(&mut self, v: &af::Array<T>) -> Result<af::Array<T>> {
        Ok(v * v)
    }
}

impl<D> UserAlgebra<Value<D>> for Graph1
where
    Eval: CoreAlgebra<D, Value = D>
        + UserAlgebra<D>
        + ArithAlgebra<D>
        + LinkedAlgebra<Value<D>, D>,
    D: HasDims + Clone + 'static + Send + Sync,
    D::Dims: PartialEq + std::fmt::Debug + Clone + 'static + Send + Sync,
{
    fn square(&mut self, v: &Value<D>) -> Result<Value<D>> {
        // Forward pass: delegate to the evaluation algebra.
        let result = self.eval().square(v.data())?;
        // Record a node whose backward closure propagates gradients to `v`.
        let value = self.make_node(result, vec![v.input()], {
            let v = v.clone();
            move |graph, store, gradient| {
                if let Some(id) = v.id() {
                    // d(v * v) = gradient * v + v * gradient
                    let c = graph.link(&v);
                    let grad1 = graph.mul(&gradient, c)?;
                    let grad2 = graph.mul(c, &gradient)?;
                    let grad = graph.add(&grad1, &grad2)?;
                    store.add_gradient(graph, id, &grad)?;
                }
                Ok(())
            }
        });
        Ok(value)
    }
}

fn main() -> Result<()> {
    let mut g = Graph1::new();
    let a = g.variable(3i32);
    let b = g.square(&a)?;
    assert_eq!(*b.data(), 9);
    let gradients = g.evaluate_gradients_once(b.gid()?, 1)?;
    assert_eq!(*gradients.get(a.gid()?).unwrap(), 6);
    Ok(())
}
```
The implementation for `GraphN` would be identical to the one for `Graph1`. We have omitted dimension checking for simplicity. We refer readers to the test files of the library for a more complete example.
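The key idea of the example above, that a new operation supplies both its forward value and a backward closure to the tape, can be reproduced in a small std-only sketch. The `Tape`, `make_node`, and `UserOps` names here are hypothetical and much simpler than the library's `make_node`, `link`, and gradient store; values are plain `f64` scalars and nodes are identified by `usize` indices.

```rust
/// Hypothetical tape where each node carries its own backward rule, so new
/// operations can be added without modifying the tape type.
#[derive(Default)]
struct Tape {
    values: Vec<f64>,
    // Maps the gradient of a node to contributions for its inputs.
    backward: Vec<Box<dyn Fn(f64) -> Vec<(usize, f64)>>>,
}

impl Tape {
    /// Records a node from its forward value and backward closure.
    fn make_node<F>(&mut self, value: f64, backward: F) -> usize
    where
        F: Fn(f64) -> Vec<(usize, f64)> + 'static,
    {
        self.values.push(value);
        self.backward.push(Box::new(backward));
        self.values.len() - 1
    }

    fn variable(&mut self, x: f64) -> usize {
        // Variables have no inputs, hence no gradient contributions.
        self.make_node(x, |_| Vec::new())
    }

    /// Reverse pass; nodes are created after their inputs.
    fn gradients(&self, start: usize, seed: f64) -> Vec<f64> {
        let mut grads = vec![0.0; self.values.len()];
        grads[start] = seed;
        for i in (0..=start).rev() {
            for (input, g) in (self.backward[i])(grads[i]) {
                grads[input] += g;
            }
        }
        grads
    }
}

/// A user-defined operation, implemented outside the tape's own impl block.
trait UserOps {
    fn square(&mut self, v: usize) -> usize;
}

impl UserOps for Tape {
    fn square(&mut self, v: usize) -> usize {
        let x = self.values[v];
        // d(x^2)/dx = 2x, so the incoming gradient is scaled by 2x.
        self.make_node(x * x, move |gradient| vec![(v, 2.0 * x * gradient)])
    }
}
```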
Users may define new "evaluation" algebras (similar to `Eval`) by implementing a subset of operation traits that includes `CoreAlgebra<Data, Value = Data>` for each supported `Data` type.

An evaluation-only algebra can be turned into algebras supporting differentiation (similar to `Graph1` and `GraphN`) using the `Graph` construction provided by the library.

The following example illustrates how to define a new evaluation algebra `SymEval`, then deduce its counterpart `SymGraph1`:
```rust
use std::sync::Arc;
// `CoreAlgebra`, `ArithAlgebra`, `HasDims`, `Graph`, `Config1`, and `Result`
// are assumed to be in scope from the library.

/// A custom algebra for forward-only symbolic evaluation.
#[derive(Clone, Default)]
struct SymEval;

/// Symbolic expressions of type T.
#[derive(Debug, PartialEq)]
enum Exp_<T> {
    Num(T),
    Neg(Exp<T>),
    Add(Exp<T>, Exp<T>),
    Mul(Exp<T>, Exp<T>),
    // ...
}

type Exp<T> = Arc<Exp_<T>>;

impl<T> CoreAlgebra<Exp<T>> for SymEval {
    type Value = Exp<T>;
    fn variable(&mut self, data: Exp<T>) -> Self::Value {
        data
    }
    fn constant(&mut self, data: Exp<T>) -> Self::Value {
        data
    }
    fn add(&mut self, v1: &Self::Value, v2: &Self::Value) -> Result<Self::Value> {
        Ok(Arc::new(Exp_::Add(v1.clone(), v2.clone())))
    }
}

impl<T> ArithAlgebra<Exp<T>> for SymEval {
    fn neg(&mut self, v: &Exp<T>) -> Exp<T> {
        Arc::new(Exp_::Neg(v.clone()))
    }
    fn sub(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
        let v2 = self.neg(v2);
        Ok(Arc::new(Exp_::Add(v1.clone(), v2)))
    }
    fn mul(&mut self, v1: &Exp<T>, v2: &Exp<T>) -> Result<Exp<T>> {
        Ok(Arc::new(Exp_::Mul(v1.clone(), v2.clone())))
    }
    // ...
}

// No dimension checks.
impl<T> HasDims for Exp_<T> {
    type Dims = ();
    fn dims(&self) {}
}

impl<T: std::fmt::Display> std::fmt::Display for Exp_<T> {
    // ...
}

/// Apply the `graph` module to derive an algebra supporting gradients.
type SymGraph1 = Graph<Config1<SymEval>>;

fn main() -> Result<()> {
    let mut g = SymGraph1::new();
    let a = g.variable(Arc::new(Exp_::Num("a")));
    let b = g.variable(Arc::new(Exp_::Num("b")));
    let c = g.mul(&a, &b)?;
    let d = g.mul(&a, &c)?;
    assert_eq!(format!("{}", d.data()), "aab");
    let gradients = g.evaluate_gradients_once(d.gid()?, Arc::new(Exp_::Num("1")))?;
    assert_eq!(format!("{}", gradients.get(a.gid()?).unwrap()), "(1ab+a1b)");
    assert_eq!(format!("{}", gradients.get(b.gid()?).unwrap()), "aa1");
    Ok(())
}
```
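The `Graph` construction can be approximated in a std-only sketch. Everything below (`EvalOps`, `F64Eval`, `StrEval`, `Graph`) is hypothetical and far simpler than the library's `Graph<Config1<_>>`: a generic wrapper records products and computes gradients using only the wrapped evaluation algebra's own `add` and `mul`. With a string-building algebra the gradients come out as expressions; with the formatting chosen here they happen to match the strings asserted in the example above.

```rust
/// Hypothetical evaluation algebra: just enough operations to build gradients.
trait EvalOps {
    type V: Clone;
    fn add(&mut self, a: &Self::V, b: &Self::V) -> Self::V;
    fn mul(&mut self, a: &Self::V, b: &Self::V) -> Self::V;
}

/// Numeric evaluation.
struct F64Eval;
impl EvalOps for F64Eval {
    type V = f64;
    fn add(&mut self, a: &f64, b: &f64) -> f64 { a + b }
    fn mul(&mut self, a: &f64, b: &f64) -> f64 { a * b }
}

/// Symbolic evaluation producing strings (a crude stand-in for `SymEval`).
struct StrEval;
impl EvalOps for StrEval {
    type V = String;
    fn add(&mut self, a: &String, b: &String) -> String { format!("({}+{})", a, b) }
    fn mul(&mut self, a: &String, b: &String) -> String { format!("{}{}", a, b) }
}

/// Generic wrapper turning an evaluation algebra into a differentiable one.
struct Graph<E: EvalOps> {
    eval: E,
    values: Vec<E::V>,
    // For each node: (input id, value of the other factor, input is left factor).
    inputs: Vec<Vec<(usize, E::V, bool)>>,
}

impl<E: EvalOps> Graph<E> {
    fn new(eval: E) -> Self {
        Graph { eval, values: Vec::new(), inputs: Vec::new() }
    }

    fn variable(&mut self, v: E::V) -> usize {
        self.values.push(v);
        self.inputs.push(Vec::new());
        self.values.len() - 1
    }

    fn mul(&mut self, a: usize, b: usize) -> usize {
        let (va, vb) = (self.values[a].clone(), self.values[b].clone());
        let v = self.eval.mul(&va, &vb);
        self.values.push(v);
        self.inputs.push(vec![(a, vb, true), (b, va, false)]);
        self.values.len() - 1
    }

    /// Backward pass computed entirely with the wrapped algebra.
    fn gradients(&mut self, start: usize, seed: E::V) -> Vec<Option<E::V>> {
        let mut grads: Vec<Option<E::V>> = vec![None; self.values.len()];
        grads[start] = Some(seed);
        for i in (0..=start).rev() {
            let g = match grads[i].clone() {
                Some(g) => g,
                None => continue,
            };
            for (input, other, is_left) in &self.inputs[i] {
                // Product rule, keeping factor order: d(a*b) = da*b + a*db.
                let contrib = if *is_left {
                    self.eval.mul(&g, other)
                } else {
                    self.eval.mul(other, &g)
                };
                let sum = match grads[*input].take() {
                    Some(acc) => self.eval.add(&acc, &contrib),
                    None => contrib,
                };
                grads[*input] = Some(sum);
            }
        }
        grads
    }
}
```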
See the CONTRIBUTING file for how to help out.
This project is available under the terms of either the Apache 2.0 license or the MIT license.