On the Generalization of Stochastic Gradient Descent with Momentum

Ali Ramezani-Kebrya, Kimon Antonakopoulos, Volkan Cevher, Ashish Khisti, Ben Liang
2024-01-15
Abstract: While momentum-based accelerated variants of stochastic gradient descent (SGD) are widely used when training machine learning models, there is little theoretical understanding on the generalization error of such methods. In this work, we first show that there exists a convex loss function for which the stability gap for multiple epochs of SGD with standard heavy-ball momentum (SGDM) becomes unbounded. Then, for smooth Lipschitz loss functions, we analyze a modified momentum-based update rule, i.e., SGD with early momentum (SGDEM) under a broad range of step-sizes, and show that it can train machine learning models for multiple epochs with a guarantee for generalization. Finally, for the special case of strongly convex loss functions, we find a range of momentum such that multiple epochs of standard SGDM, as a special form of SGDEM, also generalizes. Extending our results on generalization, we also develop an upper bound on the expected true risk, in terms of the number of training steps, sample size, and momentum. Our experimental evaluations verify the consistency between the numerical results and our theoretical bounds. SGDEM improves the generalization error of SGDM when training ResNet-18 on ImageNet in practical distributed settings.
Machine Learning
What problem does this paper attempt to address?
The paper addresses the following issues:

1. **Impact of momentum on generalization**: Although stochastic gradient descent (SGD) with momentum is widely used to train machine learning models, there is little theoretical understanding of its generalization error. The paper first shows that there exists a convex loss function for which the stability gap of SGD with standard heavy-ball momentum (SGDM) becomes unbounded over multiple epochs.
2. **Modified momentum update rule**: To address this, the authors analyze a modified update rule, SGD with early momentum (SGDEM), under a broad range of step sizes. For smooth Lipschitz loss functions, they show that SGDEM can train models for multiple epochs with a guarantee on generalization.
3. **SGDM under strongly convex loss functions**: For the special case of strongly convex loss functions, the authors find a range of momentum values for which multiple epochs of standard SGDM, as a special form of SGDEM, also generalize.
4. **Experimental validation**: Experiments training ResNet-18 on ImageNet in practical distributed settings show that SGDEM improves the generalization error of SGDM.

In summary, the paper characterizes when momentum-based SGD generalizes over multiple epochs, analyzes early momentum as a variant that comes with a generalization guarantee, and supports the theoretical bounds with experimental evidence.
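To make the relationship between SGDM and SGDEM concrete, here is a minimal sketch of a heavy-ball update in which momentum is switched off after an initial phase. The abstract does not spell out the update rule, so the form below is an assumption: momentum is applied only during the first `momentum_steps` iterations, and the function name `sgd_early_momentum`, its parameters, and the toy quadratic loss are all hypothetical choices made for illustration, not the authors' implementation.

```python
import numpy as np


def sgd_early_momentum(grad_fn, w0, data, step_size=0.1, momentum=0.9,
                       momentum_steps=50, total_steps=200, seed=0):
    """Heavy-ball SGD whose momentum term is active only in early steps.

    Assumed update (illustrative, not taken verbatim from the paper):
        w_{t+1} = w_t + mu_t * (w_t - w_{t-1}) - step_size * grad(w_t; z)
    with mu_t = momentum for t < momentum_steps and mu_t = 0 afterwards.
    """
    rng = np.random.default_rng(seed)
    w_prev = np.array(w0, dtype=float)
    w = w_prev.copy()
    for t in range(total_steps):
        # Draw one training example uniformly at random.
        z = data[rng.integers(len(data))]
        # Momentum is used early on, then the rule reduces to plain SGD.
        mu = momentum if t < momentum_steps else 0.0
        w_next = w + mu * (w - w_prev) - step_size * grad_fn(w, z)
        w_prev, w = w, w_next
    return w


def quadratic_grad(w, z):
    # Gradient of the toy loss 0.5 * (w - z)**2 with respect to w.
    return w - z


# Toy data set in which every example equals 3.0, so the minimizer is 3.0.
toy_data = np.full(20, 3.0)

w_early = sgd_early_momentum(quadratic_grad, 0.0, toy_data)
w_plain = sgd_early_momentum(quadratic_grad, 0.0, toy_data, momentum_steps=0)
w_heavy = sgd_early_momentum(quadratic_grad, 0.0, toy_data, momentum_steps=200)
```

Under this assumed form, setting `momentum_steps=0` gives plain SGD and setting `momentum_steps=total_steps` gives standard SGDM, which matches the abstract's description of SGDM as a special form of SGDEM.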