| Crates.io | tensorlogic-train |
| lib.rs | tensorlogic-train |
| version | 0.1.0-alpha.2 |
| created_at | 2025-11-07 22:46:00.044588+00 |
| updated_at | 2026-01-03 21:08:00.23507+00 |
| description | Training loops, loss composition, and optimization schedules for TensorLogic |
| homepage | https://github.com/cool-japan/tensorlogic |
| repository | https://github.com/cool-japan/tensorlogic |
| max_upload_size | |
| id | 1922305 |
| size | 1,522,082 |
Training scaffolds for Tensorlogic: loss composition, optimizers, schedulers, and callbacks.
tensorlogic-train provides comprehensive training infrastructure for Tensorlogic models, combining standard ML training components with logic-specific loss functions for constraint satisfaction and rule adherence.
Add to your Cargo.toml:
[dependencies]
tensorlogic-train = { path = "../tensorlogic-train" }
use tensorlogic_train::{
Trainer, TrainerConfig, MseLoss, AdamOptimizer, OptimizerConfig,
EpochCallback, CallbackList, MetricTracker, Accuracy,
};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;
// Create loss function
let loss = Box::new(MseLoss);
// Create optimizer
let optimizer_config = OptimizerConfig {
learning_rate: 0.001,
..Default::default()
};
let optimizer = Box::new(AdamOptimizer::new(optimizer_config));
// Create trainer
let config = TrainerConfig {
num_epochs: 10,
..Default::default()
};
let mut trainer = Trainer::new(config, loss, optimizer);
// Add callbacks
let mut callbacks = CallbackList::new();
callbacks.add(Box::new(EpochCallback::new(true)));
trainer = trainer.with_callbacks(callbacks);
// Add metrics
let mut metrics = MetricTracker::new();
metrics.add(Box::new(Accuracy::default()));
trainer = trainer.with_metrics(metrics);
// Prepare data
let train_data = Array2::zeros((100, 10));
let train_targets = Array2::zeros((100, 2));
let val_data = Array2::zeros((20, 10));
let val_targets = Array2::zeros((20, 2));
// Train model
let mut parameters = HashMap::new();
parameters.insert("weights".to_string(), Array2::zeros((10, 2)));
let history = trainer.train(
&train_data.view(),
&train_targets.view(),
Some(&val_data.view()),
Some(&val_targets.view()),
&mut parameters,
).unwrap();
// Access training history
println!("Training losses: {:?}", history.train_loss);
println!("Validation losses: {:?}", history.val_loss);
if let Some((best_epoch, best_loss)) = history.best_val_loss() {
println!("Best validation loss: {} at epoch {}", best_loss, best_epoch);
}
Combine supervised learning with logical constraints:
use tensorlogic_train::{
LogicalLoss, LossConfig, CrossEntropyLoss,
RuleSatisfactionLoss, ConstraintViolationLoss,
};
// Configure loss weights
let config = LossConfig {
supervised_weight: 1.0,
constraint_weight: 10.0, // Heavily penalize constraint violations
rule_weight: 5.0,
temperature: 1.0,
};
// Create logical loss
let logical_loss = LogicalLoss::new(
config,
Box::new(CrossEntropyLoss::default()),
vec![Box::new(RuleSatisfactionLoss::default())],
vec![Box::new(ConstraintViolationLoss::default())],
);
// Compute total loss
let total_loss = logical_loss.compute_total(
&predictions.view(),
&targets.view(),
&rule_values,
&constraint_values,
)?;
Stop training automatically when validation stops improving:
use tensorlogic_train::{CallbackList, EarlyStoppingCallback};
let mut callbacks = CallbackList::new();
callbacks.add(Box::new(EarlyStoppingCallback::new(
5, // patience: Wait 5 epochs without improvement
0.001, // min_delta: Minimum improvement threshold
)));
trainer = trainer.with_callbacks(callbacks);
// Training will stop automatically if validation doesn't improve for 5 epochs
Save model checkpoints during training:
use tensorlogic_train::{CallbackList, CheckpointCallback};
use std::path::PathBuf;
let mut callbacks = CallbackList::new();
callbacks.add(Box::new(CheckpointCallback::new(
PathBuf::from("/tmp/checkpoints"),
1, // save_frequency: Save every epoch
true, // save_best_only: Only save when validation improves
)));
trainer = trainer.with_callbacks(callbacks);
Adjust learning rate during training:
use tensorlogic_train::{CosineAnnealingLrScheduler, LrScheduler};
let scheduler = Box::new(CosineAnnealingLrScheduler::new(
0.001, // initial_lr
0.00001, // min_lr
100, // t_max: Total epochs
));
trainer = trainer.with_scheduler(scheduler);
Use L2 norm clipping for stable training of deep networks:
use tensorlogic_train::{AdamOptimizer, OptimizerConfig, GradClipMode};
let optimizer = Box::new(AdamOptimizer::new(OptimizerConfig {
learning_rate: 0.001,
grad_clip: Some(5.0), // Clip if global L2 norm > 5.0
grad_clip_mode: GradClipMode::Norm, // Use L2 norm clipping
..Default::default()
}));
// Global L2 norm is computed across all parameters:
// norm = sqrt(sum(g_i^2 for all gradients g_i))
// If norm > 5.0, all gradients are scaled by (5.0 / norm)
use tensorlogic_train::ConfusionMatrix;
let cm = ConfusionMatrix::compute(&predictions.view(), &targets.view())?;
// Pretty print the confusion matrix
println!("{}", cm);
// Output:
// Confusion Matrix:
// 0 1 2
// 0| 45 2 1
// 1| 1 38 3
// 2| 0 2 48
// Get per-class metrics
let precision = cm.precision_per_class();
let recall = cm.recall_per_class();
let f1 = cm.f1_per_class();
// Get overall accuracy
println!("Accuracy: {:.4}", cm.accuracy());
use tensorlogic_train::RocCurve;
// Binary classification example
let predictions = vec![0.9, 0.8, 0.3, 0.1];
let targets = vec![true, true, false, false];
let roc = RocCurve::compute(&predictions, &targets)?;
// Compute AUC
println!("AUC: {:.4}", roc.auc());
// Access ROC curve points
for (fpr, tpr, threshold) in izip!(
&roc.fpr,
&roc.tpr,
&roc.thresholds
) {
println!("FPR: {:.4}, TPR: {:.4}, Threshold: {:.4}",
fpr, tpr, threshold);
}
use tensorlogic_train::PerClassMetrics;
let metrics = PerClassMetrics::compute(&predictions.view(), &targets.view())?;
// Pretty print comprehensive report
println!("{}", metrics);
// Output:
// Per-Class Metrics:
// Class Precision Recall F1-Score Support
// ----- --------- ------ -------- -------
// 0 0.9583 0.9200 0.9388 50
// 1 0.9048 0.9048 0.9048 42
// 2 0.9600 0.9600 0.9600 50
// ----- --------- ------ -------- -------
// Macro 0.9410 0.9283 0.9345 142
Implement the Model trait for your own architectures:
use tensorlogic_train::{Model, TrainResult};
use scirs2_core::ndarray::{Array, ArrayView, Ix2};
use std::collections::HashMap;
struct TwoLayerNet {
parameters: HashMap<String, Array<f64, Ix2>>,
hidden_size: usize,
}
impl TwoLayerNet {
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
let mut parameters = HashMap::new();
// Initialize weights (simplified - use proper initialization)
parameters.insert(
"W1".to_string(),
Array::zeros((input_size, hidden_size))
);
parameters.insert(
"b1".to_string(),
Array::zeros((1, hidden_size))
);
parameters.insert(
"W2".to_string(),
Array::zeros((hidden_size, output_size))
);
parameters.insert(
"b2".to_string(),
Array::zeros((1, output_size))
);
Self { parameters, hidden_size }
}
}
impl Model for TwoLayerNet {
fn forward(&self, input: &ArrayView<f64, Ix2>) -> TrainResult<Array<f64, Ix2>> {
let w1 = self.parameters.get("W1").unwrap();
let b1 = self.parameters.get("b1").unwrap();
let w2 = self.parameters.get("W2").unwrap();
let b2 = self.parameters.get("b2").unwrap();
// Forward pass: hidden = ReLU(X @ W1 + b1)
let hidden = (input.dot(w1) + b1).mapv(|x| x.max(0.0));
// Output: Y = hidden @ W2 + b2
let output = hidden.dot(w2) + b2;
Ok(output)
}
fn backward(
&self,
input: &ArrayView<f64, Ix2>,
grad_output: &ArrayView<f64, Ix2>,
) -> TrainResult<HashMap<String, Array<f64, Ix2>>> {
// Implement backpropagation
// (Simplified - in practice, cache activations from forward pass)
let mut gradients = HashMap::new();
// Compute gradients for W2, b2, W1, b1
// ...
Ok(gradients)
}
fn parameters(&self) -> &HashMap<String, Array<f64, Ix2>> {
&self.parameters
}
fn parameters_mut(&mut self) -> &mut HashMap<String, Array<f64, Ix2>> {
&mut self.parameters
}
fn set_parameters(&mut self, parameters: HashMap<String, Array<f64, Ix2>>) {
self.parameters = parameters;
}
}
Prevent overfitting with L1, L2, or Elastic Net regularization:
use tensorlogic_train::{L2Regularization, Regularizer};
use scirs2_core::ndarray::Array2;
use std::collections::HashMap;
// Create L2 regularization (weight decay)
let regularizer = L2Regularization::new(0.01); // lambda = 0.01
// Compute regularization penalty
let mut parameters = HashMap::new();
parameters.insert("weights".to_string(), Array2::ones((10, 5)));
let penalty = regularizer.compute_penalty(¶meters)?;
let gradients = regularizer.compute_gradient(¶meters)?;
// Add penalty to loss and gradients to parameter updates
total_loss += penalty;
use tensorlogic_train::ElasticNetRegularization;
// Combine L1 (sparsity) and L2 (smoothness)
let regularizer = ElasticNetRegularization::new(
0.01, // l1_lambda
0.01, // l2_lambda
);
Apply on-the-fly data augmentation during training:
use tensorlogic_train::{NoiseAugmenter, ScaleAugmenter, MixupAugmenter, DataAugmenter};
use scirs2_core::ndarray::Array2;
// Gaussian noise augmentation
let noise_aug = NoiseAugmenter::new(0.0, 0.1); // mean=0, std=0.1
let augmented = noise_aug.augment(&data.view())?;
// Scale augmentation
let scale_aug = ScaleAugmenter::new(0.8, 1.2); // scale between 0.8x and 1.2x
let scaled = scale_aug.augment(&data.view())?;
// Mixup augmentation (Zhang et al., ICLR 2018)
let mixup = MixupAugmenter::new(1.0); // alpha = 1.0 (uniform mixing)
let (mixed_data, mixed_targets) = mixup.mixup(
&data.view(),
&targets.view(),
0.3, // lambda: mixing coefficient
)?;
use tensorlogic_train::CompositeAugmenter;
let mut pipeline = CompositeAugmenter::new();
pipeline.add(Box::new(NoiseAugmenter::new(0.0, 0.05)));
pipeline.add(Box::new(ScaleAugmenter::new(0.9, 1.1)));
// Apply all augmentations in sequence
let augmented = pipeline.augment(&data.view())?;
Track training progress with multiple logging backends:
use tensorlogic_train::{ConsoleLogger, FileLogger, MetricsLogger, LoggingBackend};
use std::path::PathBuf;
// Console logging with timestamps
let console = ConsoleLogger::new(true); // with_timestamp = true
console.log_epoch(1, 10, 0.532, Some(0.612))?;
// Output: [2025-11-06 10:30:15] Epoch 1/10 - Loss: 0.5320 - Val Loss: 0.6120
// File logging
let file_logger = FileLogger::new(
PathBuf::from("/tmp/training.log"),
true, // append mode
)?;
file_logger.log_batch(1, 100, 0.425)?;
// Aggregate metrics across backends
let mut metrics_logger = MetricsLogger::new();
metrics_logger.add_backend(Box::new(console));
metrics_logger.add_backend(Box::new(file_logger));
// Log to all backends
metrics_logger.log_metric("accuracy", 0.95)?;
metrics_logger.log_epoch(5, 20, 0.234, Some(0.287))?;
tensorlogic-train/
├── src/
│ ├── lib.rs # Public API exports
│ ├── error.rs # Error types
│ ├── loss.rs # 14 loss functions
│ ├── optimizer.rs # 9 optimizers
│ ├── scheduler.rs # Learning rate schedulers
│ ├── batch.rs # Batch management
│ ├── trainer.rs # Main training loop
│ ├── callbacks.rs # Training callbacks
│ ├── metrics.rs # Evaluation metrics
│ ├── model.rs # Model trait interface
│ ├── regularization.rs # L1, L2, Elastic Net
│ ├── augmentation.rs # Data augmentation
│ └── logging.rs # Logging backends
Model: Forward/backward passes and parameter managementAutodiffModel: Automatic differentiation integration (trait extension)DynamicModel: Variable-sized input supportLoss: Compute loss and gradientsOptimizer: Update parameters with gradientsLrScheduler: Adjust learning rateCallback: Hook into training eventsMetric: Evaluate model performanceRegularizer: Compute regularization penalties and gradientsDataAugmenter: Apply data transformationsLoggingBackend: Log training metrics and eventsThis crate strictly follows the SciRS2 integration policy:
// ✅ Correct: Use SciRS2 types
use scirs2_core::ndarray::{Array, Array2};
use scirs2_autograd::Variable;
// ❌ Wrong: Never use these directly
// use ndarray::Array2; // Never!
// use rand::thread_rng; // Never!
All tensor operations use scirs2_core::ndarray, ready for seamless integration with scirs2-autograd for automatic differentiation.
All modules have comprehensive unit tests:
| Module | Tests | Coverage |
|---|---|---|
loss.rs |
13 | All 14 loss functions (CE, MSE, Focal, Huber, Dice, Tversky, BCE, Contrastive, Triplet, Hinge, KL, logical) |
optimizer.rs |
18 | All 13 optimizers (SGD, Adam, AdamW, RMSprop, Adagrad, NAdam, LAMB, AdaMax, Lookahead, AdaBelief, RAdam, LARS, SAM + clipping) |
scheduler.rs |
11 | LR scheduling (Step, Exp, Cosine, OneCycle, Cyclic, Polynomial, Warmup, WarmupCosine, Noam, MultiStep, ReduceLROnPlateau) |
batch.rs |
5 | Batch iteration & sampling |
trainer.rs |
3 | Training loop |
callbacks.rs |
8 | 13+ callbacks (checkpointing, early stopping, Model EMA, Grad Accum, SWA, LR finder, profiling) |
metrics.rs |
15 | Metrics, confusion matrix, ROC/AUC, per-class analysis |
model.rs |
6 | Model interface & implementations |
regularization.rs |
8 | L1, L2, Elastic Net, Composite regularization |
augmentation.rs |
12 | Noise, Scale, Rotation, Mixup augmentations |
logging.rs |
11 | Console, File, TensorBoard loggers + metrics aggregation |
| Total | 172 | 100% |
Run tests with:
cargo nextest run -p tensorlogic-train --no-fail-fast
See TODO.md for the complete roadmap, including:
The crate includes 5 comprehensive examples demonstrating all features:
Run any example with:
cargo run --example 01_basic_training
See examples/README.md for detailed descriptions and usage patterns.
Comprehensive guides are available in the docs/ directory:
Loss Function Selection Guide - Choose the right loss for your task
Hyperparameter Tuning Guide - Optimize training performance
Performance benchmarks are available in the benches/ directory:
cargo bench -p tensorlogic-train
Benchmarks cover:
Apache-2.0
See CONTRIBUTING.md for guidelines.
Status: ✅ Production Ready (Phase 6.3+ - 100% complete) **Last Updated: 2025-12-16 Version: 0.1.0-alpha.2 Test Coverage: 172/172 tests passing (100%) Code Quality: Zero warnings, clippy clean Features: 14 losses, 13 optimizers, 11 schedulers, 13+ callbacks, regularization, augmentation, logging, curriculum, transfer, ensembling Examples: 5 comprehensive training examples
New in this update: