Crates.io | rust-optimal-transport |
lib.rs | rust-optimal-transport |
version | 0.2.0 |
source | src |
created_at | 2022-01-12 17:13:40.961148 |
updated_at | 2022-02-16 03:38:28.825226 |
description | A library of optimal transport solvers for Rust |
repository | https://github.com/kachark/rust-optimal-transport |
id | 512898 |
size | 180,672 |
This library provides solvers for performing regularized and unregularized Optimal Transport in Rust.
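For orientation, the two problem classes can be written in standard textbook notation. The symbols below (source and target weights a and b, ground cost matrix C, transport plan P, regularization strength ε) are notation assumed here for illustration. The entropic term is shown as one common choice of regularizer, not as a statement of which regularizer each solver in this library uses.

```latex
% Unregularized (Kantorovich) problem: cheapest plan with marginals a and b
\[
  \min_{P \in U(a, b)} \langle P, C \rangle
  = \min_{P \in U(a, b)} \sum_{i,j} P_{ij} C_{ij},
  \qquad
  U(a, b) = \{ P \in \mathbb{R}_{+}^{n \times m} :
               P \mathbf{1}_m = a,\; P^{\top} \mathbf{1}_n = b \}
\]

% Entropically regularized problem with strength \varepsilon > 0
\[
  \min_{P \in U(a, b)} \langle P, C \rangle
  + \varepsilon \sum_{i,j} P_{ij} \left( \log P_{ij} - 1 \right)
\]
```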
Inspired by Python Optimal Transport, this library provides the following solvers:
The library has been tested on macOS. It requires a C++ compiler for building the EMD solver and relies on the following Rust libraries:
Add the following to your Cargo.toml to use rust-optimal-transport in your project:
[dependencies]
rust-optimal-transport = "0.2"
To enable the LAPACK backend (currently supporting OpenBLAS):
[dependencies]
rust-optimal-transport = { version = "0.2", features = ["blas"] }
This links against an OpenBLAS installation already present on your system. For more details, see the ndarray-linalg crate.
use rust_optimal_transport as ot;
use ot::prelude::*;
// Generate data
let n_samples = 100;
// Mean, Covariance of the source distribution
let mu_source = array![0., 0.];
let cov_source = array![[1., 0.], [0., 1.]];
// Mean, Covariance of the target distribution
let mu_target = array![4., 4.];
let cov_target = array![[1., -0.8], [-0.8, 1.]];
// Samples of a 2D gaussian distribution
let source = ot::utils::sample_2D_gauss(n_samples, &mu_source, &cov_source).unwrap();
let target = ot::utils::sample_2D_gauss(n_samples, &mu_target, &cov_target).unwrap();
// Uniform weights on the source and target distributions
let mut source_weights = Array1::<f64>::from_elem(n_samples, 1. / (n_samples as f64));
let mut target_weights = Array1::<f64>::from_elem(n_samples, 1. / (n_samples as f64));
// Compute ground cost matrix - Squared Euclidean distance
let mut cost = dist(&source, &target, SqEuclidean);
let max_cost = cost.max().unwrap();
// Normalize cost matrix for numerical stability
cost = &cost / *max_cost;
// Compute the optimal transport matrix with the Earth Mover's Distance solver.
// The `?` operator requires an enclosing function that returns a Result.
let ot_matrix = EarthMovers::new(
    &mut source_weights,
    &mut target_weights,
    &mut cost
).solve()?;
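To make the cost-matrix steps above concrete without any dependencies, here is a minimal sketch in plain Rust using nested `Vec`s instead of ndarray arrays. The function names `sq_euclidean_cost`, `normalize_by_max` and `transport_cost` are hypothetical helpers written for this illustration; they are not part of rust-optimal-transport, and the library's own `dist` and solvers may differ in detail.

```rust
/// Squared Euclidean ground cost between every source and target point.
/// Entry [i][j] is ||source[i] - target[j]||^2.
/// (Hypothetical helper, not the library's `dist`.)
fn sq_euclidean_cost(source: &[[f64; 2]], target: &[[f64; 2]]) -> Vec<Vec<f64>> {
    source
        .iter()
        .map(|s| {
            target
                .iter()
                .map(|t| (s[0] - t[0]).powi(2) + (s[1] - t[1]).powi(2))
                .collect()
        })
        .collect()
}

/// Divide every entry by the largest entry so that all costs lie in [0, 1].
/// The matrix is left untouched if it is empty or its maximum is not positive.
fn normalize_by_max(cost: &mut [Vec<f64>]) {
    let max = cost
        .iter()
        .flatten()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);
    if max > 0.0 {
        for row in cost.iter_mut() {
            for c in row.iter_mut() {
                *c /= max;
            }
        }
    }
}

/// Total cost of a transport plan: the sum over i, j of plan[i][j] * cost[i][j].
fn transport_cost(plan: &[Vec<f64>], cost: &[Vec<f64>]) -> f64 {
    plan.iter()
        .zip(cost)
        .map(|(p_row, c_row)| p_row.iter().zip(c_row).map(|(p, c)| p * c).sum::<f64>())
        .sum()
}
```

Calling `sq_euclidean_cost` on the sampled points, then `normalize_by_max` on the result, mirrors the `dist` and normalization steps in the example above. `transport_cost` evaluates how expensive a given transport matrix is under that cost; it does not solve for the optimal plan.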
This library is inspired by Python Optimal Transport. The original authors and contributors of that project are listed at POT.