On Mesa-Optimization in Autoregressively Trained Transformers: Emergence and Capability

Chenyu Zheng, Wei Huang, Rongzhen Wang, Guoqiang Wu, Jun Zhu, Chongxuan Li
2024-10-26
Abstract: Autoregressively trained transformers have brought a profound revolution to the world, especially with their in-context learning (ICL) ability to address downstream tasks. Recently, several studies suggest that transformers learn a mesa-optimizer during autoregressive (AR) pretraining to implement ICL. Namely, the forward pass of the trained transformer is equivalent to optimizing an inner objective function in-context. However, whether the practical non-convex training dynamics will converge to the ideal mesa-optimizer is still unclear. Towards filling this gap, we investigate the non-convex dynamics of a one-layer linear causal self-attention model autoregressively trained by gradient flow, where the sequences are generated by an AR process $x_{t+1} = W x_t$. First, under a certain condition of data distribution, we prove that an autoregressively trained transformer learns $W$ by implementing one step of gradient descent to minimize an ordinary least squares (OLS) problem in-context. It then applies the learned $\widehat{W}$ for next-token prediction, thereby verifying the mesa-optimization hypothesis. Next, under the same data conditions, we explore the capability limitations of the obtained mesa-optimizer. We show that a stronger assumption related to the moments of data is the sufficient and necessary condition that the learned mesa-optimizer recovers the distribution. Besides, we conduct exploratory analyses beyond the first data condition and prove that generally, the trained transformer will not perform vanilla gradient descent for the OLS problem. Finally, our simulation results verify the theoretical results.
Machine Learning, Computation and Language
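The one-step gradient descent claim in the abstract can be made concrete with plain linear algebra. Take the in-context OLS objective to be \( L(W) = \tfrac{1}{2}\sum_{i=1}^{t-1} \lVert x_{i+1} - W x_i \rVert^2 \) and start from \( W_0 = 0 \), an initialization assumed here only for illustration. The gradient is then \( \nabla L(0) = -\sum_i x_{i+1} x_i^\top \), so one step with step size \( \eta \) gives \( \widehat{W} = \eta \sum_i x_{i+1} x_i^\top \). For an AR sequence \( x_{i+1} = W x_i \), this equals \( \eta W G \) with \( G = \sum_i x_i x_i^\top \). The exact OLS solution, by contrast, is \( W \) whenever \( G \) is invertible. The Python sketch below checks both identities. The helper names, the 2-D rotation \( W \), and the step size \( \eta = 1/(t-1) \) are hypothetical choices. The code is a generic illustration, not the paper's trained transformer or its exact learned update.

```python
import numpy as np

def ar_rollout(W, x1, T):
    """Generate T tokens of the deterministic AR process x_{t+1} = W x_t."""
    xs = [np.asarray(x1, dtype=float)]
    for _ in range(T - 1):
        xs.append(W @ xs[-1])
    return np.stack(xs)  # shape (T, d); row k holds token x_{k+1}

def one_step_gd(xs, eta):
    """One GD step from W0 = 0 (assumed init) on L(W) = 1/2 * sum_i ||x_{i+1} - W x_i||^2.

    grad L(0) = -sum_i x_{i+1} x_i^T, hence W1 = eta * sum_i x_{i+1} x_i^T.
    """
    prev, nxt = xs[:-1], xs[1:]
    return eta * nxt.T @ prev

def ols(xs):
    """Exact least-squares minimizer W = (sum_i x_{i+1} x_i^T) (sum_i x_i x_i^T)^+."""
    prev, nxt = xs[:-1], xs[1:]
    return nxt.T @ prev @ np.linalg.pinv(prev.T @ prev)

# Hypothetical ground-truth W: a 2-D rotation by 0.5 rad, so tokens stay on the unit circle.
theta = 0.5
W = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
xs = ar_rollout(W, [1.0, 0.0], T=20)

G = xs[:-1].T @ xs[:-1]        # in-context Gram matrix sum_i x_i x_i^T
eta = 1.0 / len(xs[:-1])       # hypothetical step size 1/(t-1)
W_gd = one_step_gd(xs, eta)
W_ols = ols(xs)

print(np.allclose(W_ols, W))           # → True  (exact OLS recovers W)
print(np.allclose(W_gd, eta * W @ G))  # → True  (one GD step yields eta * W * G)
x_next_pred = W_gd @ xs[-1]            # next-token prediction with the one-step estimate
```

In this rollout every \( x_i \) lies on the unit circle, so \( G/(t-1) \) is close to \( I/2 \) and the one-step estimate comes out roughly equal to \( W/2 \). A single gradient step returns \( W \) only when \( \eta G = I \).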
What problem does this paper attempt to address?
The core problem this paper addresses is understanding the mechanism behind the in-context learning (ICL) ability of autoregressively trained transformers, in particular whether and how they acquire this ability by learning a "mesa-optimizer". Specifically, the paper explores the following questions:

1. **When does the mesa-optimizer emerge?**
   - The paper seeks the specific conditions under which a mesa-optimizer emerges in autoregressively trained transformers. By analyzing the non-convex training dynamics, the authors prove that when the distribution of the initial token satisfies certain conditions, the autoregressively trained linear transformer converges to a mesa-optimizer that performs one step of gradient descent to minimize an ordinary least squares (OLS) problem in-context.
2. **If the mesa-optimizer does emerge, what are its capability limitations?**
   - The authors further study the capability limitations of the learned mesa-optimizer and identify a stronger assumption on the moments of the data distribution as the necessary and sufficient condition for the mesa-optimizer to recover the true data distribution. For example, when the initial token follows the Gaussian distribution \( N(0_d, \sigma^2 I_d) \), the mesa-optimizer cannot fully recover the data distribution.
3. **How does the trained transformer behave in more general cases?**
   - The paper also asks whether the trained transformer performs vanilla gradient descent when the first data condition does not hold. The results show that, without this structural assumption, the trained transformer generally does not perform vanilla gradient descent on the OLS problem.

### Specific problem solutions

#### 1. Conditions for the emergence of the mesa-optimizer

According to **Theorem 4.1** in the paper, the trained transformer converges to a mesa-optimizer when the distribution \( D_{x_1} \) of the initial token \( x_1 \) satisfies the following assumption.

**Assumption 4.1.**
\[
\mathbb{E}_{x_1 \sim D_{x_1}}\left[x_{1i_1}^{r_1} x_{1i_2}^{r_2} \cdots x_{1i_n}^{r_n}\right] = 0
\]
for any subset \(\{i_1, i_2, \ldots, i_n \mid n \leq 4\}\), and \( r_2, \ldots, r_n \in \mathbb{N} \), where \( x_{1i} \) denotes the \( i \)-th entry of \( x_1 \). In addition, \(\kappa_1 = \mathbb{E}[x_{1j}^4]\), \(\kappa_2 = \mathbb{E}[x_{1j}^6]\), and \(\kappa_3 = \sum_{r \neq j} \mathbb{E}[x_{1j}^2 x_{1r}^4]\) are assumed to be finite constants.

#### 2. Capability limitations of the mesa-optimizer

According to **Proposition 4.1**, when the initial token follows the Gaussian distribution \( N(0_d, \sigma^2 I_d) \), the trained linear transformer cannot recover a simple AR process, even in the idealized case of a long training context. This shows that ICL acquired through autoregressive pretraining differs from ICL acquired through few-shot pretraining.

#### 3. Behavior in more general cases

According to **Proposition 4.2**, without additional off-diagonal gradient masking, the trained transformer does not perform one step of vanilla gradient descent to minimize the OLS problem. Its behavior under more complex data distributions or model structures therefore requires further study.

### Summary

Through theoretical analysis and simulations, this paper examines the mechanism behind the in-context learning ability of autoregressively trained transformers, focusing on the emergence of the mesa-optimizer and its capability limitations. These findings offer insight into the internal workings of large language models and suggest directions for further research.
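Proposition 4.1 concerns the isotropic Gaussian initial distribution. As a sanity check on the constants named in Assumption 4.1, the sketch below estimates \( \kappa_1, \kappa_2, \kappa_3 \) by Monte Carlo for \( x_1 \sim N(0_d, \sigma^2 I_d) \). It compares them with the closed forms \( 3\sigma^4 \), \( 15\sigma^6 \) and \( 3(d-1)\sigma^6 \). These follow from \( \mathbb{E}[z^4] = 3 \) and \( \mathbb{E}[z^6] = 15 \) for a standard normal \( z \), together with the independence of the coordinates. The function name `kappa_constants`, the dimension, \( \sigma \) and the sample size are arbitrary illustrative choices. The snippet does not reproduce the paper's proof of Proposition 4.1.

```python
import numpy as np

def kappa_constants(samples, j=0):
    """Monte Carlo estimates of kappa_1, kappa_2, kappa_3 (hypothetical helper).

    samples: array of shape (n, d); each row is one draw of the initial token x_1.
    j: the coordinate index appearing in the kappa definitions.
    """
    xj = samples[:, j]
    others = np.delete(samples, j, axis=1)   # all coordinates r != j
    k1 = np.mean(xj ** 4)                    # E[x_{1j}^4]
    k2 = np.mean(xj ** 6)                    # E[x_{1j}^6]
    # sum over r != j of E[x_{1j}^2 x_{1r}^4]
    k3 = np.sum(np.mean((xj ** 2)[:, None] * others ** 4, axis=0))
    return k1, k2, k3

d, sigma, n = 4, 1.0, 200_000                # illustrative choices
rng = np.random.default_rng(0)
x1 = sigma * rng.standard_normal((n, d))     # x_1 ~ N(0_d, sigma^2 I_d)

k1, k2, k3 = kappa_constants(x1)
closed_form = (3 * sigma**4, 15 * sigma**6, 3 * (d - 1) * sigma**6)
print("Monte Carlo:", np.round([k1, k2, k3], 2))
print("Closed form:", closed_form)
```

Because the coordinates are i.i.d., the estimates do not depend on the choice of `j`. The sixth-moment estimate \( \kappa_2 \) has the largest variance, so it converges most slowly as `n` grows.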