Scalable Wasserstein Gradient Flow for Generative Modeling through Unbalanced Optimal Transport

Jaemoo Choi, Jaewoong Choi, Myungjoo Kang
2024-06-03
Abstract: Wasserstein Gradient Flow (WGF) describes the gradient dynamics of a probability density within the Wasserstein space. WGF provides a promising approach for conducting optimization over probability distributions. Numerically approximating the continuous WGF requires a time discretization method, the best known of which is the JKO scheme. Previous WGF models employ the JKO scheme and parametrize a transport map for each JKO step. However, this approach results in quadratic training complexity $O(K^2)$ in the number of JKO steps $K$, which severely limits the scalability of WGF models. In this paper, we introduce a scalable WGF-based generative model, called Semi-dual JKO (S-JKO). Our model is based on the semi-dual form of the JKO step, derived from the equivalence between the JKO step and Unbalanced Optimal Transport. Our approach reduces the training complexity to $O(K)$. We demonstrate that our model significantly outperforms existing WGF-based generative models, achieving FID scores of 2.62 on CIFAR-10 and 5.46 on CelebA-HQ-256, which are comparable to state-of-the-art image generative models.
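For reference, the JKO step mentioned in the abstract is usually written in the following standard form. The step size $h$, the energy functional $\mathcal{F}$, and the iterate notation $\mu_k$ are notational assumptions here and are not taken from the abstract itself.

$$
\mu_{k+1} = \operatorname*{arg\,min}_{\mu} \; \mathcal{F}(\mu) + \frac{1}{2h} W_2^2(\mu, \mu_k), \qquad k = 0, 1, \dots, K-1,
$$

where $W_2$ denotes the 2-Wasserstein distance and $K$ is the number of JKO steps.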
Machine Learning, Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
### Problems Addressed by the Paper

This paper addresses the scalability of Wasserstein Gradient Flow (WGF) models on high-dimensional image datasets. Existing WGF models have quadratic training complexity \(O(K^2)\), where \(K\) is the number of JKO steps. This complexity leads to prolonged training times and limits how the model can be parameterized, which severely restricts the scalability of WGF models.

### Solution

To solve this problem, the authors propose a new generative model based on the semi-dual form of the JKO step, called Semi-dual JKO (S-JKO). The main components of this model are:

1. **Semi-dual form of the JKO step**: The semi-dual form of the JKO step is derived from the equivalence between the JKO step and the Unbalanced Optimal Transport (UOT) problem. This allows the model to train with linear complexity \(O(K)\) instead of the previous quadratic complexity \(O(K^2)\).
2. **Reparameterization technique**: A reparameterization technique generates each intermediate distribution \(\mu_k\), on the path from the initial distribution \(\mu\) to the target distribution \(\nu\), through single-step inference directly from \(\mu\). This avoids simulating the entire JKO trajectory, which significantly reduces training complexity and improves training efficiency.
3. **Experimental validation**: Experiments on CIFAR-10 and CelebA-HQ-256 validate the effectiveness and scalability of the S-JKO model. S-JKO achieves an FID score of 2.62 on CIFAR-10 and 5.46 on CelebA-HQ-256, significantly outperforming existing WGF models and reaching performance comparable to state-of-the-art image generative models.

### Main Contributions

- Proposed S-JKO, a generative model based on the semi-dual form of the JKO step.
- Demonstrated the equivalence between the JKO step and the Unbalanced Optimal Transport problem, and used it to derive the semi-dual form of the JKO step.
- Significantly improved the scalability of WGF models through the reparameterization technique, especially on high-dimensional image datasets.
- Showed experimentally that S-JKO achieves FID scores on CIFAR-10 and CelebA-HQ-256 that are comparable to state-of-the-art image generative models.

### Conclusion

By introducing the S-JKO model, this paper addresses the scalability issue of WGF models on high-dimensional image datasets and offers a new direction for WGF-based generative modeling.
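
The \(O(K^2)\) versus \(O(K)\) claim in the summary above can be illustrated with a small counting sketch. This is not the authors' training code. It assumes, purely for illustration, that the unit of cost is one transport-map evaluation per sample, that a sequential scheme must push samples through all earlier maps before working on step \(k\), and that a single-step scheme needs one evaluation per JKO step. The function names `sequential_jko_cost` and `single_step_jko_cost` are hypothetical.

```python
def sequential_jko_cost(num_steps: int) -> int:
    """Count map evaluations when step k needs samples from mu_{k-1},
    obtained by composing the k-1 previously trained maps, plus one
    evaluation of the map being trained at step k (illustrative assumption)."""
    total = 0
    for k in range(1, num_steps + 1):
        total += k  # (k - 1) earlier maps + 1 current map
    return total


def single_step_jko_cost(num_steps: int) -> int:
    """Count map evaluations when each intermediate distribution is reached
    by a single evaluation from the initial distribution (illustrative assumption)."""
    return num_steps


if __name__ == "__main__":
    for K in (10, 100, 1000):
        print(K, sequential_jko_cost(K), single_step_jko_cost(K))
```

Under these assumptions the sequential count is \(K(K+1)/2\), which grows quadratically, while the single-step count grows linearly in \(K\).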