Provable Multi-Task Representation Learning by Two-Layer ReLU Neural Networks

Liam Collins, Hamed Hassani, Mahdi Soltanolkotabi, Aryan Mokhtari, Sanjay Shakkottai
2024-06-07
Abstract: An increasingly popular machine learning paradigm is to pretrain a neural network (NN) on many tasks offline, then adapt it to downstream tasks, often by re-training only the last linear layer of the network. This approach yields strong downstream performance in a variety of contexts, demonstrating that multitask pretraining leads to effective feature learning. Although several recent theoretical studies have shown that shallow NNs learn meaningful features when either (i) they are trained on a *single* task or (ii) they are *linear*, very little is known about the closer-to-practice case of *nonlinear* NNs trained on *multiple* tasks. In this work, we present the first results proving that feature learning occurs during training with a nonlinear model on multiple tasks. Our key insight is that multi-task pretraining induces a pseudo-contrastive loss that favors representations that align points that typically have the same label across tasks. Using this observation, we show that when the tasks are binary classification tasks with labels depending on the projection of the data onto an $r$-dimensional subspace within the $d\gg r$-dimensional input space, a simple gradient-based multitask learning algorithm on a two-layer ReLU NN recovers this projection, allowing for generalization to downstream tasks with sample and neuron complexity independent of $d$. In contrast, we show that with high probability over the draw of a single task, training on this single task cannot guarantee to learn all $r$ ground-truth features.
Machine Learning
What problem does this paper attempt to address?
The core question this paper addresses is: **Can multi-task pre-training provably learn useful feature representations when the model is nonlinear (here, a two-layer ReLU neural network)?** Specifically, the paper explores the following questions:

1. **Why can nonlinear neural networks learn effective feature representations in multi-task pre-training?**
2. **How does multi-task pre-training affect the performance of downstream tasks?**
3. **Can multi-task pre-training help neural networks extract low-dimensional effective features from high-dimensional input data?**

### Main contributions of the paper

1. **Proving that multi-task feature learning occurs**:
   - By analyzing the training dynamics of a two-layer ReLU neural network on multiple binary classification tasks, the authors prove, for the first time, that a nonlinear network learns effective feature representations during multi-task pre-training. These representations project the input data onto a low-dimensional subspace, which reduces the complexity of learning downstream tasks.
2. **Introducing the concept of pseudo-contrastive loss**:
   - Multi-task pre-training induces a pseudo-contrastive loss, which favors representations that align data points that typically share the same label across tasks. This mechanism helps the network learn features that transfer between tasks.
3. **Providing theoretical guarantees**:
   - The authors prove that, when task labels depend only on the projection of the data onto an \( r \)-dimensional subspace of the \( d \)-dimensional input space, multi-task pre-training allows downstream tasks to be learned with a number of samples and neurons that is independent of the input dimension \( d \).
4. **Showing the limitations of single-task pre-training**:
   - The authors also show that when pre-training is performed on only one task, or when a random feature model is used, learning all \( r \) ground-truth features cannot be guaranteed, which underlines the importance of multi-task pre-training.

### Formula summary

- **Projection matrix**:
  \[ \Pi_{\parallel}(W) = W M^{\top} M, \quad \Pi_{\perp}(W) = W M_{\perp}^{\top} M_{\perp} \]
  where \( M \in \mathbb{R}^{r \times d} \) is a matrix with orthonormal rows, and the rows of \( M_{\perp} \) form an orthonormal basis of the orthogonal complement of the row space of \( M \).
- **Error bound**:
  \[ \epsilon = O\left( \frac{d \log(d T n_2 / \delta)}{\sqrt{T n_2} \left( 1 + \sqrt{\frac{\log(T / \delta)}{n_1}} \right)} + \sqrt{\frac{dr \log(dm / \delta)}{T}} \right) \]
- **Singular value ratio**:
  \[ \frac{\sigma_1(\Pi_{\perp}(W_1))}{\sigma_r(\Pi_{\parallel}(W_1))} = O\left( \frac{r^{3.5} + \log^{3.5}(m / \delta)}{d^{1.5} + 2^r \epsilon} \right) \]

### Conclusion

Through rigorous theoretical analysis and experimental verification, this paper shows that multi-task pre-training leads to feature learning in nonlinear neural networks. These results provide a theoretical basis for understanding the mechanism behind multi-task learning and support the use of multi-task pre-training in practice.
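
The projection operators and the singular value ratio from the formula summary can be made concrete with a small numerical sketch. The NumPy snippet below is illustrative only and is not the paper's algorithm: the sizes `d`, `r`, `m`, the random subspace, and the synthetic weight matrix `W` (standing in for trained first-layer weights \( W_1 \), one neuron per row) are all assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, m = 20, 3, 50  # hypothetical input dim, subspace dim, number of neurons

# Build an orthonormal basis of R^d and split it into the r-dimensional
# "ground-truth" subspace (rows of M) and its complement (rows of M_perp).
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
M = Q[:, :r].T        # shape (r, d), orthonormal rows
M_perp = Q[:, r:].T   # shape (d - r, d), orthonormal rows


def proj_par(W):
    """Pi_par(W) = W M^T M: component of each row of W inside the subspace."""
    return W @ M.T @ M


def proj_perp(W):
    """Pi_perp(W) = W M_perp^T M_perp: component outside the subspace."""
    return W @ M_perp.T @ M_perp


# Synthetic stand-in for learned weights (an assumption, not a trained model):
# rows lie mostly in the subspace, plus a small off-subspace perturbation.
W = rng.standard_normal((m, r)) @ M + 1e-3 * rng.standard_normal((m, d))

W_par, W_perp = proj_par(W), proj_perp(W)

# Ratio sigma_1(Pi_perp(W)) / sigma_r(Pi_par(W)); singular values are
# returned in descending order, so index r - 1 is the r-th largest.
s_perp = np.linalg.svd(W_perp, compute_uv=False)
s_par = np.linalg.svd(W_par, compute_uv=False)
ratio = s_perp[0] / s_par[r - 1]
print(f"sigma_1(perp) / sigma_r(par) = {ratio:.2e}")
```

A small ratio means the weights carry little energy outside the subspace while still spanning all \( r \) directions inside it, which is the sense in which the network is said to recover the projection.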