| Crates.io | radix_mlp |
| lib.rs | radix_mlp |
| version | 0.0.2 |
| created_at | 2025-12-28 10:20:28.058113+00 |
| updated_at | 2025-12-28 17:52:22.011717+00 |
| description | RadixMLP: Prefix-based computation sharing for transformer models |
| homepage | https://github.com/michaelfeil/radix-mlp |
| repository | https://github.com/michaelfeil/radix-mlp |
| id | 2008556 |
| size | 65,752 |
Pure Rust library for prefix-based computation sharing in transformer models.
RadixMLP identifies prefixes shared among the sequences in a batch and produces a compact representation that contains each unique subsequence only once, so work on a shared prefix can be computed once and reused by every sequence that begins with it.
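To make the idea concrete, here is a minimal sketch of prefix deduplication using a trie keyed by (parent node, token). It is not the crate's implementation: the function name `dedup_prefixes`, the `u32`/`usize` types, and the assumption that positions start at 0 within each sequence are all illustrative choices.

```rust
use std::collections::HashMap;

/// Illustrative only (hypothetical helper, not part of radix_mlp).
/// Sequence `i` is assumed to span
/// `input_ids[cu_seq_lengths[i]..cu_seq_lengths[i + 1]]`.
/// Returns one (token, position) pair per unique prefix path.
fn dedup_prefixes(input_ids: &[u32], cu_seq_lengths: &[usize]) -> (Vec<u32>, Vec<usize>) {
    // Trie edges: (parent node, token) -> child node. Node 0 is the root.
    let mut edges: HashMap<(usize, u32), usize> = HashMap::new();
    let mut compact_ids = Vec::new();
    let mut compact_positions = Vec::new();
    let mut next_node = 1;

    for bounds in cu_seq_lengths.windows(2) {
        let mut node = 0; // every sequence starts at the root
        for (pos, &tok) in input_ids[bounds[0]..bounds[1]].iter().enumerate() {
            // Reuse the child if this prefix was seen before; otherwise
            // create a new node and record its token in the compact output.
            node = *edges.entry((node, tok)).or_insert_with(|| {
                compact_ids.push(tok);
                compact_positions.push(pos);
                let id = next_node;
                next_node += 1;
                id
            });
        }
    }
    (compact_ids, compact_positions)
}

fn main() {
    // Two sequences, [1, 2, 3] and [1, 2, 4], sharing the prefix [1, 2].
    let input_ids: Vec<u32> = vec![1, 2, 3, 1, 2, 4];
    let cu_seq_lengths: Vec<usize> = vec![0, 3, 6];

    let (ids, positions) = dedup_prefixes(&input_ids, &cu_seq_lengths);
    println!("Original: {} -> Compact: {}", input_ids.len(), ids.len());
    println!("ids = {:?}, positions = {:?}", ids, positions);
}
```

In this sketch the shared prefix `[1, 2]` is stored once, so six original tokens reduce to four compact ones. Two sequences only share a token when everything before it is identical, which is why the trie is keyed on the parent node rather than on the token alone.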
Add to your Cargo.toml:
[dependencies]
radix_mlp = "0.0.2"
use radix_mlp::compute_fold_and_scatter;
let input_ids = vec![1, 2, 3, 1, 2, 4];
let position_ids = vec![0, 1, 2, 0, 1, 2];
let cu_seq_lengths = vec![0, 3, 6];
let (compact_input_ids, compact_position_ids, scatter_indices, fold_gather) =
    compute_fold_and_scatter(&input_ids, &position_ids, &cu_seq_lengths, None);
println!("Original: {} -> Compact: {}", input_ids.len(), compact_input_ids.len());
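The sketch below shows how index vectors of this kind could be applied around a per-token computation. It does not call the crate. The index values are written by hand for the batch above, under the assumption that `fold_gather` holds, for each compact slot, an index into the original layout, and that `scatter_indices` holds, for each original token, an index into the compact layout. The helper `gather` is hypothetical, and the crate's actual output (ordering, integer types, padding) may differ.

```rust
/// Hypothetical helper: build a new vector by reading `source` at `indices`.
fn gather<T: Copy>(source: &[T], indices: &[usize]) -> Vec<T> {
    indices.iter().map(|&i| source[i]).collect()
}

fn main() {
    // Original flattened batch: sequences [1, 2, 3] and [1, 2, 4].
    let input_ids: Vec<u32> = vec![1, 2, 3, 1, 2, 4];

    // Hand-written for this batch under the assumptions stated above.
    let fold_gather: Vec<usize> = vec![0, 1, 2, 5]; // compact slot -> original index
    let scatter_indices: Vec<usize> = vec![0, 1, 2, 0, 1, 3]; // original token -> compact slot

    // Fold: keep one representative per unique prefix position.
    let compact = gather(&input_ids, &fold_gather);
    assert_eq!(compact, vec![1, 2, 3, 4]);

    // Stand-in for an expensive per-token computation, run on 4 tokens instead of 6.
    let compact_out: Vec<u32> = compact.iter().map(|t| t * 10).collect();

    // Unfold: expand the compact results back to the original layout.
    let full_out = gather(&compact_out, &scatter_indices);
    assert_eq!(full_out, vec![10, 20, 30, 10, 20, 40]);

    println!("compact = {:?}, unfolded = {:?}", compact_out, full_out);
}
```

Both directions are plain gathers here, which keeps the round trip easy to check: unfolding the folded input reproduces the original `input_ids`.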
compute_fold_and_scatter
Computes indices for RadixMLP-style folding and scattering.
Parameters:
- input_ids: Flattened vector of token IDs
- position_ids: Flattened vector of position IDs
- cu_seq_lengths: Cumulative sequence lengths
- pad_multiple_of: If Some(n), pad output to multiple of n for performance. If None, no padding.
Returns:
- compact_input_ids: Unique token IDs
- compact_position_ids: Corresponding position IDs
- scatter_indices: Unfold indices (compact -> original)
- fold_gather: Gather indices (original -> compact)
Run tests with:
cargo test
MIT License - Copyright (c) 2025 michaelfeil