| Field | Value |
|---|---|
| Crates.io | rust-lstm |
| lib.rs | rust-lstm |
| version | 0.5.0 |
| created_at | 2024-06-09 21:53:52 UTC |
| updated_at | 2025-08-15 00:07:05 UTC |
| description | A complete LSTM neural network library with training capabilities, multiple optimizers, and peephole variants. |
| homepage | https://github.com/SyntaxSpirits/rust-lstm |
| repository | https://github.com/SyntaxSpirits/rust-lstm |
| size | 374,698 |
A comprehensive LSTM (Long Short-Term Memory) neural network library implemented in Rust with complete training capabilities, multiple optimizers, and advanced regularization.
```mermaid
graph TD
    A["Input Sequence<br/>(x₁, x₂, ..., xₜ)"] --> B["LSTM Layer 1"]
    B --> C["LSTM Layer 2"]
    C --> D["Output Layer"]
    D --> E["Predictions<br/>(y₁, y₂, ..., yₜ)"]
    F["Hidden State h₀"] --> B
    G["Cell State c₀"] --> B
    B --> H["Hidden State h₁"]
    B --> I["Cell State c₁"]
    H --> C
    I --> C
    style A fill:#e1f5fe
    style E fill:#e8f5e8
    style B fill:#fff3e0
    style C fill:#fff3e0
```
Add to your `Cargo.toml`:

```toml
[dependencies]
rust-lstm = "0.5.0"
```
```rust
use ndarray::Array2;
use rust_lstm::models::lstm_network::LSTMNetwork;

fn main() {
    // Create an LSTM network: input_size = 3, hidden_size = 10, num_layers = 2
    let mut network = LSTMNetwork::new(3, 10, 2);

    // Create a single timestep of input as a column vector (input_size x 1)
    let input = Array2::from_shape_vec((3, 1), vec![0.5, 0.1, -0.3]).unwrap();
    let hx = Array2::zeros((10, 1));
    let cx = Array2::zeros((10, 1));

    // Forward pass through all layers
    let (output, _) = network.forward(&input, &hx, &cx);
    println!("Output: {:?}", output);
}
```
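To process a multi-step sequence, carry the state forward between timesteps. A minimal sketch, under the assumption (not confirmed by the snippet above) that `forward`'s second return value is the updated `(hidden, cell)` pair; check the crate docs for the exact return type:

```rust
// Hypothetical state-threading loop; `network` and the 10-unit hidden size
// come from the quick-start example, and `sequence` is a Vec<Array2<f64>> of
// per-timestep column vectors. ASSUMPTION: `forward` returns
// (output, (new_hidden, new_cell)) -- verify against the actual API.
let mut hx = Array2::zeros((10, 1));
let mut cx = Array2::zeros((10, 1));
let mut outputs = Vec::new();
for x in &sequence {
    let (output, (new_hx, new_cx)) = network.forward(x, &hx, &cx);
    hx = new_hx;
    cx = new_cx;
    outputs.push(output);
}
```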
Training with dropout and gradient clipping:

```rust
use ndarray::Array2;
use rust_lstm::{LSTMNetwork, create_basic_trainer, TrainingConfig};

fn main() {
    // Create a network with input and recurrent dropout
    let network = LSTMNetwork::new(1, 10, 2)
        .with_input_dropout(0.2, true)
        .with_recurrent_dropout(0.3, true);

    // Set up the trainer (uses the SGD optimizer and MSE loss by default)
    let mut trainer = create_basic_trainer(network, 0.001)
        .with_config(TrainingConfig {
            epochs: 100,
            clip_gradient: Some(1.0),
            ..Default::default()
        });

    // train_data is a slice of (input_sequence, target_sequence) tuples,
    // where each sequence is a Vec<Array2<f64>>. Placeholder data shown here;
    // see the data-preparation sketch below for a fuller example.
    let train_data: Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)> =
        vec![(vec![Array2::zeros((1, 1)); 8], vec![Array2::zeros((1, 1)); 8])];
    let validation_data = train_data.clone();

    trainer.train(&train_data, Some(&validation_data));
}
```
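To make that data format concrete, here is a minimal sketch (the `make_sine_data` helper is hypothetical and uses only `ndarray`) that builds one-step-ahead (input, target) sequence pairs from a sine wave:

```rust
use ndarray::Array2;

/// Build (input_sequence, target_sequence) pairs from a sine wave,
/// where the target at each timestep is the next sample.
fn make_sine_data(num_sequences: usize, seq_len: usize)
    -> Vec<(Vec<Array2<f64>>, Vec<Array2<f64>>)>
{
    (0..num_sequences)
        .map(|s| {
            let start = s as f64 * 0.5;
            let sample = |t: usize| (start + t as f64 * 0.1).sin();
            // Each timestep is a 1x1 column vector (input_size = 1).
            let inputs: Vec<Array2<f64>> = (0..seq_len)
                .map(|t| Array2::from_elem((1, 1), sample(t)))
                .collect();
            let targets: Vec<Array2<f64>> = (0..seq_len)
                .map(|t| Array2::from_elem((1, 1), sample(t + 1)))
                .collect();
            (inputs, targets)
        })
        .collect()
}
```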
Bidirectional LSTMs process the sequence with both past and future context:

```rust
use rust_lstm::layers::bilstm_network::{BiLSTMNetwork, CombineMode};

// BiLSTM with concatenated outputs (output_size = 2 * hidden_size)
let mut bilstm = BiLSTMNetwork::new_concat(input_size, hidden_size, num_layers);

// Process the sequence in both directions
let outputs = bilstm.forward_sequence(&sequence);
```
```mermaid
graph TD
    A["Input Sequence<br/>(x₁, x₂, x₃, x₄)"] --> B["Forward LSTM"]
    A --> C["Backward LSTM"]
    B --> D["Forward Hidden States<br/>(h₁→, h₂→, h₃→, h₄→)"]
    C --> E["Backward Hidden States<br/>(h₁←, h₂←, h₃←, h₄←)"]
    D --> F["Combine Layer<br/>(Concat/Sum/Average)"]
    E --> F
    F --> G["BiLSTM Output<br/>(combined representations)"]
    style A fill:#e1f5fe
    style B fill:#fff3e0
    style C fill:#fff3e0
    style F fill:#f3e5f5
    style G fill:#e8f5e8
```
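The combine step in the diagram is plain array arithmetic. A minimal `ndarray` sketch of the three modes for one timestep (illustrative only; `combine` is a hypothetical helper with a string stand-in for the crate's `CombineMode` enum, not the library's internal code):

```rust
use ndarray::{concatenate, Array2, Axis};

/// Combine forward/backward hidden states for one timestep.
/// Concat doubles the feature dimension; Sum and Average keep it.
fn combine(forward: &Array2<f64>, backward: &Array2<f64>, mode: &str) -> Array2<f64> {
    match mode {
        "concat" => concatenate![Axis(0), forward.view(), backward.view()],
        "sum" => forward + backward,
        "average" => (forward + backward) / 2.0,
        _ => panic!("unknown combine mode"),
    }
}
```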
A GRU network is available as an alternative to LSTM:

```rust
use rust_lstm::models::gru_network::GRUNetwork;

// Create a GRU network (alternative to LSTM)
let mut gru = GRUNetwork::new(input_size, hidden_size, num_layers)
    .with_input_dropout(0.2, true)
    .with_recurrent_dropout(0.3, true);

// Forward pass (a GRU carries only a hidden state, no cell state)
let (output, _) = gru.forward(&input, &hidden_state);
```
```mermaid
graph LR
    subgraph "LSTM Cell"
        A1["Input xₜ"] --> B1["Forget Gate<br/>fₜ = σ(Wf·[hₜ₋₁,xₜ] + bf)"]
        A1 --> C1["Input Gate<br/>iₜ = σ(Wi·[hₜ₋₁,xₜ] + bi)"]
        A1 --> D1["Candidate Values<br/>C̃ₜ = tanh(WC·[hₜ₋₁,xₜ] + bC)"]
        A1 --> E1["Output Gate<br/>oₜ = σ(Wo·[hₜ₋₁,xₜ] + bo)"]
        B1 --> F1["Cell State<br/>Cₜ = fₜ * Cₜ₋₁ + iₜ * C̃ₜ"]
        C1 --> F1
        D1 --> F1
        F1 --> G1["Hidden State<br/>hₜ = oₜ * tanh(Cₜ)"]
        E1 --> G1
    end
    subgraph "GRU Cell"
        A2["Input xₜ"] --> B2["Reset Gate<br/>rₜ = σ(Wr·[hₜ₋₁,xₜ])"]
        A2 --> C2["Update Gate<br/>zₜ = σ(Wz·[hₜ₋₁,xₜ])"]
        A2 --> D2["Candidate State<br/>h̃ₜ = tanh(W·[rₜ*hₜ₋₁,xₜ])"]
        B2 --> D2
        C2 --> E2["Hidden State<br/>hₜ = (1-zₜ)*hₜ₋₁ + zₜ*h̃ₜ"]
        D2 --> E2
    end
    style B1 fill:#ffcdd2
    style C1 fill:#c8e6c9
    style D1 fill:#fff3e0
    style E1 fill:#e1f5fe
    style B2 fill:#ffcdd2
    style C2 fill:#c8e6c9
    style D2 fill:#fff3e0
```
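To make the diagram's formulas concrete, here is a minimal `ndarray` sketch of a single LSTM and GRU step transcribed directly from the equations above (illustrative math only, not the crate's internals; the weight shapes, and the bias-free GRU gates taken from the diagram, are assumptions):

```rust
use ndarray::{concatenate, Array2, Axis};

fn sigmoid(x: &Array2<f64>) -> Array2<f64> {
    x.mapv(|v| 1.0 / (1.0 + (-v).exp()))
}

/// One LSTM cell step following the diagram. All W_* are
/// (hidden, hidden + input) matrices; x, h, c are column vectors.
fn lstm_step(
    x: &Array2<f64>, h: &Array2<f64>, c: &Array2<f64>,
    w_f: &Array2<f64>, w_i: &Array2<f64>, w_c: &Array2<f64>, w_o: &Array2<f64>,
    b_f: &Array2<f64>, b_i: &Array2<f64>, b_c: &Array2<f64>, b_o: &Array2<f64>,
) -> (Array2<f64>, Array2<f64>) {
    let hx = concatenate![Axis(0), h.view(), x.view()]; // [h_{t-1}, x_t]
    let f = sigmoid(&(w_f.dot(&hx) + b_f));             // forget gate
    let i = sigmoid(&(w_i.dot(&hx) + b_i));             // input gate
    let c_tilde = (w_c.dot(&hx) + b_c).mapv(f64::tanh); // candidate values
    let o = sigmoid(&(w_o.dot(&hx) + b_o));             // output gate
    let c_new = &f * c + &i * &c_tilde;                 // C_t = f*C_{t-1} + i*C~_t
    let h_new = &o * &c_new.mapv(f64::tanh);            // h_t = o * tanh(C_t)
    (h_new, c_new)
}

/// One GRU cell step following the diagram (no bias terms shown there).
fn gru_step(
    x: &Array2<f64>, h: &Array2<f64>,
    w_r: &Array2<f64>, w_z: &Array2<f64>, w_h: &Array2<f64>,
) -> Array2<f64> {
    let hx = concatenate![Axis(0), h.view(), x.view()];
    let r = sigmoid(&w_r.dot(&hx));                     // reset gate
    let z = sigmoid(&w_z.dot(&hx));                     // update gate
    let rh = &r * h;                                    // r_t * h_{t-1}
    let rh_x = concatenate![Axis(0), rh.view(), x.view()];
    let h_tilde = w_h.dot(&rh_x).mapv(f64::tanh);       // candidate state
    z.mapv(|v| 1.0 - v) * h + &z * &h_tilde             // (1-z)*h + z*h~
}
```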
The library includes 12 different learning rate schedulers with visualization capabilities:
```rust
use rust_lstm::{
    LSTMNetwork, create_step_lr_trainer, create_one_cycle_trainer, create_cosine_annealing_trainer,
    ScheduledOptimizer, PolynomialLR, CyclicalLR, WarmupScheduler,
    LRScheduleVisualizer, Adam,
};

// Create a network
let network = LSTMNetwork::new(1, 10, 2);

// Step decay: halve the learning rate every 10 epochs
let mut trainer = create_step_lr_trainer(network.clone(), 0.01, 10, 0.5);

// OneCycle policy for modern deep learning
let mut trainer = create_one_cycle_trainer(network.clone(), 0.1, 100);

// Cosine annealing with warm restarts
let mut trainer = create_cosine_annealing_trainer(network.clone(), 0.01, 20, 1e-6);

// Advanced combination: warmup followed by cyclical scheduling
let base_scheduler = CyclicalLR::new(0.001, 0.01, 10);
let warmup_scheduler = WarmupScheduler::new(5, base_scheduler, 0.0001);
let optimizer = ScheduledOptimizer::new(Adam::new(0.01), warmup_scheduler, 0.01);

// Polynomial decay with ASCII visualization
let poly_scheduler = PolynomialLR::new(100, 2.0, 0.001);
LRScheduleVisualizer::print_schedule(poly_scheduler, 0.01, 100, 60, 10);
```
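For intuition about what two of these schedules compute, here is a small self-contained sketch of the step-decay and cosine-annealing formulas (plain Rust, independent of the crate's scheduler types):

```rust
use std::f64::consts::PI;

/// Step decay: multiply the base LR by `gamma` once every `step_size` epochs.
fn step_lr(base_lr: f64, gamma: f64, step_size: usize, epoch: usize) -> f64 {
    base_lr * gamma.powi((epoch / step_size) as i32)
}

/// Cosine annealing: smoothly decay from `base_lr` to `min_lr` over `t_max` epochs.
fn cosine_annealing_lr(base_lr: f64, min_lr: f64, t_max: usize, epoch: usize) -> f64 {
    min_lr + 0.5 * (base_lr - min_lr) * (1.0 + (PI * epoch as f64 / t_max as f64).cos())
}

fn main() {
    // With gamma = 0.5 and step_size = 10, epoch 25 has been halved twice.
    assert_eq!(step_lr(0.01, 0.5, 10, 25), 0.01 * 0.25);
    // Cosine annealing reaches min_lr at epoch t_max.
    assert!((cosine_annealing_lr(0.01, 1e-6, 20, 20) - 1e-6).abs() < 1e-12);
}
```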
- `layers`: LSTM and GRU cells (standard, peephole, bidirectional) with dropout
- `models`: High-level network architectures (LSTM, BiLSTM, GRU)
- `training`: Training utilities with automatic train/eval mode switching
- `optimizers`: SGD, Adam, RMSprop with scheduling
- `loss`: MSE, MAE, and cross-entropy loss functions
- `schedulers`: Learning rate scheduling algorithms

Run the examples to see the library in action:
```bash
# Basic usage and training
cargo run --example basic_usage
cargo run --example training_example
cargo run --example multi_layer_lstm
cargo run --example time_series_prediction

# Advanced architectures
cargo run --example gru_example                 # GRU vs LSTM comparison
cargo run --example bilstm_example              # Bidirectional LSTM
cargo run --example dropout_example             # Dropout demo

# Learning rate scheduling
cargo run --example learning_rate_scheduling    # Basic schedulers
cargo run --example advanced_lr_scheduling      # Advanced schedulers with visualization

# Real-world applications
cargo run --example stock_prediction
cargo run --example weather_prediction
cargo run --example text_classification_bilstm
cargo run --example text_generation_advanced
cargo run --example real_data_example

# Analysis and debugging
cargo run --example model_inspection

# Run the test suite
cargo test
```
The library includes comprehensive examples that demonstrate its capabilities.

Run the learning rate scheduling examples to compare scheduler behaviors:

```bash
cargo run --example learning_rate_scheduling    # Compare basic schedulers
cargo run --example advanced_lr_scheduling      # Advanced schedulers with ASCII visualization
```

Compare LSTM vs GRU performance:

```bash
cargo run --example gru_example
```

Try the library on practical tasks:

```bash
cargo run --example stock_prediction            # Stock price prediction
cargo run --example weather_prediction          # Weather forecasting
cargo run --example text_classification_bilstm  # Text classification accuracy
```

The examples print training metrics, loss values, and predictions that you can analyze or plot with external tools.
Contributions are welcome! Please submit issues, feature requests, or pull requests.
MIT License - see the LICENSE file for details.