| Crates.io | gbrt-rs |
| lib.rs | gbrt-rs |
| version | 0.2.0 |
| created_at | 2025-12-10 09:15:19.797715+00 |
| updated_at | 2025-12-10 09:15:19.797715+00 |
| description | Gradient Boosted Regression Trees in Rust |
| homepage | |
| repository | https://github.com/Guap-Codes/gbrt-rs |
| max_upload_size | |
| id | 1977716 |
| size | 838,298 |
A high-performance, production-ready implementation of Gradient Boosted Regression Trees (GBRT) in Rust, engineered for speed, memory safety, and reproducible machine learning workflows. The library provides both a comprehensive CLI tool and a programmatic API, supporting regression and binary classification. Advanced features include intelligent categorical encoding with automatic metadata capture, early stopping with patience-based termination, k-fold cross-validation with proper metric calculation, feature importance analysis, and robust model serialization.
cargo install gbrt-rs
git clone https://github.com/Guap-Codes/gbrt-rs.git
cd gbrt-rs
cargo build --release
The optimized binary will be available at target/release/gbrt-rs.
Train a model on housing data:
# Train with validation and categorical features
./gbrt-rs train \
--data housing.csv \
--target price \
--categorical-columns neighborhood,zip_code \
--validation validation.csv \
--early-stopping-rounds 15 \
--output home_model \
--verbose
# Cross-validate to assess generalization
./gbrt-rs cross-validate \
--data housing.csv \
--target price \
--categorical-columns neighborhood,zip_code \
--folds 5
# Evaluate on test set
./gbrt-rs evaluate \
--model model.json \
--data test.csv \
--target price
# Make predictions
./gbrt-rs predict \
--model model.json \
--data new_data.csv \
--output predictions.csv
# Analyze feature importance
./gbrt-rs feature-importance \
--model model.json \
--top-k 10 \
--output importance.json
use gbrt_rs::{GBRTModel, Dataset, DataLoader, ModelMetrics, ModelIO, Result};

fn main() -> Result<()> {
    // Load data with categorical encoding
    let data_loader = DataLoader::new()?;
    let categorical_cols = vec![4, 5]; // column indices for neighborhood, zip_code
    let dataset = data_loader.load_csv_with_categorical(
        "housing.csv",
        "price",
        Some(&categorical_cols),
        0.1, // categorical auto-detection threshold
    )?;

    // Train model
    let mut model = GBRTModel::new()?;
    model.set_feature_names(vec![
        "square_feet".to_string(),
        "bedrooms".to_string(),
        "neighborhood".to_string(),
        "zip_code".to_string(),
    ]);
    model.fit(&dataset)?;

    // Evaluate
    let predictions = model.predict(dataset.features())?;
    let metrics = ModelMetrics::regression_metrics(
        dataset.targets().as_slice().unwrap(),
        &predictions,
    )?;
    println!("R²: {:.4}", metrics.r2);

    // Save model
    let model_io = ModelIO::new()?;
    model_io.save_model(&model.booster(), "model.json", "housing_model")?;

    Ok(())
}
train - Train a new model
./gbrt-rs train [OPTIONS] --data <FILE> --target <COLUMN>
Required:
-d, --data <FILE> Training data (CSV)
-t, --target <COLUMN> Target column name
Options:
-o, --output <FILE> Save trained model
-m, --model-type <TYPE> regression|classification [default: regression]
--n-estimators <N> Boosting rounds [default: 100]
--learning-rate <RATE> Step size [default: 0.1]
--max-depth <DEPTH> Tree depth [default: 6]
--subsample <RATIO> Row sampling [default: 1.0]
-V, --validation <FILE> Separate validation set
--test-split <RATIO> Internal split ratio [default: 0.2]
--early-stopping-rounds <N> Patience in rounds
--early-stopping-tolerance <TOL> Minimum improvement [default: 0.001]
--categorical-columns <COLS> Comma-separated column names
--categorical-threshold <T> Auto-detect ratio [default: 0.1]
--feature-names <NAMES> Override feature names
-v, --verbose Show training progress
-h, --help Print help
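The interaction of `--early-stopping-rounds` (patience) and `--early-stopping-tolerance` can be sketched as a standalone function. This is an illustrative assumption about the semantics, not the crate's actual code: training stops once the validation loss has failed to improve by more than the tolerance for N consecutive rounds.

```rust
/// Return the round at which training would stop, given per-round
/// validation losses (lower is better), a patience in rounds, and a
/// minimum-improvement tolerance. Returns losses.len() if never triggered.
fn early_stop_round(losses: &[f64], patience: usize, tolerance: f64) -> usize {
    let mut best = f64::INFINITY;
    let mut rounds_without_improvement = 0;
    for (round, &loss) in losses.iter().enumerate() {
        if best - loss > tolerance {
            // An improvement larger than the tolerance resets the patience counter.
            best = loss;
            rounds_without_improvement = 0;
        } else {
            rounds_without_improvement += 1;
            if rounds_without_improvement >= patience {
                return round + 1; // stop after this round
            }
        }
    }
    losses.len()
}

fn main() {
    // Validation loss stalls after round 2; a patience of 3 stops at round 6.
    let losses = [1.0, 0.5, 0.4, 0.3999, 0.3998, 0.3997, 0.3996];
    println!("stopped at round {}", early_stop_round(&losses, 3, 0.001));
}
```

Note that the tolerance guards against stopping on tiny, noise-level improvements: a loss that keeps creeping down by less than 0.001 per round still exhausts the patience budget.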
cross-validate - K-fold cross-validation
./gbrt-rs cross-validate [OPTIONS] --data <FILE> --target <COLUMN>
Options:
-f, --folds <N> Number of folds [default: 5]
--categorical-columns <COLS>
--categorical-threshold <T>
-v, --verbose
Outputs per-fold and mean R² (or accuracy for classification).
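The fold construction can be sketched independently of the crate. `kfold_indices` below is a hypothetical helper that mirrors what an index-based k-fold splitter typically does; it is not the crate's `cross_validation_splits` implementation.

```rust
/// Split sample indices 0..n into k (train, validation) index sets.
/// Remainder samples are distributed across the first folds.
fn kfold_indices(n: usize, k: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
    let base = n / k;
    let extra = n % k;
    let mut folds = Vec::with_capacity(k);
    let mut start = 0;
    for fold in 0..k {
        // The first `extra` folds absorb one leftover sample each.
        let len = base + if fold < extra { 1 } else { 0 };
        let val: Vec<usize> = (start..start + len).collect();
        let train: Vec<usize> = (0..n)
            .filter(|i| *i < start || *i >= start + len)
            .collect();
        folds.push((train, val));
        start += len;
    }
    folds
}

fn main() {
    for (fold, (train, val)) in kfold_indices(10, 5).iter().enumerate() {
        println!("fold {fold}: {} train / {} validation samples", train.len(), val.len());
    }
}
```

Each sample lands in exactly one validation fold, so the per-fold metrics are computed on disjoint held-out data before being averaged.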
evaluate - Compute metrics on test data
./gbrt-rs evaluate --model <MODEL> --data <FILE> --target <COLUMN> [--verbose]
Outputs RMSE, MAE, R², and sample count.
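For reference, the three reported metrics can be computed from scratch as follows. This is a self-contained sketch, not the crate's `ModelMetrics` implementation:

```rust
/// Root mean squared error: penalizes large errors quadratically.
fn rmse(y: &[f64], pred: &[f64]) -> f64 {
    let mse: f64 = y.iter().zip(pred).map(|(a, b)| (a - b).powi(2)).sum::<f64>()
        / y.len() as f64;
    mse.sqrt()
}

/// Mean absolute error: robust to outliers relative to RMSE.
fn mae(y: &[f64], pred: &[f64]) -> f64 {
    y.iter().zip(pred).map(|(a, b)| (a - b).abs()).sum::<f64>() / y.len() as f64
}

/// R² = 1 - SS_res / SS_tot; 1.0 is a perfect fit, 0.0 matches predicting the mean.
fn r2(y: &[f64], pred: &[f64]) -> f64 {
    let mean = y.iter().sum::<f64>() / y.len() as f64;
    let ss_res: f64 = y.iter().zip(pred).map(|(a, b)| (a - b).powi(2)).sum();
    let ss_tot: f64 = y.iter().map(|a| (a - mean).powi(2)).sum();
    1.0 - ss_res / ss_tot
}

fn main() {
    let y = [3.0, 5.0, 7.0, 9.0];
    let pred = [2.8, 5.1, 7.2, 8.9];
    println!("RMSE {:.3}  MAE {:.3}  R² {:.4}", rmse(&y, &pred), mae(&y, &pred), r2(&y, &pred));
}
```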
predict - Generate predictions
./gbrt-rs predict --model <MODEL> --data <FILE> [--output <FILE>] [--format csv|json]
Writes predictions to CSV/JSON or stdout.
feature-importance - Analyze model
./gbrt-rs feature-importance --model <MODEL> [--top-k <N>] [--output <FILE>]
Shows top features by importance score.
info - Model metadata
./gbrt-rs info --model <MODEL> [--verbose]
Displays model type, features, trees, parameters, and training state.
GBRTModel: Main model struct
- new(): Regression model
- new_classifier(): Classification model
- with_config(config): Custom configuration
- fit(&dataset): Train on data
- fit_with_validation(&train, &val): Train with early stopping
- predict(&features) -> Vec<f64>: Predictions
- feature_importance() -> Vec<f64>: Importance scores
- training_history() -> Option<&TrainingState>: Training metrics

Dataset: Data container
- features() -> &FeatureMatrix
- targets() -> &Array
- n_samples(), n_features()
- validate(): Check data integrity

DataLoader: CSV loading with encoding
- load_csv(path): Load basic CSV
- load_csv_with_categorical(path, target, cat_indices, threshold): With encoding
- cross_validation_splits(dataset, k): Generate CV folds

use gbrt_rs::core::{GBRTConfig, TreeConfig, LossFunction};
let config = GBRTConfig {
    n_estimators: 200,
    learning_rate: 0.1,
    subsample: 0.8,
    loss: LossFunction::MSE,
    tree_config: TreeConfig {
        max_depth: 6,
        min_samples_split: 2,
        min_samples_leaf: 1,
        lambda: 1.0, // L2 regularization
        ..Default::default()
    },
    early_stopping_rounds: Some(15),
    early_stopping_tolerance: 0.001,
    ..Default::default()
};

let model = GBRTModel::with_config(config)?;
--target

Explicit Specification (Recommended):
--categorical-columns neighborhood,zip_code,sector
Auto-Detection (Use with caution):
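One plausible auto-detection rule, consistent with the `--categorical-threshold` "auto-detect ratio" option but stated here as an assumption rather than the crate's exact logic: treat a column as categorical when the ratio of distinct values to rows falls below the threshold (default 0.1).

```rust
/// Hypothetical auto-detection sketch: a column "looks categorical" when
/// distinct_values / n_rows < threshold. Not the crate's actual implementation.
fn looks_categorical(values: &[&str], threshold: f64) -> bool {
    let distinct: std::collections::HashSet<&str> = values.iter().copied().collect();
    (distinct.len() as f64) / (values.len() as f64) < threshold
}

fn main() {
    // 3 distinct neighborhoods over 40 rows -> ratio 0.075 < 0.1 -> categorical.
    let rows: Vec<&str> = ["downtown", "suburb", "riverside"]
        .iter().cycle().take(40).copied().collect();
    println!("categorical? {}", looks_categorical(&rows, 0.1));
}
```

The caveat with any ratio-based rule is that it is data-dependent: a low-cardinality numeric column (e.g. a 1-5 rating) can be silently encoded, and the same column may be classified differently on a smaller sample, which is why explicit `--categorical-columns` is the recommended path.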
Why Explicit is Better:
src/
├── lib.rs # Public API & convenience functions
├── main.rs # CLI entry point & command handlers
├── core/ # Core types and configurations
│ ├── GBRTConfig # Main booster configuration
│ ├── TreeConfig # Tree-specific parameters
│ └── LossFunction # Objective enum
├── boosting/ # Gradient boosting engine
│ ├── GradientBooster # Core boosting algorithm
│ ├── BoosterFactory # Factory methods
│ └── TrainingState # Metrics tracking
├── tree/ # Decision tree components
│ ├── TreeBuilder # Builder pattern implementation
│ ├── BestSplitter # Split finding with binning
│ └── MSECriterion # Split quality computation
├── data/ # Data structures and preprocessing
│ ├── Dataset # Container for X and y
│ ├── FeatureMatrix # Feature storage with encoding
│ └── preprocessing # Categorical encoding logic
├── objective/ # Loss functions
│ ├── MSEObjective # Regression
│ └── LogLossObjective # Classification
└── io/ # I/O operations
├── DataLoader # CSV reading with encoding
├── ModelIO # JSON serialization
└── SaveOptions # Metadata storage
- TreeBuilder allows custom splitters/criteria
- Splitter, Criterion enable experimentation
- Result<T, GBRTError> for robust error handling
- FeatureMatrix avoids unnecessary allocations

# Faster, less accurate
--n-estimators 100 --max-depth 4 --learning-rate 0.3
# Balanced (recommended)
--n-estimators 200 --max-depth 6 --learning-rate 0.1
# Slower, more accurate
--n-estimators 500 --max-depth 8 --learning-rate 0.05
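To make the learning-rate / n-estimators tradeoff concrete, here is a toy, self-contained 1-D gradient booster with depth-1 stumps. It is an illustration of the general technique, not the crate's implementation: each round fits the current residuals, and a smaller learning rate shrinks each tree's contribution, so more rounds are needed to reach a comparable fit.

```rust
struct Stump { split: f64, left: f64, right: f64 }

fn mean(v: &[f64]) -> f64 { v.iter().sum::<f64>() / v.len() as f64 }

/// Fit one depth-1 regression stump to the residuals by exhaustive split search.
fn fit_stump(x: &[f64], residual: &[f64]) -> Stump {
    let mut best = Stump { split: f64::NEG_INFINITY, left: 0.0, right: mean(residual) };
    let mut best_err = f64::INFINITY;
    for &s in x {
        let (mut lsum, mut ln, mut rsum, mut rn) = (0.0, 0usize, 0.0, 0usize);
        for (&xi, &ri) in x.iter().zip(residual) {
            if xi <= s { lsum += ri; ln += 1; } else { rsum += ri; rn += 1; }
        }
        if ln == 0 || rn == 0 { continue; }
        let (lm, rm) = (lsum / ln as f64, rsum / rn as f64);
        let err: f64 = x.iter().zip(residual)
            .map(|(&xi, &ri)| { let p = if xi <= s { lm } else { rm }; (ri - p).powi(2) })
            .sum();
        if err < best_err { best_err = err; best = Stump { split: s, left: lm, right: rm }; }
    }
    best
}

/// Gradient boosting for squared loss: each round fits the residuals and
/// adds `lr` times the new stump's prediction.
fn boost(x: &[f64], y: &[f64], n_estimators: usize, lr: f64) -> Vec<f64> {
    let mut pred = vec![mean(y); y.len()];
    for _ in 0..n_estimators {
        let residual: Vec<f64> = y.iter().zip(&pred).map(|(a, b)| a - b).collect();
        let stump = fit_stump(x, &residual);
        for (p, &xi) in pred.iter_mut().zip(x) {
            *p += lr * if xi <= stump.split { stump.left } else { stump.right };
        }
    }
    pred
}

fn main() {
    let x: Vec<f64> = (0..20).map(|i| i as f64).collect();
    let y: Vec<f64> = x.iter().map(|v| if *v < 10.0 { 1.0 } else { 5.0 }).collect();
    let err = |pred: &[f64]| -> f64 {
        y.iter().zip(pred).map(|(a, b)| (a - b).powi(2)).sum::<f64>() / y.len() as f64
    };
    // Two budgets: few aggressive rounds vs many conservative ones.
    println!("mse lr=0.3,  n=10: {:.5}", err(&boost(&x, &y, 10, 0.3)));
    println!("mse lr=0.05, n=60: {:.5}", err(&boost(&x, &y, 60, 0.05)));
}
```

On real data the conservative setting usually generalizes better even when its training loss is similar, which is why the "slower, more accurate" preset pairs a small learning rate with a large estimator count.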
Use --subsample 0.8 to reduce memory footprint.

./gbrt-rs train \
--data housing.csv \
--target price \
--categorical-columns neighborhood,zip_code \
--n-estimators 300 \
--max-depth 6 \
--learning-rate 0.05 \
--early-stopping-rounds 20 \
--output housing_model.json
./gbrt-rs train \
--data stocks.csv \
--target direction \
--model-type classification \
--categorical-columns sector \
--n-estimators 200 \
--max-depth 5 \
--subsample 0.8 \
--output stock_classifier.json
./gbrt-rs train \
--data large_dataset.csv \
--target target \
--n-estimators 1000 \
--max-depth 4 \
--learning-rate 0.01 \
--subsample 0.5 \
--early-stopping-rounds 50 \
--output large_model.json
A: Common causes:
- Too few estimators (--n-estimators) or too shallow trees (--max-depth)

Debug steps:
# Check model info
./gbrt-rs info --model model.json --verbose
# Ensure encoded features count matches expectations
# Should see: Features: 8 (original) + N (encoded categories)
# Run cross-validation to detect overfitting/underfitting
./gbrt-rs cross-validate --data train.csv --target target --folds 5
A: This is the categorical encoding mismatch bug: the evaluation data wasn't encoded the same way as the training data. Update to the latest version, where this is fixed.
Q: What's the difference between --validation and --test-split?
A:
- --validation: uses a separate file you provide (recommended for reproducibility)
- --test-split: splits your training data internally (convenient, but less control)

A: The current version loads all data into memory. For large datasets:
- Use --subsample 0.5 to train on a subset
- Lower --max-depth to limit tree size

A: No! Only for:
Not needed for:
A: GBRT is not ideal for time series because:
- Consider prophet or statsforecast instead

A: Importance = the total reduction in loss contributed by splits on that feature. Important notes:
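The reporting side of that definition can be sketched from scratch: normalize the raw per-feature loss reductions so they sum to 1, then take the top-k. The feature names and raw scores below are made up for illustration, and this is not the crate's `feature-importance` implementation.

```rust
/// Normalize raw importance scores to shares of total loss reduction
/// and return the top-k features, highest share first.
fn top_k_importance(names: &[&str], scores: &[f64], k: usize) -> Vec<(String, f64)> {
    let total: f64 = scores.iter().sum();
    let mut ranked: Vec<(String, f64)> = names.iter()
        .zip(scores)
        .map(|(n, s)| (n.to_string(), if total > 0.0 { s / total } else { 0.0 }))
        .collect();
    // Sort by normalized share, descending (scores are finite, so unwrap is safe).
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    ranked.truncate(k);
    ranked
}

fn main() {
    let names = ["square_feet", "bedrooms", "neighborhood", "zip_code"];
    let raw = [120.0, 15.0, 45.0, 20.0]; // hypothetical total loss reduction per feature
    for (name, share) in top_k_importance(&names, &raw, 3) {
        println!("{name:<14} {share:.3}");
    }
}
```

Because shares are relative, a feature's score only says how much of this model's loss reduction it accounted for, not that the feature is causally important.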
We welcome contributions! Please see CONTRIBUTING.md.
# Fork and clone
git clone https://github.com/Guap-Codes/gbrt-rs.git
cd gbrt-rs
# Install Rust tools
rustup component add rustfmt clippy
# Setup pre-commit hooks
pip install pre-commit
pre-commit install
# Create feature branch
git checkout -b feature/my-feature
# Make changes ensuring:
cargo fmt --all # Format code
cargo clippy --all -- -D warnings # Fix warnings
cargo test --all # Run tests
# Submit PR
Performance on a 100,000-sample, 50-feature dataset:
| Library | Train Time | Predict Time | Memory | R² |
|---|---|---|---|---|
| gbrt-rs | 2.3s | 12ms | 85 MB | 0.847 |
| XGBoost | 2.8s | 18ms | 120 MB | 0.852 |
| LightGBM | 1.9s | 15ms | 95 MB | 0.849 |
| sklearn | 8.7s | 45ms | 280 MB | 0.831 |
Measured on: AMD Ryzen 9 5900X, 32GB RAM, single-threaded
This project is licensed under the MIT License - see LICENSE for details.
If you use GBRT-RS in research, please cite:
@software{gbrt_rs_2025,
title={GBRT-RS: High-Performance Gradient Boosting in Rust},
author={Guap-Codes},
year={2025},
url={https://github.com/Guap-Codes/gbrt-rs},
version={0.2.0}
}
ndarray, serde, clap, csv, thiserror

Happy Boosting! 🚀
Made with ❤️ in Rust