ReFT: Reasoning with Reinforced Fine-Tuning

Trung Quoc Luong, Xinbo Zhang, Zhanming Jie, Peng Sun, Xiaoran Jin, Hang Li
2024-06-27
Abstract: One way to enhance the reasoning capability of Large Language Models (LLMs) is to conduct Supervised Fine-Tuning (SFT) using Chain-of-Thought (CoT) annotations. This approach does not show sufficiently strong generalization ability, however, because the training only relies on the given CoT data. In math problem-solving, for example, there is usually only one annotated reasoning path for each question in the training data. Intuitively, it would be better for the algorithm to learn from multiple annotated reasoning paths given a question. To address this issue, we propose a simple yet effective approach called Reinforced Fine-Tuning (ReFT) to enhance the generalizability of learning LLMs for reasoning, with math problem-solving as an example. ReFT first warms up the model with SFT, and then employs on-line reinforcement learning, specifically the PPO algorithm in this paper, to further fine-tune the model, where an abundance of reasoning paths are automatically sampled given the question and the rewards are naturally derived from the ground-truth answers. Extensive experiments on GSM8K, MathQA, and SVAMP datasets show that ReFT significantly outperforms SFT, and the performance can be potentially further boosted by combining inference-time strategies such as majority voting and re-ranking. Note that ReFT obtains the improvement by learning from the same training questions as SFT, without relying on extra or augmented training questions. This indicates a superior generalization ability for ReFT.
Computation and Language
What problem does this paper attempt to address?
The paper aims to strengthen the reasoning ability of large language models (LLMs) on math problem-solving by improving on supervised fine-tuning (SFT), specifically its generalization. It proposes Reinforced Fine-Tuning (ReFT), a simple yet effective strategy for this purpose.

SFT can train a model to generate the chain of thought (CoT) needed to solve a problem, but it generalizes weakly because training relies solely on the given CoT data, which usually contains only one annotated reasoning path per question.

To address this, ReFT proceeds in two phases:

1. **Warm-up Phase**: The model is first trained with SFT so that it acquires a basic level of problem-solving ability.
2. **Reinforcement Learning Phase**: The model is then further fine-tuned with an online reinforcement learning algorithm, specifically Proximal Policy Optimization (PPO). In this phase, many reasoning paths are sampled automatically for each question and rewards are derived from the ground-truth answers, which gives the model a richer learning signal than a single annotated path.

The main contributions of the paper include:

- Proposing ReFT, which uses reinforcement learning to fine-tune LLMs for math problem-solving and generalizes better than SFT when trained on the same questions.
- Conducting extensive experiments on three standard datasets (GSM8K, MathQA, and SVAMP) with two base models (CodeLLAMA and Galactica), showing that ReFT significantly outperforms SFT.
- Exploring the combination of ReFT with inference-time strategies such as majority voting and reward model re-ranking, which further improves performance.
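As a rough illustration of two ideas mentioned above, the sketch below scores sampled reasoning paths against a ground-truth answer and then applies majority voting over the sampled answers. It is not the paper's implementation: the "answer is N" output format, the binary 0/1 reward, and the names `extract_answer`, `terminal_reward`, and `majority_vote` are all assumptions made for this example, and the PPO update itself is left out.

```python
import re
from collections import Counter
from typing import List, Optional


def extract_answer(cot: str) -> Optional[str]:
    """Return the number after the last 'answer is' (assumed output format)."""
    matches = re.findall(r"answer is\s*(-?\d+(?:\.\d+)?)", cot)
    return matches[-1] if matches else None


def terminal_reward(cot: str, ground_truth: str) -> float:
    """Assumed binary reward: 1.0 if the final answer matches, else 0.0."""
    predicted = extract_answer(cot)
    if predicted is None:
        return 0.0  # no answer could be extracted from the path
    return 1.0 if float(predicted) == float(ground_truth) else 0.0


def majority_vote(cots: List[str]) -> Optional[str]:
    """Pick the most frequent extracted answer among the sampled paths."""
    answers = [a for a in (extract_answer(c) for c in cots) if a is not None]
    if not answers:
        return None
    return Counter(answers).most_common(1)[0][0]


# Hand-written stand-ins for paths a policy might sample for one question
# ("3 bags hold 4 apples each; how many apples in total?").
sampled_paths = [
    "3 bags with 4 apples each give 3 * 4 = 12 apples. The answer is 12",
    "4 + 4 + 4 = 12. The answer is 12",
    "3 + 4 = 7. The answer is 7",
    "There are several apples in total.",
]

rewards = [terminal_reward(p, "12") for p in sampled_paths]
voted = majority_vote(sampled_paths)
print(rewards)  # [1.0, 1.0, 0.0, 0.0]
print(voted)    # 12
```

Two differently worded paths both earn the full reward here, which is the sense in which sampling can supply more than one usable reasoning path per question without additional annotation.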
Additionally, the researchers analyzed why ReFT performs less well on multiple-choice questions (e.g., MathQA) and identified a phenomenon called "reward hacking": because the reward depends only on the final answer and a multiple-choice question offers only a few options, the model can earn a positive reward by reaching the correct option through flawed reasoning. To mitigate this issue, the paper suggests using more refined reward mechanisms or having the model directly predict numerical answers.

In summary, ReFT is an effective fine-tuning method that can significantly improve the generalization ability and problem-solving accuracy of models without relying on extra or augmented training questions.
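The reward-hacking issue on multiple-choice questions can be made concrete with a toy example. The sketch below assumes a reward that looks only at the final answer; the question, the reasoning strings, and the names `final_token` and `answer_only_reward` are hypothetical and are not taken from the paper.

```python
import re
from typing import Optional


def final_token(cot: str) -> Optional[str]:
    """Return the option letter or number after the last 'answer is'."""
    matches = re.findall(r"answer is\s*([A-E]|-?\d+(?:\.\d+)?)\b", cot)
    return matches[-1] if matches else None


def answer_only_reward(cot: str, gold: str) -> float:
    """Reward that checks the final answer only and ignores the reasoning."""
    return 1.0 if final_token(cot) == gold else 0.0


# Toy question: "What is 15% of 80?"  Options: A) 8  B) 12  C) 15  D) 20  E) 24
sound_mc = "15% of 80 is 0.15 * 80 = 12, which is option B. The answer is B"
flawed_mc = "80 / 15 is about 5, and B looks closest. The answer is B"

# The same question asked for a numeric answer instead of an option letter.
sound_num = "0.15 * 80 = 12. The answer is 12"
flawed_num = "80 / 15 is about 5. The answer is 5"

mc_rewards = [answer_only_reward(c, "B") for c in (sound_mc, flawed_mc)]
num_rewards = [answer_only_reward(c, "12") for c in (sound_num, flawed_num)]
print(mc_rewards)   # [1.0, 1.0]  flawed reasoning is rewarded as well
print(num_rewards)  # [1.0, 0.0]  flawed reasoning no longer earns reward
```

In this toy case the flawed path is rewarded in the multiple-choice setting because it happens to land on the right letter, while the numeric setting forces the computed value itself to be correct.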