| Crates.io | nanogpt |
| lib.rs | nanogpt |
| version | 0.1.0 |
| created_at | 2025-12-26 06:39:51.911499+00 |
| updated_at | 2025-12-26 06:39:51.911499+00 |
| description | Nanochat in Rust |
| homepage | https://github.com/tuned-org-uk/nanogpt-rs |
| repository | https://github.com/tuned-org-uk/nanogpt-rs |
| id | 2005326 |
| size | 242,685 |
A GPT implementation in Rust using the Burn deep learning framework. This is a high-performance Rust port inspired by Andrej Karpathy's nanochat, featuring a modern transformer architecture with advanced optimizations.
See Andrej Karpathy's original nanochat repository for the upstream implementation and license.
Run with a GPU backend:
cargo run --release --features wgpu
cargo run --release --features cuda
Test model capabilities:
cargo run --example check_features --features wgpu -- --nocapture
nanochat-rs implements a decoder-only transformer with state-of-the-art features for efficient text generation. Built on Burn 0.18, it leverages Rust's performance and safety guarantees while providing GPU acceleration through multiple backends (WGPU, CUDA, CPU).
Model Architecture (gpt.rs)
Multi-layer Transformer: N stacked decoder blocks with pre-norm residual connections
Rotary Position Embeddings (RoPE): Replaces learned positional encodings with rotary embeddings for better length generalization
Multi-Query Attention (MQA): Reduces KV cache size by sharing key/value heads across query heads
RMSNorm: Parameter-free normalization for stability (instead of LayerNorm)
QK-norm: Normalizes queries and keys before attention to prevent numerical instability
ReLU² MLP: Uses ReLU(x)² activation for better gradient flow on GPUs
Softcap Logits: Bounds output logits using tanh(x/15)*15 to prevent extreme values
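To make the softcap concrete, here is a minimal sketch of the bounding function in plain Rust (independent of the Burn tensor API; the cap value of 15 follows the description above):
// Maps any logit into (-cap, cap) while staying roughly linear for small |x|.
fn softcap(logit: f32, cap: f32) -> f32 {
    (logit / cap).tanh() * cap
}
// Applied element-wise to the LM head output, e.g. softcap(logit, 15.0).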
# CPU only (fast compile)
cargo build --release
# GPU with WGPU
cargo build --release --features wgpu
# NVIDIA with CUDA
cargo build --release --features cuda
# Unit tests
cargo test
# Integration test with demo
cargo run --release --bin main
Attention Mechanism:
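As an illustration of the multi-query layout described above, the hypothetical helper below shows how query heads map onto shared KV heads (not part of the crate's API):
// With n_head = 8 and n_kv_head = 2, query heads 0..=3 read KV head 0
// and query heads 4..=7 read KV head 1.
fn kv_head_for(query_head: usize, n_head: usize, n_kv_head: usize) -> usize {
    let group_size = n_head / n_kv_head; // assumes n_head % n_kv_head == 0
    query_head / group_size
}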
Forward Pass Flow:
Input IDs [B,T]
→ Embedding [B,T,C]
→ N × Block(RMSNorm → Attn+Residual → RMSNorm → MLP+Residual)
→ Final RMSNorm [B,T,C]
→ LM Head [B,T,V]
→ Softcap + Clamp
→ Logits [B,T,V]
Decode Flow (with cache):
Last Token [B,1]
→ Embed [B,1,C]
→ N × Block(decode with cache update)
→ Final RMSNorm [B,1,C]
→ LM Head [B,1,V]
→ Sample next token
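To make the cached decode step concrete, here is a minimal sketch of the idea behind a per-layer KV cache (a hypothetical struct, not the crate's engine::KVCache type):
// Each decode step appends one position's key/value projections; attention then
// runs over all cached positions instead of re-encoding the whole prefix.
struct LayerKvCache {
    keys: Vec<Vec<f32>>,   // one row of length n_kv_head * head_dim per position
    values: Vec<Vec<f32>>,
}

impl LayerKvCache {
    fn new() -> Self {
        Self { keys: Vec::new(), values: Vec::new() }
    }
    fn push(&mut self, k: Vec<f32>, v: Vec<f32>) {
        self.keys.push(k);
        self.values.push(v);
    }
    fn seq_len(&self) -> usize {
        self.keys.len()
    }
}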
Milestones:
M2: Sampling Policies
M3: Multi-Block GPT
M4: KV Cache & Streaming
M5: RoPE (Rotary Position Embeddings), sketched below
M6: RMSNorm & QK-norm
M7: Multi-Query Attention
M8: Advanced Sampling
M9: Logits Softcap
M10: Checkpoint I/O
Separate config (JSON) and weights (MessagePack) serialization
Uses NamedMpkFileRecorder for cross-backend compatibility
Clean save/load API via checkpoint module
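For the RoPE milestone (M5), here is a minimal sketch of applying rotary embeddings to a single head vector; the pairwise rotation layout and base value are assumptions for illustration, not the crate's exact implementation:
// Rotates consecutive (even, odd) dimension pairs by a position-dependent angle.
fn apply_rope(x: &mut [f32], pos: usize, base: f32) {
    let d = x.len();
    for i in 0..d / 2 {
        let theta = pos as f32 * base.powf(-2.0 * i as f32 / d as f32);
        let (sin, cos) = theta.sin_cos();
        let (a, b) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = a * cos - b * sin;
        x[2 * i + 1] = a * sin + b * cos;
    }
}
// Typical usage: apply_rope(&mut q_head, position, 10000.0) on queries and keys.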
src/
├── lib.rs        # Public API exports
├── gpt.rs        # Core GPT model implementation
├── config.rs     # Model hyperparameters
├── engine.rs     # KV cache and streaming interface
├── sampling.rs   # Sampling strategies
├── checkpoint.rs # Model serialization
├── backend.rs    # Multi-backend support (WGPU/CUDA/Metal)
└── tokenizer.rs  # BPE tokenizer (compatible with rustbpe)
Basic usage: configure a model and generate tokens.
use burn::tensor::{Int, Tensor};
use nanochat::{
    backend::{get_device, AutoBackend},
    config::nanochatConfig,
    gpt::GptModel,
    sampling::{sample_with_policy, SamplingPolicy},
};
// Configure model
let cfg = nanochatConfig {
    vocab_size: 65536,
    n_layer: 12,
    n_head: 8,
    n_kv_head: 2, // MQA: 2 KV heads shared across 8 Q heads
    n_embd: 768,
    sequence_len: 2048,
    block_size: 2048,
    dropout: 0.0,
};
let device = get_device();
let model = GptModel::<AutoBackend>::new(&cfg, &device);
// Encode input (token IDs from tokenizer)
let input_ids = vec![1, 2, 3, 4, 5];
let input = Tensor::<AutoBackend, 1, Int>::from_ints(&input_ids, &device)
    .reshape([1, input_ids.len()]);
// Generate with temperature sampling
let output = model.generate(input, 50);
Streaming generation with the KV-cache engine:
use nanochat::engine::{Engine, KVCache};
let engine = Engine::new(model, device);
// Stream tokens one at a time
for next_token in engine.stream(input, 100) {
    let token_id = next_token.to_data().to_vec::<i64>().unwrap()[0];
    // Decode and display token
    print!("{}", tokenizer.decode(&[token_id as u32]));
}
Sampling from the last-position logits with a configurable policy:
use nanochat::sampling::{extract_last_logits, sample_with_policy, SamplingPolicy};
let logits = model.forward(input, true); // true = use softcap
let last_logits = extract_last_logits(logits);
// Nucleus sampling with temperature
let next_token = sample_with_policy(
    last_logits,
    SamplingPolicy::TempTopP { t: 0.8, p: 0.9 },
);
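To illustrate what TempTopP does, here is a plain-Rust sketch of temperature scaling followed by nucleus (top-p) filtering; it is a conceptual reference, not the crate's sampling code:
// Returns the smallest set of most-likely token ids (with probabilities) whose
// cumulative probability reaches p; the next token is sampled from this set.
fn top_p_candidates(logits: &[f32], temperature: f32, p: f32) -> Vec<(usize, f32)> {
    let scaled: Vec<f32> = logits.iter().map(|&x| x / temperature).collect();
    let max = scaled.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scaled.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    let mut probs: Vec<(usize, f32)> = exps.iter().map(|&e| e / sum).enumerate().collect();
    probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
    let mut kept = Vec::new();
    let mut cumulative = 0.0;
    for (id, prob) in probs {
        kept.push((id, prob));
        cumulative += prob;
        if cumulative >= p {
            break;
        }
    }
    kept
}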
Saving and loading checkpoints:
use nanochat::checkpoint::{save_checkpoint, load_checkpoint};
// Save
save_checkpoint(&model, &cfg, "./checkpoints/model_v1")?;
// Load
let (loaded_model, loaded_cfg) = load_checkpoint::<AutoBackend>(
    "./checkpoints/model_v1",
    &device,
)?;
Key hyperparameters in nanochatConfig:
pub struct nanochatConfig {
    pub vocab_size: usize,   // Tokenizer vocabulary size
    pub n_layer: usize,      // Number of transformer blocks
    pub n_head: usize,       // Number of query heads
    pub n_kv_head: usize,    // Number of KV heads (MQA)
    pub n_embd: usize,       // Embedding dimension
    pub sequence_len: usize, // Maximum sequence length
    pub block_size: usize,   // Context window size
    pub dropout: f64,        // Dropout rate (0.0 for inference)
}
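A quick sanity check of the example configuration above, assuming the usual head_dim = n_embd / n_head split (an assumption for illustration, not documented behavior):
fn main() {
    let (n_embd, n_head, n_kv_head) = (768, 8, 2);
    assert_eq!(n_embd % n_head, 0);    // heads must evenly divide the embedding
    assert_eq!(n_head % n_kv_head, 0); // query heads group evenly onto KV heads
    let head_dim = n_embd / n_head;      // 96
    let group_size = n_head / n_kv_head; // 4 query heads per shared KV head
    println!("head_dim = {head_dim}, group_size = {group_size}");
}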
Automatically selects the best available backend (CUDA, WGPU, or the ndarray CPU backend):
Override with environment:
export BURN_BACKEND=wgpu # or cuda, ndarray