| Crates.io | gllm-kernels |
| lib.rs | gllm-kernels |
| version | 0.1.3 |
| created_at | 2026-01-12 13:51:38.264647+00 |
| updated_at | 2026-01-12 14:10:23.150274+00 |
| description | Low-level attention kernels for gllm with CUDA/ROCm support |
| homepage | https://github.com/putao520/gllm-kernels |
| repository | https://github.com/putao520/gllm-kernels |
| max_upload_size | |
| id | 2037759 |
| size | 560,494 |
Low-level attention kernels for gllm with CUDA/ROCm support.
Benchmark results:

| Implementation | Time (seq=512) | vs burn_cuda |
|---|---|---|
| cuda_kernel | 21.27ms | 37% faster |
| burn_cuda | 33.83ms | baseline |
Add to your Cargo.toml:

```toml
[dependencies]
gllm-kernels = "0.1"
```
| Feature | Description | Default |
|---|---|---|
| `cpu` | CPU backend via burn-ndarray | Yes |
| `cuda` | CUDA backend via burn-cuda | No |
| `cuda-kernel` | Native CUDA kernels (requires CUDA toolkit) | No |
| `wgpu` | WebGPU backend | No |
| `rocm-kernel` | ROCm/HIP kernels (experimental) | No |
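Non-default backends are enabled through Cargo features. A minimal sketch using standard Cargo feature syntax with the feature names from the table above:

```toml
[dependencies]
gllm-kernels = { version = "0.1", features = ["cuda-kernel"] }
```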
Basic usage of the high-level attention module:

```rust
use gllm_kernels::ops::flash_attention::{AttentionConfig, HierarchicalFlashAttention};

// Create the attention module (8 heads of dimension 64 here, purely as an example).
let (num_heads, head_dim) = (8, 64);
let attention = HierarchicalFlashAttention::new(
    num_heads,
    head_dim,
    AttentionConfig::default(),
);

// Forward pass; q, k, v, and mask are tensors prepared by the caller.
let output = attention.forward(q, k, v, mask)?;
```
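The underlying operation is scaled-dot-product attention, softmax(QKᵀ/√d)·V. Below is a CPU-only sketch of that computation for a single head in plain Rust, purely to illustrate the math; it is not part of the crate's API, and masking and the hierarchical/tiling machinery are omitted:

```rust
/// Naive single-head attention: out[i] = sum_j softmax_j(q_i · k_j / sqrt(d)) * v[j].
/// q, k, v are row-major [seq_len, head_dim] slices.
fn naive_attention(q: &[f32], k: &[f32], v: &[f32], seq_len: usize, head_dim: usize) -> Vec<f32> {
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; seq_len * head_dim];
    for i in 0..seq_len {
        // Scaled dot products between query i and every key.
        let mut scores: Vec<f32> = (0..seq_len)
            .map(|j| {
                (0..head_dim)
                    .map(|d| q[i * head_dim + d] * k[j * head_dim + d])
                    .sum::<f32>()
                    * scale
            })
            .collect();
        // Numerically stable softmax over the scores.
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let denom: f32 = scores.iter_mut().map(|s| { *s = (*s - max).exp(); *s }).sum();
        // Output row i is the softmax-weighted sum of the value rows.
        for j in 0..seq_len {
            let w = scores[j] / denom;
            for d in 0..head_dim {
                out[i * head_dim + d] += w * v[j * head_dim + d];
            }
        }
    }
    out
}
```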
With the `cuda-kernel` feature enabled, the native CUDA kernel can be driven directly:

```rust
use std::sync::Arc;

use cudarc::driver::CudaContext;
use gllm_kernels::cuda_kernels::FlashAttentionKernel;

// Initialize CUDA on device 0 and build the kernel.
let ctx = Arc::new(CudaContext::new(0)?);
let kernel = FlashAttentionKernel::new(&ctx)?;

// `stream`, `q`, `k`, and `v` are a CUDA stream and device buffers prepared by the caller.
let output = kernel.forward(
    &stream, &q, &k, &v,
    batch_size, num_heads, seq_len, head_dim,
    is_causal, scale, position_offset,
)?;
```
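The scalar arguments follow the usual attention conventions. A hypothetical setup for them (illustrative values only; the exact integer types depend on the kernel's signature):

```rust
// Illustrative dimensions; substitute your model's actual sizes.
let (batch_size, num_heads, seq_len, head_dim) = (1, 8, 512, 64);

// Standard attention scaling factor 1 / sqrt(head_dim).
let scale = 1.0 / (head_dim as f32).sqrt();

// Causal masking enabled; position_offset = 0 for a sequence starting at position 0.
let is_causal = true;
let position_offset = 0;
```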
For reproducible results in ultra-long context scenarios:

```rust
use gllm_kernels::ops::flash_attention::{AttentionConfig, DeterministicConfig};

let config = AttentionConfig {
    determinism: DeterministicConfig::strict(),
    ..Default::default()
};
```
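The resulting config then takes the place of `AttentionConfig::default()` in the constructor shown earlier:

```rust
let attention = HierarchicalFlashAttention::new(num_heads, head_dim, config);
```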
Crate layout:

```text
gllm-kernels
├── ops/
│   ├── flash_attention.rs      # HierarchicalFlashAttention
│   ├── flash_attention_v3.rs   # Advanced attention variants
│   ├── paged_attention.rs      # KV cache paging
│   ├── ring_attention.rs       # Distributed attention
│   ├── sparse_attention.rs     # Sparse patterns
│   ├── mla.rs                  # Multi-head Latent Attention
│   ├── mamba.rs                # State space models
│   └── kv_compression.rs       # KV cache compression
├── cuda_kernels/
│   ├── flash_attn.rs           # CUDA kernel bindings
│   └── kernels/
│       ├── tiled_attention.cu  # CUDA source
│       └── tiled_attention.ptx # Compiled PTX (sm_61)
├── hip_kernels/                # ROCm/HIP (experimental)
└── comm/                       # Distributed communication
```
If you need to recompile the PTX for a different GPU architecture:

```sh
cd src/cuda_kernels/kernels
nvcc -ptx -arch=sm_XX tiled_attention.cu -o tiled_attention.ptx
```

Replace `sm_XX` with your GPU's compute capability (e.g., `sm_61` for a GTX 1060, `sm_86` for an RTX 3090).
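If you are unsure of your GPU's compute capability, recent NVIDIA drivers can report it directly (a plain `nvidia-smi` query, unrelated to this crate):

```sh
nvidia-smi --query-gpu=compute_cap --format=csv,noheader
```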
Or set the `GLLM_FLASH_ATTN_PTX` environment variable to use a custom PTX file:

```sh
export GLLM_FLASH_ATTN_PTX=/path/to/your/compiled.ptx
```
License: Apache-2.0