On Masked Pre-training and the Marginal Likelihood

Pablo Moreno-Muñoz, Pol G. Recasens, Søren Hauberg
DOI: https://doi.org/10.48550/arXiv.2306.00520
IF: 5.414
2023-06-01
Machine Learning
Abstract: Masked pre-training removes random input dimensions and learns a model that can predict the missing values. Empirical results indicate that this intuitive form of self-supervised learning yields models that generalize very well to new domains. A theoretical understanding is, however, lacking. This paper shows that masked pre-training with a suitable cumulative scoring function corresponds to maximizing the model's marginal likelihood, which is de facto the Bayesian model selection measure of generalization. Beyond shedding light on the success of masked pre-training, this insight also suggests that Bayesian models can be trained with appropriately designed self-supervision. Empirically, we confirm the developed theory and explore the main learning principles of masked pre-training in large language models.
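The "cumulative scoring function" in the abstract can be made concrete as an identity. The following is a sketch in our own notation, not copied from the paper: a single data vector $\mathbf{x} = (x_1, \dots, x_D)$, a uniformly random permutation $\sigma$ of its dimensions, a mask $M \subseteq \{1, \dots, D\}$ drawn uniformly among subsets of size $k$ (written $\mathcal{U}_k$), and $\mathbf{x}_{\setminus M}$ for the unmasked dimensions. Averaging the chain rule of probability over all orderings, and grouping terms by the number $k = D - i + 1$ of dimensions not yet conditioned on, gives a sum of masked-prediction scores over all mask sizes:

$$
\log p(\mathbf{x})
= \mathbb{E}_{\sigma}\left[\sum_{i=1}^{D} \log p\big(x_{\sigma(i)} \mid x_{\sigma(1)}, \dots, x_{\sigma(i-1)}\big)\right]
= \sum_{k=1}^{D} \mathbb{E}_{M \sim \mathcal{U}_k}\left[\frac{1}{k} \sum_{j \in M} \log p\big(x_j \mid \mathbf{x}_{\setminus M}\big)\right]
$$

As a check, for $D = 2$ the right-hand side is $\tfrac{1}{2}\big[\log p(x_1 \mid x_2) + \log p(x_2 \mid x_1)\big] + \tfrac{1}{2}\big[\log p(x_1) + \log p(x_2)\big] = \log p(x_1, x_2)$. Each $k$-term is an ordinary masked-prediction score at masking rate $k/D$; assuming each conditional is the model's predictive distribution with its parameters or latent variables integrated out, summing over all masking rates recovers the log marginal likelihood.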
What problem does this paper attempt to address?
The paper attempts to explain why masked pre-training (MPT) yields models that generalize well to new tasks. Specifically, it relates MPT to the marginal likelihood used in Bayesian model selection and uses this relationship to account theoretically for MPT's success.

### Main Issues

1. **Reasons for MPT's success**: Although MPT performs well in practice, the theoretical mechanism behind it is not yet understood. The paper aims to explain, from a theoretical perspective, why MPT produces models that generalize well.
2. **Relationship between MPT and marginal likelihood**: The paper asks whether MPT improves generalization because it implicitly optimizes the marginal likelihood, an established measure of generalization in Bayesian model selection.

### Research Background

- **Definition of MPT**: MPT is a self-supervised learning method that randomly masks part of the input and trains the model to predict the missing values. It has achieved notable results in fields such as natural language processing (NLP).
- **Marginal likelihood**: The marginal likelihood is a standard measure of generalization in Bayesian model selection. It scores a model by integrating the likelihood over the model's parameters (or latent variables), weighted by the prior, i.e. $p(\mathbf{x}) = \int p(\mathbf{x} \mid \theta)\, p(\theta)\, d\theta$, rather than by evaluating a single best-fitting setting.

### Main Contributions of the Paper

1. **Theoretical connection**: The paper proves that, under certain conditions (in particular, when masked predictions are scored with a suitable cumulative scoring function), MPT is equivalent to maximizing the model's marginal likelihood. Averaging the masked-prediction score over random masking patterns of different sizes gives a stochastic estimate of the log marginal likelihood, so training with MPT follows a stochastic gradient of it.
2. **Empirical validation**: The paper validates the theory empirically, evaluating MPT on several datasets and analyzing how the masking rate affects model performance.
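A small numerical check can make the theoretical connection above concrete. The sketch below is not the authors' code. It assumes a PPCA-style Gaussian model $\mathbf{x} \sim \mathcal{N}(\mathbf{0}, WW^\top + \sigma^2 I)$ with fixed, randomly drawn parameters (the sizes `D = 5`, `Q = 2` and `sigma2 = 0.1` are arbitrary), and every function name in it (`masked_score`, `cumulative_masked_score`, `stochastic_masked_estimate`, etc.) is hypothetical. For each mask size $k = 1, \dots, D$ it averages the log predictive density of the masked dimensions given the unmasked ones over all masks of that size, sums these averages over $k$, and compares the total with the exact log marginal likelihood. The last function instead draws one random mask size and one random mask per sample, which is the kind of unbiased stochastic estimate whose gradient MPT-style training would follow.

```python
import itertools

import numpy as np


def gaussian_logpdf(x, cov):
    """Exact log N(x; 0, cov), computed via a Cholesky factor."""
    d = x.shape[0]
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, x)  # whitened data: z @ z == x^T cov^{-1} x
    log_det = 2.0 * np.sum(np.log(np.diag(L)))
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + z @ z)


def conditional_logpdf(x, cov, j, observed):
    """log p(x_j | x_observed) for a zero-mean Gaussian with covariance cov."""
    if len(observed) == 0:
        mean, var = 0.0, cov[j, j]  # nothing observed: use the marginal of x_j
    else:
        s_oo = cov[np.ix_(observed, observed)]
        s_jo = cov[j, observed]
        w = np.linalg.solve(s_oo, s_jo)  # regression weights of x_j on x_observed
        mean = w @ x[observed]
        var = cov[j, j] - s_jo @ w
    return -0.5 * (np.log(2.0 * np.pi * var) + (x[j] - mean) ** 2 / var)


def masked_score(x, cov, mask):
    """Average log predictive density of the masked dims given the unmasked ones."""
    mask = set(int(i) for i in mask)
    observed = [i for i in range(x.shape[0]) if i not in mask]
    return np.mean([conditional_logpdf(x, cov, j, observed) for j in mask])


def cumulative_masked_score(x, cov):
    """Sum over mask sizes k of the average masked score over all masks of size k."""
    D = x.shape[0]
    total = 0.0
    for k in range(1, D + 1):
        masks = list(itertools.combinations(range(D), k))
        total += np.mean([masked_score(x, cov, m) for m in masks])
    return total


def stochastic_masked_estimate(x, cov, rng, n_samples=4000):
    """Unbiased Monte Carlo estimate: random mask size, then a random mask of that size."""
    D = x.shape[0]
    samples = []
    for _ in range(n_samples):
        k = rng.integers(1, D + 1)  # mask size drawn uniformly from {1, ..., D}
        mask = rng.choice(D, size=k, replace=False)
        # The factor D compensates for drawing k uniformly instead of summing over all k.
        samples.append(D * masked_score(x, cov, mask))
    return float(np.mean(samples))


# Hypothetical PPCA-style model: x ~ N(0, W W^T + sigma2 * I).
rng = np.random.default_rng(0)
D, Q, sigma2 = 5, 2, 0.1
W = rng.normal(size=(D, Q))
cov = W @ W.T + sigma2 * np.eye(D)
x = rng.multivariate_normal(np.zeros(D), cov)

exact = gaussian_logpdf(x, cov)
cumulative = cumulative_masked_score(x, cov)
stochastic = stochastic_masked_estimate(x, cov, rng)
print(f"exact log p(x)       = {exact:.6f}")
print(f"cumulative MPT score = {cumulative:.6f}")
print(f"Monte Carlo estimate = {stochastic:.6f}")
```

Because Gaussian conditionals are available in closed form, the exhaustive cumulative score matches the exact log density up to floating-point error, while the Monte Carlo value fluctuates around it. In a real masked language model the conditionals come from a network rather than closed-form Gaussian algebra and the parameters are updated by gradient steps, so this only illustrates the scoring identity at fixed parameters, not the training procedure.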
### Experimental Results

- **Tractable models**: On probabilistic principal component analysis (PPCA), where the marginal likelihood can be computed in closed form, the paper verifies that MPT does maximize the marginal likelihood.
- **Large language models**: On large language models such as BERT, the paper observes that the MPT curves behave similarly to those of the PPCA model, supporting the hypothesis that MPT performs implicit integration over the latent space and thereby maximizes the marginal likelihood.

### Conclusion

The paper demonstrates, theoretically and empirically, the connection between MPT and the marginal likelihood, which explains why MPT produces models that generalize well. This finding deepens the understanding of MPT and also suggests a new way to train Bayesian models through appropriately designed self-supervision.