| Field | Value |
|---|---|
| Crates.io | axonml-jit |
| lib.rs | axonml-jit |
| version | 0.2.4 |
| created_at | 2026-01-19 22:07:22.242881+00 |
| updated_at | 2026-01-25 22:38:34.558774+00 |
| description | JIT compilation for Axonml tensor operations |
| homepage | |
| repository | |
| max_upload_size | |
| id | 2055422 |
| size | 119,181 |
axonml-jit provides Just-In-Time compilation for tensor operations, enabling significant performance improvements through operation tracing, graph optimization, and compiled function caching. It builds computation graphs from traced operations and optimizes them before execution.
| Module | Description |
|---|---|
| `ir` | Graph-based intermediate representation with `Node`, `Op`, `Shape`, and `DataType` definitions |
| `trace` | Operation tracing with `TracedValue` and `Tracer` for graph construction |
| `optimize` | Optimization passes including constant folding, DCE, CSE, and algebraic simplification |
| `codegen` | JIT compiler and compiled-function execution with an interpreter fallback |
| `cache` | Function cache with LRU eviction and graph hashing |
| `error` | Error types and `Result` alias for JIT operations |
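To make the `ir` module's role concrete, here is a hypothetical sketch of a graph IR in the same spirit; the real `Node`, `Op`, and `Graph` definitions in axonml-jit may look quite different.

```rust
// Hypothetical sketch of a graph IR: a flat, topologically ordered node
// list where operands are referenced by index. Names are assumptions,
// not the crate's actual types.
#[derive(Debug, Clone, PartialEq)]
enum Op {
    Input(String),
    Add(usize, usize),     // indices of the two operand nodes
    MulScalar(usize, f64), // operand index and scalar constant
}

#[derive(Debug)]
struct Graph {
    nodes: Vec<Op>,      // nodes in topological order
    outputs: Vec<usize>, // indices of output nodes
}

// result = (a + b) * 2.0, expressed as a flat node list
fn build_graph() -> Graph {
    Graph {
        nodes: vec![
            Op::Input("a".into()),
            Op::Input("b".into()),
            Op::Add(0, 1),
            Op::MulScalar(2, 2.0),
        ],
        outputs: vec![3],
    }
}

fn main() {
    let graph = build_graph();
    println!("{} nodes, {} output(s)", graph.nodes.len(), graph.outputs.len());
}
```

Storing nodes in topological order makes single-pass interpretation and optimization straightforward, which is a common choice for small JIT IRs.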
Add this to your `Cargo.toml`:

```toml
[dependencies]
axonml-jit = "0.2.4"
```
```rust
use axonml_jit::{trace, JitCompiler};

// Trace operations to build a computation graph
let graph = trace(|tracer| {
    let a = tracer.input("a", &[2, 3]);
    let b = tracer.input("b", &[2, 3]);
    let c = a.add(&b);
    let d = c.mul_scalar(2.0);
    tracer.output("result", d)
});

// Compile the graph
let compiler = JitCompiler::new();
let compiled = compiler.compile(&graph)?;

// Execute with real data
let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
let b_data = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5];
let result = compiled.run(&[("a", &a_data), ("b", &b_data)])?;
```
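To make the semantics of the traced graph concrete, here is the same computation evaluated eagerly in plain Rust, with no axonml-jit involved: element-wise `(a + b) * 2`.

```rust
// Eager equivalent of the traced graph above: result = (a + b) * 2.0
fn add_then_scale(a: &[f64], b: &[f64], scale: f64) -> Vec<f64> {
    a.iter().zip(b).map(|(x, y)| (x + y) * scale).collect()
}

fn main() {
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [0.5; 6];
    let result = add_then_scale(&a, &b, 2.0);
    println!("{:?}", result); // [3.0, 5.0, 7.0, 9.0, 11.0, 13.0]
}
```

The JIT path pays a one-time tracing and compilation cost in exchange for reusing the optimized compiled function across calls.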
```rust
use axonml_jit::trace;

let graph = trace(|tracer| {
    let x = tracer.input("x", &[4, 4]);

    // Unary operations
    let y = x.relu()
        .mul_scalar(2.0)
        .add_scalar(1.0);

    // Activation functions
    let z = y.sigmoid().tanh().gelu();

    // Reductions
    let mean = z.mean_axis(1, true);

    // Shape operations
    let reshaped = mean.reshape(&[-1]);

    tracer.output("output", reshaped)
});
```
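For reference, here are eager scalar implementations of a few of the traced ops above. The tanh-based GELU below is the widely used approximation; the exact formula axonml-jit uses is an assumption here.

```rust
// Eager reference implementations of relu, sigmoid, and (approximate) gelu.
fn relu(x: f64) -> f64 { x.max(0.0) }

fn sigmoid(x: f64) -> f64 { 1.0 / (1.0 + (-x).exp()) }

// Common tanh approximation of GELU; the crate's exact formula may differ.
fn gelu(x: f64) -> f64 {
    0.5 * x * (1.0 + ((2.0 / std::f64::consts::PI).sqrt() * (x + 0.044715 * x.powi(3))).tanh())
}

fn main() {
    // y = relu(x) * 2 + 1, applied element-wise
    let x = [-1.0, 0.5];
    let y: Vec<f64> = x.iter().map(|&v| relu(v) * 2.0 + 1.0).collect();
    println!("{:?}", y); // [1.0, 2.0]
    println!("sigmoid(0) = {}", sigmoid(0.0)); // 0.5
    println!("gelu(0) = {}", gelu(0.0)); // 0
}
```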
```rust
use axonml_jit::{Optimizer, OptimizationPass, JitCompiler};

// Create an optimizer with custom passes
let mut optimizer = Optimizer::new();
optimizer.add_pass(OptimizationPass::ConstantFolding);
optimizer.add_pass(OptimizationPass::AlgebraicSimplification);
optimizer.add_pass(OptimizationPass::DeadCodeElimination);
optimizer.add_pass(OptimizationPass::CommonSubexpressionElimination);

// Apply the optimizations
let optimized_graph = optimizer.optimize(graph);

// Or hand the optimizer to the compiler so the passes run during compilation
let compiler = JitCompiler::with_optimizer(optimizer);
let compiled = compiler.compile(&optimized_graph)?;
```
```rust
use axonml_jit::JitCompiler;

let compiler = JitCompiler::new();

// Compile multiple graphs
let _ = compiler.compile(&graph1)?;
let _ = compiler.compile(&graph2)?;

// Check cache statistics
let stats = compiler.cache_stats();
println!("Cached functions: {}", stats.entries);
println!("Cache utilization: {:.1}%", stats.utilization());

// Clear cache if needed
compiler.clear_cache();
```
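The `cache` module is described as an LRU cache keyed by graph hash. Its internals are not shown here, so the following is a hypothetical standalone sketch of that idea using only the standard library; names and layout are assumptions.

```rust
use std::collections::{HashMap, VecDeque};

// Hypothetical LRU cache keyed by a graph hash, in the spirit of the
// `cache` module; the crate's actual implementation may differ.
struct LruCache<V> {
    map: HashMap<u64, V>,
    order: VecDeque<u64>, // front = least recently used
    capacity: usize,
}

impl<V> LruCache<V> {
    fn new(capacity: usize) -> Self {
        Self { map: HashMap::new(), order: VecDeque::new(), capacity }
    }

    fn get(&mut self, key: u64) -> Option<&V> {
        if self.map.contains_key(&key) {
            // Move the key to the back (most recently used)
            self.order.retain(|&k| k != key);
            self.order.push_back(key);
        }
        self.map.get(&key)
    }

    fn insert(&mut self, key: u64, value: V) {
        if self.map.len() == self.capacity && !self.map.contains_key(&key) {
            // Evict the least recently used entry
            if let Some(old) = self.order.pop_front() {
                self.map.remove(&old);
            }
        }
        self.order.retain(|&k| k != key);
        self.order.push_back(key);
        self.map.insert(key, value);
    }
}

fn main() {
    let mut cache = LruCache::new(2);
    cache.insert(1, "graph-1");
    cache.insert(2, "graph-2");
    cache.get(1);               // touch 1 so it becomes most recent
    cache.insert(3, "graph-3"); // evicts 2, the least recently used
    println!("has 1: {}, has 2: {}", cache.map.contains_key(&1), cache.map.contains_key(&2));
}
```

A production cache would typically use an intrusive list or an ordered map to avoid the O(n) `retain` on every touch; the queue keeps this sketch short.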
Supported operations:

| Category | Operations |
|---|---|
| Binary | `add`, `sub`, `mul`, `div`, `pow`, `max`, `min` |
| Unary | `neg`, `abs`, `sqrt`, `exp`, `log`, `sin`, `cos`, `tanh` |
| Activations | `relu`, `sigmoid`, `gelu`, `silu` |
| Scalar | `add_scalar`, `mul_scalar` |
| Reductions | `sum`, `mean`, `sum_axis`, `mean_axis` |
| Shape | `reshape`, `transpose`, `squeeze`, `unsqueeze` |
| Matrix | `matmul` |
| Comparison | `gt`, `lt`, `eq`, `where` |

| Pass | Description |
|---|---|
| `ConstantFolding` | Evaluate constant expressions at compile time |
| `DeadCodeElimination` | Remove nodes that do not contribute to outputs |
| `AlgebraicSimplification` | Simplify expressions (`x * 1 = x`, `x + 0 = x`, etc.) |
| `CommonSubexpressionElimination` | Reuse identical subexpressions |
| `ElementwiseFusion` | Fuse consecutive elementwise operations |
| `StrengthReduction` | Replace expensive ops with cheaper equivalents |
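To illustrate what constant folding and algebraic simplification do, here is a hypothetical sketch over a toy expression tree; axonml-jit applies the same ideas to its graph IR, but the real pass implementations differ.

```rust
// Hypothetical constant folding + algebraic simplification over a toy
// expression tree (not the crate's actual IR or pass code).
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Const(f64),
    Var(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

fn simplify(e: Expr) -> Expr {
    use Expr::*;
    match e {
        Add(a, b) => match (simplify(*a), simplify(*b)) {
            (Const(x), Const(y)) => Const(x + y), // constant folding
            (x, Const(c)) if c == 0.0 => x,       // x + 0 = x
            (Const(c), y) if c == 0.0 => y,       // 0 + x = x
            (x, y) => Add(Box::new(x), Box::new(y)),
        },
        Mul(a, b) => match (simplify(*a), simplify(*b)) {
            (Const(x), Const(y)) => Const(x * y), // constant folding
            (x, Const(c)) if c == 1.0 => x,       // x * 1 = x
            (Const(c), y) if c == 1.0 => y,       // 1 * x = x
            (_, Const(c)) | (Const(c), _) if c == 0.0 => Const(0.0), // x * 0 = 0
            (x, y) => Mul(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}

fn main() {
    use Expr::*;
    // (x * 1) + (2 + 3) simplifies to x + 5
    let e = Add(
        Box::new(Mul(Box::new(Var("x".into())), Box::new(Const(1.0)))),
        Box::new(Add(Box::new(Const(2.0)), Box::new(Const(3.0)))),
    );
    println!("{:?}", simplify(e));
}
```

Because `simplify` recurses bottom-up, folded constants in subtrees can enable further simplification higher in the tree, which is why compilers typically run such passes to a fixed point.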
Run the test suite:

```sh
cargo test -p axonml-jit
```
Licensed under either of two licenses, at your option.