| Crates.io | hextral |
| lib.rs | hextral |
| version | 0.8.0 |
| created_at | 2024-03-14 22:21:06.743007+00 |
| updated_at | 2025-09-26 19:33:05.063613+00 |
| description | Comprehensive neural network library with dataset loading, batch normalization, 9 activation functions, 5 loss functions, multiple optimizers, regularization, and clean async-first API |
| homepage | |
| repository | https://github.com/xStFtx/hextral |
| max_upload_size | |
| id | 1174213 |
| size | 242,466 |
A high-performance neural network library for Rust with a clean async-first API, comprehensive dataset loading, advanced preprocessing, multiple optimizers, early stopping, and checkpointing capabilities.
Add this to your Cargo.toml:
[dependencies]
hextral = { version = "0.8.0", features = ["datasets"] }
nalgebra = "0.33"
tokio = { version = "1.0", features = ["full"] } # For async features
use hextral::{
    Hextral, ActivationFunction, Optimizer,
    dataset::{csv::CsvLoader, preprocessing::Preprocessor},
};
use nalgebra::DVector;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Load CSV data
let csv_loader = CsvLoader::new()
.with_headers(true)
.with_target_columns_by_name(vec!["species".to_string()]);
let mut dataset = csv_loader.from_file("iris.csv").await?;
// Apply preprocessing
let preprocessor = Preprocessor::new()
.standardize(None) // Standardize all features
.one_hot_encode(vec![4]); // One-hot encode target column
let stats = preprocessor.fit_transform(&mut dataset).await?;
// Split data (80% train, 20% test)
let split_index = (dataset.features.len() as f64 * 0.8) as usize;
let (train_features, test_features) = dataset.features.split_at(split_index);
let (train_targets, test_targets) = dataset.targets.as_ref().unwrap().split_at(split_index);
// Create and train neural network
let mut nn = Hextral::new(
dataset.metadata.feature_count,
&[8, 6], // Hidden layers
3, // Output classes
ActivationFunction::ReLU,
Optimizer::adam(0.001),
);
let (train_history, _) = nn.train(
train_features,
train_targets,
0.01, // Learning rate
100, // Epochs
None, // Batch size
None, None, None, None, // Validation inputs/targets, early stopping, checkpointing
).await?;
println!("Training completed! Final loss: {:.4}", train_history.last().unwrap());
// Evaluate on test set
let test_loss = nn.evaluate(test_features, test_targets).await;
println!("Test loss: {:.4}", test_loss);
Ok(())
}
use hextral::{Hextral, ActivationFunction, Optimizer, EarlyStopping, CheckpointConfig};
use nalgebra::DVector;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create a neural network: 2 inputs -> [4, 3] hidden -> 1 output
let mut nn = Hextral::new(
2, // Input features
&[4, 3], // Hidden layer sizes
1, // Output size
ActivationFunction::ReLU, // Activation function
Optimizer::adam(0.01), // Modern Adam optimizer
);
// Training data for XOR problem
let train_inputs = vec![
DVector::from_vec(vec![0.0, 0.0]),
DVector::from_vec(vec![0.0, 1.0]),
DVector::from_vec(vec![1.0, 0.0]),
DVector::from_vec(vec![1.0, 1.0]),
];
let train_targets = vec![
DVector::from_vec(vec![0.0]),
DVector::from_vec(vec![1.0]),
DVector::from_vec(vec![1.0]),
DVector::from_vec(vec![0.0]),
];
// Validation data (can be same as training for demo)
let val_inputs = train_inputs.clone();
let val_targets = train_targets.clone();
// Configure early stopping and checkpointing
let early_stopping = EarlyStopping::new(10, 0.001, true);
let checkpoint_config = CheckpointConfig::new("best_model".to_string());
// Train the network with advanced features
println!("Training network with early stopping...");
let (train_history, val_history) = nn.train(
&train_inputs,
&train_targets,
0.1, // Learning rate
1000, // Max epochs
Some(2), // Batch size
Some(&val_inputs), // Validation inputs
Some(&val_targets), // Validation targets
Some(early_stopping), // Early stopping
Some(checkpoint_config), // Checkpointing
).await?;
println!("Training completed after {} epochs", train_history.len());
println!("Final validation loss: {:.6}", val_history.last().unwrap_or(&0.0));
// Make predictions
println!("\nPredictions:");
for (input, expected) in train_inputs.iter().zip(train_targets.iter()) {
let prediction = nn.predict(input).await;
println!("Input: {:?} | Expected: {:.1} | Predicted: {:.3}",
input.data.as_vec(), expected[0], prediction[0]);
}
// Batch prediction (efficient for multiple inputs)
let _batch_predictions = nn.predict_batch(&train_inputs).await;
// Evaluate performance
let final_loss = nn.evaluate(&train_inputs, &train_targets).await;
println!("Final loss: {:.6}", final_loss);
Ok(())
}
use hextral::*;
use nalgebra::DVector;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create network with advanced activation function
let mut nn = Hextral::new(
4, &[8, 6], 2,
ActivationFunction::Swish { beta: 1.0 }, // Modern Swish activation
Optimizer::adamw(0.001, 0.01), // AdamW with weight decay
);
// Enable batch normalization for better training stability
nn.enable_batch_norm();
nn.set_training_mode(true);
// Configure regularization
nn.set_regularization(Regularization::L2(0.001));
let inputs: Vec<DVector<f64>> = vec![/* your training data */];
let targets: Vec<DVector<f64>> = vec![/* your target data */];
// Advanced training with all features
let early_stop = EarlyStopping::new(
15, // Patience: stop if no improvement for 15 epochs
0.0001, // Minimum improvement threshold
true, // Restore best weights when stopping
);
let checkpoint = CheckpointConfig::new("model_checkpoint".to_string())
.save_every(10); // Save every 10 epochs
let (train_losses, val_losses) = nn.train(
&inputs, &targets,
0.01, // Learning rate
500, // Max epochs
Some(32), // Batch size
Some(&inputs), // Validation inputs
Some(&targets), // Validation targets
Some(early_stop), // Early stopping
Some(checkpoint), // Checkpointing
).await?;
// Switch to inference mode
nn.set_training_mode(false);
Ok(())
}
Configure different loss functions for your specific task:
use hextral::{Hextral, LossFunction, ActivationFunction, Optimizer};
let mut nn = Hextral::new(2, &[4], 1, ActivationFunction::ReLU, Optimizer::default());
// For regression tasks
nn.set_loss_function(LossFunction::MeanSquaredError);
nn.set_loss_function(LossFunction::MeanAbsoluteError);
nn.set_loss_function(LossFunction::Huber { delta: 1.0 });
// For classification tasks
nn.set_loss_function(LossFunction::BinaryCrossEntropy);
nn.set_loss_function(LossFunction::CategoricalCrossEntropy);
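For intuition: Huber behaves like mean squared error for errors smaller than delta and like mean absolute error beyond it, which makes it robust to outliers. A minimal standalone sketch of the standard formula (for illustration; not hextral's internal implementation):

fn huber(error: f64, delta: f64) -> f64 {
    if error.abs() <= delta {
        0.5 * error * error // quadratic near zero, like MSE
    } else {
        delta * (error.abs() - 0.5 * delta) // linear in the tails, like MAE
    }
}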
Enable batch normalization for improved training stability:
use hextral::{Hextral, ActivationFunction, Optimizer};
let mut nn = Hextral::new(10, &[64, 32], 1, ActivationFunction::ReLU, Optimizer::default());
// Enable batch normalization
nn.enable_batch_norm();
// Set training mode
nn.set_training_mode(true);
// Train your network...
let (train_history, _) = nn.train(&inputs, &targets, 0.01, 100, None, None, None, None, None).await?;
// Switch to inference mode
nn.set_training_mode(false);
// Make predictions...
let prediction = nn.predict(&input).await;
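Under the hood, batch normalization standardizes each activation with batch statistics and then applies a learned scale (gamma) and shift (beta). A minimal standalone sketch of the standard transform using nalgebra's DVector (for intuition only; hextral manages these parameters internally):

use nalgebra::DVector;

// y = gamma * (x - mean) / sqrt(var + eps) + beta
fn batch_norm(x: &DVector<f64>, gamma: f64, beta: f64, eps: f64) -> DVector<f64> {
    let mean = x.mean();
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / x.len() as f64;
    x.map(|v| gamma * (v - mean) / (var + eps).sqrt() + beta)
}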
Use state-of-the-art activation functions:
use hextral::{Hextral, ActivationFunction, Optimizer};
// Swish activation (used in EfficientNet)
let mut nn = Hextral::new(2, &[4], 1,
ActivationFunction::Swish { beta: 1.0 }, Optimizer::default());
// GELU activation (used in BERT, GPT)
let mut nn = Hextral::new(2, &[4], 1,
ActivationFunction::GELU, Optimizer::default());
// Mish activation (self-regularizing)
let mut nn = Hextral::new(2, &[4], 1,
ActivationFunction::Mish, Optimizer::default());
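For reference, the standard definitions of these three activations as plain functions (a sketch; hextral's implementations may differ in approximation details):

fn swish(x: f64, beta: f64) -> f64 {
    x / (1.0 + (-beta * x).exp()) // x * sigmoid(beta * x)
}

fn gelu(x: f64) -> f64 {
    // widely used tanh approximation of GELU
    0.5 * x * (1.0 + ((2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn mish(x: f64) -> f64 {
    x * x.exp().ln_1p().tanh() // x * tanh(softplus(x)), softplus(x) = ln(1 + e^x)
}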
Prevent overfitting with built-in regularization techniques:
use hextral::{Hextral, Regularization, ActivationFunction, Optimizer};
let mut nn = Hextral::new(3, &[16, 8], 1, ActivationFunction::ReLU,
Optimizer::Adam { learning_rate: 0.01 });
// L2 regularization (Ridge)
nn.set_regularization(Regularization::L2(0.01));
// L1 regularization (Lasso)
nn.set_regularization(Regularization::L1(0.005));
// Dropout regularization
nn.set_regularization(Regularization::Dropout(0.3));
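Conceptually, L1 and L2 add a weight-magnitude penalty to the training loss so that large weights are discouraged; a minimal sketch of the two penalty terms (illustrative only, not hextral's internals):

// total_loss = data_loss + penalty
fn l2_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w * w).sum::<f64>()
}

fn l1_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w.abs()).sum::<f64>()
}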
Choose the optimizer that works best for your problem:
// Adam: Good default choice, adaptive learning rates
let optimizer = Optimizer::Adam { learning_rate: 0.001 };
// SGD: Simple and interpretable
let optimizer = Optimizer::SGD { learning_rate: 0.1 };
// SGD with Momentum: Accelerated convergence
let optimizer = Optimizer::SGDMomentum {
learning_rate: 0.1,
momentum: 0.9
};
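For intuition, SGD with momentum keeps a running velocity that accumulates past gradients, smoothing the descent direction; a standalone sketch of the classic per-parameter update (not hextral's internal code):

// v <- momentum * v - learning_rate * grad;  w <- w + v
fn sgd_momentum_step(w: &mut f64, v: &mut f64, grad: f64, lr: f64, momentum: f64) {
    *v = momentum * *v - lr * grad;
    *w += *v;
}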
Get insights into your network:
// Network architecture
println!("Architecture: {:?}", nn.architecture()); // [2, 4, 3, 1]
// Parameter count
println!("Total parameters: {}", nn.parameter_count()); // 25
// Save/load weights
let weights = nn.get_weights();
nn.set_weights(weights);
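Because get_weights and set_weights round-trip the full parameter set, you can snapshot a model before an experiment and roll back if it regresses. A usage sketch built only on the calls shown above (assumes an async context with inputs and targets in scope):

// Snapshot parameters, continue training, restore if the loss got worse
let snapshot = nn.get_weights();
let loss_before = nn.evaluate(&inputs, &targets).await;
// ... more training here ...
if nn.evaluate(&inputs, &targets).await > loss_before {
    nn.set_weights(snapshot); // roll back to the saved parameters
}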
Hextral provides comprehensive dataset loading and preprocessing capabilities.
use hextral::dataset::csv::CsvLoader;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Load CSV with automatic type inference
let csv_loader = CsvLoader::new()
.with_headers(true)
.with_delimiter(b',')
.with_target_columns_by_name(vec!["target".to_string()])
.with_max_rows(Some(1000));
let dataset = csv_loader.from_file("data.csv").await?;
println!("Loaded {} samples with {} features",
dataset.metadata.sample_count,
dataset.metadata.feature_count);
Ok(())
}
use hextral::dataset::image::{ImageLoader, LabelStrategy, AugmentationConfig};
use image::imageops::FilterType;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Configure image preprocessing and augmentation
let augmentation = AugmentationConfig::new()
.with_horizontal_flip(0.5)
.with_rotation(15.0)
.with_brightness(0.8, 1.2)
.with_contrast(0.8, 1.2)
.with_noise(0.1);
let image_loader = ImageLoader::new()
.with_target_size(224, 224)
.with_normalization(true)
.with_grayscale(false)
.with_label_strategy(LabelStrategy::FromDirectory)
.with_augmentation(augmentation)
.with_extensions(vec!["jpg".to_string(), "png".to_string()]);
let dataset = image_loader.from_directory("./images/").await?;
println!("Loaded {} images", dataset.metadata.sample_count);
if let Some(ref class_names) = dataset.target_names {
println!("Classes: {:?}", class_names);
}
Ok(())
}
// Extract labels from directory structure
let strategy = LabelStrategy::FromDirectory;
// Extract labels from filename patterns
let strategy = LabelStrategy::FromFilename("digit".to_string()); // Extract first digit
let strategy = LabelStrategy::FromFilename("number".to_string()); // Extract first number
let strategy = LabelStrategy::FromFilename("split:_".to_string()); // Split by underscore
let strategy = LabelStrategy::FromFilename("prefix:3".to_string()); // First 3 characters
// Use manual label mapping
let mut mapping = std::collections::HashMap::new();
mapping.insert("cat_image".to_string(), 0);
mapping.insert("dog_image".to_string(), 1);
let strategy = LabelStrategy::Manual(mapping);
// Load labels from separate file
let strategy = LabelStrategy::FromFile(std::path::PathBuf::from("labels.txt"));
use hextral::dataset::{
preprocessing::{Preprocessor, PreprocessingUtils},
FillStrategy
};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut dataset = /* load your dataset */;
// Create preprocessing pipeline
let preprocessor = Preprocessor::new()
.normalize(None) // Normalize all features to [0,1]
.standardize(Some(vec![0, 1, 2])) // Standardize specific features
.fill_missing(FillStrategy::Mean) // Fill missing values with mean
.remove_outliers(3.0) // Remove outliers beyond 3 std devs
.one_hot_encode(vec![3]) // One-hot encode categorical features
.apply_polynomial_features(2); // Add polynomial features (degree 2)
// Fit preprocessor and transform data
let stats = preprocessor.fit_transform(&mut dataset).await?;
// Split dataset
let (train_set, val_set, test_set) = PreprocessingUtils::train_val_test_split(
&dataset, 0.7, 0.2 // 70% train, 20% val, 10% test
).await?;
// Shuffle dataset
PreprocessingUtils::shuffle(&mut dataset).await?;
// Calculate correlation matrix
let correlation = PreprocessingUtils::correlation_matrix(&dataset).await?;
Ok(())
}
// Apply PCA for dimensionality reduction
let preprocessor = Preprocessor::new()
.standardize(None) // Always standardize before PCA
.apply_pca(10); // Reduce to 10 principal components
let stats = preprocessor.fit_transform(&mut dataset).await?;
// Features are now transformed to principal components
println!("Reduced from {} to {} dimensions",
stats.feature_means.len(),
dataset.metadata.feature_count);
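For intuition about what apply_pca computes: the principal components are the top eigenvectors of the feature covariance matrix, and the data is projected onto them. A minimal sketch using nalgebra directly (illustrative only; hextral's implementation details may differ):

use nalgebra::DMatrix;
use nalgebra::linalg::SymmetricEigen;

// Project standardized data (rows = samples, cols = features) onto its top-k components.
fn pca_project(data: &DMatrix<f64>, k: usize) -> DMatrix<f64> {
    let n = data.nrows() as f64;
    let cov = (data.transpose() * data) / (n - 1.0); // feature covariance
    let eig = SymmetricEigen::new(cov);
    // Order components by descending eigenvalue (explained variance)
    let mut idx: Vec<usize> = (0..eig.eigenvalues.len()).collect();
    idx.sort_by(|&a, &b| eig.eigenvalues[b].partial_cmp(&eig.eigenvalues[a]).unwrap());
    let top: Vec<_> = idx.iter().take(k)
        .map(|&i| eig.eigenvectors.column(i).into_owned())
        .collect();
    data * DMatrix::from_columns(&top)
}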
use hextral::dataset::FillStrategy;
// Different strategies for handling missing values (choose one; chaining below is just for illustration)
let preprocessor = Preprocessor::new()
.fill_missing(FillStrategy::Mean) // Use column mean
.fill_missing(FillStrategy::Median) // Use column median
.fill_missing(FillStrategy::Mode) // Use most frequent value
.fill_missing(FillStrategy::Constant(0.0)) // Fill with constant
.fill_missing(FillStrategy::ForwardFill) // Use previous valid value
.fill_missing(FillStrategy::BackwardFill); // Use next valid value
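To make the fill semantics concrete, here is how forward fill behaves on a column with missing values encoded as NaN (a hand-rolled sketch, independent of hextral's internals):

// [1.0, NaN, NaN, 4.0] -> [1.0, 1.0, 1.0, 4.0]; leading NaNs stay NaN
fn forward_fill(col: &mut [f64]) {
    let mut last_valid = f64::NAN;
    for v in col.iter_mut() {
        if v.is_nan() {
            *v = last_valid; // reuse the previous valid value
        } else {
            last_valid = *v;
        }
    }
}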
Hextral - Main neural network struct with async-first API
ActivationFunction - Enum for activation functions (9 available)
Optimizer - Enum for optimization algorithms (12 available)
Regularization - Enum for regularization techniques
EarlyStopping - Configuration for automatic training termination
CheckpointConfig - Configuration for model checkpointing
LossFunction - Enum for loss functions (5 available)
// Network creation
Hextral::new(inputs, hidden_layers, outputs, activation, optimizer) -> Hextral
// Training with full feature set
async fn train(
&mut self,
train_inputs: &[DVector<f64>],
train_targets: &[DVector<f64>],
learning_rate: f64,
epochs: usize,
batch_size: Option<usize>,
val_inputs: Option<&[DVector<f64>]>,
val_targets: Option<&[DVector<f64>]>,
early_stopping: Option<EarlyStopping>,
checkpoint_config: Option<CheckpointConfig>,
) -> Result<(Vec<f64>, Vec<f64>), Box<dyn std::error::Error>>
// Predictions
async fn predict(&self, input: &DVector<f64>) -> DVector<f64>
async fn predict_batch(&self, inputs: &[DVector<f64>]) -> Vec<DVector<f64>>
// Evaluation
async fn evaluate(&self, inputs: &[DVector<f64>], targets: &[DVector<f64>]) -> f64
// Forward pass
async fn forward(&self, input: &DVector<f64>) -> DVector<f64>
// Batch normalization
fn enable_batch_norm(&mut self)
fn disable_batch_norm(&mut self)
fn set_training_mode(&mut self, training: bool)
// Regularization
fn set_regularization(&mut self, reg: Regularization)
// Loss function
fn set_loss_function(&mut self, loss: LossFunction)
// Weight management
fn get_weights(&self) -> Vec<(DMatrix<f64>, DVector<f64>)>
fn set_weights(&mut self, weights: Vec<(DMatrix<f64>, DVector<f64>)>)
fn parameter_count(&self) -> usize
// Early stopping configuration
let early_stop = EarlyStopping::new(
patience: usize, // Epochs to wait for improvement
min_delta: f64, // Minimum improvement threshold
restore_best_weights: bool // Whether to restore best weights
);
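For intuition, patience and min_delta interact like this (a hand-rolled sketch of the usual early-stopping bookkeeping, not hextral's internals):

struct Stopper { best: f64, epochs_without_improvement: usize }

impl Stopper {
    // Call once per epoch with the validation loss; returns true when training should stop.
    fn update(&mut self, val_loss: f64, patience: usize, min_delta: f64) -> bool {
        if self.best - val_loss > min_delta {
            self.best = val_loss; // meaningful improvement: reset the counter
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= patience
    }
}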
// Checkpoint configuration
let checkpoint = CheckpointConfig::new("model_path".to_string())
.save_every(10) // Save every N epochs
.save_best(true); // Save best model based on validation loss
Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
This project is licensed under the MIT OR Apache-2.0 license.