candle-optimisers

Crates.io: candle-optimisers
lib.rs: candle-optimisers
version: 0.8.0
source: src
created_at: 2023-12-07 15:59:15.624109
updated_at: 2024-11-18 10:56:50.015231
description: Optimisers for use with candle, the minimalist ML framework
repository: https://github.com/KGrewal1/optimisers
id: 1061081
size: 373,607
Kirpal Grewal (KGrewal1)


Candle Optimisers

License: MIT

A crate of optimisers for use with candle, the minimalist ML framework.

Optimisers implemented are:

  • SGD (including momentum and weight decay)

  • RMSprop

Adaptive methods:

  • AdaDelta

  • AdaGrad

  • AdaMax

  • Adam

  • AdamW (included with Adam as decoupled_weight_decay)

  • NAdam

  • RAdam

These are all checked against their PyTorch implementations (see pytorch_test.ipynb) and should implement the same functionality (though without some of the input checking).

Additionally, all of the adaptive methods listed, as well as SGD, implement decoupled weight decay as described in Decoupled Weight Decay Regularization, in addition to the standard weight decay as implemented in PyTorch.
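To make the difference between the two weight decay modes concrete, here is a minimal scalar sketch of an Adam step. It is not this crate's code: the `Decay` enum, the `ScalarAdam` struct and the default hyperparameters are assumptions chosen for illustration, and the ordering of the decoupled shrink follows the usual AdamW formulation.

```rust
/// How weight decay enters the update (hypothetical type, for illustration only).
#[derive(Clone, Copy)]
enum Decay {
    /// Standard weight decay: `wd * theta` is added to the gradient,
    /// so it passes through the adaptive rescaling.
    L2(f64),
    /// Decoupled weight decay: the parameter is shrunk directly,
    /// independently of the gradient statistics.
    Decoupled(f64),
}

/// Adam state for a single scalar parameter (hypothetical, not the crate's API).
struct ScalarAdam {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    m: f64, // running mean of gradients
    v: f64, // running mean of squared gradients
    t: i32, // step counter used for bias correction
    decay: Decay,
}

impl ScalarAdam {
    fn new(lr: f64, decay: Decay) -> Self {
        // Assumed defaults: beta1 = 0.9, beta2 = 0.999, eps = 1e-8.
        Self { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, m: 0.0, v: 0.0, t: 0, decay }
    }

    /// Applies one update and returns the new parameter value.
    fn step(&mut self, theta: f64, grad: f64) -> f64 {
        self.t += 1;
        let mut theta = theta;
        let mut g = grad;
        match self.decay {
            Decay::L2(wd) => g += wd * theta,
            Decay::Decoupled(wd) => theta -= self.lr * wd * theta,
        }
        self.m = self.beta1 * self.m + (1.0 - self.beta1) * g;
        self.v = self.beta2 * self.v + (1.0 - self.beta2) * g * g;
        // Bias-corrected moment estimates.
        let m_hat = self.m / (1.0 - self.beta1.powi(self.t));
        let v_hat = self.v / (1.0 - self.beta2.powi(self.t));
        theta - self.lr * m_hat / (v_hat.sqrt() + self.eps)
    }
}

/// Runs one step from theta = 1.0 with gradient 0.5 under both modes.
fn compare_decay_modes() -> (f64, f64) {
    let l2 = ScalarAdam::new(0.1, Decay::L2(0.1)).step(1.0, 0.5);
    let decoupled = ScalarAdam::new(0.1, Decay::Decoupled(0.1)).step(1.0, 0.5);
    (l2, decoupled)
}
```

With these inputs the L2 variant lands near 0.9 and the decoupled variant near 0.89. On the first Adam step the adaptive rescaling normalises away the decay term that was folded into the gradient, whereas the decoupled shrink is applied in full.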

Pseudo-second-order methods:

  • LBFGS

This implementation is not equivalent to PyTorch's, but is checked on the 2D Rosenbrock function.
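For reference, the 2D Rosenbrock function is a standard test problem with a narrow curved valley and a single minimum at (1, 1). The sketch below writes out the function and its analytic gradient, assuming the conventional constants a = 1 and b = 100. The function names are illustrative and are not taken from this crate's tests.

```rust
// Conventional Rosenbrock constants (an assumption; the crate's test may differ).
const A: f64 = 1.0;
const B: f64 = 100.0;

/// f(x, y) = (a - x)^2 + b * (y - x^2)^2, minimised at (a, a^2).
fn rosenbrock(x: f64, y: f64) -> f64 {
    (A - x).powi(2) + B * (y - x * x).powi(2)
}

/// Analytic gradient (df/dx, df/dy) of the function above.
fn rosenbrock_grad(x: f64, y: f64) -> (f64, f64) {
    let dx = -2.0 * (A - x) - 4.0 * B * x * (y - x * x);
    let dy = 2.0 * B * (y - x * x);
    (dx, dy)
}
```

An optimiser can be sanity-checked against this by starting from a point such as (-1.2, 1.0), where the function value is 24.2, and confirming that the iterates approach (1, 1), where both the value and the gradient vanish.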

Examples

There is an MNIST toy program along with a simple example of AdaGrad. Whilst the parameters of each method aren't tuned (all are left at their defaults, with a user-supplied learning rate), the following converges quite nicely:

cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025

For even faster training try:

cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025

to use the CUDA backend.

Usage

cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers

Documentation

Documentation is available on docs.rs: https://docs.rs/candle-optimisers
