Physics of Language Models: Part 2.2, How to Learn From Mistakes on Grade-School Math Problems

Tian Ye, Zicheng Xu, Yuanzhi Li, Zeyuan Allen-Zhu
2024-08-29
Abstract: Language models have demonstrated remarkable performance in solving reasoning tasks; however, even the strongest models still occasionally make reasoning mistakes. Recently, there has been active research aimed at improving reasoning accuracy, particularly by using pretrained language models to "self-correct" their mistakes via multi-round prompting. In this paper, we follow this line of work but focus on understanding the usefulness of incorporating "error-correction" data directly into the pretraining stage. This data consists of erroneous solution steps immediately followed by their corrections. Using a synthetic math dataset, we show promising results: this type of pretrain data can help language models achieve higher reasoning accuracy directly (i.e., through simple auto-regression, without multi-round prompting) compared to pretraining on the same amount of error-free data. We also delve into many details, such as (1) how this approach differs from beam search, (2) how such data can be prepared, (3) whether masking is needed on the erroneous tokens, (4) the amount of error required, (5) whether such data can be deferred to the fine-tuning stage, and many others.
Computation and Language, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
The paper explores how to improve the reasoning accuracy of language models by incorporating "error-correction" data directly into the pretraining stage. Specifically, it focuses on the following questions:

1. **Why do language models make mistakes during generation?** Language models can correct their errors through multiple rounds of prompting after generation, so why can they not correct these errors immediately, during the generation process itself?
2. **How can "error-correction" data be used for pretraining?** The paper proposes including erroneous solution steps immediately followed by their corrections (referred to as "retry data") in the pretraining dataset, so that the model learns to detect and correct its own errors while generating.
3. **Experimental validation and result analysis:** Experiments on a synthetic math dataset (iGSM) show that this approach yields higher reasoning accuracy than pretraining on the same amount of error-free data. Even in high-error-rate scenarios, the model can still self-correct during generation, thereby improving overall accuracy.
4. **Comparison of different training strategies:** The paper compares training on completely error-free data, using error-correction data only during the fine-tuning stage, and using error-correction data throughout pretraining. The results show that using error-correction data during pretraining is the most effective.
5. **Future work prospects:** Finally, the paper discusses how this type of training data can be prepared and used, and offers suggestions for the future development of large language models (LLMs), emphasizing the importance of incorporating error-correction data during the pretraining stage.
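
To make the idea of "retry data" concrete, here is a minimal Python sketch of how a correct step-by-step solution might be turned into an error-correction example, and how the erroneous steps could optionally be masked out of the training loss. It is an illustration under stated assumptions, not the paper's actual pipeline: the `[BACK]` marker name, the step strings, the `retry_rate` parameter, and all function names are hypothetical and do not come from the text above.

```python
import random

# Hypothetical marker meaning "the previous step was a mistake, retract it".
# The token name is an assumption made for this sketch.
BACK = "[BACK]"


def make_retry_solution(steps, wrong_candidates, retry_rate, rng):
    """Build a retry example from a correct solution.

    Before each correct step, with probability `retry_rate`, emit one wrong
    step (drawn from that step's candidates) immediately followed by BACK.
    """
    out = []
    for i, step in enumerate(steps):
        if wrong_candidates[i] and rng.random() < retry_rate:
            out.append(rng.choice(wrong_candidates[i]))
            out.append(BACK)
        out.append(step)
    return out


def loss_mask(items):
    """Return 1 for items to train on and 0 for erroneous steps.

    An erroneous step is the item directly before a BACK marker. Masking is
    optional; whether it helps is one of the questions the paper studies.
    """
    mask = [1] * len(items)
    for i, item in enumerate(items):
        if item == BACK and i > 0:
            mask[i - 1] = 0
    return mask


def strip_retries(items):
    """Recover the error-free solution by undoing every retracted step."""
    out = []
    for item in items:
        if item == BACK:
            if out:
                out.pop()
        else:
            out.append(item)
    return out


# Toy solution: each step depends on the previous one.
steps = ["a = 1", "b = a + 2", "c = b * 3"]
# Illustrative mistakes: trying to compute c before its inputs are known.
wrong = [["c = b * 3"], ["c = b * 3"], []]

# retry_rate=1.0 inserts an error wherever a candidate exists.
example = make_retry_solution(steps, wrong, retry_rate=1.0, rng=random.Random(0))
print(example)
print(loss_mask(example))
```

With `retry_rate=1.0`, the first two steps are each preceded by a retracted wrong step, and `strip_retries(example)` returns the original three-step solution. Lowering `retry_rate` gives a mix of clean and retried steps, which is one simple way to vary the amount of error in the data.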