| | |
|---|---|
| Crates.io | mlx-rs |
| lib.rs | mlx-rs |
| version | 0.25.1 |
| created_at | 2024-07-13 07:35:07.449768+00 |
| updated_at | 2025-07-08 16:32:42.501266+00 |
| description | Unofficial Rust wrapper for Apple's MLX machine learning library. |
| homepage | |
| repository | https://github.com/oxideai/mlx-rs |
| max_upload_size | |
| id | 1302107 |
| size | 1,107,472 |
Rust bindings for Apple's mlx machine learning library.
⚠️ Project is in active development - contributors welcome!
Blaze supports this project by providing ultra-fast Apple Silicon macOS Github Action Runners. Apply the discount code AI25 at checkout to enjoy 25% off your first year.
Due to a known limitation of docs.rs, we host the documentation on GitHub Pages here.
MLX is an array framework for machine learning on Apple Silicon. mlx-rs provides Rust bindings for MLX, allowing you to use MLX in your Rust projects.
Some key features of MLX and mlx-rs include:

- Lazy computation: arrays are only evaluated when their values are needed
- Unified memory: arrays live in memory shared by the CPU and GPU, so moving work between devices does not require data copies
- Composable function transformations, such as automatic differentiation
mlx-rs is designed to be a safe and idiomatic Rust interface to MLX, providing a seamless experience for Rust developers.
The examples directory contains sample projects demonstrating different use cases of our library.
Add this to your `Cargo.toml`:

```toml
[dependencies]
mlx-rs = "0.25.1"
```
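If you want GPU execution or Accelerate support, the crate's Cargo feature flags can be enabled on the same dependency line. A sketch (the `metal` and `accelerate` feature names are the ones this README documents; the version number is illustrative):

```toml
[dependencies]
# Enable the Metal (GPU) and Accelerate backends via Cargo features
mlx-rs = { version = "0.25.1", features = ["metal", "accelerate"] }
```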
The following Cargo features are available:

- `metal` - enables Metal (GPU) usage in MLX
- `accelerate` - enables using the Accelerate framework in MLX

When using automatic differentiation in mlx-rs, there's an important difference in how closures work compared to Python's MLX. In Python, variables are implicitly captured and properly traced in the compute graph. In Rust, however, we need to be more explicit about which arrays should be traced.
❌ This approach may cause segfaults:
```rust
// Don't do this
let x = random::normal::<f32>(&[num_examples, num_features], None, None, None)?;
let y = x.matmul(&w_star)? + eps;

let loss_fn = |w: &Array| -> Result<Array, Exception> {
    let y_pred = x.matmul(w)?; // x and y are captured from the outer scope
    let loss = Array::from_f32(0.5) * ops::mean(&ops::square(&(y_pred - &y))?, None, None)?;
    Ok(loss)
};
let grad_fn = transforms::grad(loss_fn, &[0]);
```
✅ Instead, pass all required arrays as inputs to ensure proper tracing:
```rust
let loss_fn = |inputs: &[Array]| -> Result<Array, Exception> {
    let w = &inputs[0];
    let x = &inputs[1];
    let y = &inputs[2];
    let y_pred = x.matmul(w)?;
    let loss = Array::from_f32(0.5) * ops::mean(&ops::square(y_pred - y)?, None, None)?;
    Ok(loss)
};

let argnums = &[0]; // Specify which argument to differentiate with respect to

// Pass all required arrays in the inputs slice
let inputs = vec![w, x, y];
let grad = transforms::grad(loss_fn, argnums)(&inputs)?;
```
When using gradients in training loops, remember to update the appropriate array in your inputs:
```rust
let mut inputs = vec![w, x, y];

for _ in 0..num_iterations {
    let grad = transforms::grad(loss_fn, argnums)(&inputs)?;
    inputs[0] = &inputs[0] - Array::from_f32(learning_rate) * grad; // Update the weight array
    inputs[0].eval()?;
}
```
We are actively working on improving this API to make it more ergonomic and closer to Python's behavior. For now, explicitly passing all required arrays as shown above is the recommended approach.
For simplicity, the main crate mlx-rs follows MLX's versioning, allowing you to easily see which MLX version you're using under the hood. The mlx-sys crate follows the versioning of mlx-c, as that is the version from which the API is generated.
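Concretely, because the main crate shares MLX's version numbers, pinning an exact mlx-rs release in `Cargo.toml` also pins the wrapped MLX version. A sketch (the version number is illustrative):

```toml
[dependencies]
# An exact pin on mlx-rs also fixes the underlying MLX version,
# since the main crate follows MLX's versioning.
mlx-rs = "=0.25.1"
```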
If you are excited about the project or want to contribute, don't hesitate to join our Discord! We try to be as welcoming as possible to everybody from any background. We're still building this out, but you can ask your questions there!
mlx-rs is currently in active development and can be used to run MLX models in Rust.
The minimum supported Rust version (MSRV), the oldest Rust version that can be used to compile each crate, is 1.81.0.
mlx-rs is distributed under the terms of both the MIT license and the Apache License (Version 2.0). See LICENSE-APACHE and LICENSE-MIT for details. Opening a pull request is assumed to signal agreement with these licensing terms.