Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein

Marco Cuturi, Laetitia Meng-Papaxanthos, Yingtao Tian, Charlotte Bunne, Geoff Davis, Olivier Teboul
DOI: https://doi.org/10.48550/arXiv.2201.12324
2022-01-29
Abstract: Optimal transport tools (OTT-JAX) is a Python toolbox that can solve optimal transport problems between point clouds and histograms. The toolbox builds on various JAX features, such as automatic and custom reverse-mode differentiation, vectorization, just-in-time compilation and accelerator support. The toolbox covers elementary computations, such as the resolution of the regularized OT problem, and more advanced extensions, such as barycenters, Gromov-Wasserstein, low-rank solvers, estimation of convex maps, differentiable generalizations of quantiles and ranks, and approximate OT between Gaussian mixtures. The toolbox code is available at https://github.com/ott-jax/ott.
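The regularized OT problem mentioned in the abstract is commonly solved with Sinkhorn iterations. The following is a minimal plain-JAX sketch, not OTT-JAX's own API. It implements log-domain Sinkhorn for entropy-regularized OT between two weighted point clouds with a squared Euclidean cost. The plan has the form P_ij = exp((f_i + g_j - C_ij) / epsilon), and the dual potentials f and g are updated alternately so that P matches the marginals a and b. The names `sqeuclidean_cost` and `sinkhorn`, the data, and the values of `epsilon` and `num_iters` are illustrative assumptions. The sketch also exercises three of the JAX features the abstract lists: `jax.jit`, reverse-mode `jax.grad` through the unrolled loop, and `jax.vmap`.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sqeuclidean_cost(x, y):
    """Pairwise squared Euclidean costs C[i, j] = ||x_i - y_j||^2."""
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


# Hypothetical helper (not OTT's API). num_iters is static so fori_loop gets a
# fixed trip count, which lowers to a scan and supports reverse-mode autodiff.
@functools.partial(jax.jit, static_argnames="num_iters")
def sinkhorn(x, y, a, b, epsilon=0.1, num_iters=200):
    """Log-domain Sinkhorn; returns the plan P and the transport cost <P, C>."""
    C = sqeuclidean_cost(x, y)
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, fg):
        f, g = fg
        # Update f so that the rows of P sum to a.
        f = epsilon * log_a - epsilon * logsumexp((g[None, :] - C) / epsilon, axis=1)
        # Update g so that the columns of P sum to b.
        g = epsilon * log_b - epsilon * logsumexp((f[:, None] - C) / epsilon, axis=0)
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros_like(a), jnp.zeros_like(b)))
    P = jnp.exp((f[:, None] + g[None, :] - C) / epsilon)
    return P, jnp.sum(P * C)


# Two small point clouds in [0, 1]^2 with uniform weights.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (5, 2))
y = jax.random.uniform(key_y, (7, 2))
a = jnp.full(5, 1 / 5)
b = jnp.full(7, 1 / 7)

P, cost = sinkhorn(x, y, a, b, epsilon=0.5, num_iters=500)

# Reverse-mode gradient of the transport cost w.r.t. the source points.
grad_x = jax.grad(lambda xs: sinkhorn(xs, y, a, b, epsilon=0.5, num_iters=500)[1])(x)

# Vectorize the solver over a batch of source point clouds.
batch_costs = jax.vmap(
    lambda xs: sinkhorn(xs, y, a, b, epsilon=0.5, num_iters=500)[1]
)(jnp.stack([x, x + 0.1]))
```

This sketch differentiates by unrolling a fixed number of iterations, which is the simplest option. The abstract's mention of custom reverse-mode differentiation suggests the toolbox also offers alternatives to unrolling. The repository linked above remains the reference for its actual solvers and interfaces.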
Machine Learning