Promoting Exploration in Memory-Augmented Adam using Critical Momenta

Pranshu Malviya, Gonçalo Mordido, Aristide Baratin, Reza Babanezhad Harikandeh, Jerry Huang, Simon Lacoste-Julien, Razvan Pascanu, Sarath Chandar
2024-06-18
Abstract: Adaptive gradient-based optimizers, notably Adam, have left their mark in training large-scale deep learning models, offering fast convergence and robustness to hyperparameter settings. However, they often struggle with generalization, attributed to their tendency to converge to sharp minima in the loss landscape. To address this, we propose a new memory-augmented version of Adam that encourages exploration towards flatter minima by incorporating a buffer of critical momentum terms during training. This buffer prompts the optimizer to overshoot beyond narrow minima, promoting exploration. Through comprehensive analysis in simple settings, we illustrate the efficacy of our approach in increasing exploration and bias towards flatter minima. We empirically demonstrate that it can improve model performance for image classification on ImageNet and CIFAR10/100, language modelling on Penn Treebank, and online learning tasks on TinyImageNet and 5-dataset. Our code is available at https://github.com/chandar-lab/CMOptimizer.
Machine Learning, Artificial Intelligence
### What problem does this paper attempt to solve?

This paper addresses the poor generalization that adaptive gradient optimizers, especially Adam, can exhibit when training large-scale deep learning models. Although adaptive optimizers such as Adam converge quickly and are robust to hyperparameter settings, they tend to converge to sharp minima in the loss landscape, which hurts generalization. To address this, the authors propose a memory-augmented version of Adam that maintains a buffer of Critical Momenta (CM). The buffer promotes exploration, letting the optimizer escape sharp minima and move towards flatter regions of the loss landscape, which in turn improves generalization and model performance.

#### Main contributions:

1. **Propose a new memory-augmented Adam optimizer**: Storing and reusing Critical Momenta (CM) during training helps the optimizer escape sharp minima.
2. **Provide theoretical convergence analysis**: The convergence of the new method is analyzed under a simplified setting.
3. **Verify the effectiveness of the method through multiple experiments**: The method is evaluated on a variety of benchmark tasks, including image classification, language modeling, and online learning.
4. **Observe improved model performance**: The new method improves model performance in both supervised learning and online learning scenarios.

#### Key technical points:

- **Critical Momenta (CM)**: Unlike previous work (such as the Critical Gradients (CG) method proposed by McRae et al.), this paper builds the buffer from momentum terms instead of gradients. This avoids the gradient cancellation problem and keeps variance low, so the optimizer can better escape sharp minima.
- **Buffer management**: Elements in the buffer are updated through a priority mechanism, so that the buffer always contains the most influential momentum terms.

#### Experimental results:

- Extensive experiments were carried out on multiple benchmark datasets (such as ImageNet, CIFAR10/100, and Penn Treebank). The results show that Adam + CM outperforms the other optimizers in generalization and final model performance.
- In image classification, Adam + CM achieved higher validation accuracy on CIFAR10, CIFAR100, and ImageNet. In language modeling, Adam + CM also reduced validation perplexity.

In conclusion, by introducing the Critical Momenta buffer, this paper addresses the generalization problem that adaptive optimizers face when training deep learning models and offers a new direction for optimizer design.
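The buffer management described under the key technical points can be illustrated with a small priority buffer. The following is a minimal sketch, not the authors' implementation: the class name `CriticalMomentaBuffer`, the helper `aggregate`, the per-step multiplicative priority decay, and the choice to aggregate by summation are all assumptions made for illustration. The summary only states that buffer elements are updated through a priority mechanism.

```python
import heapq


class CriticalMomentaBuffer:
    """Fixed-capacity buffer that keeps the momenta with the highest priority.

    Hypothetical helper for illustration; names and decay rule are assumptions.
    """

    def __init__(self, capacity, decay):
        self.capacity = capacity  # maximum number of stored momenta
        self.decay = decay        # assumed multiplicative priority decay per step
        self._heap = []           # min-heap of (priority, insertion_id, momentum)
        self._next_id = 0         # unique tie-breaker so momenta are never compared

    def decay_priorities(self):
        # Shrink every stored priority so that stale entries can be displaced.
        self._heap = [(p * self.decay, i, m) for p, i, m in self._heap]
        heapq.heapify(self._heap)

    def add(self, priority, momentum):
        entry = (priority, self._next_id, list(momentum))
        self._next_id += 1
        if len(self._heap) < self.capacity:
            heapq.heappush(self._heap, entry)
        elif priority > self._heap[0][0]:
            # Buffer is full: evict the lowest-priority entry.
            heapq.heapreplace(self._heap, entry)

    def momenta(self):
        return [m for _, _, m in self._heap]

    def priorities(self):
        return sorted(p for p, _, _ in self._heap)


def aggregate(current_momentum, buffer):
    """Sum the current momentum with all buffered momenta (assumed aggregation)."""
    total = list(current_momentum)
    for stored in buffer.momenta():
        total = [t + s for t, s in zip(total, stored)]
    return total


# Usage: a buffer of capacity 2 keeps only the two highest-priority momenta.
demo = CriticalMomentaBuffer(capacity=2, decay=0.5)
demo.add(1.0, [1.0, 0.0])
demo.add(3.0, [0.0, 2.0])
demo.add(2.0, [4.0, 4.0])  # evicts the entry with priority 1.0
print(aggregate([1.0, 1.0], demo))  # [5.0, 7.0]
```

A min-heap keeps the lowest-priority entry at the root, so deciding whether a new momentum enters the buffer only requires comparing against that one entry. In an optimizer, the aggregated vector would stand in for the plain momentum in the Adam update; how the priority is computed is left open here because the summary does not specify it.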