On Learning Gaussian Multi-index Models with Gradient Flow

Alberto Bietti, Joan Bruna, Loucas Pillaud-Vivien
2023-11-03
Abstract: We study gradient flow on the multi-index regression problem for high-dimensional Gaussian data. Multi-index functions consist of a composition of an unknown low-rank linear projection and an arbitrary unknown, low-dimensional link function. As such, they constitute a natural template for feature learning in neural networks. We consider a two-timescale algorithm, whereby the low-dimensional link function is learnt with a non-parametric model infinitely faster than the subspace parametrizing the low-rank projection. By appropriately exploiting the matrix semigroup structure arising over the subspace correlation matrices, we establish global convergence of the resulting Grassmannian population gradient flow dynamics, and provide a quantitative description of its 'saddle-to-saddle' dynamics. Notably, the timescales associated with each saddle can be explicitly characterized in terms of an appropriate Hermite decomposition of the target link function. In contrast with these positive results, we also show that the related *planted* problem, where the link function is known and fixed, in fact has a rough optimization landscape, in which gradient flow dynamics might get trapped with high probability.
Machine Learning, Optimization and Control
### What problem does this paper attempt to solve?

This paper studies gradient flow on the multi-index regression problem for high-dimensional Gaussian data, and asks when gradient flow can provably learn a multi-index model. A multi-index function is the composition of an unknown low-rank linear projection with an arbitrary unknown low-dimensional link function, which makes it a natural template for feature learning in neural networks.

#### Main research content

1. **Learning of multi-index models**:
   - A multi-index model combines a low-rank linear projection with a low-dimensional link function. It can be viewed as a simplified neural network whose internal layer extracts low-dimensional features.
   - The paper studies how gradient flow learns these models on high-dimensional Gaussian data.
2. **Two-timescale algorithm**:
   - The paper considers a two-timescale algorithm in which the low-dimensional link function is learnt with a non-parametric model infinitely faster than the subspace parametrizing the low-rank projection.
   - By exploiting the matrix semigroup structure over the subspace correlation matrices, the authors establish global convergence of the resulting Grassmannian population gradient flow and give a quantitative description of its "saddle-to-saddle" dynamics.
3. **Saddle-to-saddle dynamics**:
   - The gradient flow passes through a sequence of saddle points, and the timescale associated with each saddle can be characterized explicitly through a Hermite decomposition of the target link function.
   - The timescale of each saddle is governed by its information index \( s \), which is read off from that Hermite decomposition.
4. **Planted model**:
   - In the related planted problem, where the link function is known and fixed, the optimization landscape is rough: gradient flow may get trapped with high probability and fail to recover the target subspace.
5. **Sample complexity guarantee**:
   - The paper explores how to turn the quantitative time-complexity guarantees into sample-complexity guarantees, and proposes choosing a suitable regularization parameter so that the empirical Stiefel gradient concentrates around the population gradient.

#### Formula display

- **Representation of the loss function**:
  \[ L(f, W)=\frac{1}{2}\|f\|_{\gamma_r}^2+\frac{1}{2}\|f^*\|_{\gamma_q}^2-\langle A_M f, f^*\rangle_{\gamma_q} \]
  where \( M = W^*W \) is the subspace correlation matrix, \( A_M \) is the averaging operator, and \( \gamma_r \) denotes the standard Gaussian measure on \( \mathbb{R}^r \).
- **Definition of the averaging operator**:
  \[ A_M:L^2_{\gamma_r}\to L^2_{\gamma_q},\quad A_M f(z):=\mathbb{E}_{y\sim\gamma_r}\left[f\left(M^T z+(I_r - M^T M)^{1/2}y\right)\right] \]
- **Saddle escape time**:
  \[ t_k\simeq d^{s_k - 1} \]
  where \( d \) is the ambient dimension and \( s_k \) is the information index associated with the \( k \)-th saddle.

### Summary

This paper studies gradient flow for feature learning through multi-index models on high-dimensional Gaussian data. When the link function is learnt on a fast timescale, the population gradient flow over the subspace converges globally, with saddle-to-saddle dynamics whose timescales are set by the Hermite decomposition of the target. When the link function is instead known and fixed (the planted problem), the landscape is rough and gradient flow may get trapped.
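
The averaging operator and the escape-time scaling from the formula display can be illustrated numerically in the simplest setting. The sketch below is not taken from the paper: it assumes the scalar case \( r = q = 1 \), so that \( M \) reduces to a single correlation \( m \in [-1, 1] \), and it uses the classical single-index information exponent (the smallest degree \( k \ge 1 \) with a nonzero Hermite coefficient) as a stand-in for the per-saddle index \( s_k \). All function names are hypothetical. In this scalar case the averaging operator acts diagonally on Hermite polynomials, \( A_m \mathrm{He}_k = m^k \mathrm{He}_k \), which is the one-dimensional form of the semigroup structure mentioned above.

```python
import math
import numpy as np
from numpy.polynomial import hermite_e as He


def gaussian_quadrature(n_nodes=64):
    """Gauss-Hermite nodes and weights normalised to integrate against N(0, 1)."""
    nodes, weights = He.hermegauss(n_nodes)  # weight function exp(-y^2 / 2)
    return nodes, weights / math.sqrt(2.0 * math.pi)


def hermite(k):
    """Probabilists' Hermite polynomial He_k as a callable."""
    coeffs = np.zeros(k + 1)
    coeffs[k] = 1.0
    return lambda x: He.hermeval(x, coeffs)


def averaging_operator_1d(f, m, z):
    """A_m f(z) = E_{y ~ N(0,1)}[f(m z + sqrt(1 - m^2) y)] (scalar case, an assumption)."""
    nodes, weights = gaussian_quadrature()
    z = np.asarray(z, dtype=float)
    # One row of quadrature points per entry of z.
    shifted = m * z[..., None] + math.sqrt(1.0 - m * m) * nodes
    return f(shifted) @ weights


def hermite_coefficients(f, max_degree):
    """Coefficients of f in the orthonormal basis He_k / sqrt(k!) of L^2(N(0, 1))."""
    nodes, weights = gaussian_quadrature()
    values = f(nodes)
    return np.array([
        np.sum(weights * values * hermite(k)(nodes)) / math.sqrt(math.factorial(k))
        for k in range(max_degree + 1)
    ])


def information_exponent(f, max_degree=10, tol=1e-6):
    """Smallest k >= 1 with a nonzero Hermite coefficient, or None if none is found."""
    coeffs = hermite_coefficients(f, max_degree)
    for k in range(1, max_degree + 1):
        if abs(coeffs[k]) > tol:
            return k
    return None


def escape_time_scale(d, s):
    """Order-of-magnitude escape time d^(s - 1), following the scaling quoted above."""
    return float(d) ** (s - 1)
```

For the link function \( x^3 - 3x = \mathrm{He}_3(x) \), `information_exponent(lambda x: x**3 - 3*x)` returns 3, so `escape_time_scale(100, 3)` gives a timescale of order \( 10^4 \) at \( d = 100 \). Gauss-Hermite quadrature is used instead of Monte Carlo sampling so that the check of \( A_m \mathrm{He}_k = m^k \mathrm{He}_k \) is deterministic and exact for polynomials.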