Mechanics of Next Token Prediction with Self-Attention

Yingcong Li, Yixiao Huang, M. Emrullah Ildiz, Ankit Singh Rawat, Samet Oymak
2024-03-13
Abstract: Transformer-based language models are trained on large datasets to predict the next token given an input sequence. Despite this simple training objective, they have led to revolutionary advances in natural language processing. Underlying this success is the self-attention mechanism. In this work, we ask: $\textit{What does a single self-attention layer learn from next-token prediction?}$ We show that training self-attention with gradient descent learns an automaton which generates the next token in two distinct steps: $\textbf{(1) Hard retrieval:}$ Given input sequence, self-attention precisely selects the $\textit{high-priority input tokens}$ associated with the last input token. $\textbf{(2) Soft composition:}$ It then creates a convex combination of the high-priority tokens from which the next token can be sampled. Under suitable conditions, we rigorously characterize these mechanics through a directed graph over tokens extracted from the training data. We prove that gradient descent implicitly discovers the strongly-connected components (SCC) of this graph and self-attention learns to retrieve the tokens that belong to the highest-priority SCC available in the context window. Our theory relies on decomposing the model weights into a directional component and a finite component that correspond to hard retrieval and soft composition steps respectively. This also formalizes a related implicit bias formula conjectured in [Tarzanagh et al. 2023]. We hope that these findings shed light on how self-attention processes sequential data and pave the path toward demystifying more complex architectures.
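The split into a directional and a finite component can be illustrated with a small numerical sketch. This is not the paper's construction: the per-token scores, the `scale` parameter, and the function names below are hypothetical, chosen only to show how a softmax behaves when one part of the attention score grows without bound while the other stays fixed. Tokens with the largest directional score survive (hard retrieval), and their relative weights are set by the finite scores alone (soft composition).

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(directional, finite, scale):
    """Softmax of scale * directional + finite (illustrative decomposition)."""
    return softmax([scale * d + f for d, f in zip(directional, finite)])

# Hypothetical scores for four context tokens (assumptions, not paper values).
# Tokens 0 and 1 tie for the largest directional score, so they play the role
# of the "high-priority" tokens in this toy example.
directional = [1.0, 1.0, 0.0, -1.0]
finite = [0.0, math.log(3.0), 2.0, 5.0]

for scale in (0.0, 5.0, 50.0):
    weights = attention_weights(directional, finite, scale)
    print(scale, [round(w, 4) for w in weights])

# With a large scale, the low-priority tokens receive vanishing weight and the
# remaining weights are a convex combination fixed by the finite scores:
# exp(0) : exp(log 3) = 1 : 3.
limit = attention_weights(directional, finite, 200.0)
print([round(w, 4) for w in limit])  # approximately [0.25, 0.75, 0.0, 0.0]
```

At `scale = 0.0` the token with the largest finite score dominates, even though its directional score is the lowest. As `scale` grows, that token is suppressed regardless of its finite score, which is the sense in which the directional part decides which tokens are selected and the finite part decides how they are mixed.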
Machine Learning, Artificial Intelligence, Computation and Language, Optimization and Control
What problem does this paper attempt to address?
This paper investigates what a single self-attention layer in a Transformer learns when trained on next-token prediction. It finds that training with gradient descent yields a two-step mechanism: hard retrieval (precisely selecting the high-priority input tokens associated with the last input token) and soft composition (forming a convex combination of those high-priority tokens, from which the next token can be sampled). The paper characterizes this mechanism through a directed graph over tokens extracted from the training data and proves that, under suitable conditions, gradient descent implicitly discovers the strongly connected components (SCCs) of this graph, with self-attention learning to retrieve the tokens in the highest-priority SCC available in the context window. These findings clarify how self-attention processes sequential data and lay a foundation for analyzing more complex architectures.
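The retrieval rule described above can be sketched on a toy graph. The summary does not say how the token graph is built from data, so everything concrete here is an assumption made for illustration: the edge list, the convention that an edge `(u, v)` means "u has priority over v", and the helper names `strongly_connected_components` and `retrieve`. The sketch uses Kosaraju's algorithm to find SCCs, then keeps the context tokens that no token from a different SCC in the context has priority over.

```python
from collections import defaultdict

def build_adjacency(edges):
    """Forward and reversed adjacency lists for a directed edge list."""
    graph, reverse = defaultdict(list), defaultdict(list)
    for u, v in edges:
        graph[u].append(v)
        reverse[v].append(u)
    return graph, reverse

def strongly_connected_components(nodes, edges):
    """Kosaraju's algorithm. Returns a dict mapping token -> component id."""
    graph, reverse = build_adjacency(edges)

    # First pass: record nodes in order of DFS completion on the graph.
    seen, order = set(), []
    def visit(u):
        seen.add(u)
        for v in graph[u]:
            if v not in seen:
                visit(v)
        order.append(u)
    for u in nodes:
        if u not in seen:
            visit(u)

    # Second pass: sweep the reversed graph in decreasing completion order;
    # each sweep collects exactly one strongly connected component.
    component = {}
    def assign(u, label):
        component[u] = label
        for v in reverse[u]:
            if v not in component:
                assign(v, label)
    label = 0
    for u in reversed(order):
        if u not in component:
            assign(u, label)
            label += 1
    return component

def reachable(graph, start):
    """Set of nodes reachable from start (including start itself)."""
    found, stack = {start}, [start]
    while stack:
        for v in graph[stack.pop()]:
            if v not in found:
                found.add(v)
                stack.append(v)
    return found

def retrieve(context, edges, component):
    """Keep context tokens not dominated by a context token of another SCC."""
    graph, _ = build_adjacency(edges)
    reach = {s: reachable(graph, s) for s in set(context)}
    return [t for t in context
            if not any(component[s] != component[t] and t in reach[s]
                       for s in reach)]

# Hypothetical token graph: a and b form a cycle, so they share one SCC;
# under the assumed convention that SCC has priority over c, and c over d.
tokens = ["a", "b", "c", "d"]
edges = [("a", "b"), ("b", "a"), ("b", "c"), ("c", "d")]
component = strongly_connected_components(tokens, edges)

print(retrieve(["d", "c", "a", "b"], edges, component))  # ['a', 'b']
print(retrieve(["d", "c"], edges, component))            # ['c']
```

The second call shows the "available in the context window" qualifier: when neither `a` nor `b` appears in the context, the highest-priority component present is the one containing `c`. The recursive depth-first search is adequate for a toy vocabulary; a large graph would call for an iterative version.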