Crates.io | gbdt |
lib.rs | gbdt |
version | 0.1.3 |
source | src |
created_at | 2018-12-19 00:23:10.830495 |
updated_at | 2024-01-24 05:45:54.845241 |
description | An implementation of Gradient Boosting Regression Tree in the Rust programming language |
homepage | https://github.com/mesalock-linux/gbdt-rs |
repository | https://github.com/mesalock-linux/gbdt-rs |
max_upload_size | |
id | 102615 |
size | 210,613 |
MesaTEE GBDT-RS is a gradient boosting decision tree (GBDT) library written in safe Rust; the library contains no unsafe Rust code.
MesaTEE GBDT-RS provides both training and inference, and it can also run inference with models trained by xgboost.
New! The MesaTEE GBDT-RS paper has been accepted by IEEE S&P'19!
At this time, MesaTEE GBDT-RS supports inference with models trained by xgboost. The model should be trained by xgboost with the following configuration:
We have tested that MesaTEE GBDT-RS is compatible with xgboost 0.81 and 0.82.
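Conceptually, inference with a boosted tree ensemble, whether trained here or imported from xgboost, routes the input through every tree down to a leaf, sums the leaf values into a raw margin, and, for a logistic objective such as xgboost's `binary:logistic`, maps that margin through the sigmoid. The self-contained sketch below shows only that idea: the `Node` layout and the `eval_tree`, `predict_margin`, and `predict_prob` helpers are hypothetical, it does not use the gbdt crate, and it ignores missing values and xgboost's base score.

```rust
/// Hypothetical flattened tree node (not the library's internal layout).
#[derive(Debug, Clone)]
enum Node {
    /// Go to `left` if x[feature] < threshold, otherwise to `right`.
    Split { feature: usize, threshold: f32, left: usize, right: usize },
    /// Leaf holding this tree's contribution to the raw score.
    Leaf(f32),
}

/// Walk one tree (root at index 0) until a leaf is reached.
fn eval_tree(tree: &[Node], x: &[f32]) -> f32 {
    let mut idx = 0;
    loop {
        match &tree[idx] {
            Node::Split { feature, threshold, left, right } => {
                idx = if x[*feature] < *threshold { *left } else { *right };
            }
            Node::Leaf(value) => return *value,
        }
    }
}

/// Raw margin: sum of every tree's leaf value (base score assumed to be 0).
fn predict_margin(trees: &[Vec<Node>], x: &[f32]) -> f32 {
    trees.iter().map(|t| eval_tree(t, x)).sum()
}

/// Logistic objective: squash the margin into a probability.
fn predict_prob(trees: &[Vec<Node>], x: &[f32]) -> f32 {
    1.0 / (1.0 + (-predict_margin(trees, x)).exp())
}

fn main() {
    // Two depth-1 trees over a 2-feature input.
    let trees = vec![
        vec![
            Node::Split { feature: 0, threshold: 0.5, left: 1, right: 2 },
            Node::Leaf(-1.0),
            Node::Leaf(1.0),
        ],
        vec![
            Node::Split { feature: 1, threshold: 2.0, left: 1, right: 2 },
            Node::Leaf(-0.5),
            Node::Leaf(0.5),
        ],
    ];
    let x = [1.0_f32, 3.0];
    println!("margin = {}", predict_margin(&trees, &x)); // 1.0 + 0.5
    println!("prob = {:.4}", predict_prob(&trees, &x));
}
```

Summing per-tree outputs is why every tree in the ensemble can be evaluated independently once the model is loaded.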
```rust
use gbdt::config::Config;
use gbdt::decision_tree::{DataVec, PredVec};
use gbdt::gradient_boost::GBDT;
use gbdt::input::{InputFormat, load};

fn main() {
    // Training configuration: 22 features, depth-3 trees,
    // 50 boosting iterations, shrinkage (learning rate) 0.1.
    let mut cfg = Config::new();
    cfg.set_feature_size(22);
    cfg.set_max_depth(3);
    cfg.set_iterations(50);
    cfg.set_shrinkage(0.1);
    // "LogLikelyhood" is the loss name exactly as the library spells it.
    cfg.set_loss("LogLikelyhood");
    cfg.set_debug(true);
    // Use every sample and every feature in each iteration (no subsampling).
    cfg.set_data_sample_ratio(1.0);
    cfg.set_feature_sample_ratio(1.0);
    cfg.set_training_optimization_level(2);

    // Load data: CSV rows with 22 features, label at index 22.
    let train_file = "dataset/agaricus-lepiota/train.txt";
    let test_file = "dataset/agaricus-lepiota/test.txt";
    let mut input_format = InputFormat::csv_format();
    input_format.set_feature_size(22);
    input_format.set_label_index(22);
    let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data");
    let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data");

    // Train and save the model.
    let mut gbdt = GBDT::new(&cfg);
    gbdt.fit(&mut train_dv);
    gbdt.save_model("gbdt.model").expect("failed to save the model");

    // Load the model back and run inference.
    let model = GBDT::load_model("gbdt.model").expect("failed to load the model");
    let predicted: PredVec = model.predict(&test_dv);
    println!("predicted {} values", predicted.len());
}
```
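The CSV settings above (`set_feature_size(22)` and `set_label_index(22)`) suggest rows made of 22 feature columns plus the label. Purely as an illustration of that layout, the sketch below parses one such row by hand. `parse_row` is hypothetical and is not how `gbdt::input::load` is implemented, and reading the label index as a zero-based column is an assumption.

```rust
/// Hypothetical helper (not part of the gbdt crate): split one CSV row into
/// features and a label. Assumes the row has `feature_size + 1` numeric
/// columns and the label sits at zero-based column `label_index`.
fn parse_row(line: &str, feature_size: usize, label_index: usize) -> Result<(Vec<f32>, f32), String> {
    let cols: Vec<&str> = line.trim().split(',').collect();
    if cols.len() != feature_size + 1 || label_index > feature_size {
        return Err(format!(
            "expected {} columns with the label in column {}, found {} columns",
            feature_size + 1,
            label_index,
            cols.len()
        ));
    }
    let mut features = Vec::with_capacity(feature_size);
    let mut label = 0.0;
    for (i, col) in cols.iter().enumerate() {
        // Every column, including the label, must parse as a number.
        let v: f32 = col.trim().parse().map_err(|e| format!("column {}: {}", i, e))?;
        if i == label_index {
            label = v;
        } else {
            features.push(v);
        }
    }
    Ok((features, label))
}

fn main() {
    // Toy row: 3 features followed by the label (feature_size = 3, label_index = 3),
    // mirroring the 22/22 setting used for agaricus-lepiota above.
    let (features, label) = parse_row("0.5,1.0,2.0,1", 3, 3).expect("malformed row");
    println!("features = {:?}, label = {}", features, label);
}
```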
At this time, training in MesaTEE GBDT-RS is single-threaded.
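Across boosting rounds, training is sequential by construction: each new tree is fitted to the residuals (the negative gradient of the loss) left by all earlier trees, so round t cannot start until round t-1 has updated the predictions. The from-scratch sketch below illustrates this with squared loss and depth-1 trees on a single feature, swapped in for simplicity instead of the `LogLikelyhood` loss and depth-3 trees used above; `iterations` and `shrinkage` play the roles of `set_iterations` and `set_shrinkage`. It is not MesaTEE GBDT-RS's training code, and `Stump`, `fit_stump`, `boost`, and `predict_ensemble` are hypothetical names.

```rust
/// A depth-1 regression tree (a "stump") on a single feature.
#[derive(Debug, Clone, Copy)]
struct Stump {
    threshold: f64,
    left: f64,  // output when x < threshold
    right: f64, // output when x >= threshold
}

impl Stump {
    fn predict(&self, x: f64) -> f64 {
        if x < self.threshold { self.left } else { self.right }
    }
}

fn mean(v: &[f64]) -> f64 {
    if v.is_empty() { 0.0 } else { v.iter().sum::<f64>() / v.len() as f64 }
}

/// Fit a stump to `targets` by trying every x value as a threshold and keeping
/// the split with the smallest sum of squared errors.
fn fit_stump(xs: &[f64], targets: &[f64]) -> Stump {
    let mut best = Stump { threshold: f64::INFINITY, left: mean(targets), right: 0.0 };
    let mut best_sse = f64::INFINITY;
    for &t in xs {
        let (mut l, mut r) = (Vec::new(), Vec::new());
        for (&x, &y) in xs.iter().zip(targets) {
            if x < t { l.push(y) } else { r.push(y) }
        }
        let (ml, mr) = (mean(&l), mean(&r));
        let sse = l.iter().map(|y| (y - ml).powi(2)).sum::<f64>()
            + r.iter().map(|y| (y - mr).powi(2)).sum::<f64>();
        if sse < best_sse {
            best_sse = sse;
            best = Stump { threshold: t, left: ml, right: mr };
        }
    }
    best
}

/// Squared-loss boosting: each round fits a stump to the current residuals
/// and adds its output scaled by `shrinkage`.
fn boost(xs: &[f64], ys: &[f64], iterations: usize, shrinkage: f64) -> (f64, Vec<Stump>) {
    let base = mean(ys); // initial constant prediction
    let mut preds = vec![base; ys.len()];
    let mut trees = Vec::with_capacity(iterations);
    for _ in 0..iterations {
        // This round needs the predictions produced by all previous rounds.
        let residuals: Vec<f64> = ys.iter().zip(&preds).map(|(y, p)| y - p).collect();
        let stump = fit_stump(xs, &residuals);
        for (p, &x) in preds.iter_mut().zip(xs) {
            *p += shrinkage * stump.predict(x);
        }
        trees.push(stump);
    }
    (base, trees)
}

fn predict_ensemble(base: f64, trees: &[Stump], shrinkage: f64, x: f64) -> f64 {
    base + trees.iter().map(|s| shrinkage * s.predict(x)).sum::<f64>()
}

fn main() {
    let xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let ys = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0];
    let (base, trees) = boost(&xs, &ys, 50, 0.1);
    for &x in &xs {
        println!("x = {}: prediction = {:.3}", x, predict_ensemble(base, &trees, 0.1, x));
    }
}
```

A smaller shrinkage shrinks each round's step, so more iterations are needed to fit the data; that is the usual trade-off between the two settings.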
The inference functions are also single-threaded, but they are thread-safe. We provide an example of multi-threaded inference in example/test-multithreads.rs.
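Because prediction only reads the model, a thread-safe model can be shared by reference across threads, with each thread scoring its own slice of the input. The sketch below shows that pattern using the standard library's scoped threads (`std::thread::scope`, Rust 1.63 or later). `Model` and its `predict` method are hypothetical stand-ins, not the gbdt API, and this is not the code in example/test-multithreads.rs.

```rust
use std::thread;

/// Hypothetical stand-in for a trained model (not the gbdt API). Inference only
/// reads it, so a shared reference can be used from several threads at once.
struct Model {
    weights: Vec<f32>,
}

impl Model {
    /// Score a batch of rows with a toy linear rule; takes `&self`, not `&mut self`.
    fn predict(&self, rows: &[Vec<f32>]) -> Vec<f32> {
        rows.iter()
            .map(|r| r.iter().zip(&self.weights).map(|(x, w)| x * w).sum::<f32>())
            .collect()
    }
}

/// Split `rows` into roughly equal chunks, score each chunk on its own scoped
/// thread, and concatenate the results in the original row order.
fn parallel_predict(model: &Model, rows: &[Vec<f32>], n_threads: usize) -> Vec<f32> {
    if rows.is_empty() {
        return Vec::new();
    }
    let n_threads = n_threads.max(1);
    let chunk_size = (rows.len() + n_threads - 1) / n_threads; // ceiling division
    thread::scope(|s| {
        // Spawn every thread first, then join them in order.
        let handles: Vec<_> = rows
            .chunks(chunk_size)
            .map(|part| s.spawn(move || model.predict(part)))
            .collect();
        handles
            .into_iter()
            .flat_map(|h| h.join().expect("worker thread panicked"))
            .collect()
    })
}

fn main() {
    let model = Model { weights: vec![0.5, -1.0] };
    let rows: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32, 1.0]).collect();
    let preds = parallel_predict(&model, &rows, 4);
    // Same values as the single-threaded call, in the same order.
    assert_eq!(preds, model.predict(&rows));
    println!("{:?}", preds);
}
```

Collecting the join handles before joining matters: joining inside the same lazy iterator chain would start and finish one thread at a time.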
Because MesaTEE GBDT-RS is written in pure Rust, it can easily be used inside an SGX enclave with the help of rust-sgx-sdk, by adding this dependency to Cargo.toml:
```toml
gbdt_sgx = { git = "https://github.com/mesalock-linux/gbdt-rs" }
```
This imports a crate named `gbdt_sgx`. If you prefer to import it under the usual name `gbdt`, rename the package:
```toml
gbdt = { package = "gbdt_sgx", git = "https://github.com/mesalock-linux/gbdt-rs" }
```
For more information and concrete examples, please look at the directory `sgx/gbdt-sgx-test`.
License: Apache 2.0
Tianyi Li @n0b0dyCN n0b0dypku@gmail.com
Tongxin Li @litongxin1991 litongxin1991@gmail.com
Yu Ding @dingelish dingelish@gmail.com
Tao Wei, Yulong Zhang
Thanks to @qiyiping for their earlier project gbdt; we read their code before starting this project.