Interpretable Contrastive Monte Carlo Tree Search Reasoning

Zitian Gao, Boye Niu, Xuzheng He, Haotian Xu, Hongzhang Liu, Aiwei Liu, Xuming Hu, Lijie Wen
2024-10-12
Abstract: We propose SC-MCTS*, a novel Monte Carlo Tree Search (MCTS) reasoning algorithm for Large Language Models (LLMs) that significantly improves both reasoning accuracy and speed. Our motivation is threefold: 1. previous MCTS-based LLM reasoning work often overlooked its biggest drawback, namely its slower speed compared to CoT; 2. previous research mainly used MCTS as a tool for LLM reasoning on various tasks, with limited quantitative analysis or ablation studies of its components from a reasoning-interpretability perspective; 3. the reward model is the most crucial component of MCTS, yet previous work has rarely studied or improved MCTS reward models in depth. We therefore conducted extensive ablation studies and quantitative analysis of the components of MCTS, revealing how each component affects the MCTS reasoning performance of LLMs. Building on this, (i) we designed a highly interpretable reward model based on the principle of contrastive decoding, and (ii) achieved an average speed improvement of 51.9% per node using speculative decoding. Additionally, (iii) we improved the UCT node selection strategy and the backpropagation used in previous works, resulting in significant performance improvements. Using Llama-3.1-70B with SC-MCTS*, we outperformed o1-mini by an average of 17.4% on the Blocksworld multi-step reasoning dataset. Our code is available at https://github.com/zitian-gao/SC-MCTS.
Computation and Language, Artificial Intelligence
What problem does this paper attempt to address?
### The Problem the Paper Attempts to Solve

This paper addresses two main challenges that large language models (LLMs) face in multi-step reasoning tasks:

1. **Slow reasoning speed**: Monte Carlo Tree Search (MCTS) reasoning methods achieve good accuracy but are significantly slower than Chain of Thought (CoT) prompting. MCTS builds a multi-level reasoning tree in which each node corresponds to one round of interaction with the LLM, which makes it computationally expensive and slow.
2. **Difficulty of designing reward models**: The reasoning ability of MCTS depends heavily on the quality of its reward model, yet designing a reward model that is dense, general, and efficient is very challenging. Existing approaches either require multiple LLMs or extra training cycles, which increases VRAM and compute demands, or rely on domain-specific tools or datasets, which makes them hard to generalize to other tasks or datasets.

To address these challenges, the authors propose a new MCTS reasoning algorithm, (S)peculative (C)ontrastive MCTS* (SC-MCTS*), which significantly improves reasoning accuracy and speed through the following changes:

- **A highly interpretable reward model**: Based on the principle of contrastive decoding, it combines multiple highly interpretable reward functions and normalizes the rewards using clustering-based prior distributions.
- **Speculative decoding**: A smaller language model drafts tokens for speculative decoding, which reduces reasoning time and speeds up each node by 51.9% on average.
- **Improved UCT node selection and backpropagation**: The choice of exploration constant in the UCT strategy is optimized, and the backpropagation algorithm is modified to favor paths with steady improvements.
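To make the contrastive-decoding idea behind the reward model more concrete, here is a minimal sketch, not the authors' implementation. It assumes per-token probability distributions are available from a larger "expert" model and a smaller "amateur" model for the tokens of one reasoning step. Each chosen token is scored as log p_expert minus log p_amateur, with the adaptive plausibility cutoff `alpha` used in contrastive decoding. The function names, the `implausible_penalty` value, and the reward names in the usage example are hypothetical. The normalization is simplified to a single prior mean and standard deviation per reward function; the paper's clustering-based priors are not reproduced.

```python
import math
from typing import Dict, Sequence, Tuple

EPS = 1e-12  # guards against log(0)


def contrastive_step_reward(
    expert_probs: Sequence[Sequence[float]],
    amateur_probs: Sequence[Sequence[float]],
    tokens: Sequence[int],
    alpha: float = 0.1,
    implausible_penalty: float = -10.0,  # hypothetical value
) -> float:
    """Mean contrastive score over the tokens of one reasoning step.

    Each chosen token is scored as log p_expert(tok) - log p_amateur(tok).
    Tokens failing the plausibility check p_expert(tok) < alpha * max(p_expert)
    get a fixed penalty (a hypothetical choice; contrastive decoding itself
    simply masks such tokens out during generation).
    """
    if not (len(expert_probs) == len(amateur_probs) == len(tokens) > 0):
        raise ValueError("inputs must be non-empty and of equal length")
    scores = []
    for p_e, p_a, tok in zip(expert_probs, amateur_probs, tokens):
        if p_e[tok] < alpha * max(p_e):
            scores.append(implausible_penalty)
        else:
            scores.append(math.log(max(p_e[tok], EPS)) - math.log(max(p_a[tok], EPS)))
    return sum(scores) / len(scores)


def combine_rewards(
    raw: Dict[str, float],
    priors: Dict[str, Tuple[float, float]],
    weights: Dict[str, float],
) -> float:
    """Weighted sum of rewards, each z-normalized with a prior (mean, std).

    The priors would be estimated offline from earlier rollouts; a single
    Gaussian per reward stands in for the paper's clustering-based priors.
    """
    total = 0.0
    for name, value in raw.items():
        mean, std = priors[name]
        total += weights[name] * (value - mean) / std
    return total


if __name__ == "__main__":
    r = contrastive_step_reward(
        [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]],
        [[0.5, 0.3, 0.2], [0.2, 0.5, 0.3]],
        tokens=[0, 0],
    )
    print(combine_rewards(
        {"contrastive": r, "loglik": -1.0},
        {"contrastive": (0.5, 0.3), "loglik": (-2.0, 0.5)},
        {"contrastive": 0.5, "loglik": 0.5},
    ))
```

Normalizing each component before weighting keeps rewards that live on different scales (log-ratios versus log-likelihoods, for example) from dominating one another, which is the motivation the summary gives for prior-based normalization.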
Experimental results show that SC-MCTS* significantly outperforms existing methods on the Blocksworld multi-step reasoning dataset: with Llama-3.1-70B, it exceeds o1-mini's reasoning accuracy by an average of 17.4%.
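For reference, the standard UCT selection rule and averaging backpropagation that SC-MCTS* builds on can be sketched as follows. This is the baseline, not the paper's modified version: the exploration constant `c` is left as a free parameter rather than the authors' tuned value, and the "favor paths with steady improvements" backpropagation is not reproduced. The `Node` class and function names are hypothetical.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class Node:
    """Hypothetical minimal MCTS node: one reasoning step in the tree."""
    parent: Optional["Node"] = field(default=None, repr=False)
    children: List["Node"] = field(default_factory=list, repr=False)
    visits: int = 0
    value_sum: float = 0.0

    def add_child(self) -> "Node":
        child = Node(parent=self)
        self.children.append(child)
        return child


def uct_score(child: Node, parent_visits: int, c: float) -> float:
    """Classic UCT: mean value plus c * sqrt(ln N_parent / N_child)."""
    if child.visits == 0:
        return math.inf  # unvisited children are always tried first
    exploit = child.value_sum / child.visits
    explore = c * math.sqrt(math.log(parent_visits) / child.visits)
    return exploit + explore


def select_child(node: Node, c: float) -> Node:
    """Pick the child with the highest UCT score."""
    return max(node.children, key=lambda ch: uct_score(ch, node.visits, c))


def backpropagate(leaf: Node, reward: float) -> None:
    """Standard backpropagation: add the reward to every node up to the root."""
    node: Optional[Node] = leaf
    while node is not None:
        node.visits += 1
        node.value_sum += reward
        node = node.parent
```

A larger `c` gives more weight to the exploration term and therefore to rarely visited reasoning steps; the summary indicates that SC-MCTS* treats the choice of this constant as something to optimize rather than fixing it at a conventional default.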