burn_attention

Crates.io: burn_attention
lib.rs: burn_attention
version: 0.1.0
created_at: 2025-11-13 01:53:18.910147+00
updated_at: 2025-11-13 01:53:18.910147+00
description: Flash Attention v3 implementation for Burn framework
repository: https://github.com/mosure/burn_attention
id: 1930280
size: 217,237
Mitchell Mosure (mosure)

README

burn_attention

Flash Attention v3 implementation for the Burn deep learning framework.

Overview

This crate provides an efficient implementation of Flash Attention v3, a memory-efficient attention algorithm that reduces memory usage from quadratic to linear in sequence length. The implementation supports multiple backends including:

  • WGPU (default): Cross-platform GPU support via WebGPU
  • CubeCL: High-performance compute kernels
  • CUDA: Direct CUDA support for NVIDIA GPUs

Features

  • ✅ Standard scaled dot-product attention
  • ✅ Causal masking for autoregressive models
  • ✅ Custom attention masks
  • ✅ Configurable softmax scaling
  • ✅ Multiple backend support (WGPU, CubeCL, CUDA)
  • ✅ Comprehensive test suite
  • ✅ Criterion benchmarks for performance testing

Installation

Add this to your Cargo.toml:

[dependencies]
burn_attention = "0.1"

Feature Flags

  • wgpu (default): Enable WGPU backend
  • cubecl: Enable CubeCL backend
  • cuda: Enable CUDA backend

Example with CUDA support:

[dependencies]
burn_attention = { version = "0.1", features = ["cuda"] }
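
With the default wgpu feature, the usage example below can target the GPU by swapping the backend type alias. The following is a minimal sketch, assuming your project also enables Burn's own wgpu feature so that burn::backend::Wgpu is available:

use burn::backend::Wgpu;

// Swap this alias in for the `Backend` alias used in the basic example below;
// the rest of the code, including `let device = Default::default();`, stays the same.
type Backend = Wgpu;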

Usage

Basic Example

use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    let device = Default::default();

    // Create input tensors
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Compute attention (no custom mask, non-causal)
    let output = FlashAttentionV3::forward(query, key, value, None, false);

    println!("Output shape: {:?}", output.dims());
}
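
With the dimensions above, the output has the same shape as the query tensor, so this prints "Output shape: [2, 8, 128, 64]" (see Tensor Shapes below).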

Causal Attention

For autoregressive models, enable causal masking by passing true as the final argument:

let output = FlashAttentionV3::forward(query, key, value, None, true);
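
The feature list also includes custom attention masks, which are passed in place of the None argument. The mask type and semantics are defined by the crate, so the following is only a hedged sketch: it assumes the mask is a rank-4 Tensor shaped [batch_size, num_heads, seq_len_q, seq_len_k], with placeholder values; consult the crate documentation for the exact expectations.

// Hypothetical sketch only: the mask shape, values, and semantics here are
// assumptions, not taken from the crate's documentation.
let mask = Tensor::<Backend, 4>::ones(
    [batch_size, num_heads, seq_len, seq_len],
    &device,
);

let output = FlashAttentionV3::forward(query, key, value, Some(mask), false);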

Custom Configuration

use burn_attention::FlashAttentionV3Config;

let config = FlashAttentionV3Config {
    causal: true,                // apply a causal mask
    dropout_p: 0.1,              // attention dropout probability
    softmax_scale: Some(0.125),  // explicit softmax scale, here 1/sqrt(64) = 0.125
    block_size_q: 128,           // query tile size for the tiled computation
    block_size_k: 128,           // key/value tile size for the tiled computation
};

let output = FlashAttentionV3::forward_with_config(
    query,
    key,
    value,
    None,
    config,
);

Benchmarks

Run benchmarks with:

cargo bench

This will run throughput benchmarks for various sequence lengths and batch sizes.

Testing

Run the test suite:

cargo test

The test suite includes:

  • Unit tests for basic functionality
  • Numerical correctness tests comparing against a reference implementation (see the sketch below)
  • Property-based tests for attention output
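
As a concrete illustration of the correctness tests, the sketch below compares the crate's output against a naive reference attention built from Burn tensor ops. It is a sketch, not the crate's actual test code: the shapes, the tolerance, and the helper name reference_attention are all illustrative, while the FlashAttentionV3::forward call mirrors the usage shown above.

use burn::backend::NdArray;
use burn::tensor::{activation::softmax, Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

// Naive reference: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix.
fn reference_attention(
    q: Tensor<Backend, 4>,
    k: Tensor<Backend, 4>,
    v: Tensor<Backend, 4>,
) -> Tensor<Backend, 4> {
    let head_dim = q.dims()[3] as f64;
    let scores = q.matmul(k.swap_dims(2, 3)) / head_dim.sqrt();
    softmax(scores, 3).matmul(v)
}

#[test]
fn matches_reference_attention() {
    let device = Default::default();
    let shape = [2, 4, 64, 32];
    let q = Tensor::<Backend, 4>::random(shape, Distribution::Normal(0.0, 1.0), &device);
    let k = Tensor::<Backend, 4>::random(shape, Distribution::Normal(0.0, 1.0), &device);
    let v = Tensor::<Backend, 4>::random(shape, Distribution::Normal(0.0, 1.0), &device);

    let expected = reference_attention(q.clone(), k.clone(), v.clone());
    let actual = FlashAttentionV3::forward(q, k, v, None, false);

    // Compare the largest absolute difference against a loose tolerance.
    let max_abs_diff = (actual - expected).abs().max().into_scalar();
    assert!(max_abs_diff < 1e-3, "max abs diff: {max_abs_diff}");
}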

Implementation Details

This implementation follows the Flash Attention v3 algorithm with optimizations for:

  1. Memory Efficiency: Tiled computation to reduce memory usage
  2. Numerical Stability: Online softmax computation (see the sketch after this list)
  3. Performance: Kernel fusion and optimized memory access patterns
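
The online softmax mentioned in item 2 keeps a running maximum and a running sum of exponentials, rescaling the accumulated sum whenever a later block raises the maximum. Below is a minimal, backend-free sketch of that idea on a plain f32 slice; it is illustrative only and does not mirror the crate's kernel code (the function name online_softmax_stats is made up for this example).

// Streaming softmax statistics over `scores`, processed in blocks, as in the
// online softmax used by Flash Attention: keep a running maximum `m` and a
// running sum `l` of exp(score - m), rescaling `l` whenever `m` increases.
fn online_softmax_stats(scores: &[f32], block_size: usize) -> (f32, f32) {
    let mut m = f32::NEG_INFINITY; // running maximum
    let mut l = 0.0f32;            // running sum of exp(score - m)

    for block in scores.chunks(block_size) {
        let block_max = block.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let new_m = m.max(block_max);

        // Rescale the previously accumulated sum to the new maximum,
        // then add this block's contribution.
        let block_sum: f32 = block.iter().map(|&s| (s - new_m).exp()).sum();
        l = l * (m - new_m).exp() + block_sum;
        m = new_m;
    }

    (m, l)
}

The softmax weight of an individual score s is then exp(s - m) / l; the attention kernel applies the same rescaling factor to its partially accumulated output tile, which is what lets it avoid materializing the full attention matrix.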

Tensor Shapes

  • Query: [batch_size, num_heads, seq_len_q, head_dim]
  • Key: [batch_size, num_heads, seq_len_k, head_dim]
  • Value: [batch_size, num_heads, seq_len_k, head_dim]
  • Output: [batch_size, num_heads, seq_len_q, head_dim]

References

  • FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Shah et al., 2024)
  • FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022)

License

This project is licensed under either of:

at your option.

Contributing

Contributions are welcome! Please feel free to submit a Pull Request.
