Towards an Understanding of Stepwise Inference in Transformers: A Synthetic Graph Navigation Model

Mikail Khona, Maya Okawa, Jan Hula, Rahul Ramesh, Kento Nishi, Robert Dick, Ekdeep Singh Lubana, Hidenori Tanaka
2024-02-13
Abstract: Stepwise inference protocols, such as scratchpads and chain-of-thought, help language models solve complex problems by decomposing them into a sequence of simpler subproblems. Despite the significant gain in performance achieved via these protocols, the underlying mechanisms of stepwise inference have remained elusive. To address this, we propose to study autoregressive Transformer models on a synthetic task that embodies the multi-step nature of problems where stepwise inference is generally most useful. Specifically, we define a graph navigation problem wherein a model is tasked with traversing a path from a start to a goal node on the graph. Despite its simplicity, we find we can empirically reproduce and analyze several phenomena observed at scale: (i) the stepwise inference reasoning gap, the cause of which we find in the structure of the training data; (ii) a diversity-accuracy tradeoff in model generations as sampling temperature varies; (iii) a simplicity bias in the model's output; and (iv) compositional generalization and a primacy bias with in-context exemplars. Overall, our work introduces a grounded, synthetic framework for studying stepwise inference and offers mechanistic hypotheses that can lay the foundation for a deeper understanding of this phenomenon.
Artificial Intelligence
What problem does this paper attempt to address?
This paper sets out to understand the internal mechanisms by which Transformer models perform stepwise inference. Specifically, it uses a synthetic graph-navigation task to study how stepwise inference protocols help a Transformer decompose complex problems and thereby improve its problem-solving ability.

### Problem Background
Stepwise inference protocols, such as scratchpads and chain-of-thought, significantly improve the performance of language models by decomposing complex problems into a series of simpler subproblems. Despite these gains, their underlying mechanisms remain unclear. To understand these mechanisms better, the authors propose studying them through a synthetic graph-navigation task.

### Research Objectives
1. **Establish a Synthetic Framework**: Design a synthetic graph-navigation task that captures the multi-step nature of reasoning problems.
2. **Analyze Phenomena**: Reproduce and analyze several phenomena through experiments: the stepwise inference reasoning gap, the trade-off between diversity and accuracy, the simplicity bias of model outputs, and compositional generalization together with the effects of in-context exemplars.
3. **Provide Mechanistic Hypotheses**: Propose mechanistic hypotheses for these stepwise inference phenomena, laying the groundwork for a deeper understanding of them.

### Main Contributions
1. **Stepwise Inference Framework**: Propose a synthetic graph-navigation task as an abstraction of stepwise inference, and show that behaviors observed when using stepwise inference can be reproduced and explained within it.
2. **Simplicity Bias**: Demonstrate that when multiple paths connect two nodes, the model tends to choose the shortest one.
3. **Controllability via In-Context Exemplars**: Show that the path the model prefers when navigating between nodes can be steered through in-context exemplars, and evaluate how well the model generalizes to longer paths and how it responds to conflicting exemplars.

### Method Overview
- **Graph-Navigation Task**: A Transformer model is tasked with generating a path from a start node to a goal node on a graph. Experiments cover two settings: without in-context exemplars and with in-context exemplars.
- **Data Generation Process**: The training dataset is built from randomly generated directed acyclic graphs (DAGs) by selecting a start node and a goal node and sampling a path between them.
- **Experimental Setup**: Experiments use a two-layer Transformer model, whose performance is evaluated under different conditions.

### Key Findings
1. **Advantages of Stepwise Inference**: On the graph-navigation task, stepwise inference clearly outperforms direct inference, and the gap is especially pronounced on hierarchical graphs.
2. **Diversity-Accuracy Trade-off**: As the sampling temperature increases, the paths the model generates become more diverse but less accurate.
3. **Simplicity Bias**: The model tends to generate shorter paths, showing a preference for simplicity.
4. **Evolution of Failure Modes**: During training, the model first learns to avoid missteps and only later learns to plan paths effectively.

Through these studies, the authors aim to provide deeper insight into the mechanisms of stepwise inference and a foundation for future research.
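The data-generation step in the method overview (random DAGs, then a path sampled between a chosen start and goal node) can be sketched roughly as follows. This is a minimal illustration under our own assumptions, not the authors' pipeline: the node count, the edge probability, the "edges only from a lower to a higher index" construction used to guarantee acyclicity, and the helper names `random_dag`, `can_reach`, and `sample_path` are all hypothetical.

```python
import random
from collections import deque


def random_dag(num_nodes: int, edge_prob: float, seed: int = 0) -> dict[int, list[int]]:
    """Sample a random DAG as an adjacency list (hypothetical construction).

    Edges only go from a lower to a higher node index, so no cycle can form.
    """
    rng = random.Random(seed)
    adj: dict[int, list[int]] = {i: [] for i in range(num_nodes)}
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            if rng.random() < edge_prob:
                adj[i].append(j)
    return adj


def can_reach(adj: dict[int, list[int]], goal: int) -> set[int]:
    """Return every node from which `goal` is reachable (BFS on reversed edges)."""
    reverse: dict[int, list[int]] = {v: [] for v in adj}
    for u, succs in adj.items():
        for v in succs:
            reverse[v].append(u)
    seen = {goal}
    queue = deque([goal])
    while queue:
        v = queue.popleft()
        for u in reverse[v]:
            if u not in seen:
                seen.add(u)
                queue.append(u)
    return seen


def sample_path(adj, start, goal, rng):
    """Randomly walk from `start` to `goal`, stepping only onto nodes that can still reach `goal`.

    Returns the list of visited nodes, or None if `goal` is unreachable from `start`.
    """
    alive = can_reach(adj, goal)
    if start not in alive:
        return None
    path = [start]
    while path[-1] != goal:
        # At least one successor is in `alive`, because the current node can reach `goal`.
        choices = [v for v in adj[path[-1]] if v in alive]
        path.append(rng.choice(choices))
    return path


adj = random_dag(num_nodes=20, edge_prob=0.3, seed=0)
print(sample_path(adj, 0, 19, random.Random(1)))
```

Restricting the walk to nodes that can still reach the goal means it never dead-ends, and because every edge increases the node index, the walk is guaranteed to terminate.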
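The diversity-accuracy trade-off listed among the key findings comes down to temperature-scaled sampling of each next step. The toy sketch below uses hypothetical logits for three candidate next nodes, the last standing in for a step that leads off every valid path. Raising the temperature flattens the next-step distribution, so entropy (a proxy for diversity) rises along with the probability of a misstep. It illustrates the general mechanism only and does not reproduce the paper's model or numbers; the function names are our own.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities; lower temperature sharpens, higher flattens."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_node(candidates, logits, temperature, rng):
    """Draw one candidate next node according to the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(candidates, weights=probs, k=1)[0]


def entropy(probs):
    """Shannon entropy in nats; higher means a more spread-out (diverse) distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Hypothetical logits: the third candidate represents an invalid step.
logits = [2.0, 1.0, -1.0]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: entropy={entropy(probs):.3f}, P(invalid step)={probs[2]:.4f}")
```

At low temperature almost all mass sits on the top-scoring node, giving accurate but repetitive paths; at high temperature more alternatives are explored, at the cost of more invalid steps.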
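One plausible way to score generated paths for accuracy and for the simplicity bias is to check that each path is valid (it starts at the start node, ends at the goal, and uses only existing edges) and then compare its length with the breadth-first-search shortest distance. The metric and the helper names `is_valid_path`, `shortest_path_length`, and `shortest_path_rate` are our own assumptions, not necessarily the paper's exact evaluation.

```python
from collections import deque


def is_valid_path(adj, path, start, goal):
    """True if `path` starts at `start`, ends at `goal`, and every step is an existing edge."""
    if not path or path[0] != start or path[-1] != goal:
        return False
    return all(b in adj.get(a, []) for a, b in zip(path, path[1:]))


def shortest_path_length(adj, start, goal):
    """BFS distance from `start` to `goal` in edges; None if `goal` is unreachable."""
    dist = {start: 0}
    queue = deque([start])
    while queue:
        u = queue.popleft()
        if u == goal:
            return dist[u]
        for v in adj.get(u, []):
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return None


def shortest_path_rate(adj, samples):
    """Fraction of valid generated paths whose length equals the shortest possible length.

    `samples` is a list of (start, goal, generated_path) triples; invalid paths are skipped.
    """
    valid = [(s, g, p) for s, g, p in samples if is_valid_path(adj, p, s, g)]
    if not valid:
        return 0.0
    hits = sum(len(p) - 1 == shortest_path_length(adj, s, g) for s, g, p in valid)
    return hits / len(valid)


# Toy graph with two routes from 0 to 3: a direct edge and a three-step detour.
adj = {0: [1, 3], 1: [2], 2: [3], 3: []}
samples = [(0, 3, [0, 3]), (0, 3, [0, 1, 2, 3]), (0, 3, [0, 2, 3])]
print(shortest_path_rate(adj, samples))  # prints 0.5
```

In this toy example the third sample uses the nonexistent edge 0 → 2 and is dropped as invalid. Of the two remaining valid paths only the direct one is shortest, so the rate is 0.5. A model with a simplicity bias would push this rate toward 1.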