hextral

Crates.io: hextral
lib.rs: hextral
version: 0.7.1
created_at: 2024-03-14 22:21:06.743007+00
updated_at: 2025-09-26 00:58:17.184075+00
description: Comprehensive neural network library with batch normalization, 9 activation functions, 5 loss functions, multiple optimizers, regularization, and clean async-first API
repository: https://github.com/xStFtx/hextral
id: 1174213
size: 130,088
Philip (xStFtx)

Hextral

A high-performance neural network library for Rust with a clean async-first API, advanced activation functions, multiple optimizers, early stopping, and checkpointing.


Features

Core Architecture

  • Multi-layer perceptrons with configurable hidden layers
  • Batch normalization for improved training stability and convergence
  • Xavier weight initialization for stable gradient flow
  • Flexible network topology - specify any number of hidden layers and neurons
  • Clean async-first API with intelligent yielding for non-blocking operations
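Xavier (Glorot) uniform initialization draws each weight from U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). This keeps activation variance roughly constant from layer to layer. The sketch below is a standalone illustration, not hextral's code. The tiny xorshift generator is a hypothetical stand-in for a real RNG crate such as `rand`, and hextral may use the normal rather than the uniform variant.

```rust
/// Minimal xorshift64 PRNG so the example needs no external crates
/// (hypothetical stand-in for a proper RNG such as the `rand` crate).
/// The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_f64(&mut self) -> f64 {
        // One xorshift64 step
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        // Use the top 53 bits to build a float in [0, 1)
        (x >> 11) as f64 / (1u64 << 53) as f64
    }
}

/// Xavier/Glorot uniform init for a layer with `fan_in` inputs and `fan_out` outputs.
/// Returns a row-major `fan_out x fan_in` weight matrix as a flat Vec.
fn xavier_uniform(fan_in: usize, fan_out: usize, rng: &mut XorShift64) -> Vec<f64> {
    let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
    (0..fan_in * fan_out)
        .map(|_| (rng.next_f64() * 2.0 - 1.0) * limit) // map [0, 1) onto [-limit, limit)
        .collect()
}
```

For the first hidden layer of a 2 -> 4 network the bound is sqrt(6 / 6) = 1, so every weight lands in [-1, 1).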

Activation Functions (9 Available)

  • ReLU - Rectified Linear Unit (good for most cases)
  • Sigmoid - Smooth activation for binary classification
  • Tanh - Hyperbolic tangent for centered outputs
  • Leaky ReLU - Prevents dying ReLU problem
  • ELU - Exponential Linear Unit for smoother gradients
  • Linear - For regression output layers
  • Swish - Modern activation with smooth derivatives
  • GELU - Gaussian Error Linear Unit used in transformers
  • Mish - Self-regularizing activation function
  • Quaternion - Quaternion-based normalization for 4D data
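The classic activations above have short closed forms. The sketch below uses the standard textbook definitions. It is not hextral's implementation, and hextral's defaults (for example the Leaky ReLU slope or the ELU alpha) may differ. Quaternion is omitted because its exact definition is not given here.

```rust
/// ReLU: passes positive inputs, zeroes negatives.
fn relu(x: f64) -> f64 {
    x.max(0.0)
}

/// Leaky ReLU: a small negative slope keeps gradients alive for x < 0.
fn leaky_relu(x: f64, alpha: f64) -> f64 {
    if x > 0.0 { x } else { alpha * x }
}

/// ELU: smoothly saturates to -alpha for large negative inputs.
fn elu(x: f64, alpha: f64) -> f64 {
    if x > 0.0 { x } else { alpha * (x.exp() - 1.0) }
}

/// Sigmoid: squashes any input into (0, 1).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Tanh: zero-centered output in (-1, 1).
fn tanh(x: f64) -> f64 {
    x.tanh()
}

/// Linear (identity): typical for regression output layers.
fn linear(x: f64) -> f64 {
    x
}
```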

Loss Functions (5 Available)

  • Mean Squared Error (MSE) - Standard regression loss
  • Mean Absolute Error (MAE) - Robust to outliers
  • Binary Cross-Entropy - Binary classification
  • Categorical Cross-Entropy - Multi-class classification
  • Huber Loss - Robust hybrid of MSE and MAE
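The "robust to outliers" remarks come from how each regression loss grows with the residual. MSE grows quadratically, MAE grows linearly, and Huber is quadratic up to `delta` and linear beyond it. The sketch below is a standalone illustration. Averaging over elements is an assumption, since hextral may reduce differently.

```rust
/// Mean squared error: average of squared residuals (large errors dominate).
fn mse(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f64;
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f64>() / n
}

/// Mean absolute error: average of absolute residuals (grows linearly).
fn mae(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f64;
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f64>() / n
}

/// Huber loss: 0.5 * e^2 for |e| <= delta, delta * (|e| - 0.5 * delta) otherwise.
fn huber(pred: &[f64], target: &[f64], delta: f64) -> f64 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f64;
    pred.iter()
        .zip(target)
        .map(|(p, t)| {
            let e = (p - t).abs();
            if e <= delta { 0.5 * e * e } else { delta * (e - 0.5 * delta) }
        })
        .sum::<f64>()
        / n
}
```

With predictions `[1.0, 2.0]` and targets `[1.0, 4.0]`, MSE is 2.0, MAE is 1.0, and Huber with `delta = 1.0` is 0.75.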

Optimization Algorithms (12 Available)

  • Adam - Adaptive moment estimation (recommended for most cases)
  • AdamW - Adam with decoupled weight decay
  • NAdam - Nesterov-accelerated Adam
  • AdaBelief - Adapts step sizes according to the "belief" in observed gradients
  • Lion - Evolved sign momentum optimizer
  • SGD - Stochastic Gradient Descent (simple and reliable)
  • SGD with Momentum - Accelerated gradient descent
  • RMSprop - Root mean square propagation
  • AdaGrad - Adaptive gradient algorithm
  • AdaDelta - Extension of AdaGrad
  • LBFGS - Limited-memory BFGS (quasi-Newton method)
  • Ranger - Combination of RAdam and LookAhead
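The Adam update, which the list recommends as the default, keeps running averages of the gradient and the squared gradient and corrects their startup bias. The sketch below is the standard textbook update for one parameter vector, using the usual default hyperparameters. It is not hextral's internal code.

```rust
/// Adam state for one parameter vector.
struct Adam {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    t: i32,
    m: Vec<f64>, // first-moment (mean) estimate
    v: Vec<f64>, // second-moment (uncentered variance) estimate
}

impl Adam {
    fn new(lr: f64, n: usize) -> Self {
        // Commonly used default hyperparameters
        Adam { lr, beta1: 0.9, beta2: 0.999, eps: 1e-8, t: 0, m: vec![0.0; n], v: vec![0.0; n] }
    }

    /// Apply one Adam update to `params` in place, given gradients `grads`.
    fn step(&mut self, params: &mut [f64], grads: &[f64]) {
        self.t += 1;
        let bc1 = 1.0 - self.beta1.powi(self.t); // bias correction for m
        let bc2 = 1.0 - self.beta2.powi(self.t); // bias correction for v
        for i in 0..params.len() {
            let g = grads[i];
            self.m[i] = self.beta1 * self.m[i] + (1.0 - self.beta1) * g;
            self.v[i] = self.beta2 * self.v[i] + (1.0 - self.beta2) * g * g;
            let m_hat = self.m[i] / bc1;
            let v_hat = self.v[i] / bc2;
            params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.eps);
        }
    }
}
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about 1 in magnitude. Each weight therefore moves by roughly `lr` against the sign of its gradient, whatever the gradient's scale.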

Advanced Training Features

  • Early Stopping - Automatic training termination based on validation loss
  • Checkpointing - Save and restore model weights with bincode serialization
  • Regularization - L1/L2 regularization and dropout support
  • Batch Training - Configurable batch sizes for memory efficiency
  • Training Progress Tracking - Loss history and validation monitoring
  • Dual sync/async API for both blocking and non-blocking operations

Async/Concurrent Processing

  • Async training methods with cooperative multitasking
  • Parallel batch prediction using futures
  • Intelligent yielding - only yields for large workloads (>1000 elements)
  • Concurrent activation function processing
  • Performance-optimized async implementation alongside synchronous methods
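The "intelligent yielding" rule can be sketched as chunked processing that only pauses when the workload is larger than 1000 elements. In async code the pause would be something like `tokio::task::yield_now().await`. Here a plain closure stands in for it so the sketch runs without a runtime. The function name, the chunking strategy, and the hook are hypothetical, not hextral's API.

```rust
/// Workloads at or below this size run straight through (threshold from the README).
const YIELD_THRESHOLD: usize = 1000;

/// Applies `f` to every element and calls `on_yield` after each `chunk`-sized block,
/// but only when the workload exceeds `YIELD_THRESHOLD`. Returns how many times it yielded.
/// `chunk` must be greater than zero.
fn map_with_yielding<F, Y>(data: &mut [f64], chunk: usize, mut f: F, mut on_yield: Y) -> usize
where
    F: FnMut(f64) -> f64,
    Y: FnMut(),
{
    let should_yield = data.len() > YIELD_THRESHOLD;
    let mut yields = 0;
    for block in data.chunks_mut(chunk) {
        for x in block.iter_mut() {
            *x = f(*x);
        }
        if should_yield {
            on_yield(); // async version: give other tasks a chance to run here
            yields += 1;
        }
    }
    yields
}
```

Small inputs pay no yielding overhead. A 2500-element input processed in chunks of 1000 pauses three times.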

Quick Start

Add this to your Cargo.toml:

[dependencies]
hextral = "0.7.0"
nalgebra = "0.33"
tokio = { version = "1.0", features = ["full"] }  # For async features

Basic Async Usage (Recommended)

use hextral::{Hextral, ActivationFunction, Optimizer, EarlyStopping, CheckpointConfig};
use nalgebra::DVector;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create a neural network: 2 inputs -> [4, 3] hidden -> 1 output
    let mut nn = Hextral::new(
        2,                                    // Input features
        &[4, 3],                             // Hidden layer sizes  
        1,                                    // Output size
        ActivationFunction::ReLU,             // Activation function
        Optimizer::adam(0.01),                // Modern Adam optimizer
    );

    // Training data for XOR problem
    let train_inputs = vec![
        DVector::from_vec(vec![0.0, 0.0]),
        DVector::from_vec(vec![0.0, 1.0]),
        DVector::from_vec(vec![1.0, 0.0]),
        DVector::from_vec(vec![1.0, 1.0]),
    ];
    let train_targets = vec![
        DVector::from_vec(vec![0.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![1.0]),
        DVector::from_vec(vec![0.0]),
    ];

    // Validation data (can be same as training for demo)
    let val_inputs = train_inputs.clone();
    let val_targets = train_targets.clone();

    // Configure early stopping and checkpointing
    let early_stopping = EarlyStopping::new(10, 0.001, true);
    let checkpoint_config = CheckpointConfig::new("best_model".to_string());

    // Train the network with advanced features
    println!("Training network with early stopping...");
    let (train_history, val_history) = nn.train(
        &train_inputs,
        &train_targets,
        0.1,                           // Learning rate
        1000,                          // Max epochs
        Some(2),                       // Batch size
        Some(&val_inputs),             // Validation inputs
        Some(&val_targets),            // Validation targets
        Some(early_stopping),          // Early stopping
        Some(checkpoint_config),       // Checkpointing
    ).await?;

    println!("Training completed after {} epochs", train_history.len());
    println!("Final validation loss: {:.6}", val_history.last().unwrap_or(&0.0));

    // Make predictions
    println!("\nPredictions:");
    for (input, expected) in train_inputs.iter().zip(train_targets.iter()) {
        let prediction = nn.predict(input).await;
        println!("Input: {:?} | Expected: {:.1} | Predicted: {:.3}", 
                 input.data.as_vec(), expected[0], prediction[0]);
    }

    // Batch prediction (efficient for multiple inputs)
    let batch_predictions = nn.predict_batch(&train_inputs).await;
    println!("Batch predictions: {}", batch_predictions.len());
    
    // Evaluate performance
    let final_loss = nn.evaluate(&train_inputs, &train_targets).await;
    println!("Final loss: {:.6}", final_loss);

    Ok(())
}

Advanced Features

use hextral::*;

#[tokio::main] 
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create network with advanced activation function
    let mut nn = Hextral::new(
        4, &[8, 6], 2,
        ActivationFunction::Swish { beta: 1.0 },  // Modern Swish activation
        Optimizer::adamw(0.001, 0.01),            // AdamW with weight decay
    );

    // Enable batch normalization for better training stability
    nn.enable_batch_norm();
    nn.set_training_mode(true);

    // Configure regularization
    nn.set_regularization(Regularization::L2(0.001));

    let inputs: Vec<nalgebra::DVector<f64>> = vec![/* your training data */];
    let targets: Vec<nalgebra::DVector<f64>> = vec![/* your target data */];

    // Advanced training with all features
    let early_stop = EarlyStopping::new(
        15,      // Patience: stop if no improvement for 15 epochs
        0.0001,  // Minimum improvement threshold
        true,    // Restore best weights when stopping
    );

    let checkpoint = CheckpointConfig::new("model_checkpoint".to_string())
        .save_every(10);  // Save every 10 epochs

    let (train_losses, val_losses) = nn.train(
        &inputs, &targets,
        0.01,               // Learning rate
        500,                // Max epochs
        Some(32),           // Batch size
        Some(&inputs),      // Validation inputs
        Some(&targets),     // Validation targets
        Some(early_stop),   // Early stopping
        Some(checkpoint),   // Checkpointing
    ).await?;

    // Switch to inference mode
    nn.set_training_mode(false);
    
    Ok(())
}
Benefits of the async API:

  • Scalable architecture - Ideal for web services and concurrent applications
  • Parallel batch processing - Multiple predictions processed concurrently using futures

Loss Functions

Configure different loss functions for your specific task:

use hextral::{Hextral, LossFunction, ActivationFunction, Optimizer};

let mut nn = Hextral::new(2, &[4], 1, ActivationFunction::ReLU, Optimizer::default());

// For regression tasks
nn.set_loss_function(LossFunction::MeanSquaredError);
nn.set_loss_function(LossFunction::MeanAbsoluteError);
nn.set_loss_function(LossFunction::Huber { delta: 1.0 });

// For classification tasks
nn.set_loss_function(LossFunction::BinaryCrossEntropy);
nn.set_loss_function(LossFunction::CategoricalCrossEntropy);
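The two classification losses compare predicted probabilities with targets on a log scale. The sketch below shows their standard formulas. It clamps probabilities away from 0 and 1 so `ln` never returns infinity. The clamping constant and the mean reduction for BCE are assumptions, not taken from hextral.

```rust
/// Small constant that keeps ln() away from 0 (assumed value).
const EPS: f64 = 1e-12;

/// Binary cross-entropy for predicted probabilities and 0/1 targets, averaged over elements.
fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    let n = pred.len() as f64;
    -pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| {
            let p = p.clamp(EPS, 1.0 - EPS);
            t * p.ln() + (1.0 - t) * (1.0 - p).ln()
        })
        .sum::<f64>()
        / n
}

/// Categorical cross-entropy for one sample: `pred` is a probability
/// distribution (e.g. a softmax output) and `target` is one-hot.
fn categorical_cross_entropy(pred: &[f64], target: &[f64]) -> f64 {
    assert_eq!(pred.len(), target.len());
    -pred
        .iter()
        .zip(target)
        .map(|(&p, &t)| t * p.max(EPS).ln())
        .sum::<f64>()
}
```

With a one-hot target, categorical cross-entropy reduces to `-ln(p_true)`. Predicting `[0.7, 0.2, 0.1]` for class 0 therefore costs `-ln(0.7)`.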

Batch Normalization

Enable batch normalization for improved training stability:

use hextral::{Hextral, ActivationFunction, Optimizer};

let mut nn = Hextral::new(10, &[64, 32], 1, ActivationFunction::ReLU, Optimizer::default());

// Enable batch normalization
nn.enable_batch_norm();

// Set training mode
nn.set_training_mode(true);

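// Train your network (assuming `inputs` and `targets` are Vec<DVector<f64>>)
let (loss_history, _val_history) = nn.train(
    &inputs, &targets,
    0.01,                          // Learning rate
    100,                           // Epochs
    None, None, None, None, None,  // Batch size, validation data, early stopping, checkpointing
).await?;

// Switch to inference mode
nn.set_training_mode(false);

// Make predictions...
let prediction = nn.predict(&input).await;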
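What training mode versus inference mode changes is easiest to see for a single feature. In training, batch normalization normalizes with the batch's own mean and variance and updates running averages. In inference, it normalizes with those running averages. The sketch below is standalone and not hextral's implementation. The struct name and the momentum convention (weight 0.9 on the old running value) are assumptions.

```rust
/// Batch normalization for one feature across a mini-batch (hypothetical helper).
struct BatchNorm1 {
    gamma: f64,        // learned scale
    beta: f64,         // learned shift
    running_mean: f64, // used in inference mode
    running_var: f64,
    momentum: f64,     // weight kept on the previous running statistics
    eps: f64,
}

impl BatchNorm1 {
    fn new() -> Self {
        BatchNorm1 { gamma: 1.0, beta: 0.0, running_mean: 0.0, running_var: 1.0, momentum: 0.9, eps: 1e-5 }
    }

    /// Training mode uses batch statistics and updates the running averages;
    /// inference mode uses the running averages instead.
    fn forward(&mut self, batch: &[f64], training: bool) -> Vec<f64> {
        let (mean, var) = if training {
            let n = batch.len() as f64;
            let mean = batch.iter().sum::<f64>() / n;
            let var = batch.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / n;
            self.running_mean = self.momentum * self.running_mean + (1.0 - self.momentum) * mean;
            self.running_var = self.momentum * self.running_var + (1.0 - self.momentum) * var;
            (mean, var)
        } else {
            (self.running_mean, self.running_var)
        };
        let inv_std = 1.0 / (var + self.eps).sqrt();
        batch.iter().map(|x| self.gamma * (x - mean) * inv_std + self.beta).collect()
    }
}
```

Forgetting to switch to inference mode means single-sample predictions get normalized by their own statistics. For one sample that always yields `beta`, which is why the example toggles `set_training_mode(false)` before predicting.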

Modern Activation Functions

Use state-of-the-art activation functions:

use hextral::{Hextral, ActivationFunction, Optimizer};

// Swish activation (used in EfficientNet)
let mut nn = Hextral::new(2, &[4], 1, 
    ActivationFunction::Swish { beta: 1.0 }, Optimizer::default());

// GELU activation (used in BERT, GPT)
let mut nn = Hextral::new(2, &[4], 1, 
    ActivationFunction::GELU, Optimizer::default());

// Mish activation (self-regularizing)
let mut nn = Hextral::new(2, &[4], 1, 
    ActivationFunction::Mish, Optimizer::default());
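For reference, the three modern activations can be written directly from their usual definitions. GELU below uses the common tanh approximation. Whether hextral uses that approximation or the exact erf form is not stated, so treat this as a sketch.

```rust
/// Swish: x * sigmoid(beta * x). With beta = 1.0 this is also known as SiLU.
fn swish(x: f64, beta: f64) -> f64 {
    x / (1.0 + (-beta * x).exp())
}

/// GELU, tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
fn gelu(x: f64) -> f64 {
    let k = (2.0 / std::f64::consts::PI).sqrt();
    0.5 * x * (1.0 + (k * (x + 0.044715 * x.powi(3))).tanh())
}

/// Mish: x * tanh(softplus(x)), where softplus(x) = ln(1 + e^x).
fn mish(x: f64) -> f64 {
    x * x.exp().ln_1p().tanh()
}
```

All three pass through the origin, approach the identity for large positive inputs, and stay smooth around zero, unlike ReLU's kink.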

Regularization

Prevent overfitting with built-in regularization techniques:

use hextral::{Hextral, Regularization, ActivationFunction, Optimizer};

let mut nn = Hextral::new(3, &[16, 8], 1, ActivationFunction::ReLU, 
                          Optimizer::Adam { learning_rate: 0.01 });

// L2 regularization (Ridge)
nn.set_regularization(Regularization::L2(0.01));

// L1 regularization (Lasso) 
nn.set_regularization(Regularization::L1(0.005));

// Dropout regularization
nn.set_regularization(Regularization::Dropout(0.3));
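The three regularizers act differently. L1 and L2 add a weight penalty to the loss, while dropout randomly zeroes activations during training. The sketch below uses the common definitions, including "inverted" dropout scaling. The mask is passed in to keep the example deterministic. Whether hextral includes a 0.5 factor on L2, or which layers dropout applies to, is not stated, so both are assumptions.

```rust
/// L1 penalty: lambda * sum(|w|); tends to drive weights to exactly zero.
fn l1_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w.abs()).sum::<f64>()
}

/// L2 penalty: lambda * sum(w^2); shrinks weights smoothly toward zero.
fn l2_penalty(weights: &[f64], lambda: f64) -> f64 {
    lambda * weights.iter().map(|w| w * w).sum::<f64>()
}

/// Inverted dropout at training time: zero units where `keep` is false and
/// scale survivors by 1 / (1 - rate) so the expected activation is unchanged.
/// Real code would sample `keep` randomly with probability 1 - rate.
fn dropout(activations: &[f64], keep: &[bool], rate: f64) -> Vec<f64> {
    let scale = 1.0 / (1.0 - rate);
    activations
        .iter()
        .zip(keep)
        .map(|(&a, &k)| if k { a * scale } else { 0.0 })
        .collect()
}
```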

Different Optimizers

Choose the optimizer that works best for your problem:

// Adam: Good default choice, adaptive learning rates
let optimizer = Optimizer::Adam { learning_rate: 0.001 };

// SGD: Simple and interpretable
let optimizer = Optimizer::SGD { learning_rate: 0.1 };

// SGD with Momentum: Accelerated convergence
let optimizer = Optimizer::SGDMomentum { 
    learning_rate: 0.1, 
    momentum: 0.9 
};
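The difference between plain SGD and SGD with momentum is a velocity term that accumulates past gradients. The sketch below shows the classical (heavy-ball) formulation side by side with plain SGD. It is a textbook illustration, not hextral's internal update.

```rust
/// One step of SGD with classical momentum:
///   v <- momentum * v - learning_rate * grad
///   w <- w + v
fn sgd_momentum_step(w: &mut [f64], v: &mut [f64], grad: &[f64], learning_rate: f64, momentum: f64) {
    for ((wi, vi), gi) in w.iter_mut().zip(v.iter_mut()).zip(grad) {
        *vi = momentum * *vi - learning_rate * gi;
        *wi += *vi;
    }
}

/// Plain SGD for comparison: w <- w - learning_rate * grad
fn sgd_step(w: &mut [f64], grad: &[f64], learning_rate: f64) {
    for (wi, gi) in w.iter_mut().zip(grad) {
        *wi -= learning_rate * gi;
    }
}
```

With a constant gradient of 1.0, `learning_rate = 0.1`, and `momentum = 0.9`, the first momentum step moves the weight by 0.1 and the second by 0.19. The step keeps growing toward `learning_rate / (1 - momentum)`, which is where the acceleration comes from.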

Network Introspection

Get insights into your network:

// Network architecture
println!("Architecture: {:?}", nn.architecture()); // [2, 4, 3, 1]

// Parameter count  
println!("Total parameters: {}", nn.parameter_count()); // 31 = (2*4+4) + (4*3+3) + (3*1+1)

// Save/load weights
let weights = nn.get_weights();
nn.set_weights(weights);
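The parameter count of a fully connected network follows from its architecture alone. Each layer contributes a weight matrix plus one bias per output neuron. The helper below is hypothetical and only illustrates the arithmetic. Batch-norm scale and shift parameters, if hextral counts them, are not included.

```rust
/// Weights plus biases for a dense network described by its layer sizes,
/// e.g. [inputs, hidden..., outputs]. Each layer with `n_in` inputs and
/// `n_out` outputs has `n_out * n_in` weights and `n_out` biases.
fn dense_parameter_count(architecture: &[usize]) -> usize {
    architecture
        .windows(2)
        .map(|pair| pair[0] * pair[1] + pair[1])
        .sum()
}
```

For the `[2, 4, 3, 1]` network this gives 12 + 15 + 4 = 31.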

API Reference

Core Types

  • Hextral - Main neural network struct with async-first API
  • ActivationFunction - Enum for activation functions (9 available)
  • Optimizer - Enum for optimization algorithms (12 available)
  • Regularization - Enum for regularization techniques
  • EarlyStopping - Configuration for automatic training termination
  • CheckpointConfig - Configuration for model checkpointing
  • LossFunction - Enum for loss functions (5 available)

Primary Methods (All Async)

// Network creation
Hextral::new(inputs, hidden_layers, outputs, activation, optimizer) -> Hextral

// Training with full feature set
async fn train(
    &mut self,
    train_inputs: &[DVector<f64>],
    train_targets: &[DVector<f64>],
    learning_rate: f64,
    epochs: usize,
    batch_size: Option<usize>,
    val_inputs: Option<&[DVector<f64>]>,
    val_targets: Option<&[DVector<f64>]>,
    early_stopping: Option<EarlyStopping>,
    checkpoint_config: Option<CheckpointConfig>,
) -> Result<(Vec<f64>, Vec<f64>), Box<dyn std::error::Error>>

// Predictions
async fn predict(&self, input: &DVector<f64>) -> DVector<f64>
async fn predict_batch(&self, inputs: &[DVector<f64>]) -> Vec<DVector<f64>>

// Evaluation
async fn evaluate(&self, inputs: &[DVector<f64>], targets: &[DVector<f64>]) -> f64

// Forward pass
async fn forward(&self, input: &DVector<f64>) -> DVector<f64>

Configuration Methods

// Batch normalization
fn enable_batch_norm(&mut self)
fn disable_batch_norm(&mut self)
fn set_training_mode(&mut self, training: bool)

// Regularization
fn set_regularization(&mut self, reg: Regularization)

// Loss function
fn set_loss_function(&mut self, loss: LossFunction)

// Weight management
fn get_weights(&self) -> Vec<(DMatrix<f64>, DVector<f64>)>
fn set_weights(&mut self, weights: Vec<(DMatrix<f64>, DVector<f64>)>)
fn parameter_count(&self) -> usize

Early Stopping & Checkpointing

// Early stopping configuration
let early_stop = EarlyStopping::new(
    patience,             // usize: epochs to wait for improvement
    min_delta,            // f64: minimum improvement threshold
    restore_best_weights, // bool: whether to restore the best weights
);

// Checkpoint configuration  
let checkpoint = CheckpointConfig::new("model_path".to_string())
    .save_every(10)           // Save every N epochs
    .save_best(true);         // Save best model based on validation loss
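The interaction between `patience` and `min_delta` is easiest to see as a small state machine. A loss counts as an improvement only if it beats the best loss so far by more than `min_delta`, and training stops after `patience` consecutive non-improving epochs. The `EarlyStopper` type below is a hypothetical sketch, not hextral's `EarlyStopping`. It records the best epoch instead of actually restoring weights.

```rust
/// Hypothetical early-stopping tracker mirroring `patience` and `min_delta`.
struct EarlyStopper {
    patience: usize,
    min_delta: f64,
    best_loss: f64,
    best_epoch: usize, // where a real trainer would snapshot the best weights
    epochs_without_improvement: usize,
}

impl EarlyStopper {
    fn new(patience: usize, min_delta: f64) -> Self {
        EarlyStopper { patience, min_delta, best_loss: f64::INFINITY, best_epoch: 0, epochs_without_improvement: 0 }
    }

    /// Returns true once `patience` consecutive epochs fail to improve the
    /// best validation loss by more than `min_delta`.
    fn should_stop(&mut self, epoch: usize, val_loss: f64) -> bool {
        if val_loss < self.best_loss - self.min_delta {
            self.best_loss = val_loss;
            self.best_epoch = epoch;
            self.epochs_without_improvement = 0;
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}
```

With `patience = 2` and `min_delta = 0.01`, the losses `1.0, 0.5, 0.495, 0.492` stop training at epoch 3. The last two values are each within 0.01 of the best loss of 0.5, so neither counts as an improvement.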

Performance Tips

  1. Use ReLU activation for hidden layers in most cases
  2. Start with Adam optimizer - it adapts learning rates automatically
  3. Apply L2 regularization if you see overfitting (validation loss noticeably higher than training loss)
  4. Use dropout for large networks to prevent co-adaptation
  5. Normalize your input data to [0,1] or [-1,1] range for better training stability
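Tip 5 amounts to min-max scaling each input feature. The sketch below maps one feature column to [0, 1], or to [-1, 1]. A constant column is mapped to zeros to avoid dividing by zero, which is a choice made for this example. In practice, compute the minimum and maximum on the training data and reuse them for validation and test inputs.

```rust
/// Min-max scaling of one feature column to [0, 1].
/// A constant column is mapped to all zeros to avoid dividing by zero.
fn min_max_scale(values: &[f64]) -> Vec<f64> {
    let min = values.iter().cloned().fold(f64::INFINITY, f64::min);
    let max = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let range = max - min;
    values
        .iter()
        .map(|&x| if range > 0.0 { (x - min) / range } else { 0.0 })
        .collect()
}

/// The same scaling mapped to [-1, 1].
fn min_max_scale_symmetric(values: &[f64]) -> Vec<f64> {
    min_max_scale(values).into_iter().map(|x| 2.0 * x - 1.0).collect()
}
```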

Architecture Decisions

  • Built on nalgebra for efficient linear algebra operations
  • Xavier initialization for stable gradient flow from the start
  • Proper error handling throughout the API
  • Modular design allowing easy extension of activation functions and optimizers
  • Zero-copy predictions where possible for performance

Contributing

We welcome contributions! Please feel free to:

  • Report bugs by opening an issue
  • Suggest new features or improvements
  • Submit pull requests with enhancements
  • Improve documentation
  • Add more test cases

Changelog

v0.7.0 (Latest)

  • Removed Redundancy: Eliminated confusing duplicate methods and verbose naming patterns
  • Better Performance: Streamlined async implementation with intelligent yielding
  • Updated Documentation: All examples now use clean, consistent API
  • All Tests Updated: Comprehensive test suite updated for new API patterns

v0.6.0

  • Full Async/Await Support: Complete async API alongside synchronous methods
  • Intelligent Yielding: Performance-optimized async with yielding only for large workloads (>1000 elements)
  • Concurrent Processing: Parallel batch predictions using futures and join_all
  • Async Training: Non-blocking training with cooperative multitasking
  • Code Optimization: Removed verbose AI-generated patterns for cleaner, more professional code
  • Performance Improvements: Smart async yielding prevents unnecessary overhead
  • Enhanced Documentation: Updated examples and API documentation

v0.5.1

  • Improved Documentation: Enhanced README with comprehensive examples of all new features
  • Better Crates.io Presentation: Updated documentation to properly showcase library capabilities

v0.5.0

  • Major Feature Expansion: Added comprehensive loss functions, batch normalization, and modern activation functions
  • 5 Loss Functions: MSE, MAE, Binary Cross-Entropy, Categorical Cross-Entropy, Huber Loss
  • Batch Normalization: Full implementation with training/inference modes
  • 3 New Activation Functions: Swish, GELU, Mish (total of 9 activation functions)
  • Code Organization: Separated tests into dedicated files for cleaner library structure
  • Enhanced API: Flexible loss function configuration and batch normalization controls

v0.4.0

  • Complete rewrite with proper error handling and fixed implementations
  • Implemented all documented features - train(), predict(), evaluate() methods
  • Fixed critical bugs in batch normalization and backward pass
  • Added regularization support - L1, L2, and Dropout
  • Improved documentation with usage examples and API reference

Contributing

Contributions are welcome! Please feel free to submit a Pull Request. For major changes, please open an issue first to discuss what you would like to change.

License

This project is licensed under the MIT OR Apache-2.0 license.

Commit count: 44
