aegir

version: 2.0.0
created_at: 2021-01-02 19:29:37.747857
updated_at: 2022-12-29 19:21:29.861317
description: Strongly-typed, reverse-mode autodiff library in Rust
repository: https://github.com/tspooner/aegir
id: 330719
size: 227,008
owner: Thomas Spooner (tspooner)
documentation: https://docs.rs/aegir

README

aegir

Overview

Strongly-typed, compile-time autodifferentiation in Rust.

aegir is an experimental autodifferentiation framework designed to leverage Rust's powerful type system and avoid runtime overhead as much as humanly possible. The approach resembles that of expression templates, as commonly used in C++ linear-algebra libraries: each operation builds a nested type that describes the expression, rather than immediately computing a value.
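To make the expression-template idea concrete, here is a minimal, self-contained sketch that does not use aegir at all. It assumes a single scalar variable and uses type-level symbolic differentiation rather than aegir's reverse-mode machinery. The names `Expr`, `Var`, `Const`, `Sum`, `Prod` and `overload_ops!` are hypothetical. Overloaded `+` and `*` build nested types such as `Sum<Prod<Prod<Var, Var>, Var>, Const>`, and each node names its derivative's type through an associated type, so a derivative can itself be differentiated again.

```rust
use std::ops::{Add, Mul};

// Hypothetical sketch, not aegir's API: every node can evaluate itself
// at a point and names the *type* of its own derivative.
trait Expr: Copy {
    type D: Expr;
    fn eval(&self, x: f64) -> f64;
    fn diff(&self) -> Self::D;
}

#[derive(Clone, Copy)]
struct Var; // the single independent variable x

#[derive(Clone, Copy)]
struct Const(f64); // a constant leaf

#[derive(Clone, Copy)]
struct Sum<L, R>(L, R); // l + r

#[derive(Clone, Copy)]
struct Prod<L, R>(L, R); // l * r

impl Expr for Var {
    type D = Const;
    fn eval(&self, x: f64) -> f64 { x }
    fn diff(&self) -> Const { Const(1.0) }
}

impl Expr for Const {
    type D = Const;
    fn eval(&self, _x: f64) -> f64 { self.0 }
    fn diff(&self) -> Const { Const(0.0) }
}

impl<L: Expr, R: Expr> Expr for Sum<L, R> {
    // Sum rule: (l + r)' = l' + r'
    type D = Sum<L::D, R::D>;
    fn eval(&self, x: f64) -> f64 { self.0.eval(x) + self.1.eval(x) }
    fn diff(&self) -> Self::D { Sum(self.0.diff(), self.1.diff()) }
}

impl<L: Expr, R: Expr> Expr for Prod<L, R> {
    // Product rule: (l * r)' = l' * r + l * r'
    type D = Sum<Prod<L::D, R>, Prod<L, R::D>>;
    fn eval(&self, x: f64) -> f64 { self.0.eval(x) * self.1.eval(x) }
    fn diff(&self) -> Self::D {
        Sum(Prod(self.0.diff(), self.1), Prod(self.0, self.1.diff()))
    }
}

// Overloading `+` and `*` builds a bigger *type* instead of a number.
macro_rules! overload_ops {
    ($ty:ty $(, $g:ident)*) => {
        impl<$($g: Expr,)* Rhs: Expr> Add<Rhs> for $ty {
            type Output = Sum<Self, Rhs>;
            fn add(self, rhs: Rhs) -> Self::Output { Sum(self, rhs) }
        }
        impl<$($g: Expr,)* Rhs: Expr> Mul<Rhs> for $ty {
            type Output = Prod<Self, Rhs>;
            fn mul(self, rhs: Rhs) -> Self::Output { Prod(self, rhs) }
        }
    };
}

overload_ops!(Var);
overload_ops!(Const);
overload_ops!(Sum<L, R>, L, R);
overload_ops!(Prod<L, R>, L, R);

fn main() {
    let x = Var;
    // f(x) = x^3 + 3, encoded as Sum<Prod<Prod<Var, Var>, Var>, Const>.
    let f = x * x * x + Const(3.0);
    let df = f.diff(); // f'(x) = 3x^2, itself a (larger) expression type
    let d2f = df.diff(); // f''(x) = 6x: the derivative is differentiable too
    println!("f(2) = {}, f'(2) = {}, f''(2) = {}", f.eval(2.0), df.eval(2.0), d2f.eval(2.0));
    // prints: f(2) = 11, f'(2) = 12, f''(2) = 12
}
```

Because every derivative is just another `Expr` type, nothing limits how many times `diff` can be applied, at the cost of rapidly growing type names.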

Key Features

  • Built-in arithmetic, linear-algebraic, trigonometric and special operators.
  • Infinitely differentiable: compute Jacobians, Hessians, and higher-order derivatives.
  • Custom DSL for operator expansion.
  • Decoupled/generic tensor type.
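The example below uses a macro to write `(model - y) ^ 2`. As a rough illustration of how an operator-expansion macro can work in general, here is a hypothetical `pow_dsl!` macro that operates on plain `f64` values. It is not aegir's `aegir!` macro. On `f64`, `^` cannot be used at all, because Rust's `^` is bitwise XOR and is implemented only for integers and `bool`. The macro therefore rewrites `base ^ n` into `base.powi(n)` before type checking.

```rust
// Hypothetical `pow_dsl!` macro (an illustration, not aegir's `aegir!`):
// it rewrites `base ^ n` into `base.powi(n)` before type checking.
macro_rules! pow_dsl {
    // `base ^ n`: expand the base recursively, then raise it to the power n.
    ($base:tt ^ $p:literal) => {
        pow_dsl!($base).powi($p)
    };
    // A parenthesised group: keep the inner expression as-is.
    (($($inner:tt)*)) => {
        ($($inner)*)
    };
    // Anything else (a variable or literal) passes through unchanged.
    ($e:expr) => {
        $e
    };
}

fn main() {
    let (prediction, target) = (5.0_f64, 2.0_f64);
    // Expands to `(prediction - target).powi(2)`.
    let squared_error = pow_dsl!((prediction - target) ^ 2);
    println!("{}", squared_error); // prints 9
}
```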

Installation

[dependencies]
aegir = "2.0"

Example

#[macro_use]
extern crate aegir;
extern crate rand;

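use aegir::{Differentiable, Function, Identifier, Node, ids::{X, Y, W}};
use rand::Rng;

// Input dimension; the target below depends on x[0] and x[1] only.
const N: usize = 2;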

ctx!(Ctx { x: X, y: Y, w: W });

fn main() {
    let mut rng = rand::thread_rng();
    let mut ctx = Ctx {
        x: [0.0; N],
        y: 0.0,
        w: [0.0; N],
    };

    let x = X.into_var();
    let y = Y.into_var();
    let w = W.into_var();

    let model = x.dot(w);

    // Using standard method calls...
    let sse = model.sub(y).squared();
    let adj = sse.adjoint(W);

    // ...or using the aegir! macro:
    let sse = aegir!((model - y) ^ 2);
    let adj = sse.adjoint(W);

    for _ in 0..100_000 {
        // Independent variables:
        ctx.x = rng.gen();

        // Dependent variable:
        ctx.y = ctx.x[0] * 2.0 - ctx.x[1] * 4.0;

        // Evaluate gradient:
        let g: [f64; N] = adj.evaluate(&ctx).unwrap();

        // Update weights:
        ctx.w[0] -= 0.01 * g[0];
        ctx.w[1] -= 0.01 * g[1];
    }

    println!("{:?}", ctx.w.to_vec());
}
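The example trains a linear model by stochastic gradient descent on the squared error (w·x - y)^2. The gradient of that error with respect to w is 2(w·x - y)x, so that is what `adj.evaluate(&ctx)` should produce. The sketch below runs the same loop with the gradient written out by hand, as a way to sanity-check the result. It is not part of aegir. It also swaps `rand` for a small xorshift64 generator; that substitution is an assumption made only to keep the sketch dependency-free. The target is y = 2·x[0] - 4·x[1] with no noise, so the learned weights should approach [2, -4].

```rust
// Hand-derived version of the training loop above (no aegir, no rand).
const N: usize = 2;

// Small xorshift64 generator standing in for `rand::thread_rng()`
// (an assumption made only to keep this sketch dependency-free).
struct XorShift64(u64);

impl XorShift64 {
    /// Returns a pseudo-random f64 in [0, 1).
    fn next_f64(&mut self) -> f64 {
        let mut s = self.0;
        s ^= s << 13;
        s ^= s >> 7;
        s ^= s << 17;
        self.0 = s;
        // Keep the top 53 bits so the value fits exactly in an f64 mantissa.
        (s >> 11) as f64 / (1u64 << 53) as f64
    }
}

fn dot(a: &[f64; N], b: &[f64; N]) -> f64 {
    a.iter().zip(b).map(|(p, q)| p * q).sum()
}

fn train(steps: usize, lr: f64) -> [f64; N] {
    let mut rng = XorShift64(0x9E37_79B9_7F4A_7C15);
    let mut w = [0.0; N];
    for _ in 0..steps {
        // Independent variables, each in [0, 1).
        let x = [rng.next_f64(), rng.next_f64()];
        // Dependent variable, same target as the example.
        let y = 2.0 * x[0] - 4.0 * x[1];
        // Gradient of (w·x - y)^2 with respect to w is 2 (w·x - y) x.
        let r = dot(&w, &x) - y;
        for (wi, xi) in w.iter_mut().zip(&x) {
            *wi -= lr * 2.0 * r * xi;
        }
    }
    w
}

fn main() {
    let w = train(100_000, 0.01);
    // Should be very close to [2.0, -4.0].
    println!("{:?}", w);
}
```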
Commit count: 122