Self-Training with Direct Preference Optimization Improves Chain-of-Thought Reasoning

Tianduo Wang, Shichen Li, Wei Lu
2024-07-26
Abstract: Effective training of language models (LMs) for mathematical reasoning tasks demands high-quality supervised fine-tuning data. Besides obtaining annotations from human experts, a common alternative is sampling from larger and more powerful LMs. However, this knowledge distillation approach can be costly and unstable, particularly when relying on closed-source, proprietary LMs like GPT-4, whose behaviors are often unpredictable. In this work, we demonstrate that the reasoning abilities of small-scale LMs can be enhanced through self-training, a process where models learn from their own outputs. We also show that conventional self-training can be further augmented by a preference learning algorithm called Direct Preference Optimization (DPO). By integrating DPO into self-training, we leverage preference data to guide LMs towards more accurate and diverse chain-of-thought reasoning. We evaluate our method across various mathematical reasoning tasks using different base models. Our experiments show that this approach not only improves LMs' reasoning performance but also offers a more cost-effective and scalable solution compared to relying on large proprietary LMs.
Computation and Language
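The abstract names DPO as the ingredient added to self-training. The sketch below illustrates the idea. It is not code from the paper. It implements the standard per-pair DPO objective from Rafailov et al. (2023), which is -log σ(β[(log π_θ(y_w|x) − log π_ref(y_w|x)) − (log π_θ(y_l|x) − log π_ref(y_l|x))]). It also includes a hypothetical helper, `build_preference_pairs`, that turns the model's own sampled rationales into (chosen, rejected) pairs. The helper treats a rationale as correct when its final answer matches the gold answer. That labeling rule, the helper's name, and the default `beta=0.1` are assumptions made for this illustration.

```python
import math
from itertools import product


def neg_log_sigmoid(x: float) -> float:
    """Numerically stable -log(sigmoid(x)), i.e. log(1 + exp(-x))."""
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Standard DPO loss for a single preference pair.

    Each argument is the summed log-probability of a full response under
    the trainable policy or the frozen reference model. beta=0.1 is an
    illustrative default, not a value taken from the paper.
    """
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    return neg_log_sigmoid(beta * (chosen_ratio - rejected_ratio))


def build_preference_pairs(question, samples, gold_answer):
    """Hypothetical helper: pair the model's own correct and incorrect
    rationales for one question.

    `samples` is a list of (rationale, predicted_answer) tuples. A rationale
    counts as correct when its final answer equals `gold_answer` (an assumed
    labeling rule). Returns (question, chosen, rejected) triples.
    """
    correct = [r for r, ans in samples if ans == gold_answer]
    wrong = [r for r, ans in samples if ans != gold_answer]
    return [(question, c, w) for c, w in product(correct, wrong)]
```

For example, `build_preference_pairs("q", [("r1", "24"), ("r2", "20"), ("r3", "24")], "24")` yields `[("q", "r1", "r2"), ("q", "r3", "r2")]`. When policy and reference agree exactly, the loss equals log 2. It falls below log 2 as the policy raises the chosen response's likelihood relative to the rejected one.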
What problem does this paper attempt to address?
The paper primarily aims to address the following issues:
1. **Enhancing the mathematical reasoning ability of small language models (LMs)**: While large language models perform well on mathematical reasoning tasks, how to effectively enhance the reasoning ability of small language models remains underexplored.
2. **Optimizing the self-training framework**: The paper proposes an improved self-training method that integrates Direct Preference Optimization (DPO) to improve the performance of small language models on chain-of-thought (CoT) reasoning tasks.
3. **Reducing dependence on large proprietary models**: Traditional knowledge distillation can improve the performance of small models, but it relies on large, proprietary language models for data annotation. This is costly and also limits sustainability and scalability.
4. **Improving inference efficiency**: To compensate for small models' weakness at basic arithmetic, the paper integrates an external calculator. It also optimizes the decoding process to support larger batch sizes, which increases inference speed.

In summary, this research aims to develop an efficient and cost-effective method for enhancing the mathematical reasoning ability of small language models. The method should also reduce dependence on large proprietary models and improve inference efficiency in practical applications.
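Item 4 mentions pairing an external calculator with batched decoding. The sketch below shows one way such a hook could work, and every part of it is an assumption. It assumes the model writes arithmetic in the `<<expression=result>>` annotation style used in GSM8K's reference solutions. After each decoding step, every sequence in the batch that ends with an open call such as `<<48/2=` gets the computed value and a closing `>>` appended, and the other sequences pass through unchanged. The function names `safe_eval` and `apply_calculator` are hypothetical, and the paper's actual mechanism may differ. The expression is evaluated with Python's `ast` module rather than `eval()`, so only numeric literals, parentheses, unary minus, and `+ - * /` are accepted.

```python
import ast
import operator
import re

# Arithmetic operators the hypothetical calculator accepts.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}

# Matches an unfinished annotation such as "... <<48/2=" at the end of the text.
_OPEN_CALL = re.compile(r"<<([^<>=]+)=$")


def safe_eval(expr: str):
    """Evaluate +, -, *, / over numeric literals without calling eval()."""
    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
            return -walk(node.operand)
        raise ValueError(f"unsupported expression: {expr!r}")
    return walk(ast.parse(expr.strip(), mode="eval"))


def apply_calculator(batch):
    """Hypothetical per-step hook for a batch of partial generations.

    Sequences ending in an open calculator call receive the computed value
    followed by ">>". All other sequences, and calls whose expression is not
    valid arithmetic, are returned unchanged so the model continues decoding.
    """
    out = []
    for text in batch:
        match = _OPEN_CALL.search(text)
        if match is None:
            out.append(text)
            continue
        try:
            value = safe_eval(match.group(1))
        except (ValueError, SyntaxError, ZeroDivisionError):
            out.append(text)
            continue
        if isinstance(value, float) and value.is_integer():
            value = int(value)  # print 24.0 as 24, matching GSM8K-style annotations
        out.append(f"{text}{value}>>")
    return out
```

For example, `apply_calculator(["Half of 48 is <<48/2=", "The answer is 5."])` returns `["Half of 48 is <<48/2=24>>", "The answer is 5."]`. The hook works on the whole batch at once, so it fits a decoding loop that keeps many sequences in flight rather than pausing generation for one sequence at a time.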