Optimal Transport Tools (OTT): A JAX Toolbox for all things Wasserstein

Marco Cuturi, Laetitia Meng-Papaxanthos, Yingtao Tian, Charlotte Bunne, Geoff Davis, Olivier Teboul
DOI: https://doi.org/10.48550/arXiv.2201.12324
2022-01-29
Abstract: Optimal transport tools (OTT-JAX) is a Python toolbox that can solve optimal transport problems between point clouds and histograms. The toolbox builds on various JAX features, such as automatic and custom reverse mode differentiation, vectorization, just-in-time compilation and accelerators support. The toolbox covers elementary computations, such as the resolution of the regularized OT problem, and more advanced extensions, such as barycenters, Gromov-Wasserstein, low-rank solvers, estimation of convex maps, differentiable generalizations of quantiles and ranks, and approximate OT between Gaussian mixtures. The toolbox code is available at https://github.com/ott-jax/ott
Machine Learning
What problem does this paper attempt to address?
The paper addresses the computational cost and the lack of differentiability that Optimal Transport (OT) faces on large-scale and high-dimensional data. It focuses on three main challenges:

1. **Scalability**: The linear optimal transport problem (formula (1)) can be solved by network flow algorithms, but their worst-case time complexity is super-cubic (for example, \( O(nm(n + m)\log(n + m)) \)). The quadratic optimal transport problem (formula (2)) is harder still: it is usually solved only approximately by iterative linearization, which means paying the cost of a linear problem repeatedly.
2. **Curse of Dimensionality**: In high-dimensional spaces, computing the optimal transport cost between discrete measures exactly (such as \( L_c(\hat{\mu}_n, \hat{\nu}_n) \)) wastes computational resources and tends to overfit the samples. The problem may be even more severe for the quadratic optimal transport problem (formula (5)).
3. **Argmin Differentiability**: Many applications care not only about the optimal transport values \( L_c \) or \( Q_{c_X, c_Y} \), but also about how the optimal transport matrix \( P^* \) varies with its inputs. However, \( P^* \) is usually not a smooth function of the inputs. For example, the optimal matching typically does not change under small perturbations of the input points, so the Jacobian \( J_{x_i}P^* \) is zero almost everywhere.

To address these challenges, the paper relies on entropic regularization, which improves not only computational efficiency but also statistical performance and differentiability.
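To make the role of entropic regularization concrete, here is a minimal sketch of log-domain Sinkhorn iterations written in plain JAX. It does not use the OTT-JAX API; the function names `sq_euclidean_cost` and `sinkhorn_plan`, the squared Euclidean cost, the value of `epsilon`, and the fixed iteration count are all assumptions made for illustration. The point is that the regularized plan is a smooth function of the input points, so its Jacobian can be obtained by automatic differentiation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sq_euclidean_cost(x, y):
    """Pairwise squared Euclidean cost matrix of shape (n, m)."""
    diff = x[:, None, :] - y[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)


def sinkhorn_plan(x, y, a, b, epsilon=0.5, num_iters=200):
    """Entropy-regularized transport plan via log-domain Sinkhorn updates.

    Illustrative only: a fixed number of iterations, no convergence check.
    """
    cost = sq_euclidean_cost(x, y)
    log_a, log_b = jnp.log(a), jnp.log(b)

    def body(_, potentials):
        f, g = potentials
        # Update f so that the row marginals of the plan match a.
        f = -epsilon * logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        # Update g so that the column marginals of the plan match b.
        g = -epsilon * logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
        return f, g

    init = (jnp.zeros_like(a), jnp.zeros_like(b))
    f, g = jax.lax.fori_loop(0, num_iters, body, init)
    # P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / epsilon)
    log_plan = (f[:, None] + g[None, :] - cost) / epsilon + log_a[:, None] + log_b[None, :]
    return jnp.exp(log_plan)


# Hypothetical toy data: two small uniform point clouds in the unit square.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.uniform(key_x, (5, 2))
y = jax.random.uniform(key_y, (6, 2))
a = jnp.full((5,), 1.0 / 5)
b = jnp.full((6,), 1.0 / 6)

plan = sinkhorn_plan(x, y, a, b)

# Jacobian of the regularized plan with respect to the source points x;
# its shape is (5, 6, 5, 2).
jac = jax.jacobian(sinkhorn_plan)(x, y, a, b)
```

Differentiating through the unrolled iterations, as done here, is only one option; the abstract also mentions custom reverse-mode differentiation, which this sketch does not attempt to reproduce.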
In addition, the low-rank Sinkhorn method reduces the computational cost further by imposing a low-rank structure on the transport matrix \( P^* \).

In summary, the paper presents OTT-JAX, a Python toolbox built on JAX for efficiently solving a variety of optimal transport problems, particularly in large-scale and high-dimensional settings.
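The saving from a low-rank structure can be illustrated with a short sketch. It assumes the factorization \( P = Q \,\mathrm{diag}(1/g)\, R^\top \) with \( Q \in \mathbb{R}^{n \times r} \) and \( R \in \mathbb{R}^{m \times r} \), a parameterization commonly used for low-rank couplings; the summary above does not state which one the paper uses. The helper `apply_low_rank_plan` and the toy factors are hypothetical, and the code only shows how a factored plan is applied, not how the low-rank solver finds it.

```python
import jax.numpy as jnp


def apply_low_rank_plan(q, r, g, v):
    """Compute P @ v for P = q @ diag(1/g) @ r.T without forming P.

    Costs O((n + m) * rank) operations instead of O(n * m).
    """
    return q @ ((r.T @ v) / g)


# Hypothetical feasible factors: q has row sums a and column sums g,
# r has row sums b and column sums g. This simple choice yields the
# independent coupling P = a b^T, used here only to check the algebra.
n, m = 4, 5
a = jnp.full((n,), 1.0 / n)
b = jnp.full((m,), 1.0 / m)
g = jnp.array([0.3, 0.7])
q = jnp.outer(a, g)
r = jnp.outer(b, g)

v = jnp.arange(m, dtype=jnp.float32)
fast = apply_low_rank_plan(q, r, g, v)

# Dense reference, formed explicitly only for comparison.
dense = q @ jnp.diag(1.0 / g) @ r.T
```

With a small rank \( r \), both storing the plan and applying it scale linearly in \( n + m \), which is where the speed-up over a dense \( n \times m \) matrix comes from.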