Why Does Sharpness-Aware Minimization Generalize Better Than SGD?

Zixiang Chen, Junkai Zhang, Yiwen Kou, Xiangning Chen, Cho-Jui Hsieh, Quanquan Gu
2023-10-11
Abstract: The challenge of overfitting, in which the model memorizes the training data and fails to generalize to test data, has become increasingly significant in the training of large neural networks. To tackle this challenge, Sharpness-Aware Minimization (SAM) has emerged as a promising training method, which can improve the generalization of neural networks even in the presence of label noise. However, a deep understanding of how SAM works, especially in the setting of nonlinear neural networks and classification tasks, remains largely missing. This paper fills this gap by demonstrating why SAM generalizes better than Stochastic Gradient Descent (SGD) for a certain data model and two-layer convolutional ReLU networks. The loss landscape of our studied problem is nonsmooth, thus current explanations for the success of SAM based on the Hessian information are insufficient. Our result explains the benefits of SAM, particularly its ability to prevent noise learning in the early stages, thereby facilitating more effective learning of features. Experiments on both synthetic and real data corroborate our theory.
Machine Learning, Optimization and Control
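The abstract contrasts SAM with SGD but does not spell out the update. Below is a minimal sketch of the standard SAM step of Foret et al. (2021). It approximately minimizes \(\max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)\) by taking \(\epsilon = \rho \nabla L(w) / \|\nabla L(w)\|_2\) and then descending from \(w\) with the gradient evaluated at \(w + \epsilon\). The toy linear model with logistic loss and the names `logistic_loss_grad`, `sgd_step`, `sam_step`, and `rho` are illustrative assumptions, not the paper's network or code. The SAM variant analyzed in the paper may differ in details such as how the perturbation is normalized.

```python
import numpy as np

def logistic_loss_grad(w, X, y):
    """Mean logistic loss log(1 + exp(-y <w, x>)) and its gradient; labels y are in {-1, +1}."""
    margins = y * (X @ w)
    loss = np.mean(np.logaddexp(0.0, -margins))
    coeff = -y * 0.5 * (1.0 - np.tanh(margins / 2.0))   # -y * sigmoid(-margin), computed stably
    return loss, X.T @ coeff / len(y)

def sgd_step(w, X, y, lr):
    """Plain gradient step on the batch (X, y)."""
    _, g = logistic_loss_grad(w, X, y)
    return w - lr * g

def sam_step(w, X, y, lr, rho):
    """SAM-style step (Foret et al., 2021): perturb to w + eps, then descend from w."""
    _, g = logistic_loss_grad(w, X, y)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)   # first-order worst case in an L2 ball of radius rho
    _, g_adv = logistic_loss_grad(w + eps, X, y)  # gradient at the perturbed weights...
    return w - lr * g_adv                         # ...applied to the unperturbed weights

# Hypothetical usage on random, nearly separable data.
rng = np.random.default_rng(0)
X = rng.normal(size=(64, 10))
y = np.where(X[:, 0] + 0.1 * rng.normal(size=64) > 0, 1.0, -1.0)
w_sgd, w_sam = np.zeros(10), np.zeros(10)
for _ in range(100):
    w_sgd = sgd_step(w_sgd, X, y, lr=0.5)
    w_sam = sam_step(w_sam, X, y, lr=0.5, rho=0.05)
print(logistic_loss_grad(w_sgd, X, y)[0], logistic_loss_grad(w_sam, X, y)[0])
```

With `rho = 0` the SAM step reduces exactly to the SGD step. SAM therefore costs one extra gradient evaluation per iteration.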
### What problem does this paper attempt to solve?

This paper explores why Sharpness-Aware Minimization (SAM) generalizes better than Stochastic Gradient Descent (SGD) in some settings. Specifically, the authors study how SAM keeps the model from overfitting the training data in the presence of label noise, and so generalizes better to test data.

#### Background and Motivation

As deep neural networks grow larger, overfitting becomes more prominent: the model memorizes the details of the training data but performs poorly on test data. SGD can generalize well under certain conditions, but with high-dimensional data and complex models it remains prone to unstable training and harmful overfitting. To address this, SAM was proposed as an alternative training method. Rather than only seeking parameters with low training loss, SAM minimizes the worst-case loss within a small neighborhood of the current parameters. This penalizes sharpness (how rapidly the loss rises around the parameters) and is intended to improve generalization.

#### Main Contributions of the Paper

1. **Theoretical Explanation**: By analyzing two-layer convolutional ReLU networks, whose loss landscape is nonsmooth, the paper explains why SAM can outperform SGD. Traditional explanations based on the Hessian matrix are insufficient in this setting.
2. **Benign Overfitting Conditions**: The authors precisely characterize the conditions under which two-layer convolutional ReLU networks trained with SGD overfit benignly, that is, fit the noisy training labels while still achieving near-optimal test error. According to the authors, this is the first study of benign overfitting for neural networks trained with mini-batch SGD.
3. **Phase-Transition Phenomenon**: The paper proves that SGD leads to harmful overfitting under certain conditions, while SAM achieves benign overfitting under the same conditions.
   This difference shows that, for this data model, SAM attains strictly lower test error than SGD in the regime where SGD overfits harmfully.
4. **Early Noise-Learning Suppression**: The perturbation in SAM's update keeps the network from memorizing noise in the early stages of training, so it learns features more effectively.

#### Theoretical Results

The authors establish a separation between SGD and SAM in test error through the following theorems, summarized here in simplified form. In these statements, \(d\) is the data dimension, \(\mu\) is the signal vector, \(p\) is the label-flip probability, \(\epsilon > 0\) is a small target accuracy, and \(\mathbf{W}^{(t)}\) denotes the weights at iteration \(t\). Because a fraction \(p\) of the labels is flipped at random, \(p\) is the smallest achievable test error.

- **For SGD**:
  - If \(\|\mu\|_2 \geq \Omega(d^{1/4})\), the test error satisfies \(L_{0-1}^{\mathcal{D}}(\mathbf{W}^{(t)}) \leq p + \epsilon\).
  - If \(\|\mu\|_2 \leq O(d^{1/4})\), the test error satisfies \(L_{0-1}^{\mathcal{D}}(\mathbf{W}^{(t)}) \geq p + 0.1\).
- **For SAM**:
  - As long as \(\|\mu\|_2 \geq e^{\Omega(1)}\), a threshold that does not grow with \(d\), the test error satisfies \(L_{0-1}^{\mathcal{D}}(\mathbf{W}^{(t)}) \leq p + \epsilon\).

Consequently, whenever \(e^{\Omega(1)} \leq \|\mu\|_2 \leq O(d^{1/4})\), SGD's test error is at least \(p + 0.1\) while SAM's is at most \(p + \epsilon\). At the same signal strength and data dimension, SAM therefore generalizes near-optimally over a strictly wider range of signal strengths.

#### Experimental Verification

The authors verify these theoretical results with experiments on synthetic and real data. The experiments show that SAM prevents overfitting better than SGD when the data are high-dimensional and the signal is weak, yielding a significantly lower test error.

### Summary

Through theoretical analysis and experiments, this paper shows that SAM generalizes better than SGD in some settings, especially in the presence of label noise. The perturbation SAM introduces suppresses noise learning early in training, which improves feature learning and, in turn, generalization.