Understanding Learning with Sliced-Wasserstein Requires Rethinking Informative Slices

Huy Tran, Yikun Bai, Ashkan Shahbazi, John R. Hershey, Soheil Kolouri
2024-11-16
Abstract: The practical applications of Wasserstein distances (WDs) are constrained by their sample and computational complexities. Sliced-Wasserstein distances (SWDs) provide a workaround by projecting distributions onto one-dimensional subspaces, leveraging the more efficient, closed-form WDs for one-dimensional distributions. However, in high dimensions, most random projections become uninformative due to the concentration of measure phenomenon. Although several SWD variants have been proposed to focus on *informative* slices, they often introduce additional complexity, numerical instability, and compromise desirable theoretical (metric) properties of SWD. Amidst the growing literature that focuses on directly modifying the slicing distribution, which often faces challenges, we revisit the classical Sliced-Wasserstein and propose instead to rescale the 1D Wasserstein to make all slices equally informative. Importantly, we show that with an appropriate data assumption and notion of *slice informativeness*, rescaling for all individual slices simplifies to **a single global scaling factor** on the SWD. This, in turn, translates to the standard learning rate search for gradient-based learning in common machine learning workflows. We perform extensive experiments across various machine learning tasks showing that the classical SWD, when properly configured, can often match or surpass the performance of more complex variants. We then answer the following question: "Is Sliced-Wasserstein all you need for common learning tasks?"
Machine Learning, Artificial Intelligence, Computer Vision and Pattern Recognition, Applications, Computation
What problem does this paper attempt to address?
The paper addresses the computational and sample complexity of the Wasserstein distance (WD) on high-dimensional data. For \(N\) samples, computing the standard Wasserstein distance has a time complexity of \(O(N^3 \log N)\) and a space complexity of \(O(N^2)\), and its sample complexity of \(O(N^{-1/d})\) degrades as the dimension \(d\) grows, which makes it impractical in many applications. The Sliced-Wasserstein distance (SWD) was proposed as a workaround: it projects high-dimensional distributions onto one-dimensional subspaces, where the Wasserstein distance has an efficient closed form. However, as the data dimension increases, most random projections become uninformative and contribute little to the overall SWD. So although SWD has theoretical advantages, its performance may be limited in practice.

In response, the paper proposes a new perspective and method:

1. **Rethinking informative slices**: The paper points out that when the data actually lies in a low-dimensional subspace, each one-dimensional Wasserstein distance can be rescaled so that all slices are equally informative. This avoids the added complexity of searching for informative slices and preserves the theoretical properties of the classical SWD.
2. **Introducing a global scaling factor**: The authors show that, under an appropriate data assumption and notion of slice informativeness, the rescaling of all individual slices simplifies to a single global scaling factor on the SWD. In gradient-based learning, this factor can be found through the standard learning rate search, without explicitly searching for informative slices.
3. **Extensive experimental verification**: Through experiments on a variety of machine learning tasks, including image generation and color transfer, the paper shows that the classical SWD, when properly configured, can often match or surpass the performance of more complex variants.
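
One way to see why a single global factor can suffice is a simplified special case. The setup here is an illustrative assumption (data on a \(k\)-dimensional coordinate subspace, uniform slicing, and the symbols \(k\), \(a\), \(\hat{a}\), \(\mu_k\), \(\nu_k\), \(\eta\), \(c\)), not necessarily the paper's exact data assumption or derivation. Suppose \(\mu\) and \(\nu\) are supported on \(\mathbb{R}^k \times \{0\}^{d-k}\), let \(\mu_k, \nu_k\) denote the same measures viewed on \(\mathbb{R}^k\), and write \(\theta^\# \mu\) for the distribution of \(\langle \theta, x \rangle\) with \(x \sim \mu\). Split a slice \(\theta \sim U(S^{d-1})\) as \(\theta = (a, b)\) with \(a \in \mathbb{R}^k\). Only \(a\) acts on the data, and scaling a one-dimensional distribution by \(\|a\|\) scales \(W_p^p\) by \(\|a\|^p\):

\[ W_p^p(\theta^\# \mu, \theta^\# \nu) = \|a\|^p \, W_p^p(\hat{a}^\# \mu_k, \hat{a}^\# \nu_k), \qquad \hat{a} = a / \|a\| \in S^{k-1} \]

Since \(\hat{a}\) is uniform on \(S^{k-1}\) and independent of \(\|a\|\), taking the expectation over \(\theta\) factorizes:

\[ SW_p^p(\mu, \nu; U(S^{d-1})) = \mathbb{E}\left[\|a\|^p\right] \, SW_p^p(\mu_k, \nu_k; U(S^{k-1})) \]

For \(p = 2\), \(\mathbb{E}[\|a\|^2] = k/d\), so in this special case the ambient \(SW_2^2\) is the subspace \(SW_2^2\) shrunk by the constant \(k/d\). In gradient descent, a constant factor \(c\) on the loss is equivalent to rescaling the learning rate \(\eta\), because \(\eta \nabla (c L) = (c \eta) \nabla L\).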
In summary, the paper aims to answer the question: "Is Sliced-Wasserstein all you need for common learning tasks?" The authors argue that the classical SWD, when properly configured, can often match or surpass more complex variants while keeping its simplicity and theoretical guarantees.

### Formula Explanation

- **Wasserstein Distance (WD)**: \[ W_p^p(\mu, \nu)=\inf_{\pi \in \Pi(\mu, \nu)} \int_{\mathbb{R}^d \times \mathbb{R}^d}\|x - y\|^p \, d\pi(x, y) \] where \(\Pi(\mu, \nu)\) is the set of all couplings of \(\mu\) and \(\nu\), i.e., joint distributions whose marginals are \(\mu\) and \(\nu\).
- **Sliced-Wasserstein Distance (SWD)**: \[ SW_p(\mu, \nu; \sigma)=\left(\mathbb{E}_{\theta \sim \sigma}\left[W_p^p(\theta^\# \mu, \theta^\# \nu)\right]\right)^{1/p} \] where \(\theta^\# \mu\) is the one-dimensional distribution obtained by projecting \(\mu\) onto the slice vector \(\theta\), and \(\sigma\) is the reference measure of \(\theta\), by default the uniform distribution \(U(S^{d - 1})\).
- **Re-weighted Sliced-Wasserstein Distance**: \[ gSW_p(\mu, \nu; \sigma, \rho_\phi)=\left(\int_{S^{d - 1}} \rho_\phi(\phi(\theta)) W_p^p(\theta^\# \mu, \theta^\# \nu) \, d\sigma(\theta)\right)^{1/p} \] where \(\rho_\phi\) is a weighting function that weights each slice according to its informativeness.

Together, these definitions frame the paper's approach to learning with the Sliced-Wasserstein distance on high-dimensional data.
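
The SWD definition above can be estimated by Monte Carlo: sample slice directions, project, and use the sorted-sample form of the one-dimensional Wasserstein distance. The following is a minimal sketch, not the paper's implementation. It assumes two point clouds with the same number of samples and uniform weights, and the function name `sliced_wasserstein` and the parameters `n_slices` and `seed` are hypothetical names chosen for illustration.

```python
import numpy as np


def sliced_wasserstein(x, y, n_slices=256, p=2, seed=0):
    """Monte Carlo estimate of SW_p between two equal-size point clouds.

    x, y: arrays of shape (n_samples, d), treated as uniform empirical
    measures (an assumption of this sketch).
    """
    rng = np.random.default_rng(seed)
    d = x.shape[1]

    # Sample slice directions uniformly on the unit sphere S^{d-1}
    # by normalizing standard Gaussian vectors.
    theta = rng.standard_normal((n_slices, d))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)

    # Project both clouds onto every slice: shape (n_samples, n_slices).
    x_proj = x @ theta.T
    y_proj = y @ theta.T

    # In 1D, with equal sample counts and uniform weights, the optimal
    # coupling matches sorted samples, so W_p^p is a mean of |gaps|^p.
    x_sorted = np.sort(x_proj, axis=0)
    y_sorted = np.sort(y_proj, axis=0)
    w_pp = np.mean(np.abs(x_sorted - y_sorted) ** p, axis=0)

    # Average W_p^p over slices, then take the p-th root.
    return np.mean(w_pp) ** (1.0 / p)


# Usage: compare a Gaussian cloud with a shifted copy of itself.
rng = np.random.default_rng(1)
source = rng.standard_normal((200, 3))
target = source + np.array([3.0, 0.0, 0.0])
print(sliced_wasserstein(source, target))
```

Because every slice gets the same weight in this estimator, multiplying the returned value by a constant only rescales its gradient, which is the role the global scaling factor plays in the summary above.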