| Crates.io | burn_attention |
| lib.rs | burn_attention |
| version | 0.1.0 |
| created_at | 2025-11-13 01:53:18.910147+00 |
| updated_at | 2025-11-13 01:53:18.910147+00 |
| description | Flash Attention v3 implementation for Burn framework |
| homepage | |
| repository | https://github.com/mosure/burn_attention |
| max_upload_size | |
| id | 1930280 |
| size | 217,237 |
# burn_attention

Flash Attention v3 implementation for the Burn deep learning framework.

This crate provides an efficient implementation of Flash Attention v3, a memory-efficient attention algorithm that reduces memory usage from quadratic to linear in sequence length by computing attention in blocks instead of materializing the full score matrix. The implementation supports multiple backends, including WGPU, CubeCL, and CUDA (see the feature flags below).
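For contrast, here is a minimal sketch of naive scaled dot-product attention written directly against Burn's tensor API (this is not the crate's code, just a reference point): it materializes the full `[seq_len_q, seq_len_k]` score matrix, which is exactly the quadratic memory cost that flash attention avoids.

```rust
use burn::tensor::{activation::softmax, backend::Backend, Tensor};

/// Naive attention for reference: memory grows quadratically with
/// sequence length because `scores` holds seq_len_q * seq_len_k entries
/// per batch and head.
fn naive_attention<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
) -> Tensor<B, 4> {
    let head_dim = query.dims()[3] as f64;
    // scores: [batch, heads, seq_len_q, seq_len_k]
    let scores = query
        .matmul(key.swap_dims(2, 3))
        .mul_scalar(1.0 / head_dim.sqrt());
    // Softmax over the key dimension, then weight the values.
    softmax(scores, 3).matmul(value)
}
```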
## Installation

Add this to your Cargo.toml:

```toml
[dependencies]
burn_attention = "0.1"
```
## Feature Flags

- `wgpu` (default): Enable the WGPU backend
- `cubecl`: Enable the CubeCL backend
- `cuda`: Enable the CUDA backend

Example with CUDA support:

```toml
[dependencies]
burn_attention = { version = "0.1", features = ["cuda"] }
```
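With the default `wgpu` feature, the usage example below can target the GPU by swapping the backend type alias. A minimal sketch, assuming Burn itself is compiled with its `wgpu` feature so that `burn::backend::Wgpu` is available:

```rust
// Drop-in replacement for the NdArray alias used in the usage example.
use burn::backend::Wgpu;

type Backend = Wgpu;
```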
## Usage

```rust
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    let device = Default::default();

    // Create input tensors
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Compute attention
    let output = FlashAttentionV3::forward(query, key, value, None, false);
    println!("Output shape: {:?}", output.dims());
}
```
## Causal Masking

For autoregressive models, use causal masking:

```rust
// The final argument enables causal masking.
let output = FlashAttentionV3::forward(query, key, value, None, true);
```
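A property worth knowing, and a hedged way to check it (reusing the API shown above, plus Burn's `slice` method): with causal masking, outputs for the first `t` positions cannot depend on later keys or values, so a full pass and a prefix-only pass should agree on those rows.

```rust
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    let device = Default::default();
    let (b, h, n, d) = (1, 2, 16, 8);
    let t = 4; // prefix length to check

    let dist = Distribution::Normal(0.0, 1.0);
    let q = Tensor::<Backend, 4>::random([b, h, n, d], dist, &device);
    let k = Tensor::<Backend, 4>::random([b, h, n, d], dist, &device);
    let v = Tensor::<Backend, 4>::random([b, h, n, d], dist, &device);

    // Full causal pass over all n positions.
    let full = FlashAttentionV3::forward(q.clone(), k.clone(), v.clone(), None, true);

    // Causal pass over just the first t positions.
    let prefix = FlashAttentionV3::forward(
        q.slice([0..b, 0..h, 0..t, 0..d]),
        k.slice([0..b, 0..h, 0..t, 0..d]),
        v.slice([0..b, 0..h, 0..t, 0..d]),
        None,
        true,
    );

    // With causal masking, the first t output rows must not depend on
    // positions > t, so both computations should agree up to float error.
    let diff = (full.slice([0..b, 0..h, 0..t, 0..d]) - prefix).abs().max();
    println!("max abs diff over prefix: {}", diff.into_scalar());
}
```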
## Configuration

```rust
use burn_attention::FlashAttentionV3Config;

let config = FlashAttentionV3Config {
    causal: true,
    dropout_p: 0.1,
    softmax_scale: Some(0.125), // 1 / sqrt(head_dim) for head_dim = 64
    block_size_q: 128,
    block_size_k: 128,
};

let output = FlashAttentionV3::forward_with_config(
    query,
    key,
    value,
    None,
    config,
);
```
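A sanity check on the numbers above: flash attention implementations conventionally scale scores by `1 / sqrt(head_dim)`, and with `head_dim = 64` as in the usage example, `1 / sqrt(64) = 0.125`. Whether `softmax_scale: None` falls back to that conventional default is not stated here, so treat it as an assumption to verify against the crate docs. The `block_size_q` / `block_size_k` fields set the tile sizes for the blocked computation that keeps memory linear in sequence length.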
## Benchmarks

Run benchmarks with:

```sh
cargo bench
```

This will run throughput benchmarks for various sequence lengths and batch sizes.
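To benchmark a non-default backend, the standard Cargo feature syntax applies (the `cuda` feature name comes from the flags above):

```sh
cargo bench --features cuda
```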
## Testing

Run the test suite:

```sh
cargo test
```
## Implementation

This implementation follows the Flash Attention v3 algorithm; the blocked computation (tiled by `block_size_q` / `block_size_k` above) is what keeps memory usage linear in sequence length.
### Tensor Shapes

| Tensor | Shape |
|--------|-------|
| Query  | `[batch_size, num_heads, seq_len_q, head_dim]` |
| Key    | `[batch_size, num_heads, seq_len_k, head_dim]` |
| Value  | `[batch_size, num_heads, seq_len_k, head_dim]` |
| Output | `[batch_size, num_heads, seq_len_q, head_dim]` |
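Since `seq_len_q` and `seq_len_k` appear separately above, query and key/value lengths presumably may differ, as in cross-attention. A hedged sketch of that reading:

```rust
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    let device = Default::default();
    let dist = Distribution::Normal(0.0, 1.0);

    // seq_len_q = 32 for queries; seq_len_k = 128 for keys/values.
    let query = Tensor::<Backend, 4>::random([2, 8, 32, 64], dist, &device);
    let key = Tensor::<Backend, 4>::random([2, 8, 128, 64], dist, &device);
    let value = Tensor::<Backend, 4>::random([2, 8, 128, 64], dist, &device);

    let output = FlashAttentionV3::forward(query, key, value, None, false);
    assert_eq!(output.dims(), [2, 8, 32, 64]); // output follows seq_len_q
}
```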
## License

This project is licensed under either of the licenses included in the repository, at your option.
## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.