Enhancing Reasoning through Process Supervision with Monte Carlo Tree Search

Shuangtao Li, Shuaihao Dong, Kexin Luan, Xinhan Di, Chaofan Ding
2025-01-02
Abstract: Large language models (LLMs) have demonstrated their remarkable capacity across a variety of tasks. However, reasoning remains a challenge for LLMs. To improve LLMs' reasoning ability, process supervision has proven to be better than outcome supervision. In this work, we study using Monte Carlo Tree Search (MCTS) to generate process supervision data with LLMs themselves for training them. We sample reasoning steps with an LLM and assign each step a score that captures its "relative correctness," and the LLM is then trained by minimizing weighted log-likelihood of generating the reasoning steps. This generate-then-train process is repeated iteratively until convergence. Our experimental results demonstrate that the proposed methods considerably improve the performance of LLMs on two mathematical reasoning datasets. Furthermore, models trained on one dataset also exhibit improved performance on the other, showing the transferability of the enhanced reasoning ability.
Artificial Intelligence, Computation and Language, Machine Learning
What problem does this paper attempt to address?
This paper addresses the limited reasoning ability of large language models (LLMs). Although LLMs perform well on a wide range of language tasks, in some cases approaching human level, reasoning remains a challenge. In particular, existing LLMs perform poorly on problems that require multi-step reasoning, especially mathematical reasoning tasks. To enhance the reasoning ability of LLMs, the authors propose a method based on Monte Carlo Tree Search (MCTS) that generates process supervision data and uses this data to train the LLMs. Unlike traditional outcome supervision, which only gives feedback on the final answer, process supervision gives feedback on each reasoning step and can therefore better guide LLMs toward a correct reasoning path.

### Main Problems and Solutions

1. **Problem**: Existing LLMs have limited reasoning ability, especially on multi-step reasoning tasks.
2. **Solutions**:
   - Use MCTS to generate reasoning paths and assign each step a score reflecting its "relative correctness."
   - Train the LLM by minimizing a weighted log-likelihood loss, where the weight of each step is determined by its score.
   - Iteratively generate training data and train the LLM until convergence.

### Method Overview

- **Data Generation**: For each problem, use MCTS to explore possible reasoning paths and assign a score to each step. These scores reflect the relative correctness of the steps.
- **Iterative Training**: Use the generated data to perform supervised fine-tuning on the LLM, and evaluate the model after each iteration until performance no longer improves significantly.

### Experimental Results

The experimental results show that this method considerably improves the reasoning ability of LLMs on two mathematical reasoning datasets (MATH and GSM8K). In addition, models trained on one dataset also perform better on the other, which indicates that the improved reasoning ability transfers across datasets.
### Summary

This paper proposes a method to enhance the reasoning ability of LLMs by combining MCTS with process supervision. The experimental results support the effectiveness of the method and show that it transfers across datasets. The study also notes a limitation: performance converges after a few iterations and does not keep improving.

### Formula Example

When generating training data, the score of each step is computed as follows:

\[
r_{i,j,k} = \alpha \cdot N(v_{i,j,k}) \cdot \left( \frac{Q(v_{i,j,k})}{N(v_{i,j,k})} - \frac{\sum_m Q(v_{i,j,m})}{\sum_m N(v_{i,j,m})} \right)
\]

where:

- \( Q(\cdot) \) is the cumulative reward of the node,
- \( N(\cdot) \) is the number of visits to the node,
- \( \alpha \) is a manually set constant used to control the scale of the score.
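
The score formula above can be illustrated with a short Python sketch. This is not the authors' code: the `Node` class, the function names, and the handling of unvisited nodes are hypothetical choices made for illustration. The `weighted_nll` function shows one plausible reading of the "weighted log-likelihood" objective, namely a negative log-likelihood in which each step's log-probability is multiplied by its score; the exact sign and normalization used in the paper are assumptions here.

```python
from dataclasses import dataclass
import math


@dataclass
class Node:
    """Hypothetical container for the MCTS statistics of one candidate step."""
    q: float  # cumulative reward Q(v)
    n: int    # visit count N(v)


def step_scores(siblings, alpha=1.0):
    """Score sibling steps v_{i,j,k} that share the same parent.

    Implements r = alpha * N(v) * (Q(v)/N(v) - sum_m Q(v_m) / sum_m N(v_m)).
    """
    total_q = sum(node.q for node in siblings)
    total_n = sum(node.n for node in siblings)
    if total_n == 0:
        # Assumption: with no visits at all, every score is zero.
        return [0.0 for _ in siblings]
    baseline = total_q / total_n  # pooled mean reward over all siblings
    scores = []
    for node in siblings:
        if node.n == 0:
            # The factor N(v) = 0 makes the score zero; skip the 0/0 division.
            scores.append(0.0)
        else:
            scores.append(alpha * node.n * (node.q / node.n - baseline))
    return scores


def weighted_nll(step_log_probs, scores):
    """Assumed training loss: -sum_k r_k * log p(step_k).

    Steps with positive scores are pushed up in probability and steps
    with negative scores are pushed down when this loss is minimized.
    """
    return -sum(r * lp for r, lp in zip(scores, step_log_probs))


if __name__ == "__main__":
    siblings = [Node(q=3.0, n=4), Node(q=1.0, n=4), Node(q=0.0, n=2)]
    scores = step_scores(siblings, alpha=1.0)
    log_probs = [math.log(0.5), math.log(0.25), math.log(0.25)]
    print(scores)
    print(weighted_nll(log_probs, scores))
```

With these made-up statistics the pooled mean reward is 4/10 = 0.4, so the three scores come out at approximately 1.4, -0.6, and -0.8. Because every step is compared against the visit-weighted mean of its siblings, the scores under one parent sum to zero: a step is rewarded only to the extent that it does better than its alternatives, which is one way to read "relative correctness."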