Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning

Yuxi Xie, Anirudh Goyal, Wenyue Zheng, Min-Yen Kan, Timothy P. Lillicrap, Kenji Kawaguchi, Michael Shieh
2024-06-18
Abstract: We introduce an approach aimed at enhancing the reasoning capabilities of Large Language Models (LLMs) through an iterative preference learning process inspired by the successful strategy employed by AlphaZero. Our work leverages Monte Carlo Tree Search (MCTS) to iteratively collect preference data, utilizing its look-ahead ability to break down instance-level rewards into more granular step-level signals. To enhance consistency in intermediate steps, we combine outcome validation and stepwise self-evaluation, continually updating the quality assessment of newly generated data. The proposed algorithm employs Direct Preference Optimization (DPO) to update the LLM policy using this newly generated step-level preference data. Theoretical analysis reveals the importance of using on-policy sampled data for successful self-improving. Extensive evaluations on various arithmetic and commonsense reasoning tasks demonstrate remarkable performance improvements over existing models. For instance, our approach outperforms the Mistral-7B Supervised Fine-Tuning (SFT) baseline on GSM8K, MATH, and ARC-C, with substantial increases in accuracy to $81.8\%$ (+$5.9\%$), $34.7\%$ (+$5.8\%$), and $76.4\%$ (+$15.8\%$), respectively. Additionally, our research delves into the training and inference compute tradeoff, providing insights into how our method effectively maximizes performance gains. Our code is publicly available at https://github.com/YuxiXie/MCTS-DPO.
Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
This paper addresses how to enhance the reasoning ability of Large Language Models (LLMs) through an iterative preference learning process inspired by AlphaZero. It uses Monte Carlo Tree Search (MCTS) to iteratively collect preference data, exploiting the search's look-ahead ability to decompose instance-level rewards into finer-grained step-level signals. To improve the consistency of intermediate steps, it combines outcome validation with stepwise self-evaluation, continually updating the quality assessment of newly generated data. The algorithm then applies Direct Preference Optimization (DPO) to update the LLM policy on this step-level preference data. Theoretical analysis shows that using on-policy sampled data is important for successful self-improvement. Evaluations on arithmetic and commonsense reasoning tasks show gains over the Mistral-7B Supervised Fine-Tuning (SFT) baseline: accuracy rises to 81.8% (+5.9%) on GSM8K, 34.7% (+5.8%) on MATH, and 76.4% (+15.8%) on ARC-C. The paper also examines the trade-off between training and inference compute, offering insights into how the method maximizes performance gains. Compared with conventional reinforcement learning from human feedback (RLHF), the iterative approach emphasizes continual adaptation of the LLM, which is intended to help it handle the complexity of human decision-making and reasoning. The experiments indicate that MCTS is effective at improving model performance, particularly on reasoning tasks.
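To make the step-level idea concrete, here is a minimal Python sketch of how one preference pair might be formed from sibling reasoning steps and scored with the standard DPO objective. It is not the authors' implementation. The names `StepCandidate`, `pick_preference_pair`, and `dpo_loss`, the value `beta=0.1`, and all numbers are hypothetical. Pairing the highest-valued sibling with the lowest-valued one is an assumption made for illustration. The value estimates and log-probabilities would in practice come from the tree search and the language models.

```python
import math
from dataclasses import dataclass


@dataclass
class StepCandidate:
    """One candidate reasoning step at a tree node (illustrative only)."""
    text: str
    q_value: float      # value estimate from the search (hypothetical number)
    logp_policy: float  # log-probability of the step under the current policy
    logp_ref: float     # log-probability under the frozen reference model


def pick_preference_pair(candidates):
    """Assumed pairing rule: highest-valued sibling vs. lowest-valued sibling."""
    ranked = sorted(candidates, key=lambda c: c.q_value)
    return ranked[-1], ranked[0]


def dpo_loss(chosen, rejected, beta=0.1):
    """Standard DPO loss for one pair: -log(sigmoid(beta * log-ratio margin))."""
    margin = beta * (
        (chosen.logp_policy - chosen.logp_ref)
        - (rejected.logp_policy - rejected.logp_ref)
    )
    # Numerically stable form of log(1 + exp(-margin)).
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))


# Hypothetical sibling steps expanded from the same partial solution.
siblings = [
    StepCandidate("step A", q_value=0.9, logp_policy=-1.0, logp_ref=-1.2),
    StepCandidate("step B", q_value=0.2, logp_policy=-0.8, logp_ref=-0.7),
    StepCandidate("step C", q_value=0.5, logp_policy=-1.5, logp_ref=-1.5),
]
chosen, rejected = pick_preference_pair(siblings)
print(chosen.text, rejected.text, dpo_loss(chosen, rejected))
```

When the policy and reference assign identical log-probabilities to both steps, the margin is zero and the loss equals log 2. Raising the policy's log-probability of the preferred step relative to the reference lowers the loss.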