Crates.io | randomforest |
lib.rs | randomforest |
version | 0.1.6 |
source | src |
created_at | 2020-09-13 10:14:11.655883 |
updated_at | 2020-09-18 10:39:01.225583 |
description | Random forest regressor and classifier |
homepage | https://github.com/sile/randomforest |
repository | https://github.com/sile/randomforest |
max_upload_size | |
id | 288135 |
size | 48,305 |
A random forest implementation in Rust.
use randomforest::criterion::Mse;
use randomforest::RandomForestRegressorOptions;
use randomforest::table::TableBuilder;
// Twelve training samples, each borrowed as a slice of four feature values.
let samples = [
    &[0.0, 2.0, 1.0, 0.0][..],
    &[0.0, 2.0, 1.0, 1.0][..],
    &[1.0, 2.0, 1.0, 0.0][..],
    &[2.0, 1.0, 1.0, 0.0][..],
    &[2.0, 0.0, 0.0, 0.0][..],
    &[2.0, 0.0, 0.0, 1.0][..],
    &[1.0, 0.0, 0.0, 1.0][..],
    &[0.0, 1.0, 1.0, 0.0][..],
    &[0.0, 0.0, 0.0, 0.0][..],
    &[2.0, 1.0, 0.0, 0.0][..],
    &[0.0, 1.0, 0.0, 1.0][..],
    &[1.0, 1.0, 1.0, 1.0][..],
];
// One target value per sample, listed in the same order as `samples`.
let labels = [
    25.0, 30.0, 46.0, 45.0, 52.0, 23.0, 43.0, 35.0, 38.0, 46.0, 48.0, 52.0,
];

// Feed each (sample, label) pair to the builder in order. `try_for_each`
// stops at the first row the builder rejects and `?` propagates that error.
let mut builder = TableBuilder::new();
samples
    .iter()
    .zip(labels.iter())
    .try_for_each(|(row, label)| builder.add_row(row, *label))?;
let table = builder.build()?;

// Fit a regressor with the `Mse` criterion. The seed is fixed at 0,
// presumably so the exact prediction asserted below is reproducible.
let forest = RandomForestRegressorOptions::new()
    .seed(0)
    .fit(Mse, table);

// Predict for a feature row that does not appear in the training samples.
let query = [1.0, 2.0, 0.0, 0.0];
let predicted = forest.predict(&query);
assert_eq!(predicted, 41.9785);