Fast Training of Diffusion Models with Masked Transformers

Hongkai Zheng, Weili Nie, Arash Vahdat, Anima Anandkumar
2024-03-05
Abstract: We propose an efficient approach to train large diffusion models with masked transformers. While masked transformers have been extensively explored for representation learning, their application to generative learning is less explored in the vision domain. Our work is the first to exploit masked training to reduce the training cost of diffusion models significantly. Specifically, we randomly mask out a high proportion (e.g., 50%) of patches in diffused input images during training. For masked training, we introduce an asymmetric encoder-decoder architecture consisting of a transformer encoder that operates only on unmasked patches and a lightweight transformer decoder on full patches. To promote a long-range understanding of full patches, we add an auxiliary task of reconstructing masked patches to the denoising score matching objective that learns the score of unmasked patches. Experiments on ImageNet-256x256 and ImageNet-512x512 show that our approach achieves competitive and even better generative performance than the state-of-the-art Diffusion Transformer (DiT) model, using only around 30% of its original training time. Thus, our method shows a promising way of efficiently training large transformer-based diffusion models without sacrificing the generative performance.
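To make the masking scheme concrete, here is a minimal NumPy sketch of the data flow the abstract describes; it is not the authors' implementation. It patchifies a toy latent, drops 50% of the patches at random, passes only the visible ones through a stand-in encoder, and refills the masked positions with a shared mask token before a stand-in decoder. The helper names (`patchify`, `random_mask`, `fill_masked`, `masked_loss`), the shared mask token, the linear "encoder", the identity "decoder", the MSE form of both loss terms, the weight `lam`, and the toy targets are all hypothetical assumptions; the paper's actual networks, score parametrization, and reconstruction targets may differ.

```python
import numpy as np

def patchify(x, p):
    """Split a (C, H, W) array into (N, p*p*C) non-overlapping patches in row-major order."""
    c, h, w = x.shape
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    x = x.reshape(c, h // p, p, w // p, p)
    x = x.transpose(1, 3, 2, 4, 0)  # (H/p, W/p, p, p, C)
    return x.reshape((h // p) * (w // p), p * p * c)

def random_mask(num_patches, mask_ratio, rng):
    """Randomly choose patches to drop. Returns kept indices and a bool mask (True = masked)."""
    num_keep = int(round(num_patches * (1 - mask_ratio)))
    keep_idx = np.sort(rng.permutation(num_patches)[:num_keep])
    mask = np.ones(num_patches, dtype=bool)
    mask[keep_idx] = False
    return keep_idx, mask

def fill_masked(encoded_kept, keep_idx, num_patches, mask_token):
    """Full-length decoder input: encoder outputs at kept positions, a shared mask token elsewhere."""
    full = np.tile(mask_token, (num_patches, 1))
    full[keep_idx] = encoded_kept
    return full

def masked_loss(pred, denoise_target, recon_target, mask, lam):
    """Hypothetical objective: denoising loss on unmasked patches + lam * reconstruction loss on masked ones."""
    denoise = np.mean((pred[~mask] - denoise_target[~mask]) ** 2)
    recon = np.mean((pred[mask] - recon_target[mask]) ** 2)
    return denoise + lam * recon

# Toy usage: a 4x32x32 latent with patch size 2 gives 256 patches of dimension 16.
rng = np.random.default_rng(0)
x_noisy = rng.standard_normal((4, 32, 32))            # stands in for a diffused input
patches = patchify(x_noisy, p=2)                      # (256, 16)
keep_idx, mask = random_mask(len(patches), 0.5, rng)  # keep 128, mask 128
visible = patches[keep_idx]                           # only these enter the encoder
W_enc = rng.standard_normal((16, 16)) * 0.1           # stand-in for the transformer encoder
encoded = visible @ W_enc
dec_in = fill_masked(encoded, keep_idx, len(patches), np.zeros(16))
pred = dec_in                                         # stand-in for the lightweight decoder (identity)
denoise_target = rng.standard_normal(patches.shape)   # toy denoising target (hypothetical)
loss = masked_loss(pred, denoise_target, patches, mask, lam=0.1)
```

Because the encoder only ever sees the kept patches, its cost scales with the number of visible tokens, while the full-length decoder input is what allows the auxiliary reconstruction term to supervise the masked positions.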
Computer Vision and Pattern Recognition, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
The main goal of this paper is to reduce the high computational cost of training large-scale diffusion models. The paper proposes a method that uses masked transformers to cut training time substantially without sacrificing generative performance. Specifically, it addresses the following issues:

1. **Reducing Training Costs**: Current diffusion models require substantial compute and time to train, which is a major bottleneck for most researchers and practitioners. The paper randomly masks a high proportion (e.g., 50%) of the input patches so that the encoder processes only the unmasked ones, lowering the computational cost of each training step.
2. **Improving Training Efficiency**: The paper introduces an asymmetric encoder-decoder architecture (a transformer encoder on unmasked patches and a lightweight transformer decoder on all patches) and a training objective that combines denoising score matching on the unmasked patches with an auxiliary task of reconstructing the masked patches. With these, the method trains faster and uses less memory while maintaining competitive generative performance.
3. **Validating Scalability**: Experiments on ImageNet-256×256 and ImageNet-512×512 show that the method matches or even exceeds the generative performance of the state-of-the-art Diffusion Transformer (DiT) while using only about 30% of DiT's original training time, supporting the method's effectiveness and scalability.

In summary, the paper uses masking to make diffusion model training more efficient and less costly while maintaining or improving generative performance.
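The per-step saving from masking can be estimated with a back-of-the-envelope FLOP count. The sketch below uses the standard approximate forward cost of a transformer block and compares an encoder that sees only the visible tokens plus a full-length decoder against an unmasked baseline. The DiT-XL/2-like configuration (256 tokens at 256×256, width 1152, 28 blocks), the 26 + 2 encoder/decoder split, and the assumption that the decoder has the same width as the encoder are illustrative assumptions rather than the paper's reported setup; the function names are hypothetical.

```python
def block_flops(num_tokens, width, mlp_ratio=4):
    """Approximate forward FLOPs of one standard transformer block (2 FLOPs per multiply-add).

    - QKV and output projections: 4 * N * d^2 multiply-adds
    - MLP with hidden size mlp_ratio * d: 2 * mlp_ratio * N * d^2 multiply-adds
    - attention scores (QK^T) and weighted sum (AV): 2 * N^2 * d multiply-adds
    Norms, softmax, and conditioning layers are ignored.
    """
    n, d = num_tokens, width
    mult_adds = 4 * n * d**2 + 2 * mlp_ratio * n * d**2 + 2 * n**2 * d
    return 2 * mult_adds

def masked_cost_ratio(num_tokens, width, mask_ratio, enc_layers, dec_layers, base_layers):
    """FLOPs of (encoder on visible tokens + decoder on all tokens) relative to an unmasked baseline."""
    visible = int(round(num_tokens * (1 - mask_ratio)))
    masked = enc_layers * block_flops(visible, width) + dec_layers * block_flops(num_tokens, width)
    baseline = base_layers * block_flops(num_tokens, width)
    return masked / baseline

# Assumed DiT-XL/2-like numbers at 256x256: 256 tokens, width 1152, 28 blocks.
# The 26 + 2 encoder/decoder split is hypothetical, and the decoder is assumed to be
# as wide as the encoder.
ratio = masked_cost_ratio(num_tokens=256, width=1152, mask_ratio=0.5,
                          enc_layers=26, dec_layers=2, base_layers=28)
print(f"per-step forward FLOPs vs. unmasked baseline: {ratio:.2f}")
```

Under these assumptions one training step costs a little over half the forward FLOPs of the unmasked model, and the quadratic attention term makes the per-block saving slightly better than linear in the number of visible tokens. The roughly 30% of DiT's training time reported in the abstract is an end-to-end figure and is not derived from this estimate.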