DisCo-Diff: Enhancing Continuous Diffusion Models with Discrete Latents

Yilun Xu, Gabriele Corso, Tommi Jaakkola, Arash Vahdat, Karsten Kreis
2024-07-04
Abstract: Diffusion models (DMs) have revolutionized generative learning. They utilize a diffusion process to encode data into a simple Gaussian distribution. However, encoding a complex, potentially multimodal data distribution into a single continuous Gaussian distribution arguably represents an unnecessarily challenging learning problem. We propose Discrete-Continuous Latent Variable Diffusion Models (DisCo-Diff) to simplify this task by introducing complementary discrete latent variables. We augment DMs with learnable discrete latents, inferred with an encoder, and train DM and encoder end-to-end. DisCo-Diff does not rely on pre-trained networks, making the framework universally applicable. The discrete latents significantly simplify learning the DM's complex noise-to-data mapping by reducing the curvature of the DM's generative ODE. An additional autoregressive transformer models the distribution of the discrete latents, a simple step because DisCo-Diff requires only a few discrete variables with small codebooks. We validate DisCo-Diff on toy data, several image synthesis tasks as well as molecular docking, and find that introducing discrete latents consistently improves model performance. For example, DisCo-Diff achieves state-of-the-art FID scores on class-conditioned ImageNet-64/128 datasets with an ODE sampler.
Machine Learning, Artificial Intelligence, Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
This paper proposes DisCo-Diff (Discrete-Continuous Latent Variable Diffusion Models) to make complex, multimodal data distributions easier to model. Standard diffusion models (DMs) encode data into a single continuous Gaussian distribution, which the authors argue is an unnecessarily hard learning problem when the data is multimodal. DisCo-Diff augments the DM with complementary discrete latent variables that are inferred by an encoder, and the encoder and DM are trained end-to-end. According to the paper, the discrete latents reduce the curvature of the DM's generative ODE (ordinary differential equation), which simplifies learning the mapping from noise to data. Because DisCo-Diff does not rely on pre-trained networks, the authors describe the framework as universally applicable. Only a few discrete latents with small codebooks are needed, so an additional autoregressive transformer can model their distribution in a simple extra step. The paper validates DisCo-Diff on toy data, several image synthesis tasks, and molecular docking, and reports that introducing discrete latents consistently improves model performance. In image generation, DisCo-Diff achieves state-of-the-art FID (Fréchet Inception Distance) scores on class-conditioned ImageNet-64 and ImageNet-128 with an ODE sampler. In molecular docking, the discrete latents improve performance by indicating key atoms, which helps resolve multimodal uncertainty.
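The claim that a discrete latent simplifies the denoising problem can be illustrated on a toy case. The sketch below is not the paper's method or its curvature metric. It assumes a 1-D, equal-weight mixture of two Gaussians, additive noising x_t = x_0 + sigma * eps, and a discrete latent that simply equals the true mixture component (in DisCo-Diff the latents are learned by an encoder). All function names are hypothetical. Under these assumptions the optimal denoiser has a closed form: conditioned on the component it is linear in x_t, while the unconditional one is not. The maximum second difference on a grid is used here only as a crude proxy for nonlinearity.

```python
import math

def conditional_denoiser(x_t, sigma, mu_k, s):
    """E[x_0 | x_t, k] for x_0 ~ N(mu_k, s^2) and x_t = x_0 + sigma * eps."""
    shrink = s**2 / (s**2 + sigma**2)
    return mu_k + shrink * (x_t - mu_k)

def unconditional_denoiser(x_t, sigma, mus, s):
    """E[x_0 | x_t] for an equal-weight mixture of N(mu_k, s^2) components."""
    var = s**2 + sigma**2  # variance of x_t within each component
    # Posterior weight of each component given x_t (equal priors cancel).
    log_w = [-0.5 * (x_t - mu) ** 2 / var for mu in mus]
    top = max(log_w)  # subtract the max for numerical stability
    w = [math.exp(lw - top) for lw in log_w]
    total = sum(w)
    return sum(wk / total * conditional_denoiser(x_t, sigma, mu, s)
               for wk, mu in zip(w, mus))

def max_second_difference(values):
    """Largest |f(x-h) - 2 f(x) + f(x+h)| on the grid; ~0 for a linear f."""
    return max(abs(values[i - 1] - 2 * values[i] + values[i + 1])
               for i in range(1, len(values) - 1))

# Assumed toy setup: two well-separated modes, moderate noise level.
mus, s, sigma = [-2.0, 2.0], 0.3, 1.0
grid = [-4.0 + 0.04 * i for i in range(201)]

curv_uncond = max_second_difference(
    [unconditional_denoiser(x, sigma, mus, s) for x in grid])
curv_cond = max_second_difference(
    [conditional_denoiser(x, sigma, mus[0], s) for x in grid])

print("conditional:", curv_cond, "unconditional:", curv_uncond)
```

In this toy setting the conditional value is zero up to floating-point error, while the unconditional value is clearly positive because the denoiser has to switch between modes near x_t = 0. Real data and learned latents will not be this clean, so this should be read as intuition only.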