| Field | Value |
|---|---|
| Crates.io | candle-optimisers |
| lib.rs | candle-optimisers |
| version | 0.8.0 |
| source | src |
| created_at | 2023-12-07 15:59:15.624109 |
| updated_at | 2024-11-18 10:56:50.015231 |
| description | Optimisers for use with candle, the minimalist ML framework |
| homepage | |
| repository | https://github.com/KGrewal1/optimisers |
| max_upload_size | |
| id | 1061081 |
| size | 373,607 |
A crate providing optimisers for use with candle, the minimalist ML framework.
Optimisers implemented are:
- SGD (including momentum and weight decay)
- RMSprop

Adaptive methods:

- AdaDelta
- AdaGrad
- AdaMax
- Adam
- AdamW (included with Adam as `decoupled_weight_decay`)
- NAdam
- RAdam
These are all checked against their PyTorch implementations (see pytorch_test.ipynb) and should provide the same functionality (though without some of PyTorch's input checking).
Additionally, all of the adaptive methods listed, as well as SGD, implement decoupled weight decay as described in *Decoupled Weight Decay Regularization*, in addition to the standard weight decay as implemented in PyTorch.
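As a sketch of how choosing between the two kinds of weight decay might look (the names used here, `ParamsAdam`, `Decay::WeightDecay` and `Decay::DecoupledWeightDecay`, and the `Default` impl are assumptions; check the docs for the exact API):

```rust
use candle_optimisers::{adam::ParamsAdam, Decay};

fn weight_decay_configs() -> (ParamsAdam, ParamsAdam) {
    // PyTorch-style (coupled) weight decay: the L2 term is added to the gradient.
    let adam = ParamsAdam {
        lr: 1e-3,
        weight_decay: Some(Decay::WeightDecay(1e-2)),
        ..Default::default()
    };
    // Decoupled weight decay (AdamW): the decay is applied directly to the weights.
    let adamw = ParamsAdam {
        lr: 1e-3,
        weight_decay: Some(Decay::DecoupledWeightDecay(1e-2)),
        ..Default::default()
    };
    (adam, adamw)
}
```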
Pseudo-second order methods:

- LBFGS

This is not implemented to be equivalent to the PyTorch version, but is checked on the 2D Rosenbrock function.
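For reference, the 2D Rosenbrock function used in that check (this helper is purely illustrative, not part of the crate):

```rust
/// 2D Rosenbrock function: f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2.
/// Its global minimum of 0 sits at (x, y) = (1, 1), which makes it a
/// simple convergence check for an optimiser.
fn rosenbrock(x: f64, y: f64) -> f64 {
    (1.0 - x).powi(2) + 100.0 * (y - x * x).powi(2)
}
```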
There is an MNIST toy example, along with a simple Adagrad example. Whilst the parameters of each method aren't tuned (all defaults, with a user-supplied learning rate), the following converges quite nicely:
cargo r -r --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
For even faster training, try:
cargo r -r --features cuda --example mnist mlp --optim r-adam --epochs 2000 --learning-rate 0.025
to use the CUDA backend.
To add the crate to your own project:

cargo add --git https://github.com/KGrewal1/optimisers.git candle-optimisers
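As a rough usage sketch, assuming the crate's optimisers implement candle-nn's `Optimizer` trait (as candle's built-in optimisers do) and that the `adam` module exposes `Adam` and `ParamsAdam` as shown (check the docs for the exact API), a training step looks like any other candle optimiser:

```rust
use candle_core::{Device, Result, Tensor, Var};
use candle_nn::Optimizer; // trait providing backward_step
use candle_optimisers::adam::{Adam, ParamsAdam}; // module path assumed, see docs

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // One trainable scalar parameter and a toy loss (w - 3)^2.
    let w = Var::from_tensor(&Tensor::new(0f32, &dev)?)?;
    let target = Tensor::new(3f32, &dev)?;
    let mut opt = Adam::new(vec![w.clone()], ParamsAdam::default())?;
    for _ in 0..200 {
        let loss = w.as_tensor().sub(&target)?.sqr()?;
        opt.backward_step(&loss)?; // compute gradients and update w in place
    }
    println!("w = {}", w.as_tensor().to_scalar::<f32>()?);
    Ok(())
}
```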
Documentation is available on docs.rs: https://docs.rs/candle-optimisers