| Crates.io | sklears-tree |
| lib.rs | sklears-tree |
| version | 0.1.0-beta.1 |
| created_at | 2025-10-13 12:14:48.027393+00 |
| updated_at | 2026-01-01 21:31:24.74156+00 |
| description | Decision tree algorithms for sklears: CART, ID3, C4.5 |
| homepage | https://github.com/cool-japan/sklears |
| repository | https://github.com/cool-japan/sklears |
| max_upload_size | |
| id | 1880452 |
| size | 1,186,589 |
State-of-the-art tree-based algorithms for Rust with 5-20x performance improvements over scikit-learn. Features advanced algorithms like BART, soft trees, and LightGBM-style optimizations.
Latest release: 0.1.0-beta.1 (January 1, 2026). See the workspace release notes for highlights and upgrade guidance.
sklears-tree provides comprehensive tree-based ML algorithms:
use sklears_tree::{Criterion, DecisionTreeClassifier, GradientBoostingRegressor, MaxFeatures, RandomForestClassifier};
use ndarray::array;

// Decision tree
let tree = DecisionTreeClassifier::builder()
    .max_depth(5)
    .min_samples_split(2)
    .criterion(Criterion::Entropy)
    .build();

// Random forest with parallel training
let rf = RandomForestClassifier::builder()
    .n_estimators(100)
    .max_features(MaxFeatures::Sqrt)
    .n_jobs(4)
    .build();

// Gradient boosting with early stopping
let gb = GradientBoostingRegressor::builder()
    .n_estimators(1000)
    .learning_rate(0.1)
    .early_stopping(true)
    .validation_fraction(0.2)
    .build();

// Train and predict
let X = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
let y = array![0, 1, 0];
let fitted = tree.fit(&X, &y)?;
let predictions = fitted.predict(&X)?;
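For a quick sanity check, the predictions can be compared against the labels directly. A minimal sketch, assuming (as in the example above) that `predict` returns class labels aligned with `y`:

// Minimal sketch: training accuracy by counting exact matches.
let correct = predictions
    .iter()
    .zip(y.iter())
    .filter(|(pred, truth)| pred == truth)
    .count();
let accuracy = correct as f64 / y.len() as f64;
println!("training accuracy: {accuracy:.3}");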
use sklears_tree::BART;

let bart = BART::builder()
    .n_trees(200)
    .n_chains(4)
    .n_samples(1000)
    .build();

let fitted = bart.fit(&X, &y)?;
let (predictions, lower, upper) = fitted.predict_with_uncertainty(&X_test, 0.95)?;
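The three returned arrays are aligned by index, so per-point credible intervals can be reported directly. A small usage sketch under that assumption:

// Sketch: print each posterior mean with its 95% credible interval.
for i in 0..predictions.len() {
    println!(
        "y_hat = {:.3}  95% CI = [{:.3}, {:.3}]",
        predictions[i], lower[i], upper[i]
    );
}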
use sklears_tree::SoftDecisionTree;

let soft_tree = SoftDecisionTree::builder()
    .temperature(0.5)
    .learning_rate(0.01)
    .use_batch_norm(true)
    .build();
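For intuition: a soft tree replaces hard threshold splits with probabilistic routing, and the temperature controls how sharp that routing is. A conceptual sketch of the usual sigmoid gate (illustration only, not part of the crate's API):

// Conceptual sketch: probability of routing a sample to the left child.
// Lower `temperature` values sharpen the gate toward a hard 0/1 decision.
fn soft_route_left(x: &[f64], w: &[f64], bias: f64, temperature: f64) -> f64 {
    let z: f64 = x.iter().zip(w).map(|(xi, wi)| xi * wi).sum::<f64>() + bias;
    1.0 / (1.0 + (-z / temperature).exp())
}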
use sklears_tree::{HistogramGradientBoosting, GOSS, EFB};

let lgb = HistogramGradientBoosting::builder()
    .max_bins(255)
    .use_goss(true)         // Gradient-based One-Side Sampling
    .use_efb(true)          // Exclusive Feature Bundling
    .leaf_wise_growth(true)
    .build();
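Histogram-based boosting first quantizes each feature into at most `max_bins` integer bins, then searches splits over bin boundaries instead of raw sorted values. A conceptual equal-width binning sketch (illustration only, not the crate's internal code):

// Conceptual sketch: map raw feature values to bin indices. With
// max_bins <= 256, a u8 index suffices (matching max_bins(255) above).
fn bin_feature(values: &[f64], max_bins: usize) -> Vec<u8> {
    let lo = values.iter().cloned().fold(f64::INFINITY, f64::min);
    let hi = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let width = ((hi - lo) / max_bins as f64).max(f64::EPSILON);
    values
        .iter()
        .map(|v| (((v - lo) / width) as usize).min(max_bins - 1) as u8)
        .collect()
}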
use sklears_tree::{TreeSHAP, LIME, PartialDependence};

// SHAP values for tree models
let shap = TreeSHAP::new(&fitted_model);
let shap_values = shap.explain(&X)?;

// LIME local explanations
let lime = LIME::builder()
    .n_samples(1000)
    .kernel_width(0.75)
    .build();
let explanation = lime.explain(&fitted_model, &instance)?;

// Partial dependence plots
let pd = PartialDependence::new(&fitted_model);
let pd_values = pd.compute(&X, &[0, 1])?; // Features 0 and 1
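Partial dependence itself is easy to state: for each grid value of a feature, overwrite that column in every row and average the model's predictions. A model-agnostic sketch with ndarray (illustration only, not the crate's implementation):

use ndarray::{Array1, Array2};

// Conceptual sketch: partial dependence of `predict` on one feature,
// evaluated over `grid`. `predict` stands in for any fitted model's
// prediction function.
fn partial_dependence<F>(predict: F, x: &Array2<f64>, feature: usize, grid: &[f64]) -> Vec<f64>
where
    F: Fn(&Array2<f64>) -> Array1<f64>,
{
    grid.iter()
        .map(|&value| {
            let mut x_mod = x.clone();
            x_mod.column_mut(feature).fill(value);
            predict(&x_mod).mean().unwrap_or(0.0)
        })
        .collect()
}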
let rf = RandomForestClassifier::builder()
    .n_estimators(1000)
    .n_jobs(-1) // Use all cores
    .parallel_predict(true)
    .build();
use sklears_tree::MemoryMappedForest;

// Save large models to disk
let mmap_forest = MemoryMappedForest::from_forest(&rf)?;
mmap_forest.save_to_file("model.mmap")?;

// Load and use without loading into RAM
let loaded = MemoryMappedForest::load("model.mmap")?;
let predictions = loaded.predict(&X)?;
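The benefit of memory mapping is that the operating system pages model bytes in lazily as prediction touches them, keeping resident memory well below the file size. A conceptual sketch using the memmap2 crate (an assumption for illustration; not necessarily what MemoryMappedForest uses internally):

use memmap2::Mmap;
use std::fs::File;

// Conceptual sketch: map the serialized model file; bytes are paged in on
// demand rather than read eagerly into RAM.
let file = File::open("model.mmap")?;
let mapped = unsafe { Mmap::map(&file)? };
println!("mapped {} bytes lazily", mapped.len());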
use sklears_tree::{HoeffdingTree, StreamingGradientBoosting};

// Hoeffding tree for streaming data
let mut hoeffding = HoeffdingTree::builder()
    .grace_period(200)
    .split_confidence(0.95)
    .build();

for batch in data_stream {
    hoeffding.partial_fit(&batch.X, &batch.y)?;
}
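The split logic behind `grace_period` and `split_confidence` rests on the Hoeffding bound: after n observations of a statistic with range R, the observed mean is within epsilon of the true mean with probability 1 - delta. A conceptual sketch of the bound (illustration only, not the crate's internal code):

// Conceptual sketch: epsilon = sqrt(R^2 * ln(1/delta) / (2n)). A split is
// committed once the gap between the best and second-best split criterion
// exceeds this bound.
fn hoeffding_bound(range: f64, delta: f64, n_observations: usize) -> f64 {
    ((range * range * (1.0 / delta).ln()) / (2.0 * n_observations as f64)).sqrt()
}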
use sklears_tree::{FairDecisionTree, FairnessConstraint};

let fair_tree = FairDecisionTree::builder()
    .protected_attribute(2) // Column index
    .constraint(FairnessConstraint::DemographicParity)
    .fairness_threshold(0.8)
    .build();
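The 0.8 threshold follows the common "80% rule" for demographic parity: the disadvantaged group's positive-prediction rate should be at least 80% of the other group's. A conceptual sketch of that ratio for a binary protected attribute (illustration only, not the crate's API):

// Conceptual sketch: ratio of positive-prediction rates between the two
// groups of a binary protected attribute; values >= 0.8 pass the 80% rule.
fn demographic_parity_ratio(predictions: &[u8], protected: &[u8]) -> f64 {
    let rate = |group: u8| {
        let mut positives = 0usize;
        let mut total = 0usize;
        for (&pred, &grp) in predictions.iter().zip(protected) {
            if grp == group {
                total += 1;
                positives += (pred == 1) as usize;
            }
        }
        if total == 0 { 0.0 } else { positives as f64 / total as f64 }
    };
    let (r0, r1) = (rate(0), rate(1));
    if r0.max(r1) == 0.0 { 1.0 } else { r0.min(r1) / r0.max(r1) }
}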
use sklears_tree::{MultiOutputTree, MultiLabelRandomForest, MultiOutputStrategy};

// Multi-output regression
let mo_tree = MultiOutputTree::builder()
    .strategy(MultiOutputStrategy::Chained)
    .build();

// Multi-label classification
let ml_rf = MultiLabelRandomForest::builder()
    .n_estimators(100)
    .label_correlation(true)
    .build();
use sklears_tree::{TemporalRandomForest, SpatialDecisionTree, CoordinateSystem, SpatialIndex};

// Time series with seasonal patterns
let temporal_rf = TemporalRandomForest::builder()
    .seasonal_period(12)
    .trend_detection(true)
    .build();

// Geospatial data
let spatial_tree = SpatialDecisionTree::builder()
    .coordinate_system(CoordinateSystem::Geographic)
    .spatial_index(SpatialIndex::QuadTree)
    .build();
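With geographic coordinates, plain Euclidean distance on latitude/longitude is misleading, which is why the coordinate system matters for spatial splits and neighborhood queries. A conceptual haversine great-circle helper (illustration only, not part of the crate):

// Conceptual sketch: great-circle distance in kilometers between two
// (latitude, longitude) points given in degrees.
fn haversine_km(lat1: f64, lon1: f64, lat2: f64, lon2: f64) -> f64 {
    let (dlat, dlon) = ((lat2 - lat1).to_radians(), (lon2 - lon1).to_radians());
    let a = (dlat / 2.0).sin().powi(2)
        + lat1.to_radians().cos() * lat2.to_radians().cos() * (dlon / 2.0).sin().powi(2);
    6371.0 * 2.0 * a.sqrt().asin()
}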
Performance on standard datasets:
| Algorithm | scikit-learn | sklears-tree | Speedup |
|---|---|---|---|
| Decision Tree | 5.2ms | 0.8ms | 6.5x |
| Random Forest | 125ms | 12ms | 10.4x |
| Gradient Boosting | 850ms | 95ms | 8.9x |
| Extra Trees | 110ms | 8ms | 13.8x |
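These figures depend on hardware and dataset, so it is worth re-measuring on your own data. A minimal timing sketch (hypothetical; not the exact setup behind the table), reusing `rf`, `X`, and `y` from the examples above:

use std::time::Instant;

// Minimal sketch: wall-clock time for a single random-forest fit.
let start = Instant::now();
let _fitted = rf.fit(&X, &y)?;
println!("random forest fit took {:?}", start.elapsed());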
Further speedups are expected once the upcoming GPU support lands (see the gpu/ module in the layout below).
Crate layout:
sklears-tree/
├── core/ # Base tree structures
├── ensemble/ # Forest algorithms
├── boosting/ # Gradient boosting variants
├── interpretability/ # SHAP, LIME, PDP
├── streaming/ # Online algorithms
├── distributed/ # Distributed training
├── specialized/ # BART, soft trees, etc.
└── gpu/ # GPU kernels (WIP)
We welcome contributions! See CONTRIBUTING.md for priority areas and guidelines.
Licensed under either of the licenses listed in the repository.
@software{sklears_tree,
title = {sklears-tree: High-Performance Tree Algorithms for Rust},
author = {COOLJAPAN OU (Team KitaSan)},
year = {2026},
url = {https://github.com/cool-japan/sklears}
}