Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics

Hanlin Zhu, Baihe Huang, Shaolun Zhang, Michael Jordan, Jiantao Jiao, Yuandong Tian, Stuart Russell
2024-10-28
Abstract: Auto-regressive large language models (LLMs) show impressive capacities to solve many complex reasoning tasks while struggling with some simple logical reasoning tasks such as inverse search: when trained on '$A \to B$' (e.g., 'Tom is the parent of John'), LLMs fail to directly conclude '$B \gets A$' (e.g., 'John is the child of Tom') during inference, even though the two sentences are semantically identical. This is known as the 'reversal curse'. In this paper, we theoretically analyze the reversal curse via the training dynamics of (stochastic) gradient descent for two auto-regressive models: (1) a bilinear model that can be viewed as a simplification of a one-layer transformer; (2) one-layer transformers under certain assumptions. Our analysis reveals that for both models, the reversal curse is a consequence of the 'asymmetry' of the (effective) model weights, i.e., an increase in the weights from a token $A$ to a token $B$ during training does not necessarily cause an increase in the weights from $B$ to $A$. This asymmetry is caused by the training dynamics under certain choices of loss function and the optimization space of model parameters. Moreover, our analysis can be naturally applied to other logical reasoning tasks such as chain-of-thought (COT), providing a new perspective that differs from previous work focusing on expressivity. Finally, we conduct experiments to validate our theory on multi-layer transformers under different settings. Our code is available at https://github.com/marlo-z/reversal_curse_analysis/.
Machine Learning, Computation and Language
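The weight-asymmetry argument in the abstract can be made concrete with a toy model. The sketch below is not the authors' code; it is a minimal illustration in the spirit of the bilinear model, assuming a next-token score logit(y | x) = e_x^T Θ e_y with fixed random token embeddings e (an assumption), a single trainable matrix Θ, and per-example SGD on cross-entropy over hypothetical facts A_i → B_i (token ids 0 to 4 mapped to 10 to 14). It then compares the trained direction p(B | A) with the reverse direction p(A | B). The `symmetric` flag is a hypothetical constraint that restricts Θ to symmetric matrices, included only to show how the optimization space affects the outcome.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())  # subtract the max for numerical stability
    return e / e.sum()

def train_bilinear(pairs, vocab_size, dim=256, lr=1.0, steps=200, symmetric=False, seed=0):
    """Toy bilinear next-token model: logit(y | x) = emb[x] @ theta @ emb[y].

    Embeddings are fixed random vectors (nearly orthogonal when dim is large);
    only theta is trained, with one SGD step per (x, y) pair on cross-entropy.
    symmetric=True restricts theta to symmetric matrices (a hypothetical
    constraint for illustration, not the paper's setup).
    """
    rng = np.random.default_rng(seed)
    emb = rng.standard_normal((vocab_size, dim)) / np.sqrt(dim)
    theta = np.zeros((dim, dim))
    for _ in range(steps):
        for x, y in pairs:
            p = softmax(emb[x] @ theta @ emb.T)  # distribution over next tokens given x
            p[y] -= 1.0                          # d(-log p(y|x)) / d logits = p - onehot(y)
            grad = np.outer(emb[x], p @ emb)     # rank-one gradient w.r.t. theta
            if symmetric:
                grad = 0.5 * (grad + grad.T)     # project onto symmetric matrices
            theta -= lr * grad
    return emb, theta

def next_token_prob(emb, theta, x, y):
    return softmax(emb[x] @ theta @ emb.T)[y]

V = 20
pairs = [(a, a + 10) for a in range(5)]   # hypothetical facts "A_i -> B_i"
emb, theta = train_bilinear(pairs, V)
fwd = next_token_prob(emb, theta, 0, 10)  # trained direction p(B_0 | A_0)
rev = next_token_prob(emb, theta, 10, 0)  # reverse direction p(A_0 | B_0)

emb_s, theta_s = train_bilinear(pairs, V, symmetric=True)
rev_sym = next_token_prob(emb_s, theta_s, 10, 0)
print(f"forward {fwd:.3f}  reverse {rev:.3f}  reverse (symmetric theta) {rev_sym:.3f}  uniform {1 / V:.3f}")
```

Each gradient step adds a rank-one term e_A u^T to Θ, so the reverse score e_B^T Θ e_A changes by roughly (e_B · e_A)(u · e_A), which is close to zero when embeddings are nearly orthogonal. In this toy, p(A | B) should therefore stay near uniform even as p(B | A) approaches 1, whereas the symmetric variant should raise both together.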
What problem does this paper attempt to address?
### The Problem the Paper Attempts to Solve

This paper attempts to understand the "reversal curse" in autoregressive large language models (LLMs). Specifically, when the model learns the relationship from A to B during training (e.g., "Tom is John's parent"), it fails to directly infer the reverse relationship from B to A (e.g., "John is Tom's child"), even though the two sentences are semantically equivalent.

### Main Contributions

1. **Theoretical Analysis**:
   - The authors explain the reversal curse by analyzing the training dynamics of two autoregressive models: a bilinear model (which can be seen as a simplified one-layer Transformer) and a one-layer Transformer (under certain assumptions).
   - The analysis shows that for both models, the reversal curse is caused by an asymmetry in the (effective) model weights: increasing the weight from A to B does not necessarily increase the weight from B to A. This asymmetry arises from the training dynamics under the chosen loss function and the optimization space of the model parameters.
2. **Application to Logical Reasoning Tasks**:
   - The same analysis applies naturally to other logical reasoning tasks, such as chain-of-thought (COT). This offers a perspective based on training dynamics, different from previous research that focuses on expressivity.
3. **Experimental Validation**:
   - The authors ran experiments on multi-layer Transformers under different settings to validate their theoretical results.

### Background and Motivation

- **Performance of LLMs**: Although LLMs perform well on many complex reasoning tasks, they struggle with some simple logical reasoning tasks, such as inverse search.
- **Reversal Curse**: This phenomenon is particularly evident in autoregressive LLMs: the model learns a relationship in one direction during training but fails to infer the reverse relationship.

### Solutions

- **Model Parameter Constraints**: One possible remedy is to constrain the model parameters so that they satisfy high-level rules for specific relationships. For example, a reversal-type rule can be viewed as a pair of relationships (→, ←) together with two sets of entities A and B, such that training the model on "A → B" also increases the probability of "B ← A". However, manually hard-coding such constraints is very difficult in practice.
- **Different Loss Functions**: Another approach is to replace the commonly used cross-entropy loss with a symmetric loss. However, a symmetric loss may lead the model to learn meaningless sentences, so cross-entropy loss remains the standard choice in practice.

### Experimental Results

- **Theoretical Validation**: Through their analysis of training dynamics, the authors showed that the reversal curse arises in the models they study, and their experiments on multi-layer Transformers support this theory.
- **Importance of Chain-of-Thought**: The authors also used training dynamics to analyze why chain-of-thought matters in logical reasoning: without chain-of-thought, the model cannot directly infer certain conclusions (e.g., after training on "A → B" and "B → C", it may fail to output C directly from A).

### Conclusion

This paper examines the reversal curse in autoregressive LLMs and its underlying causes through theoretical analysis and experimental validation. The results emphasize that, under the currently popular training paradigm, autoregressive LLMs may not automatically infer certain types of conclusions, highlighting the importance of in-context learning, data augmentation, or planning for solving complex reasoning tasks.
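The chain-of-thought point can be illustrated with an even smaller toy. This is a hypothetical sketch, not the paper's construction: it assumes one-hot token embeddings, which reduce a bilinear model to a V x V table of next-token logits, and trains it with cross-entropy on hypothetical facts A_i → B_i and B_i → C_i (ids 0 to 4, 5 to 9, and 10 to 14). It then compares the one-step probability p(C | A) with a greedy two-step chain that feeds each prediction back in as context, which stands in for chain-of-thought here.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def train_logit_table(pairs, vocab_size, lr=1.0, steps=200):
    """Cross-entropy training of a V x V next-token logit table
    (a bilinear model with one-hot embeddings, assumed for this sketch)."""
    theta = np.zeros((vocab_size, vocab_size))
    for _ in range(steps):
        for x, y in pairs:
            grad = softmax(theta[x])
            grad[y] -= 1.0         # d(-log p(y|x)) / d theta[x] = p - onehot(y)
            theta[x] -= lr * grad  # only the row of the conditioning token x changes
    return theta

def greedy_chain(theta, start, hops):
    """Decode step by step, feeding each predicted token back in as the next context."""
    path = [start]
    for _ in range(hops):
        path.append(int(np.argmax(theta[path[-1]])))
    return path

V = 20
A, B, C = range(0, 5), range(5, 10), range(10, 15)  # hypothetical entity groups
pairs = [(a, b) for a, b in zip(A, B)] + [(b, c) for b, c in zip(B, C)]
theta = train_logit_table(pairs, V)

direct = softmax(theta[0])[10]  # p(C_0 | A_0) predicted in a single step
chain = greedy_chain(theta, 0, hops=2)
print(f"direct p(C|A) = {direct:.4f} (uniform = {1 / V:.4f}), chain = {chain}")
```

Training on A_0 → B_0 only raises the logit of B_0 in row A_0 and lowers all other entries of that row equally, so the direct probability of C_0 ends up below uniform, while the two-step chain recovers A_0 → B_0 → C_0.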