Crates.io | hextral |
lib.rs | hextral |
version | 0.7.1 |
created_at | 2024-03-14 22:21:06.743007+00 |
updated_at | 2025-09-26 00:58:17.184075+00 |
description | Comprehensive neural network library with batch normalization, 9 activation functions, 5 loss functions, multiple optimizers, regularization, and clean async-first API |
homepage | |
repository | https://github.com/xStFtx/hextral |
max_upload_size | |
id | 1174213 |
size | 130,088 |
A high-performance neural network library for Rust with a clean async-first API, advanced activation functions, multiple optimizers, early stopping, and checkpointing capabilities.
Add this to your Cargo.toml:
[dependencies]
hextral = "0.7.1"
nalgebra = "0.33"
tokio = { version = "1.0", features = ["full"] } # For async features
use hextral::{Hextral, ActivationFunction, Optimizer, EarlyStopping, CheckpointConfig};
use nalgebra::DVector;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create a neural network: 2 inputs -> [4, 3] hidden -> 1 output
let mut nn = Hextral::new(
2, // Input features
&[4, 3], // Hidden layer sizes
1, // Output size
ActivationFunction::ReLU, // Activation function
Optimizer::adam(0.01), // Modern Adam optimizer
);
// Training data for XOR problem
let train_inputs = vec![
DVector::from_vec(vec![0.0, 0.0]),
DVector::from_vec(vec![0.0, 1.0]),
DVector::from_vec(vec![1.0, 0.0]),
DVector::from_vec(vec![1.0, 1.0]),
];
let train_targets = vec![
DVector::from_vec(vec![0.0]),
DVector::from_vec(vec![1.0]),
DVector::from_vec(vec![1.0]),
DVector::from_vec(vec![0.0]),
];
// Validation data (can be same as training for demo)
let val_inputs = train_inputs.clone();
let val_targets = train_targets.clone();
// Configure early stopping and checkpointing
let early_stopping = EarlyStopping::new(10, 0.001, true);
let checkpoint_config = CheckpointConfig::new("best_model".to_string());
// Train the network with advanced features
println!("Training network with early stopping...");
let (train_history, val_history) = nn.train(
&train_inputs,
&train_targets,
0.1, // Learning rate
1000, // Max epochs
Some(2), // Batch size
Some(&val_inputs), // Validation inputs
Some(&val_targets), // Validation targets
Some(early_stopping), // Early stopping
Some(checkpoint_config), // Checkpointing
).await?;
println!("Training completed after {} epochs", train_history.len());
println!("Final validation loss: {:.6}", val_history.last().unwrap_or(&0.0));
// Make predictions
println!("\nPredictions:");
for (input, expected) in train_inputs.iter().zip(train_targets.iter()) {
let prediction = nn.predict(input).await;
println!("Input: {:?} | Expected: {:.1} | Predicted: {:.3}",
input.data.as_vec(), expected[0], prediction[0]);
}
// Batch prediction (efficient for multiple inputs)
let batch_predictions = nn.predict_batch(&train_inputs).await;
// Evaluate performance
let final_loss = nn.evaluate(&train_inputs, &train_targets).await;
println!("Final loss: {:.6}", final_loss);
Ok(())
}
use hextral::*;
use nalgebra::DVector;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create network with advanced activation function
let mut nn = Hextral::new(
4, &[8, 6], 2,
ActivationFunction::Swish { beta: 1.0 }, // Modern Swish activation
Optimizer::adamw(0.001, 0.01), // AdamW with weight decay
);
// Enable batch normalization for better training stability
nn.enable_batch_norm();
nn.set_training_mode(true);
// Configure regularization
nn.set_regularization(Regularization::L2(0.001));
let inputs: Vec<DVector<f64>> = vec![/* your training data */];
let targets: Vec<DVector<f64>> = vec![/* your target data */];
// Advanced training with all features
let early_stop = EarlyStopping::new(
15, // Patience: stop if no improvement for 15 epochs
0.0001, // Minimum improvement threshold
true, // Restore best weights when stopping
);
let checkpoint = CheckpointConfig::new("model_checkpoint".to_string())
.save_every(10); // Save every 10 epochs
let (train_losses, val_losses) = nn.train(
&inputs, &targets,
0.01, // Learning rate
500, // Max epochs
Some(32), // Batch size
Some(&inputs), // Validation inputs
Some(&targets), // Validation targets
Some(early_stop), // Early stopping
Some(checkpoint), // Checkpointing
).await?;
// Switch to inference mode
nn.set_training_mode(false);
Ok(())
}
Configure different loss functions for your specific task:
use hextral::{Hextral, LossFunction, ActivationFunction, Optimizer};
let mut nn = Hextral::new(2, &[4], 1, ActivationFunction::ReLU, Optimizer::default());
// For regression tasks
nn.set_loss_function(LossFunction::MeanSquaredError);
nn.set_loss_function(LossFunction::MeanAbsoluteError);
nn.set_loss_function(LossFunction::Huber { delta: 1.0 });
// For classification tasks
nn.set_loss_function(LossFunction::BinaryCrossEntropy);
nn.set_loss_function(LossFunction::CategoricalCrossEntropy);
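For a complete picture, here is a short regression sketch that combines a loss function choice with the training and evaluation calls shown elsewhere in this README (the data is illustrative, and the snippet assumes an async context such as a #[tokio::main] function):
use hextral::{Hextral, LossFunction, ActivationFunction, Optimizer};
use nalgebra::DVector;
// Small regression network: 1 input -> [8] hidden -> 1 output
let mut nn = Hextral::new(1, &[8], 1, ActivationFunction::ReLU, Optimizer::adam(0.01));
// Huber loss is robust to outliers, a common choice for regression
nn.set_loss_function(LossFunction::Huber { delta: 1.0 });
// Illustrative data: y = 2x
let inputs: Vec<DVector<f64>> = (0..10).map(|i| DVector::from_vec(vec![i as f64])).collect();
let targets: Vec<DVector<f64>> = (0..10).map(|i| DVector::from_vec(vec![2.0 * i as f64])).collect();
// Full-batch training with no validation, early stopping, or checkpointing
nn.train(&inputs, &targets, 0.01, 200, None, None, None, None, None).await?;
let loss = nn.evaluate(&inputs, &targets).await;
println!("Huber loss after training: {:.6}", loss);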
Enable batch normalization for improved training stability:
use hextral::{Hextral, ActivationFunction, Optimizer};
let mut nn = Hextral::new(10, &[64, 32], 1, ActivationFunction::ReLU, Optimizer::default());
// Enable batch normalization
nn.enable_batch_norm();
// Set training mode
nn.set_training_mode(true);
// Train your network (inputs/targets prepared as in the examples above)...
let (train_losses, _val_losses) = nn.train(&inputs, &targets, 0.01, 100, None, None, None, None, None).await?;
// Switch to inference mode
nn.set_training_mode(false);
// Make predictions...
let prediction = nn.predict(&input).await;
Use state-of-the-art activation functions:
use hextral::{Hextral, ActivationFunction, Optimizer};
// Swish activation (used in EfficientNet)
let mut nn = Hextral::new(2, &[4], 1,
ActivationFunction::Swish { beta: 1.0 }, Optimizer::default());
// GELU activation (used in BERT, GPT)
let mut nn = Hextral::new(2, &[4], 1,
ActivationFunction::GELU, Optimizer::default());
// Mish activation (self-regularizing)
let mut nn = Hextral::new(2, &[4], 1,
ActivationFunction::Mish, Optimizer::default());
Prevent overfitting with built-in regularization techniques:
use hextral::{Hextral, Regularization, ActivationFunction, Optimizer};
let mut nn = Hextral::new(3, &[16, 8], 1, ActivationFunction::ReLU,
Optimizer::Adam { learning_rate: 0.01 });
// L2 regularization (Ridge)
nn.set_regularization(Regularization::L2(0.01));
// L1 regularization (Lasso)
nn.set_regularization(Regularization::L1(0.005));
// Dropout regularization
nn.set_regularization(Regularization::Dropout(0.3));
Choose the optimizer that works best for your problem:
// Adam: Good default choice, adaptive learning rates
let optimizer = Optimizer::Adam { learning_rate: 0.001 };
// SGD: Simple and interpretable
let optimizer = Optimizer::SGD { learning_rate: 0.1 };
// SGD with Momentum: Accelerated convergence
let optimizer = Optimizer::SGDMomentum {
learning_rate: 0.1,
momentum: 0.9
};
Get insights into your network:
// Network architecture
println!("Architecture: {:?}", nn.architecture()); // [2, 4, 3, 1]
// Parameter count
println!("Total parameters: {}", nn.parameter_count()); // 25
// Save/load weights
let weights = nn.get_weights();
nn.set_weights(weights);
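Weights retrieved with get_weights() can be loaded into another network, provided it was built with the same layer sizes. A minimal sketch (assuming nn is the [2, 4, 3, 1] network from this section; the second network's activation and optimizer here are illustrative choices):
// Copy trained weights into a freshly constructed network with identical layer sizes
let weights = nn.get_weights();
let mut fresh = Hextral::new(2, &[4, 3], 1, ActivationFunction::ReLU, Optimizer::adam(0.01));
fresh.set_weights(weights);
// Both networks now hold the same number of parameters
assert_eq!(fresh.parameter_count(), nn.parameter_count());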
Hextral - Main neural network struct with async-first API
ActivationFunction - Enum for activation functions (9 available)
Optimizer - Enum for optimization algorithms (12 available)
Regularization - Enum for regularization techniques
EarlyStopping - Configuration for automatic training termination
CheckpointConfig - Configuration for model checkpointing
LossFunction - Enum for loss functions (5 available)
// Network creation
Hextral::new(inputs, hidden_layers, outputs, activation, optimizer) -> Hextral
// Training with full feature set
async fn train(
&mut self,
train_inputs: &[DVector<f64>],
train_targets: &[DVector<f64>],
learning_rate: f64,
epochs: usize,
batch_size: Option<usize>,
val_inputs: Option<&[DVector<f64>]>,
val_targets: Option<&[DVector<f64>]>,
early_stopping: Option<EarlyStopping>,
checkpoint_config: Option<CheckpointConfig>,
) -> Result<(Vec<f64>, Vec<f64>), Box<dyn std::error::Error>>
// Predictions
async fn predict(&self, input: &DVector<f64>) -> DVector<f64>
async fn predict_batch(&self, inputs: &[DVector<f64>]) -> Vec<DVector<f64>>
// Evaluation
async fn evaluate(&self, inputs: &[DVector<f64>], targets: &[DVector<f64>]) -> f64
// Forward pass
async fn forward(&self, input: &DVector<f64>) -> DVector<f64>
// Batch normalization
fn enable_batch_norm(&mut self)
fn disable_batch_norm(&mut self)
fn set_training_mode(&mut self, training: bool)
// Regularization
fn set_regularization(&mut self, reg: Regularization)
// Loss function
fn set_loss_function(&mut self, loss: LossFunction)
// Weight management
fn get_weights(&self) -> Vec<(DMatrix<f64>, DVector<f64>)>
fn set_weights(&mut self, weights: Vec<(DMatrix<f64>, DVector<f64>)>)
fn parameter_count(&self) -> usize
// Early stopping configuration
EarlyStopping::new(
    patience: usize,             // Epochs to wait for improvement
    min_delta: f64,              // Minimum improvement threshold
    restore_best_weights: bool,  // Whether to restore best weights
) -> EarlyStopping
// Checkpoint configuration
let checkpoint = CheckpointConfig::new("model_path".to_string())
.save_every(10) // Save every N epochs
.save_best(true); // Save best model based on validation loss
Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.
This project is licensed under the MIT OR Apache-2.0 license.