MDTv2: Masked Diffusion Transformer is a Strong Image Synthesizer

Shanghua Gao, Pan Zhou, Ming-Ming Cheng, Shuicheng Yan
2024-02-21
Abstract: Despite their success in image synthesis, we observe that diffusion probabilistic models (DPMs) often lack the contextual reasoning ability to learn the relations among object parts in an image, leading to a slow learning process. To solve this issue, we propose a Masked Diffusion Transformer (MDT) that introduces a mask latent modeling scheme to explicitly enhance the DPMs' ability to learn contextual relations among object semantic parts in an image. During training, MDT operates in the latent space to mask certain tokens. Then, an asymmetric diffusion transformer is designed to predict masked tokens from unmasked ones while maintaining the diffusion generation process. Our MDT can reconstruct the full information of an image from its incomplete contextual input, thus enabling it to learn the associated relations among image tokens. We further improve MDT with a more efficient macro network structure and training strategy, named MDTv2. Experimental results show that MDTv2 achieves superior image synthesis performance, e.g., a new SOTA FID score of 1.58 on the ImageNet dataset, and has more than 10x faster learning speed than the previous SOTA DiT. The source code is released at https://github.com/sail-sg/MDT.
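The token-masking step described in the abstract can be illustrated with a minimal pure-Python sketch. It assumes the latent image has already been split into a flat sequence of tokens (for example, a 16x16 grid of latent patches gives 256 tokens) and that a hypothetical `mask_ratio` parameter sets the fraction of tokens to hide. The names `random_token_mask` and `gather` are illustrative and are not taken from the MDT code base; the point is only that the encoder sees the visible tokens while training targets include the masked positions.

```python
import random


def random_token_mask(num_tokens, mask_ratio, seed=None):
    """Split token positions into (visible, masked) index lists.

    Hypothetical helper: hides round(num_tokens * mask_ratio) randomly chosen
    tokens. Both returned lists are sorted so positional order is preserved.
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    rng = random.Random(seed)
    num_masked = int(round(num_tokens * mask_ratio))
    masked = sorted(rng.sample(range(num_tokens), num_masked))
    masked_set = set(masked)
    visible = [i for i in range(num_tokens) if i not in masked_set]
    return visible, masked


def gather(tokens, indices):
    """Keep only the tokens at the given positions (what the encoder would receive)."""
    return [tokens[i] for i in indices]


# Usage: 256 latent tokens with a hypothetical mask ratio of 0.3.
visible, masked = random_token_mask(256, 0.3, seed=0)
print(len(visible), len(masked))  # 179 77
```

Keeping the index lists sorted lets the visible tokens retain their original positions, so a model could still attach the correct positional information to each one before predicting the hidden tokens.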
Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
The paper addresses two main issues of diffusion probabilistic models (DPMs) in image synthesis:

1. **Insufficient Contextual Reasoning Ability**: The authors observe that existing DPMs struggle to learn the associations between the parts of objects in an image, which slows down training. During training, the model tends to learn each semantic part independently and fails to establish the relationships among these parts.
2. **Slow Training Speed**: DPMs require a large number of time steps to ensure the convergence of the stochastic differential equation (SDE), which makes training time-consuming and computationally expensive.

To solve these problems, the authors propose the Masked Diffusion Transformer (MDT), which introduces a masked latent modeling scheme to strengthen the DPM's ability to learn contextual relations among the semantic parts of an image. The paper also presents an improved version, MDTv2, which adopts a more efficient macro network structure and training strategy, further improving learning speed and image generation quality. Experimental results show that MDTv2 outperforms existing methods on ImageNet, reaching a state-of-the-art FID of 1.58, and learns more than 10x faster than the previous state-of-the-art DiT.
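The slow-training point rests on the standard DPM setup, in which each training example is a clean latent noised at a randomly chosen time step, and the network learns to predict the added noise. The following pure-Python sketch uses the common DDPM linear beta schedule (1000 steps, beta from 1e-4 to 0.02) as an assumption; it is not MDT's exact schedule, loss, or SDE formulation, and the token values are made up. It shows why a long chain of steps is involved: the cumulative signal coefficient alpha_bar_t must decay from nearly 1 to nearly 0 so that the last step is close to pure noise.

```python
import math
import random


def linear_alpha_bars(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Return alpha_bar_t = prod_{s<=t} (1 - beta_s) for an assumed linear beta schedule."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    alpha_bars, prod = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / (num_steps - 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def add_noise(tokens, t, alpha_bars, rng):
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(a_bar) * x_0, (1 - a_bar) * I) for each token vector."""
    a_bar = alpha_bars[t]
    noisy, noise = [], []
    for tok in tokens:
        eps = [rng.gauss(0.0, 1.0) for _ in tok]
        noisy.append([math.sqrt(a_bar) * x + math.sqrt(1.0 - a_bar) * e
                      for x, e in zip(tok, eps)])
        noise.append(eps)
    return noisy, noise


def eps_mse(pred, target):
    """Mean squared error between predicted and true noise (epsilon-prediction loss)."""
    diffs = [(p - q) ** 2 for pt, qt in zip(pred, target) for p, q in zip(pt, qt)]
    return sum(diffs) / len(diffs)


# Usage: four hypothetical 3-dimensional latent tokens, noised at an early and the final step.
alpha_bars = linear_alpha_bars()
tokens = [[0.5, -0.2, 1.0] for _ in range(4)]
rng = random.Random(0)
early, _ = add_noise(tokens, 10, alpha_bars, rng)
late, true_noise = add_noise(tokens, 999, alpha_bars, rng)
print(alpha_bars[10], alpha_bars[-1])
print(eps_mse(true_noise, true_noise))  # a perfect noise predictor has zero loss
```

At an early step the noisy tokens stay close to the clean ones, while at the final step almost all of the signal has been replaced by noise; the model has to learn denoising across this whole range, which is part of what makes DPM training expensive.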