runnt

Crates.io: runnt
lib.rs: runnt
version: 0.7.0
source: src
created_at: 2023-02-21 17:05:56.329267
updated_at: 2023-05-01 20:06:22.616586
description: Easy Neural Network for machine learning
repository: https://github.com/griccardos/runnt/
id: 790886
size: 130,341
owner: griccardos


runnt (Rust neural net)

A very simple fully connected neural network.
For when you just want to throw something together with minimal dependencies and few lines of code.
The aim is to create a fully connected network, run it on data, and get results in about 10 lines of code.
This library exists because there did not seem to be a nice Rust library that was easy to use and free of external dependencies.
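In a fully connected layer, every neuron computes a weighted sum of all of the previous layer's outputs, adds a bias, and applies an activation function. The std-only sketch below shows that computation for a small hypothetical 2-3-1 network (tanh hidden layer, linear output). The weights are made-up numbers, and the dense and linear functions are illustrative helpers, not runnt's internals.

```rust
/// One fully connected (dense) layer: every output neuron sees every input.
/// weights[j][i] is the weight from input i to output neuron j.
fn dense(input: &[f32], weights: &[Vec<f32>], biases: &[f32], act: fn(f32) -> f32) -> Vec<f32> {
    weights
        .iter()
        .zip(biases)
        .map(|(row, b)| {
            // weighted sum of all inputs plus this neuron's bias
            let sum = row.iter().zip(input).map(|(w, x)| w * x).sum::<f32>() + b;
            act(sum)
        })
        .collect()
}

/// Identity activation, as used for a linear output layer.
fn linear(x: f32) -> f32 {
    x
}

fn main() {
    // Hypothetical 2-3-1 network with hand-picked weights.
    let w1: Vec<Vec<f32>> = vec![vec![0.5, -0.5], vec![1.0, 1.0], vec![-1.0, 0.5]];
    let b1: Vec<f32> = vec![0.0, -0.5, 0.25];
    let w2: Vec<Vec<f32>> = vec![vec![1.0, -1.0, 0.5]];
    let b2: Vec<f32> = vec![0.1];

    // input -> hidden (tanh) -> output (linear)
    let hidden = dense(&[1.0, 0.0], &w1, &b1, f32::tanh);
    let output = dense(&hidden, &w2, &b2, linear);
    println!("hidden = {hidden:?}, output = {output:?}");
}
```

A layer description such as [2, 8, 1] just fixes the shapes: 8 rows of 2 weights for the hidden layer and 1 row of 8 weights for the output.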

You are welcome to raise an issue or PR if you identify any errors or optimisations.

Functionality:

  • fully connected neural network
  • minimal dependencies
  • no external static libraries/DLLs required
  • regression and classification
  • able to define layer sizes
  • able to define activation types
  • can save/load model
  • Stochastic, mini-batch and full-batch gradient descent
  • Regularisation
  • Dataset manager
    • csv
    • one-hot encoding
    • normalisation
  • Reporting
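The dataset manager's one-hot encoding and normalisation are standard transformations. The std-only sketch below shows what they typically do to a single column. Two assumptions are made here and may differ from runnt: categories are ordered alphabetically, and mean normalisation is taken to be (x - mean) / (max - min). The function names one_hot and normalise_mean are hypothetical.

```rust
use std::collections::BTreeSet;

/// One-hot encode a column of category labels.
/// Categories are ordered alphabetically here (an assumption for determinism).
fn one_hot(column: &[&str]) -> (Vec<String>, Vec<Vec<f32>>) {
    // distinct categories, sorted
    let categories: Vec<String> = column
        .iter()
        .map(|s| s.to_string())
        .collect::<BTreeSet<String>>()
        .into_iter()
        .collect();
    // each value becomes a vector with a single 1.0 at its category's position
    let encoded: Vec<Vec<f32>> = column
        .iter()
        .map(|value| {
            categories
                .iter()
                .map(|c| if c.as_str() == *value { 1.0 } else { 0.0 })
                .collect()
        })
        .collect();
    (categories, encoded)
}

/// Mean normalisation, assumed here to be (x - mean) / (max - min).
/// A constant column (max == min) maps to all zeros.
fn normalise_mean(column: &[f32]) -> Vec<f32> {
    let n = column.len() as f32;
    let mean = column.iter().sum::<f32>() / n;
    let min = column.iter().cloned().fold(f32::INFINITY, f32::min);
    let max = column.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let range = max - min;
    column
        .iter()
        .map(|x| if range == 0.0 { 0.0 } else { (x - mean) / range })
        .collect()
}

fn main() {
    let (categories, encoded) = one_hot(&["setosa", "virginica", "setosa", "versicolor"]);
    println!("{categories:?}");
    println!("{encoded:?}");
    println!("{:?}", normalise_mean(&[1.0, 2.0, 3.0, 4.0]));
}
```

One-hot encoding turns a column with k categories into k input (or target) columns, which is why set.input_size() and set.target_size() can be larger than the number of CSV columns selected.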

How to use

Simple example

All you need is NN and your data:

    // XOR
    use runnt::{activation::ActivationType, nn::NN};

    let inputs = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]];
    let outputs = [[0.], [1.], [1.], [0.]];

    // 2 inputs, one hidden layer of 8 neurons, 1 output
    let mut nn = NN::new(&[2, 8, 1])
        .with_learning_rate(0.2)
        .with_hidden_type(ActivationType::Tanh)
        .with_output_type(ActivationType::Linear);

    // train on one sample at a time, cycling through the 4 XOR cases
    for i in 0..5000 {
        nn.fit_one(&inputs[i % 4], &outputs[i % 4]);
    }
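Each fit_one call is a single training step on one sample. The std-only sketch below shows what such a step conventionally involves for the same 2-8-1 shape with a tanh hidden layer and linear output: a forward pass, backpropagation of a squared-error loss, and a stochastic gradient descent update. It is an illustration, not runnt's implementation. The Net struct, the LCG-based initialisation in [-0.5, 0.5), the loss 0.5 * (y - target)^2, and the learning rate of 0.1 are all assumptions chosen for this sketch.

```rust
// Hypothetical 2-H-1 network: tanh hidden layer, single linear output.
struct Net {
    w1: Vec<[f32; 2]>, // one row of 2 input weights per hidden neuron
    b1: Vec<f32>,
    w2: Vec<f32>, // hidden -> output weights
    b2: f32,
}

/// Tiny deterministic LCG so the sketch needs no crates; returns a value in [-0.5, 0.5).
fn next_rand(state: &mut u32) -> f32 {
    *state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
    (*state >> 8) as f32 / (1u32 << 24) as f32 - 0.5
}

impl Net {
    fn new(hidden: usize, seed: u32) -> Net {
        let mut s = seed;
        let w1: Vec<[f32; 2]> = (0..hidden).map(|_| [next_rand(&mut s), next_rand(&mut s)]).collect();
        let w2: Vec<f32> = (0..hidden).map(|_| next_rand(&mut s)).collect();
        Net { w1, b1: vec![0.0; hidden], w2, b2: 0.0 }
    }

    /// Forward pass: returns the hidden activations and the output.
    fn forward(&self, x: &[f32; 2]) -> (Vec<f32>, f32) {
        let h: Vec<f32> = self
            .w1
            .iter()
            .zip(&self.b1)
            .map(|(w, b)| (w[0] * x[0] + w[1] * x[1] + b).tanh())
            .collect();
        let y = h.iter().zip(&self.w2).map(|(hi, wi)| hi * wi).sum::<f32>() + self.b2;
        (h, y)
    }

    /// One stochastic gradient descent step on a single (input, target) pair.
    fn fit_one(&mut self, x: &[f32; 2], target: f32, lr: f32) {
        let (h, y) = self.forward(x);
        // dLoss/dy for loss = 0.5 * (y - target)^2 with a linear output
        let delta_out = y - target;
        for j in 0..h.len() {
            // backpropagate through the (not yet updated) output weight and tanh' = 1 - tanh^2
            let delta_h = delta_out * self.w2[j] * (1.0 - h[j] * h[j]);
            self.w2[j] -= lr * delta_out * h[j];
            self.w1[j][0] -= lr * delta_h * x[0];
            self.w1[j][1] -= lr * delta_h * x[1];
            self.b1[j] -= lr * delta_h;
        }
        self.b2 -= lr * delta_out;
    }

    /// Mean squared error over a dataset.
    fn mse(&self, inputs: &[[f32; 2]], targets: &[f32]) -> f32 {
        inputs.iter().zip(targets).map(|(x, t)| (self.forward(x).1 - t).powi(2)).sum::<f32>()
            / inputs.len() as f32
    }
}

fn main() {
    let inputs: [[f32; 2]; 4] = [[0., 0.], [0., 1.], [1., 0.], [1., 1.]];
    let targets: [f32; 4] = [0., 1., 1., 0.];
    let mut net = Net::new(8, 42);
    let before = net.mse(&inputs, &targets);
    // same schedule as the example above: cycle through the 4 samples 5000 times
    for i in 0..5000 {
        net.fit_one(&inputs[i % 4], targets[i % 4], 0.1);
    }
    let after = net.mse(&inputs, &targets);
    println!("MSE before: {before:.4}, after: {after:.4}");
}
```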

Simple example with Dataset and reporting

Dataset makes loading and transforming data a bit easier, and train makes running epochs and reporting easy.
A complete neural net with reporting takes fewer than 10 lines:

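    // assumes NN, Dataset, Conversion and ReportAccuracy are imported from runnt
    let set = Dataset::builder()
        .read_csv("examples/data/iris.csv")
        .add_input_columns(&[0, 1, 2, 3], Conversion::NormaliseMean)
        .add_target_columns(&[4], Conversion::OneHot)
        .allocate_to_test_data(0.2) // hold out 20% of rows as test data
        .build();

    let mut net = NN::new(&[set.input_size(), 32, set.target_size()]).with_learning_rate(0.15);
    // run for 1000 epochs, with batch size 8, and report every 100 epochs
    net.train(&set, 1000, 8, 100, ReportAccuracy::CorrectClassification);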
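The name ReportAccuracy::CorrectClassification suggests reporting the share of rows that are classified correctly. The std-only sketch below computes that metric under a common convention, which is an assumption here rather than confirmed runnt behaviour: the predicted class is the index of the largest output, and the true class is the hot index of the one-hot target. The argmax and classification_accuracy functions are hypothetical helpers.

```rust
/// Index of the largest value (the first one wins on ties).
fn argmax(values: &[f32]) -> usize {
    let mut best = 0;
    for (i, v) in values.iter().enumerate() {
        if *v > values[best] {
            best = i;
        }
    }
    best
}

/// Fraction of rows whose predicted class matches the one-hot target.
fn classification_accuracy(predictions: &[Vec<f32>], targets: &[Vec<f32>]) -> f32 {
    let correct = predictions
        .iter()
        .zip(targets)
        .filter(|(p, t)| argmax(p) == argmax(t))
        .count();
    correct as f32 / predictions.len() as f32
}

fn main() {
    // made-up network outputs for 3 rows and their one-hot targets
    let predictions: Vec<Vec<f32>> = vec![vec![0.1, 0.8, 0.1], vec![0.7, 0.2, 0.1], vec![0.2, 0.3, 0.5]];
    let targets: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 0.0], vec![0.0, 0.0, 1.0], vec![0.0, 0.0, 1.0]];
    println!("accuracy = {}", classification_accuracy(&predictions, &targets));
}
```

Here the first and third rows are classified correctly and the second is not, so 2 of the 3 rows count as correct.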

Example with Dataset, reporting and save/load

    let set = Dataset::builder()
        .read_csv(r"/temp/diamonds.csv")
        .allocate_to_test_data(0.2)
        .add_input_columns(&[0, 4, 5, 7, 8, 9], Conversion::NormaliseMean)
        .add_input_columns(&[1, 2, 3], Conversion::OneHot)
        .add_target_columns(
            &[6],
            // custom conversion: parse as f32 (0.0 if unparseable) and divide by 1,000
            Conversion::Function(|f| f.parse::<f32>().unwrap_or_default() / 1_000.),
        )
        .build();

    let save_path = "network.txt";
    // resume from a saved network if one exists, otherwise start a new one
    let mut net = if std::path::Path::new(save_path).exists() {
        NN::load(save_path)
    } else {
        NN::new(&[set.input_size(), 32, set.target_size()])
    };
    // run for 100 epochs, with batch size 32, and report every 10 epochs
    net.train(&set, 100, 32, 10, ReportAccuracy::RSquared);
    net.save(save_path);
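ReportAccuracy::RSquared refers to the coefficient of determination, R^2 = 1 - SS_res / SS_tot, where SS_res is the sum of squared prediction errors and SS_tot is the sum of squared deviations of the actual values from their mean. The std-only sketch below computes it for a single target column. Whether runnt reports it on the training or the test split, and how it combines several target columns, is not stated here; the r_squared function is a hypothetical helper.

```rust
/// Coefficient of determination: R^2 = 1 - SS_res / SS_tot.
/// Undefined (division by zero) when all actual values are equal.
fn r_squared(actual: &[f32], predicted: &[f32]) -> f32 {
    let n = actual.len() as f32;
    let mean = actual.iter().sum::<f32>() / n;
    // residual sum of squares: how far predictions are from the truth
    let ss_res: f32 = actual.iter().zip(predicted).map(|(y, p)| (y - p).powi(2)).sum();
    // total sum of squares: how far the truth is from its own mean
    let ss_tot: f32 = actual.iter().map(|y| (y - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let actual: [f32; 4] = [3.0, 5.0, 7.0, 9.0];
    let predicted: [f32; 4] = [2.5, 5.0, 7.5, 9.0];
    println!("R^2 = {}", r_squared(&actual, &predicted));
}
```

A value of 1.0 means perfect predictions, while always predicting the mean of the actual values gives 0.0, so R^2 is a convenient scale-free score for regression targets such as the scaled diamond prices above.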