Promises and Pitfalls of Generative Masked Language Modeling: Theoretical Framework and Practical Guidelines

Yuchen Li, Alexandre Kirchmeyer, Aashay Mehta, Yilong Qin, Boris Dadachev, Kishore Papineni, Sanjiv Kumar, Andrej Risteski
2024-07-23
Abstract: Autoregressive language models are the currently dominant paradigm for text generation, but they have some fundamental limitations that cannot be remedied by scale, for example inherently sequential and unidirectional generation. While alternate classes of models have been explored, we have limited mathematical understanding of their fundamental power and limitations. In this paper we focus on Generative Masked Language Models (GMLMs), a non-autoregressive paradigm in which we train a model to fit conditional probabilities of the data distribution via masking, which are subsequently used as inputs to a Markov Chain to draw samples from the model. These models empirically strike a promising speed-quality trade-off, as each step can typically be parallelized by decoding the entire sequence in parallel. We develop a mathematical framework for analyzing and improving such models which sheds light on questions of sample complexity and inference speed and quality. Empirically, we adapt the T5 model for iteratively-refined parallel decoding, achieving 2-3x speedup in machine translation with minimal sacrifice in quality compared with autoregressive models. We run careful ablation experiments to give recommendations on key design choices, and make fine-grained observations on the common error modes in connection with our theory. Our mathematical analyses and empirical observations characterize both potentials and limitations of this approach, and can be applied to future works on improving understanding and performance of GMLMs. Our code is released at https://github.com/google-research/google-research/tree/master/padir
Computation and Language, Machine Learning
What problem does this paper attempt to address?
The paper addresses the theoretical foundations of Generative Masked Language Models (GMLMs) and offers practical guidance for their training and inference. It focuses on three core questions:

1. **Training Objectives of GMLMs**: GMLMs are trained to fit conditional probabilities of the data distribution. Does this objective also learn the joint distribution effectively? That is, can the joint probability of an entire sequence be inferred from the learned conditional probabilities of partial sequences?
2. **Impact of Data Distribution and Algorithm Characteristics on Model Performance**: Which properties of the data distribution and of the training and inference algorithms affect the quality of the learned model and of the samples it generates?
3. **Best Training Practices for GMLMs**: How should the loss function, training procedure, and inference procedure be designed to get the best performance from GMLMs? Is there theory that can narrow down this space of design choices?

### Main Contributions of the Paper

To answer these questions, the paper introduces a theoretical framework for analyzing the potential and limitations of GMLMs during training and inference. Specific contributions include:

- **Asymptotic Sample Complexity**: The paper studies the asymptotic sample complexity of estimating distribution parameters through masked prediction losses and relates it to the mixing time of the corresponding Markov chain. It proves that using larger masks always improves statistical efficiency (Theorem 1).
- **Theoretical Analysis of Finite Sample Bounds**: The paper examines how well the conditional probabilities of the data distribution are learned from a finite sample, and how well this translates into learning the joint distribution. Given some capacity control over the class of distributions being learned (e.g., covering number bounds), finite sample bounds can be derived (Section 2.3).
- **Limitations of Transformers**: Transformers used for parallel decoding can only represent decoding steps that factorize over coordinates, which makes it difficult for them to efficiently sample even simple distributions with strong correlations (Section 2.4).

### Experimental Validation

The paper also explores the key components and common error patterns of GMLMs through a series of detailed experiments. The results show:

- **Key Components**: A high masking ratio, a custom vocabulary, distillation from autoregressive models, and architectural improvements such as positional attention are key components of GMLMs.
- **Machine Translation Tasks**: GMLMs perform well on machine translation, producing reasonable translations even with a single forward pass. This is consistent with the theoretical framework, since machine translation targets typically have low entropy and few modes.
- **Limitations of Parallel Decoding**: Common error patterns, such as "stuttering", indicate that GMLMs are limited in their ability to model strong dependencies between tokens.

### Combination of Theory and Empirical Findings

Taken together, the theoretical and empirical findings suggest that when the target distribution has strong correlations, designing faster-mixing Markov chains and using loss functions that inherit good statistical behavior can jointly improve the performance of GMLMs.

Overall, the paper provides theoretical foundations and practical guidance for further work on GMLMs through a combination of theoretical analysis and empirical study.
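
To make the stuttering limitation concrete, the following is a minimal sketch that is not taken from the paper or its released code. It assumes a hypothetical two-token target distribution (`TARGET`) in which the two positions are perfectly anti-correlated, and it compares two decoding schemes computed exactly from the true conditionals: a single parallel pass from a fully masked input, which factorizes over coordinates, and sequential unmasking of one position at a time. All names (`TARGET`, `conditional`, `parallel_one_pass`, `sequential_unmasking`, `stutter_mass`) are illustrative assumptions.

```python
from itertools import product

# Hypothetical toy target: the two tokens always differ (strong correlation).
TARGET = {("a", "b"): 0.5, ("b", "a"): 0.5}
VOCAB = ("a", "b")
LENGTH = 2


def conditional(pos, context):
    """Exact p(x_pos | context) under TARGET; context maps position -> token."""
    weights = {tok: 0.0 for tok in VOCAB}
    for seq, prob in TARGET.items():
        if all(seq[i] == tok for i, tok in context.items()):
            weights[seq[pos]] += prob
    total = sum(weights.values())
    if total == 0.0:
        raise ValueError("context has zero probability under TARGET")
    return {tok: w / total for tok, w in weights.items()}


def parallel_one_pass():
    """Decode every position independently from a fully masked input.

    The resulting distribution is the product of per-position marginals,
    i.e. a decoding step that factorizes over coordinates.
    """
    marginals = [conditional(i, {}) for i in range(LENGTH)]
    return {
        seq: marginals[0][seq[0]] * marginals[1][seq[1]]
        for seq in product(VOCAB, repeat=LENGTH)
    }


def sequential_unmasking():
    """Decode position 0, then position 1 conditioned on position 0."""
    dist = {}
    for tok0, p0 in conditional(0, {}).items():
        for tok1, p1 in conditional(1, {0: tok0}).items():
            dist[(tok0, tok1)] = p0 * p1
    return dist


def stutter_mass(dist):
    """Total probability of outputs that repeat the same token twice."""
    return sum(p for seq, p in dist.items() if seq[0] == seq[1])


if __name__ == "__main__":
    print(stutter_mass(parallel_one_pass()))     # 0.5
    print(stutter_mass(sequential_unmasking()))  # 0.0
```

In this toy setting the parallel pass places half of its probability on the repeated outputs `("a", "a")` and `("b", "b")`, which the target never produces, while sequential unmasking recovers the target exactly at the cost of one extra decoding step. The example uses exact conditionals, so it isolates the effect of the factorized decoding step from any estimation error in a learned model.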