Three Mechanisms of Feature Learning in the Exact Solution of a Latent Variable Model

Yizhou Xu, Liu Ziyin
2024-05-04
Abstract: We identify and exactly solve the learning dynamics of a one-hidden-layer linear model at any finite width whose limits exhibit both the kernel phase and the feature learning phase. We analyze the phase diagram of this model in different limits of common hyperparameters including width, layer-wise learning rates, scale of output, and scale of initialization. Our solution identifies three novel prototype mechanisms of feature learning: (1) learning by alignment, (2) learning by disalignment, and (3) learning by rescaling. In sharp contrast, none of these mechanisms is present in the kernel regime of the model. We empirically demonstrate that these discoveries also appear in deep nonlinear networks in real tasks.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The core problem this paper addresses is understanding the transition between the feature learning phase and the kernel regime in neural networks, and the mechanisms behind that transition. Specifically, the authors describe exactly the training dynamics of a one-hidden-layer linear model at finite width, and they identify three new mechanisms of feature learning: learning by alignment, learning by disalignment, and learning by rescaling. These questions matter both for understanding feature learning in infinite-width models and for carrying those insights over to finite-width models.

### Main contributions of the paper

1. **Exactly solving the finite-width model**: The authors provide an exact solution for the evolution of the NTK (neural tangent kernel) in a finite-width, one-hidden-layer linear model. This model can exhibit both the NTK phase and the feature learning phase.
2. **Revealing three feature learning mechanisms**:
   - **Learning by alignment**: Feature learning that occurs when the weight vectors of the model's two layers gradually become more aligned.
   - **Learning by disalignment**: Feature learning that occurs when the weight vectors of the model's two layers gradually become less aligned.
   - **Learning by rescaling**: Feature learning that is achieved by adjusting the output scale.
3. **Theoretical and experimental verification**: The authors provide a theoretical analysis and also verify experimentally that these mechanisms appear in deep nonlinear networks.

### Problem background

In the study of neural networks, especially when the width tends to infinity, the learning dynamics of the model can be described by the NTK. If the NTK remains unchanged throughout training, the model is in the kernel regime; otherwise, the model is in the feature learning phase.
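The distinction between the kernel regime and the feature learning phase can be checked numerically by comparing the NTK before and after training. The following is a minimal sketch under stated assumptions, not the paper's exact setup: it assumes a scalar-output model `f(x) = u @ W @ x` trained by plain gradient descent on a mean squared error, and the function names `ntk`, `train`, and `relative_ntk_change` are hypothetical helpers introduced only for illustration.

```python
import numpy as np

def ntk(u, W, X):
    """Empirical NTK of the assumed model f(x) = u @ W @ x on the rows of X.

    Since df/du = W x and df/dW = u x^T, the kernel is
    K(x, x') = (W x).(W x') + ||u||^2 (x.x').
    """
    H = X @ W.T                                  # hidden features, shape (n, d)
    return H @ H.T + (u @ u) * (X @ X.T)

def train(u, W, X, y, lr=0.01, steps=500):
    """Plain gradient descent on the loss (1/2n) * sum (f(x_i) - y_i)^2."""
    u, W = u.copy(), W.copy()
    n = len(y)
    for _ in range(steps):
        H = X @ W.T                              # (n, d)
        err = H @ u - y                          # residuals, shape (n,)
        grad_u = H.T @ err / n                   # (d,)
        grad_W = np.outer(u, X.T @ err) / n      # (d, d0)
        u -= lr * grad_u
        W -= lr * grad_W
    return u, W

def relative_ntk_change(u0, W0, X, y, **kw):
    """Frobenius-norm change of the NTK over training, relative to its initial norm."""
    K0 = ntk(u0, W0, X)
    u1, W1 = train(u0, W0, X, y, **kw)
    return np.linalg.norm(ntk(u1, W1, X) - K0) / np.linalg.norm(K0)

# Illustrative synthetic data; the sizes and scales are arbitrary choices.
rng = np.random.default_rng(0)
n, d0, d = 20, 5, 50
X = rng.normal(size=(n, d0))
y = X @ rng.normal(size=d0)
for scale in (0.1, 1.0):
    u0 = scale * rng.normal(size=d) / np.sqrt(d)
    W0 = scale * rng.normal(size=(d, d0)) / np.sqrt(d0)
    print(f"init scale {scale}: relative NTK change = "
          f"{relative_ntk_change(u0, W0, X, y):.3f}")
```

A relative change close to zero would indicate kernel-like behavior for that configuration, while a large change indicates that the kernel, and therefore the learned features, moved during training.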
Although infinite-width models have been studied extensively, the feature learning mechanisms of finite-width models are still poorly understood.

### Solution

The authors first solve exactly the learning dynamics of a simple but non-convex one-hidden-layer linear model. They find that these dynamics depend on the initialization and on the choice of hyperparameters, and the solution reveals the three feature learning mechanisms listed above. None of these mechanisms exists in the kernel regime, so they offer a new perspective on feature learning.

### Experimental results

Through a series of experiments, the authors show that the three feature learning mechanisms appear not only in linear models but also in deep nonlinear networks on practical tasks. In particular, they verify experimentally how inter-layer alignment changes during the feature learning phase and how different initialization scales affect model performance.

### Conclusion

This research provides a theoretical basis for understanding feature learning in finite-width neural networks and reveals three new feature learning mechanisms. These results may help inform the design and optimization of deep learning models.
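To make the notion of inter-layer alignment concrete, the sketch below tracks a simple alignment score while a small two-layer linear model trains. Everything here is an assumption made for illustration: the summary does not state how the paper measures alignment, so the score `||W^T u|| / (||u|| * ||W||_2)`, the scalar-output model `f(x) = u @ W @ x`, and the helper names `alignment` and `alignment_trajectory` are hypothetical.

```python
import numpy as np

def alignment(u, W):
    """Hypothetical alignment score in [0, 1].

    It equals 1 when u lies along the top left singular vector of W,
    because ||W^T u|| <= ||W||_2 * ||u|| with equality in that case.
    """
    return np.linalg.norm(W.T @ u) / (np.linalg.norm(u) * np.linalg.norm(W, 2))

def alignment_trajectory(u, W, X, y, lr=0.01, steps=500, every=100):
    """Run gradient descent on a mean squared error and record the score."""
    u, W = u.copy(), W.copy()
    n = len(y)
    scores = [alignment(u, W)]                   # score at initialization
    for t in range(1, steps + 1):
        g = X.T @ (X @ W.T @ u - y) / n          # gradient w.r.t. the product W^T u
        grad_u = W @ g                           # chain rule through u
        grad_W = np.outer(u, g)                  # chain rule through W
        u -= lr * grad_u
        W -= lr * grad_W
        if t % every == 0:
            scores.append(alignment(u, W))
    return scores

# Illustrative synthetic data with a small initialization (arbitrary choices).
rng = np.random.default_rng(0)
X = rng.normal(size=(20, 5))
y = X @ rng.normal(size=5)
u0 = 0.1 * rng.normal(size=50)
W0 = 0.1 * rng.normal(size=(50, 5))
print(alignment_trajectory(u0, W0, X, y))
```

Whether the recorded score rises or falls over training depends on the initialization and the hyperparameters, so the same loop can be rerun with different initialization scales to compare configurations.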