Deep Grokking: Would Deep Neural Networks Generalize Better?

Simin Fan, Razvan Pascanu, Martin Jaggi
2024-05-30
Abstract: Recent research on the grokking phenomenon has illuminated the intricacies of neural networks' training dynamics and their generalization behaviors. Grokking refers to a sharp rise of the network's generalization accuracy on the test set, which occurs long after an extended overfitting phase, during which the network perfectly fits the training set. While the existing research primarily focuses on shallow networks such as 2-layer MLPs and 1-layer Transformers, we explore grokking on deep networks (e.g. 12-layer MLP). We empirically replicate the phenomenon and find that deep neural networks can be more susceptible to grokking than their shallower counterparts. Meanwhile, we observe an intriguing multi-stage generalization phenomenon when increasing the depth of the MLP model, where the test accuracy exhibits a secondary surge that is scarcely seen on shallow models. We further uncover compelling correspondences between the decrease of feature ranks and the phase transition from overfitting to the generalization stage during grokking. Additionally, we find that the multi-stage generalization phenomenon often aligns with a double-descent pattern in feature ranks. These observations suggest that internal feature rank could serve as a more promising indicator of the model's generalization behavior compared to the weight-norm. We believe our work is the first to dive into grokking in deep neural networks and to investigate the relationship between feature rank and generalization performance.
Machine Learning
What problem does this paper attempt to address?
This paper studies grokking in deep neural networks (DNNs), which it calls "Deep Grokking", and how it affects the model's generalization ability. Grokking refers to a sharp improvement in test-set accuracy that arrives only after a long period of overfitting, during which the network already fits the training set perfectly. The research focuses on deep multi-layer perceptron (MLP) models (e.g. 12 layers) and finds that they can be more susceptible to grokking than shallow models, and that they can exhibit a multi-stage generalization phenomenon in which the test accuracy surges a second time. The study also observes that a decrease in feature rank coincides with the transition from overfitting to generalization. Furthermore, the multi-stage generalization often aligns with a double-descent pattern in feature rank, suggesting that feature rank may be a more effective indicator of model generalization behavior than weight norm. The paper further discusses the impact of initialization scale and weight decay on Deep Grokking: smaller weight decay may lead to delayed generalization (grokking), while larger weight decay may result in two-stage generalization. The paper also points out that the relationship between feature rank and generalization performance has not been sufficiently studied, and proposes it as a research direction. In conclusion, the paper investigates grokking in deep neural networks, proposes feature rank as an indicator of generalization behavior, and questions the notion that deeper networks always generalize better.
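To make the idea of tracking feature rank more concrete, here is a minimal sketch of one common way to estimate the rank of a layer's feature matrix: count the singular values above a relative threshold. This is an illustrative assumption, not necessarily the definition used in the paper. The function name `numerical_rank`, the parameter `rel_tol`, and the synthetic feature matrices are all hypothetical.

```python
import numpy as np


def numerical_rank(features, rel_tol=1e-3):
    """Estimate the rank of a (num_samples, num_features) matrix.

    Assumption: rank is the number of singular values larger than
    rel_tol times the largest singular value. The paper may use a
    different estimator or threshold.
    """
    singular_values = np.linalg.svd(features, compute_uv=False)
    if singular_values.size == 0 or singular_values[0] == 0:
        return 0
    return int(np.sum(singular_values > rel_tol * singular_values[0]))


# Synthetic stand-ins for hidden-layer activations (not real model features).
rng = np.random.default_rng(0)

# Features confined to a 5-dimensional subspace of a 64-dimensional layer.
low_rank_features = rng.normal(size=(256, 5)) @ rng.normal(size=(5, 64))

# Features that spread over all 64 dimensions.
full_rank_features = rng.normal(size=(256, 64))

print("low-rank features :", numerical_rank(low_rank_features))
print("full-rank features:", numerical_rank(full_rank_features))
```

In a training loop, one could call such a function on each layer's activations for a fixed batch at regular intervals and plot the result next to train and test accuracy, to see whether drops in rank line up with the jumps in test accuracy described above.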