| Crates.io | burn-mlx |
| lib.rs | burn-mlx |
| version | 0.1.2 |
| created_at | 2025-12-31 08:42:02.565418+00 |
| updated_at | 2025-12-31 09:30:36.368117+00 |
| description | MLX backend for Burn deep learning framework - native Apple Silicon GPU acceleration |
| homepage | https://github.com/TuringWorks/burn-mlx |
| repository | https://github.com/TuringWorks/burn-mlx |
| max_upload_size | |
| id | 2014280 |
| size | 276,553 |
MLX backend for Burn — native Apple Silicon GPU acceleration for deep learning.
This crate provides a Burn backend using Apple's MLX framework, enabling high-performance machine learning on M1/M2/M3/M4 Macs.
Add to your Cargo.toml:

```toml
[dependencies]
burn-mlx = "0.1"
burn = "0.16"
```
Quick start:

```rust
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

fn main() {
    // Create tensors on the Apple Silicon GPU
    let device = MlxDevice::Gpu;
    let a: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
    let b: Tensor<Mlx, 2> = Tensor::ones([2, 3], &device);
    let c = a + b;
    println!("Result shape: {:?}", c.shape());
}
```
For training, wrap the backend in Autodiff:

```rust
use burn::backend::Autodiff;
use burn_mlx::Mlx;

// Use TrainBackend anywhere training with automatic differentiation is needed
type TrainBackend = Autodiff<Mlx>;
```
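As a brief illustration of what this enables, here is a minimal sketch of computing gradients through the MLX backend, assuming Burn's standard autodiff API (`require_grad`, `backward`, and `grad`):

```rust
use burn::backend::Autodiff;
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

type TrainBackend = Autodiff<Mlx>;

fn main() {
    let device = MlxDevice::Gpu;
    // Track gradients for x
    let x: Tensor<TrainBackend, 2> = Tensor::ones([2, 3], &device).require_grad();
    // y = sum(x * x), so dy/dx = 2x
    let y = (x.clone() * x.clone()).sum();
    let grads = y.backward();
    // The gradient is a tensor on the inner (non-autodiff) backend
    let dx = x.grad(&grads).expect("x should have a gradient");
    println!("dy/dx shape: {:?}", dx.shape());
}
```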
burn-mlx fully supports pooling operations, including both forward and backward passes, so they can be used in training workflows (see the backward-pass sketch after the adaptive pooling example below).
2D average pooling:

```rust
use burn::nn::pool::{AvgPool2d, AvgPool2dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

fn main() {
    let device = MlxDevice::Gpu;
    // Create a 4D tensor: [batch, channels, height, width]
    let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);
    // Average pooling with a 2x2 kernel and stride 2
    let config = AvgPool2dConfig::new([2, 2]).with_strides([2, 2]);
    let pool = AvgPool2d::new(config);
    let output = pool.forward(input);
    // Output shape: [1, 3, 16, 16]
}
```
2D max pooling:

```rust
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

fn main() {
    let device = MlxDevice::Gpu;
    let input: Tensor<Mlx, 4> = Tensor::ones([1, 3, 32, 32], &device);
    // Max pooling with a 2x2 kernel and stride 2
    let config = MaxPool2dConfig::new([2, 2]).with_strides([2, 2]);
    let pool = MaxPool2d::new(config);
    let output = pool.forward(input);
    // Output shape: [1, 3, 16, 16]
}
```
1D pooling works the same way on [batch, channels, length] tensors:

```rust
use burn::nn::pool::{AvgPool1d, AvgPool1dConfig, MaxPool1d, MaxPool1dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

fn main() {
    let device = MlxDevice::Gpu;
    // Create a 3D tensor: [batch, channels, length]
    let input: Tensor<Mlx, 3> = Tensor::ones([1, 64, 128], &device);

    // Average pooling with kernel size 4 and stride 4
    let avg_config = AvgPool1dConfig::new(4).with_stride(4);
    let avg_pool = AvgPool1d::new(avg_config);
    let avg_output = avg_pool.forward(input.clone());
    // Output shape: [1, 64, 32]

    // Max pooling with kernel size 4 and stride 4
    let max_config = MaxPool1dConfig::new(4).with_stride(4);
    let max_pool = MaxPool1d::new(max_config);
    let max_output = max_pool.forward(input);
    // Output shape: [1, 64, 32]
}
```
Adaptive average pooling to a fixed output size:

```rust
use burn::nn::pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

fn main() {
    let device = MlxDevice::Gpu;
    let input: Tensor<Mlx, 4> = Tensor::ones([1, 512, 14, 14], &device);
    // Adaptive pool to a fixed output size (common before fully connected layers)
    let config = AdaptiveAvgPool2dConfig::new([1, 1]);
    let pool = AdaptiveAvgPool2d::new(config);
    let output = pool.forward(input);
    // Output shape: [1, 512, 1, 1]
}
```
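To illustrate the backward-pass support mentioned above, here is a minimal training-style sketch, assuming Burn's standard autodiff API; the shapes follow the 2D examples:

```rust
use burn::backend::Autodiff;
use burn::nn::pool::{MaxPool2d, MaxPool2dConfig};
use burn::tensor::Tensor;
use burn_mlx::{Mlx, MlxDevice};

type B = Autodiff<Mlx>;

fn main() {
    let device = MlxDevice::Gpu;
    let input: Tensor<B, 4> = Tensor::ones([1, 3, 32, 32], &device).require_grad();
    let pool = MaxPool2d::new(MaxPool2dConfig::new([2, 2]).with_strides([2, 2]));
    // Forward through the pool, reduce to a scalar loss, then differentiate
    let loss = pool.forward(input.clone()).sum();
    let grads = loss.backward();
    // The input gradient has the input's shape: [1, 3, 32, 32]
    let input_grad = input.grad(&grads).expect("input should have a gradient");
    println!("input grad shape: {:?}", input_grad.shape());
}
```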
A lower-level tensor API is also available:

```rust
use burn_mlx::{MlxTensor, MlxDevice};

fn main() {
    let device = MlxDevice::Gpu;
    // Create tensors
    let a = MlxTensor::<f32>::ones(&[1024, 1024], device);
    let b = MlxTensor::<f32>::ones(&[1024, 1024], device);
    // Build up operations
    let c = a.matmul(&b);
    let d = c.relu();
    let e = d.softmax();
    // MLX computes lazily; eval() forces the queued computation to run
    e.eval().expect("evaluation failed");
}
```
The pooling operations are implemented using MLX's as_strided function combined with reduction operations:

- Forward pass: uses as_strided to create sliding-window views over the input, then applies mean_axes (avg pool) or max_axes (max pool) for the reduction.
- Backward pass: uses scatter_add to accumulate gradients from each pooled output back into the input positions that produced it.
- Layout handling: automatically converts between Burn's NCHW format and MLX's native NHWC format.
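As a rough illustration of those two patterns, here is a sketch in plain Rust over a flat [H, W] buffer; these are not the crate's actual MLX kernels, just the same window-view and scatter-accumulate logic spelled out:

```rust
/// Forward: each output cell is a strided window view reduced by its mean,
/// mirroring the as_strided + mean_axes approach.
fn avg_pool2d(input: &[f32], h: usize, w: usize, k: usize, s: usize) -> Vec<f32> {
    let (oh, ow) = ((h - k) / s + 1, (w - k) / s + 1);
    let mut out = vec![0.0; oh * ow];
    for oy in 0..oh {
        for ox in 0..ow {
            // A strided "view": window element (ky, kx) lives at
            // (oy * s + ky) * w + (ox * s + kx) in the flat buffer
            let mut sum = 0.0;
            for ky in 0..k {
                for kx in 0..k {
                    sum += input[(oy * s + ky) * w + (ox * s + kx)];
                }
            }
            out[oy * ow + ox] = sum / (k * k) as f32;
        }
    }
    out
}

/// Backward: each input position that contributed to a window receives an
/// equal share of that window's output gradient, accumulated with += --
/// the scatter_add pattern (overlapping windows sum their contributions).
fn avg_pool2d_backward(grad_out: &[f32], h: usize, w: usize, k: usize, s: usize) -> Vec<f32> {
    let (oh, ow) = ((h - k) / s + 1, (w - k) / s + 1);
    let mut grad_in = vec![0.0; h * w];
    for oy in 0..oh {
        for ox in 0..ow {
            let g = grad_out[oy * ow + ox] / (k * k) as f32;
            for ky in 0..k {
                for kx in 0..k {
                    grad_in[(oy * s + ky) * w + (ox * s + kx)] += g;
                }
            }
        }
    }
    grad_in
}
```

For max pooling, the backward pass instead routes each output gradient only to the window's argmax position, but the scatter-accumulate pattern is the same.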
On Apple M-series chips, burn-mlx leverages MLX's unified-memory architecture, in which the CPU and GPU share tensor buffers without host-device copies, along with Metal-accelerated compute kernels.
License: Apache-2.0