Masked Thought: Simply Masking Partial Reasoning Steps Can Improve Mathematical Reasoning Learning of Language Models

Changyu Chen,Xiting Wang,Ting-En Lin,Ang Lv,Yuchuan Wu,Xin Gao,Ji-Rong Wen,Rui Yan,Yongbin Li
2024-07-11
Abstract: In reasoning tasks, even a minor error can cascade into inaccurate results, leading to suboptimal performance of large language models in such domains. Earlier fine-tuning approaches sought to mitigate this by leveraging more precise supervisory signals from human labeling, larger models, or self-sampling, although at a high cost. Conversely, we develop a method that avoids external resources, relying instead on introducing perturbations to the input. Our training approach randomly masks certain tokens within the chain of thought, a technique we found to be particularly effective for reasoning tasks. When applied to fine-tuning with GSM8K on Llama-2-7B, this method achieved a 5% improvement in GSM8K accuracy and a 10% improvement in GSM-IC accuracy over standard supervised fine-tuning, with only a few lines of code modified. Furthermore, it is complementary to existing methods. When integrated with related explicit data augmentation methods, it leads to improvements across five datasets of various augmentation methods, as well as two different base models. We further investigate the mechanisms behind this improvement through case studies and quantitative analysis, suggesting that our approach may provide superior support for the model in capturing long-distance dependencies, especially those related to questions. This enhancement could deepen understanding of the premises in questions and prior steps. Our code is available at GitHub.
Computation and Language,Artificial Intelligence,Machine Learning
What problem does this paper attempt to address?
### Problems Addressed by the Paper

The paper primarily aims to address the performance issues of large language models (LLMs) in multi-step reasoning tasks. Specifically:

1. **Error Cascade Problem**: Even small errors can lead to issues in the entire solution process, thereby affecting the accuracy of the final result.
2. **Hallucination Problem**: State-of-the-art models are prone to hallucinations during reasoning, which can lead to errors.
3. **Cost of Supervision Signals**: Previous methods often rely on human annotations, larger models, or self-sampling to obtain more accurate supervision signals, but these methods are costly.

### Solution

The paper proposes a simple and effective method, **Masked Thought Fine-Tuning (MFT)**. This method introduces noise by randomly masking certain tokens during the reasoning steps. Experiments show that this method can significantly improve the performance of models in reasoning tasks and has the following characteristics:

1. **Simplicity**: Easy to implement, requiring only the replacement of specific tokens in the reasoning chain.
2. **Effectiveness**: Achieves significant performance improvements across multiple datasets.
3. **Complementarity**: Complementary to existing data augmentation techniques, further enhancing model performance.

### Main Contributions

1. **Proposing the MFT Method**: Improves the reasoning ability of language models by randomly masking certain tokens during the reasoning steps.
2. **Analyzing the Method's Effectiveness**: Analyzes the MFT method from a regularization perspective and proposes two guiding principles.
3. **Enhancing Dependency**: Through quantitative analysis and case studies, it is found that the MFT method enhances dependency on the initial mathematical problem and early steps, thereby reducing the risk of misunderstanding and reasoning inconsistencies.

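
The masking step described under "Solution" can be illustrated with a minimal sketch. This is not the authors' released code: the function name `mask_reasoning_tokens`, the parameter `mask_prob`, its default value, and the choice to keep the original tokens as training labels are all assumptions made for illustration. The label value `-100` follows the common PyTorch convention for positions ignored by the cross-entropy loss.

```python
import random
from typing import List, Optional, Tuple

IGNORE_INDEX = -100  # conventional "ignore" label for PyTorch cross-entropy


def mask_reasoning_tokens(
    question_ids: List[int],
    reasoning_ids: List[int],
    mask_id: int,
    mask_prob: float = 0.2,  # hypothetical default, not taken from the paper
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[int]]:
    """Randomly replace chain-of-thought tokens in the model input.

    Assumption: only the reasoning part of the input is perturbed; the
    question is left intact and the labels keep the original reasoning
    tokens, so the model still learns to predict the unmasked sequence.
    """
    rng = rng or random.Random()

    # Each reasoning token is independently swapped for the mask id.
    masked_reasoning = [
        mask_id if rng.random() < mask_prob else tok for tok in reasoning_ids
    ]

    input_ids = list(question_ids) + masked_reasoning
    # No loss on question tokens; original tokens are the targets elsewhere.
    labels = [IGNORE_INDEX] * len(question_ids) + list(reasoning_ids)
    return input_ids, labels


if __name__ == "__main__":
    q = [11, 12, 13]          # toy token ids for the question
    r = [21, 22, 23, 24, 25]  # toy token ids for the reasoning steps
    inputs, labels = mask_reasoning_tokens(
        q, r, mask_id=0, mask_prob=0.4, rng=random.Random(0)
    )
    print(inputs)
    print(labels)
```

Under these assumptions, the only change relative to standard supervised fine-tuning is the construction of `input_ids`, which is consistent with the abstract's remark that only a few lines of code need to be modified.
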
### Experimental Results

The paper validates the effectiveness of the MFT method through various datasets and models, demonstrating its generalization ability and sample efficiency across different tasks. Particularly on smaller datasets, MFT shows significant performance improvements. Additionally, compared to other regularization techniques, MFT is more effective in introducing noise.