candle-flash-attn-3

Crates.io: candle-flash-attn-3
lib.rs: candle-flash-attn-3
version: 0.0.1
created_at: 2025-12-21 07:14:16.299541+00
updated_at: 2025-12-21 07:14:16.299541+00
description: Flash attention V3 layer for the candle ML framework.
repository: https://github.com/michaelfeil/candle-flash-attn-v3
id: 1997631
size: 24,696,146
owner: Michael Feil (michaelfeil)

README

Candle Flash Attention v3 Layer

Flash Attention v3 layer for Hopper GPUs (NVIDIA sm90a architecture) and the candle framework.

Usage

use candle_flash_attn_v3;
use anyhow::Result;
use candle::{DType, Device, IndexOp, Tensor, D};

fn flash_attn_acausal() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::arange(0u32, 3 * 2 * 64, &device)?
        .to_dtype(DType::F16)?
        .reshape((1, 3, 2, 64))?; // batch, head, seqlen, hidden_dim
    let k = (&q / 400.)?;
    let v = (&q / 500.)?;
    let q = (&q / 300.)?;

    let att = {
        let q = q.transpose(1, 2)?;
        let k = k.transpose(1, 2)?;
        let v = v.transpose(1, 2)?;
        candle_flash_attn_v3::flash_attn(&q, &k, &v, 0.5, false, false)?.transpose(1, 2)?
    };
    // Back to (batch, head, seqlen, hidden_dim) after the final transpose.
    println!("attention output shape: {:?}", att.dims());
    Ok(())
}
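
The kernel's output can be sanity-checked against a naive attention computed with plain candle ops. The sketch below is not part of this crate's API: attention_ref is a hypothetical helper that assumes (batch, head, seqlen, hidden_dim) inputs (i.e. the tensors before the transposes above) and the same 0.5 softmax scale.

use candle::{Tensor, D};
use candle_nn::ops::softmax;

// Hypothetical reference implementation: naive softmax(q k^T * scale) v on
// (batch, head, seqlen, hidden_dim) tensors, for comparison with flash_attn.
fn attention_ref(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> candle::Result<Tensor> {
    let scores = (q.matmul(&k.t()?)? * scale)?; // (batch, head, seqlen, seqlen)
    let weights = softmax(&scores, D::Minus1)?; // normalize over the key axis
    weights.matmul(v)                           // (batch, head, seqlen, hidden_dim)
}

With the tensors from the example above (before the transposes), attention_ref(&q, &k, &v, 0.5)? should agree with att up to f16 rounding error.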

Integration instructions

[dependencies]
candle = { version = "*", package = "candle-core", default-features = false }
candle-nn = { version = "*" }
candle-transformers = { version = "*" }
candle-flash-attn-3 = { git = "https://github.com/michaelfeil/candle-flash-attn-3", rev = "main", optional = true }
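
Because the dependency is declared optional, callers usually gate it behind a Cargo feature. The names below are assumptions for illustration (a feature called flash-attn-v3 wired to the optional dependency), not something this crate defines; the flash_attn call simply mirrors the usage example above.

use candle::Tensor;

// Minimal sketch, assuming a Cargo feature named "flash-attn-v3" that enables the
// optional candle-flash-attn-3 dependency declared above (hypothetical setup).
#[cfg(feature = "flash-attn-v3")]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> candle::Result<Tensor> {
    // Hopper-only fast path; the boolean flags mirror the usage example above.
    candle_flash_attn_v3::flash_attn(q, k, v, softmax_scale, false, false)
}

#[cfg(not(feature = "flash-attn-3"))]
fn attention(q: &Tensor, k: &Tensor, v: &Tensor, softmax_scale: f32) -> candle::Result<Tensor> {
    // Fallback for builds without the feature (e.g. non-sm90a hardware).
    let _ = (q, k, v, softmax_scale);
    candle::bail!("rebuild with the flash-attn-v3 feature to use Flash Attention v3")
}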

Install and test locally:

git submodule update --init --recursive
cargo test --release
cargo fmt