| Crates.io | trustformers-training |
| lib.rs | trustformers-training |
| version | 0.1.0-alpha.1 |
| created_at | 2025-11-09 10:18:14.698035+00 |
| updated_at | 2025-11-09 10:18:14.698035+00 |
| description | Training infrastructure for TrustformeRS |
| homepage | |
| repository | https://github.com/cool-japan/trustformers |
| max_upload_size | |
| id | 1923940 |
| size | 1,953,049 |
Comprehensive training infrastructure for transformer models with distributed training, hyperparameter optimization, and advanced training techniques.
This crate provides production-ready training capabilities including distributed training across multiple nodes, mixed precision training, hyperparameter tuning, and quantization-aware training. The design closely follows HuggingFace Transformers' Trainer API for familiarity.
```rust
use trustformers_training::{
    Trainer, TrainingArguments,
    optimizers::{AdamW, AdamWConfig},
    schedulers::{LinearScheduler, SchedulerConfig},
};

// Configure training
let args = TrainingArguments {
    output_dir: "output".to_string(),
    num_train_epochs: 3,
    per_device_train_batch_size: 32,
    learning_rate: 5e-5,
    warmup_steps: 500,
    logging_steps: 100,
    save_steps: 1000,
    fp16: true,
    gradient_accumulation_steps: 4,
    ..Default::default()
};

// Create optimizer
let optimizer = AdamW::new(AdamWConfig {
    lr: args.learning_rate,
    weight_decay: 0.01,
    ..Default::default()
})?;

// Create trainer
let trainer = Trainer::new(
    model,
    args,
    train_dataset,
    eval_dataset,
    optimizer,
)?;

// Train
trainer.train()?;
```
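`LinearScheduler` and `SchedulerConfig` are imported above but not used; a learning-rate schedule would normally be constructed alongside the optimizer to implement the `warmup_steps` ramp. The snippet below is a minimal sketch only: the `SchedulerConfig` fields and the `step()`/`current_lr()` methods are assumptions, not confirmed by this README.

```rust
use trustformers_training::schedulers::{LinearScheduler, SchedulerConfig};

// Minimal sketch of a warmup-then-linear-decay schedule. Field and method
// names (warmup_steps, total_steps, step, current_lr) are assumptions.
let mut scheduler = LinearScheduler::new(SchedulerConfig {
    warmup_steps: 500,    // ramp the LR up from 0 over the first 500 steps
    total_steps: 10_000,  // then decay linearly toward 0 by step 10,000
    ..Default::default()
})?;

// Advance the schedule once per optimizer step inside the training loop.
for step in 0..10_000u64 {
    // ... forward pass, backward pass, optimizer.step() ...
    scheduler.step();
    if step % 1_000 == 0 {
        println!("step {step}: lr = {:.2e}", scheduler.current_lr());
    }
}
```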
```rust
use trustformers_training::distributed::{
    DistributedTrainer,
    ProcessGroup,
    ZeroStage,
};

// Initialize distributed environment
let process_group = ProcessGroup::new_from_env()?;

// Configure ZeRO
let trainer = DistributedTrainer::new(
    model,
    args,
    process_group,
    ZeroStage::Three, // Full parameter sharding
)?;

// Train across multiple GPUs/nodes
trainer.train()?;
```
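As background: ZeRO Stage 1 shards optimizer state across ranks, Stage 2 additionally shards gradients, and Stage 3 also shards the parameters themselves, which is what allows the larger models in the benchmark table below to be trained across 8 to 16 GPUs. Note also that under data parallelism the effective global batch size is per_device_train_batch_size × gradient_accumulation_steps × world_size; with the arguments from the first example on 8 ranks, that is 32 × 4 × 8 = 1024 samples per optimizer step.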
```text
trustformers-training/
├── src/
│   ├── trainer.rs           # Main trainer implementation
│   ├── args.rs              # Training arguments
│   ├── distributed/         # Distributed training
│   │   ├── data_parallel.rs
│   │   ├── zero.rs          # ZeRO optimizer
│   │   └── process_group.rs
│   ├── mixed_precision/     # AMP implementation
│   ├── hyperparameter/      # HP optimization
│   ├── loss/                # Loss functions
│   ├── metrics/             # Evaluation metrics
│   ├── callbacks/           # Callback system
│   └── schedulers/          # LR schedulers
```
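The `hyperparameter/` module backs the hyperparameter-tuning feature mentioned above, but its API is not shown in this README, so the following is only an illustrative sketch: the `HyperparameterSearch`, `SearchSpace`, and trial types, and every method on them, are hypothetical placeholders.

```rust
// Hypothetical sketch of a search loop; none of these type or method names
// are confirmed by this README -- only the hyperparameter/ module is.
use trustformers_training::hyperparameter::{HyperparameterSearch, SearchSpace};
use trustformers_training::TrainingArguments;

let space = SearchSpace::new()
    .log_uniform("learning_rate", 1e-6, 1e-3)          // sample LR on a log scale
    .categorical("per_device_train_batch_size", &[16, 32, 64]);

let mut search = HyperparameterSearch::new(space, 20); // e.g. 20 trials

while let Some(trial) = search.next_trial() {
    let args = TrainingArguments {
        learning_rate: trial.f64("learning_rate"),
        per_device_train_batch_size: trial.usize("per_device_train_batch_size"),
        ..Default::default()
    };
    // Run a short training job and report the validation loss to the search.
    let eval_loss = train_and_evaluate(&args)?;         // user-provided helper
    search.report(&trial, eval_loss);
}
```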
| Configuration | GPUs | Model | Throughput | Speedup |
|---|---|---|---|---|
| Single GPU | 1 | BERT-Large | 250 samples/s | 1.0x |
| Data Parallel | 8 | BERT-Large | 1,920 samples/s | 7.7x |
| ZeRO Stage 2 | 8 | GPT-2 1.5B | 450 samples/s | - |
| ZeRO Stage 3 | 16 | LLaMA 7B | 320 samples/s | - |
*Benchmarks on NVIDIA A100 GPUs with NVLink.*
Extension points: the `TrainerCallback` trait and the `Metric` trait. Further additions are planned.
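A custom callback would implement the `TrainerCallback` trait from the `callbacks/` module. The trait's actual methods are not documented in this README, so the hook name and signature below (`on_log`) are assumptions chosen for illustration, as is the `add_callback` registration call.

```rust
use std::collections::HashMap;
use trustformers_training::callbacks::TrainerCallback;

// Hypothetical callback that prints the loss at every logging step.
// The on_log hook and its signature are assumptions, not the documented API.
struct LossPrinter;

impl TrainerCallback for LossPrinter {
    fn on_log(&mut self, step: usize, metrics: &HashMap<String, f64>) {
        if let Some(loss) = metrics.get("loss") {
            println!("step {step}: loss = {loss:.4}");
        }
    }
}

// Registration sketch (method name assumed):
// trainer.add_callback(Box::new(LossPrinter));
```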
MIT OR Apache-2.0