Understanding Factual Recall in Transformers via Associative Memories

Eshaan Nichani, Jason D. Lee, Alberto Bietti
2024-12-09
Abstract: Large language models have demonstrated an impressive ability to perform factual recall. Prior work has found that transformers trained on factual recall tasks can store information at a rate proportional to their parameter count. In our work, we show that shallow transformers can use a combination of associative memories to obtain such near-optimal storage capacity. We begin by proving that the storage capacities of both linear and MLP associative memories scale linearly with parameter count. We next introduce a synthetic factual recall task, and prove that a transformer with a single layer of self-attention followed by an MLP can obtain 100% accuracy on the task whenever either the total number of self-attention parameters or MLP parameters scales (up to log factors) linearly with the number of facts. In particular, the transformer can trade off between using the value matrices or the MLP as an associative memory to store the dataset of facts. We complement these expressivity results with an analysis of the gradient flow trajectory of a simplified linear attention model trained on our factual recall task, where we show that the model exhibits sequential learning behavior.
Machine Learning, Computation and Language, Information Theory
What problem does this paper attempt to address?
This paper asks how a shallow transformer can achieve near-optimal storage capacity on factual recall tasks by using associative memories. Specifically, the authors aim to understand how a transformer encodes and stores a large number of facts in its weights, and they show that a shallow transformer can do so by using its self-attention layer and its MLP (multi-layer perceptron) as associative memories.

### Analysis of the Core Problems in the Paper

1. **Storage Capacity Problem**:
   - Prior work has found that transformers trained on factual recall tasks can store information at a rate proportional to their parameter count. However, it is not yet clear how these models encode the facts so efficiently.
   - The paper proves that a shallow transformer can reach near-optimal storage capacity by combining associative memories.
2. **Mechanism Understanding**:
   - The authors study which components of the transformer (the self-attention layer or the MLP) store the facts. They show that the transformer can trade off between using the value matrices and using the MLP as the associative memory.
   - On a synthetic factual recall task, they prove that a transformer with a single layer of self-attention followed by an MLP can obtain 100% accuracy whenever either the total number of self-attention parameters or the number of MLP parameters scales linearly (up to log factors) with the number of facts.
3. **Optimization Dynamics**:
   - The paper also analyzes the gradient flow trajectory of a simplified linear attention model trained on the task. Learning proceeds sequentially and passes through a "hallucination" phase, in which the model makes predictions based only on the relation.
### Key Contributions

- **Storage Capacity of Associative Memories**: The authors prove that the storage capacities of both linear and MLP associative memories scale linearly (up to logarithmic factors) with the number of parameters, which is significantly better than what orthogonal embeddings allow.
- **Synthetic Factual Recall Task**: A synthetic task is introduced, on which a single-layer transformer provably achieves 100% accuracy once either the self-attention parameter count or the MLP parameter count grows linearly (up to log factors) with the number of facts.
- **Analysis of Gradient Flow Dynamics**: The gradient flow trajectory of a simplified linear attention model is analyzed. The model passes through a "hallucination" phase during training and finally converges to the correct prediction.

### Formula Summary

- **Linear Associative Memory**:
  \[ \text{Theorem 1: } \arg\max_{y \in [M]} u_y^\top W e_x = f^*(x) \quad \text{for all } x \in [N] \]
  where \( W = \sum_{x \in [N]} u_{f^*(x)} e_x^\top \). Here \( e_x \) and \( u_y \) denote the input and output embeddings, and \( f^* : [N] \to [M] \) is the association to be stored.
- **MLP Associative Memory**:
  \[ \text{Theorem 2 (Informal): } \arg\max_{y \in [M]} u_y^\top V^\top \sigma(W e_x) = f^*(x) \quad \text{for all } x \in [N] \]
  where \( F(e_x) = V^\top \sigma(W e_x) \) and \( V, W \in \mathbb{R}^{m \times d} \).
- **Gradient Flow Dynamics** (training loss):
  \[ L(\theta) = \mathbb{E}_{z_1^{T+1}}\left[-\log \hat{p}(z_{T+1} \mid z_1^T)\right] \]
  where \( \hat{p}(a \mid z_1^T) = \frac{\exp(\langle \phi(a), F_{\text{lin}}(X; \theta) \rangle)}{\sum_{a' \in A} \exp(\langle \phi(a'), F_{\text{lin}}(X; \theta) \rangle)} \).
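
The linear associative memory formula above can be checked numerically. The following is a minimal NumPy sketch, not the authors' code. It assumes random unit-norm Gaussian embeddings, an injective map `f` (a random permutation, so M = N), and illustrative sizes `d = 256`, `N = 512`; the helper names `random_unit_embeddings`, `build_linear_memory`, and `recall` are hypothetical. It builds \( W = \sum_x u_{f^*(x)} e_x^\top \) and measures how often \( \arg\max_y u_y^\top W e_x \) recovers \( f^*(x) \).

```python
import numpy as np


def random_unit_embeddings(n, d, rng):
    """Draw n random Gaussian vectors in R^d and scale each to unit length."""
    X = rng.standard_normal((n, d))
    return X / np.linalg.norm(X, axis=1, keepdims=True)


def build_linear_memory(U, E, f):
    """Return W = sum_x u_{f(x)} e_x^T as a (d, d) matrix."""
    # Row x of U[f] is u_{f(x)}, so U[f].T @ E sums the N outer products.
    return U[f].T @ E


def recall(W, U, E):
    """Return argmax_y u_y^T W e_x for every input x."""
    scores = U @ W @ E.T  # scores[y, x] = u_y^T W e_x
    return scores.argmax(axis=0)


rng = np.random.default_rng(0)
d, N = 256, 512          # assumed sizes: more facts than embedding dimensions
f = rng.permutation(N)   # assumed injective association f*: [N] -> [N]
E = random_unit_embeddings(N, d, rng)  # input embeddings e_x
U = random_unit_embeddings(N, d, rng)  # output embeddings u_y

W = build_linear_memory(U, E, f)
accuracy = float(np.mean(recall(W, U, E) == f))
print(f"d^2 = {d * d} parameters, N = {N} facts, recall accuracy = {accuracy:.3f}")
```

Because `N` exceeds `d` in this sketch, the embeddings cannot all be mutually orthogonal, so recall relies on random vectors being nearly orthogonal. Shrinking `d` or growing `N` should eventually degrade the measured accuracy, which is one way to probe the parameter-count scaling informally.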