MindStar: Enhancing Math Reasoning in Pre-trained LLMs at Inference Time

Jikun Kang, Xin Zhe Li, Xi Chen, Amirreza Kazemi, Qianyi Sun, Boxing Chen, Dong Li, Xu He, Quan He, Feng Wen, Jianye Hao, Jun Yao
2024-06-26
Abstract: Although Large Language Models (LLMs) achieve remarkable performance across various tasks, they often struggle with complex reasoning tasks, such as answering mathematical questions. Recent efforts to address this issue have primarily focused on leveraging mathematical datasets through supervised fine-tuning or self-improvement techniques. However, these methods often depend on high-quality datasets that are difficult to prepare, or they require substantial computational resources for fine-tuning. Inspired by findings that LLMs know how to produce the right answer but struggle to select the correct reasoning path, we propose a purely inference-based searching method -- MindStar (M*). This method formulates reasoning tasks as searching problems and proposes two search ideas to identify the optimal reasoning paths. We evaluate the M* framework on both the GSM8K and MATH datasets, comparing its performance with existing open and closed-source LLMs. Our results demonstrate that M* significantly enhances the reasoning abilities of open-source models, such as Llama-2-13B and Mistral-7B, and achieves comparable performance to GPT-3.5 and Grok-1, but with substantially reduced model size and computational costs.
Machine Learning
What problem does this paper attempt to address?
This paper addresses the limited ability of large language models (LLMs) to handle complex reasoning tasks, such as mathematical questions, at inference time. Existing approaches mainly rely on supervised fine-tuning or self-improvement techniques, which require either high-quality datasets that are hard to prepare or substantial computational resources. Motivated by prior findings that LLMs often know how to produce the correct answer but struggle to select the correct reasoning path, the authors propose MindStar (M*), a purely inference-based search method that formulates reasoning as a search problem and uses two search strategies to identify the optimal reasoning path. To keep this search efficient, M* relies on tree-based search algorithms such as beam search and A* tree search. Evaluated on the GSM8K and MATH datasets, M* significantly improves the reasoning ability of open-source models such as Llama-2-13B and Mistral-7B, reaching performance comparable to GPT-3.5 and Grok-1 with a much smaller model size and lower computational cost, and thereby narrowing the gap between open-source and closed-source models.
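The idea of treating reasoning as a search over partial reasoning paths can be illustrated with a minimal Python sketch. It assumes an LLM that proposes candidate next steps (`propose`) and a scorer that rates partial paths (`score`). Both are hypothetical callables, replaced here by toy stand-ins for the question "2 + 3 * 4". The sketch shows beam search and a plain best-first search ordered by path score. The best-first variant is a simplification, not a full A* implementation (it has no separate cost-to-come term), and neither function should be read as the paper's exact algorithm.

```python
import heapq
from typing import Callable, List, Optional, Tuple

# A reasoning path is a tuple of step strings (an illustrative choice).
Path = Tuple[str, ...]
Proposer = Callable[[str, Path], List[str]]  # hypothetical: LLM proposing next steps
Scorer = Callable[[str, Path], float]        # hypothetical: scorer for partial paths
IsFinal = Callable[[Path], bool]


def beam_search(question: str, propose: Proposer, score: Scorer,
                is_final: IsFinal, beam_width: int = 2, max_depth: int = 8) -> Path:
    """Keep the `beam_width` highest-scoring partial paths at each depth."""
    beam: List[Path] = [()]
    for _ in range(max_depth):
        candidates: List[Path] = []
        for path in beam:
            if is_final(path):
                candidates.append(path)  # finished paths compete unchanged
                continue
            candidates.extend(path + (step,) for step in propose(question, path))
        if not candidates:
            break
        candidates.sort(key=lambda p: score(question, p), reverse=True)
        beam = candidates[:beam_width]
        if all(is_final(p) for p in beam):
            break
    return max(beam, key=lambda p: score(question, p))


def best_first_search(question: str, propose: Proposer, score: Scorer,
                      is_final: IsFinal, max_expansions: int = 100) -> Optional[Path]:
    """Always expand the highest-scoring open path; return the first final one."""
    counter = 0  # tie-breaker so equal scores pop in insertion order
    frontier = [(-score(question, ()), counter, ())]
    expansions = 0
    while frontier and expansions < max_expansions:
        _, _, path = heapq.heappop(frontier)
        if is_final(path):
            return path
        expansions += 1
        for step in propose(question, path):
            child = path + (step,)
            counter += 1
            heapq.heappush(frontier, (-score(question, child), counter, child))
    return None


# Toy stand-ins for "2 + 3 * 4": one correct and one flawed step at each depth.
GOOD: Path = ("3 * 4 = 12", "2 + 12 = 14", "Answer: 14")
BAD: Path = ("2 + 3 = 5", "5 * 4 = 20", "Answer: 20")


def toy_propose(question: str, path: Path) -> List[str]:
    i = len(path)
    return [BAD[i], GOOD[i]] if i < len(GOOD) else []


def toy_score(question: str, path: Path) -> float:
    # Stand-in for a learned scorer: fraction of reference steps matched.
    return sum(s == g for s, g in zip(path, GOOD)) / len(GOOD)


def toy_is_final(path: Path) -> bool:
    return bool(path) and path[-1].startswith("Answer:")


print(beam_search("2 + 3 * 4", toy_propose, toy_score, toy_is_final)[-1])  # Answer: 14
result = best_first_search("2 + 3 * 4", toy_propose, toy_score, toy_is_final)
print(result[-1] if result else None)  # Answer: 14
```

Beam search spends a fixed budget of `beam_width` paths per depth. Best-first search can return to an earlier, higher-scoring branch at any point. In both cases, the quality of the chosen path depends on how well the scorer separates sound steps from flawed ones.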