rstorch

Crates.io: rstorch
lib.rs: rstorch
version: 0.2.0
source: src
created_at: 2023-05-17 19:51:29.922826
updated_at: 2023-11-25 19:32:38.745592
description: Implementation from scratch of a neural network framework in Rust inspired by PyTorch
homepage: https://github.com/ferranSanchezLlado/rstorch.git
repository: https://github.com/ferranSanchezLlado/rstorch.git
id: 867175
size: 91,293
Ferran Sanchez Llado (ferranSanchezLlado)

documentation

https://docs.rs/rstorch

README

RsTorch

Implementation from scratch of a deep learning framework in Rust with a PyTorch-like API. The project is still in its early stages and is not ready for production use. Therefore, the API is not stable and may change at any time.

The project has currently reached a Minimum Viable Product that allows the user to train a sequential model. It also provides the MNIST dataset, which is downloaded automatically from the internet.

Installation

Add the following to your Cargo.toml:

[dependencies]
rstorch = "0.2.0"

Or, to use the latest version from the master branch:

[dependencies]
rstorch = { git = "https://github.com/ferranSanchezLlado/rstorch.git" }

Usage

A small example of how to use the library to train a model on the MNIST dataset:

use rstorch::data::{DataLoader, SequentialSampler};
use rstorch::hub::MNIST;
use rstorch::prelude::*;
use rstorch::utils::{accuracy, flatten, normalize_zero_one, one_hot};
use rstorch::{CrossEntropyLoss, Identity, Linear, ReLU, Sequential, SGD};
use std::fs;
use std::path::PathBuf;

const BATCH_SIZE: usize = 32;
const EPOCHS: usize = 5;

fn main() {
    // Directory for the MNIST files (note: the crate's tests delete this path)
    let path: PathBuf = ["data", "mnist"].iter().collect();

    // Scale each image to [0, 1], flatten it, and one-hot encode the label over 10 classes
    let train_data = MNIST::new(path, true, true)
        .transform(|(x, y)| (flatten(normalize_zero_one(x)), one_hot(y, 10)));
    let sampler = SequentialSampler::new(train_data.len());
    let mut data_loader = DataLoader::new(train_data, BATCH_SIZE, true, sampler);

    // 784 inputs (28x28 pixels), two hidden layers of 100 units, 10 output classes
    let mut model = sequential!(
        Identity(),
        Linear(784, 100),
        ReLU(),
        Linear(100, 100),
        ReLU(),
        Linear(100, 10),
    );
    let mut loss = CrossEntropyLoss::new();
    let mut optim = SGD::new(0.01);

    for i in 0..EPOCHS {
        let n = data_loader.len() as f64;
        let mut total_loss = 0.0;
        let mut total_acc = 0.0;

        for (x, y) in data_loader.iter_array() {
            let pred = model.forward(x);
            let l = loss.forward(pred.clone(), y.clone());
            let acc = accuracy(pred, y);

            total_loss += l;
            total_acc += acc;

            // Backpropagate the loss gradient, then update the model parameters
            model.backward(loss.backward());
            optim.step(&mut model);
        }

        let avg_loss = total_loss / n;
        let avg_acc = total_acc / n;
        println!("EPOCH {i}: Average loss {avg_loss} - Average accuracy {avg_acc}");
    }
}
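
The preprocessing and accuracy steps used in the example above can be pictured with plain standard-library Rust. The sketch below is not rstorch's implementation. The helper names (`flatten_image`, `scale_pixels`, `one_hot_label`, `argmax`, `batch_accuracy`) are hypothetical, and it assumes that pixels are 8-bit values and that accuracy means the fraction of rows whose highest score matches the one-hot target.

```rust
// Std-only sketch of the ideas behind the preprocessing and metric helpers.
// All names here are hypothetical and are NOT part of rstorch's API.

/// Flatten a 2D image (rows of pixels) into a single row-major vector.
fn flatten_image(image: &[Vec<u8>]) -> Vec<u8> {
    image.iter().flatten().copied().collect()
}

/// Scale raw 8-bit pixel values into the range [0.0, 1.0].
fn scale_pixels(pixels: &[u8]) -> Vec<f64> {
    pixels.iter().map(|&p| f64::from(p) / 255.0).collect()
}

/// Encode a class label as a one-hot vector with `classes` entries.
fn one_hot_label(label: usize, classes: usize) -> Vec<f64> {
    assert!(label < classes, "label out of range");
    let mut v = vec![0.0; classes];
    v[label] = 1.0;
    v
}

/// Index of the largest value (the first one wins on ties); `None` if empty.
fn argmax(values: &[f64]) -> Option<usize> {
    let mut best: Option<(usize, f64)> = None;
    for (i, &v) in values.iter().enumerate() {
        match best {
            // Keep the current best when the new value is not strictly larger.
            Some((_, b)) if v <= b => {}
            _ => best = Some((i, v)),
        }
    }
    best.map(|(i, _)| i)
}

/// Fraction of rows whose predicted class matches the one-hot target.
fn batch_accuracy(preds: &[Vec<f64>], targets: &[Vec<f64>]) -> f64 {
    assert_eq!(preds.len(), targets.len(), "batch sizes differ");
    if preds.is_empty() {
        return 0.0;
    }
    let correct = preds
        .iter()
        .zip(targets)
        .filter(|(p, t)| argmax(p) == argmax(t))
        .count();
    correct as f64 / preds.len() as f64
}

fn main() {
    // A tiny 2x2 "image" standing in for a 28x28 MNIST digit.
    let image = vec![vec![0u8, 51], vec![102, 255]];
    let input = scale_pixels(&flatten_image(&image));
    println!("input = {input:?}");

    // Two fake prediction rows over 3 classes, with their one-hot targets.
    let preds = vec![vec![0.1, 0.7, 0.2], vec![0.6, 0.3, 0.1]];
    let targets = vec![one_hot_label(1, 3), one_hot_label(2, 3)];
    println!("accuracy = {}", batch_accuracy(&preds, &targets));
}
```

In this sketch a 28x28 MNIST image would flatten to 784 values, which is why the first `Linear` layer in the example takes 784 inputs.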

License

This project is licensed under either the MIT License or the Apache License, Version 2.0, at your option.

Commit count: 41
