extern crate gbdt; use gbdt::config::Config; use gbdt::decision_tree::{DataVec, PredVec}; use gbdt::fitness::almost_equal_thrs; use gbdt::gradient_boost::GBDT; use gbdt::input::{load, InputFormat}; fn main() { let mut cfg = Config::new(); cfg.set_feature_size(4); cfg.set_max_depth(4); cfg.set_iterations(100); cfg.set_shrinkage(0.1); cfg.set_loss("LAD"); cfg.set_debug(true); cfg.set_training_optimization_level(2); // load data let train_file = "dataset/iris/train.txt"; let test_file = "dataset/iris/test.txt"; let mut input_format = InputFormat::csv_format(); input_format.set_feature_size(4); input_format.set_label_index(4); let mut train_dv: DataVec = load(train_file, input_format).expect("failed to load training data"); let test_dv: DataVec = load(test_file, input_format).expect("failed to load test data"); // train and save the model let mut gbdt = GBDT::new(&cfg); gbdt.fit(&mut train_dv); gbdt.save_model("gbdt.model") .expect("failed to save the model"); // load the model and do inference let model = GBDT::load_model("gbdt.model").expect("failed to load the model"); let predicted: PredVec = model.predict(&test_dv); assert_eq!(predicted.len(), test_dv.len()); let mut correct = 0; let mut wrong = 0; for i in 0..predicted.len() { if almost_equal_thrs(test_dv[i].label, predicted[i], 0.0001) { correct += 1; } else { wrong += 1; }; println!("[{}] {} {}", i, test_dv[i].label, predicted[i]); } println!("correct: {}", correct); println!("wrong: {}", wrong); assert!(wrong <= 2); }