Enhancing Multi-Step Reasoning Abilities of Language Models through Direct Q-Function Optimization

Guanlin Liu, Kaixuan Ji, Renjie Zheng, Zheng Wu, Chen Dun, Quanquan Gu, Lin Yan
2024-10-12
Abstract: Reinforcement Learning (RL) plays a crucial role in aligning large language models (LLMs) with human preferences and improving their ability to perform complex tasks. However, current approaches either require significant computational resources due to the use of multiple models and extensive online sampling for training (e.g., PPO) or are framed as bandit problems (e.g., DPO, DRO), which often struggle with multi-step reasoning tasks, such as math problem-solving and complex reasoning that involve long chains of thought. To overcome these limitations, we introduce Direct Q-function Optimization (DQO), which formulates the response generation process as a Markov Decision Process (MDP) and utilizes the soft actor-critic (SAC) framework to optimize a Q-function directly parameterized by the language model. The MDP formulation of DQO offers structural advantages over bandit-based methods, enabling more effective process supervision. Experimental results on two math problem-solving datasets, GSM8K and MATH, demonstrate that DQO outperforms previous methods, establishing it as a promising offline reinforcement learning approach for aligning language models.
Machine Learning, Artificial Intelligence, Computation and Language
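The abstract describes a Q-function "directly parameterized by the language model." One standard way to make that concrete comes from KL-regularized (soft) RL rather than from the paper's code: Q(s, a) = β log(π(a|s) / π_ref(a|s)) + V(s), where π is the language model, π_ref a reference model, β the regularization coefficient, and V(s) the soft state value. The minimal NumPy sketch below assumes this parameterization; the function names, `beta`, `v`, and the random logits are illustrative, and DQO's exact formulation may differ. It checks that the soft value recovered from the implied Q-values matches V(s).

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))


def implied_q(policy_logits, ref_logits, value, beta):
    """Assumed parameterization: Q(s, a) = beta * log(pi(a|s) / pi_ref(a|s)) + V(s)."""
    return beta * (log_softmax(policy_logits) - log_softmax(ref_logits)) + value


def soft_value(q, ref_logits, beta):
    """V(s) = beta * log sum_a pi_ref(a|s) * exp(Q(s, a) / beta), via a stable logsumexp."""
    z = log_softmax(ref_logits) + q / beta
    z_max = np.max(z)
    return beta * (z_max + np.log(np.sum(np.exp(z - z_max))))


# Toy next-token distributions over an 8-token vocabulary (random, illustrative only).
rng = np.random.default_rng(0)
policy_logits = rng.normal(size=8)
ref_logits = rng.normal(size=8)
beta, v = 0.1, 0.7  # hypothetical KL coefficient and state value

q = implied_q(policy_logits, ref_logits, v, beta)
print(np.isclose(soft_value(q, ref_logits, beta), v))  # True
```

Because π sums to one over the vocabulary, the soft value of these implied Q-values always equals V(s), and π(a|s) = π_ref(a|s) exp((Q(s, a) - V(s)) / β) is recovered exactly. Under this assumption, the language model's log-probability ratios supply the action-dependent part of Q, while V(s) has to be estimated separately.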
### What problems does this paper attempt to solve?

This paper targets two main problems that arise when current Reinforcement Learning (RL) methods are used to align Large Language Models (LLMs):

1. **High computational cost**: Methods such as Proximal Policy Optimization (PPO) rely on extensive online sampling and on training multiple models, which makes them expensive.
2. **Difficulty with multi-step reasoning**: Methods such as DPO and DRO frame the task as a bandit (single-step decision-making) problem, so they perform poorly on tasks that require long chains of thought, such as math problem-solving and complex reasoning. In particular, they handle sparse reward signals poorly and cannot effectively exploit process supervision, i.e., feedback on intermediate steps.

These problems limit how much such methods can improve model performance on complex tasks.

To address them, the authors propose **Direct Q-function Optimization (DQO)**, which improves on existing methods in the following ways:

- It models the response generation process as a Markov Decision Process (MDP), which supports multi-step reasoning tasks.
- It uses the Soft Actor-Critic (SAC) framework to directly optimize a Q-function parameterized by the language model, avoiding reliance on additional reward models.
- It trains in an offline RL manner, which reduces the need for online sampling and improves training efficiency and stability.

Experimental results show that DQO outperforms other methods on the math problem-solving datasets GSM8K and MATH, especially on multi-step reasoning tasks. In addition, DQO makes better use of process reward signals, which further improves model performance.
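To see why an MDP view helps with sparse rewards and process supervision, the sketch below unrolls one sampled response into token-level transitions: the state is the prompt plus the tokens generated so far, the action is the next token, and the transition simply appends that token. The outcome reward (for example, 1.0 for a correct final answer) arrives only on the last token, while optional process rewards can be attached to intermediate steps. The `Transition` class, `response_to_transitions`, the word-level "tokens", and the step reward value are hypothetical illustrations, not the paper's data format.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple


@dataclass
class Transition:
    state: Tuple[str, ...]       # prompt plus the response tokens generated so far
    action: str                  # the next token chosen by the model
    reward: float
    next_state: Tuple[str, ...]  # state with the action appended
    done: bool                   # True only on the final token


def response_to_transitions(
    prompt: Sequence[str],
    response: Sequence[str],
    outcome_reward: float,
    process_rewards: Optional[Dict[int, float]] = None,
) -> List[Transition]:
    """Unroll one response into token-level MDP transitions (hypothetical helper)."""
    process_rewards = process_rewards or {}
    state = tuple(prompt)
    transitions = []
    for t, token in enumerate(response):
        next_state = state + (token,)  # deterministic transition: append the token
        done = t == len(response) - 1
        # Sparse outcome reward on the last token, plus any step-level reward at index t.
        reward = process_rewards.get(t, 0.0) + (outcome_reward if done else 0.0)
        transitions.append(Transition(state, token, reward, next_state, done))
        state = next_state
    return transitions


# Toy example with word-level "tokens"; the step reward at index 2 is made up.
prompt = ["Q:", "2+3*4=?"]
response = ["3*4=12", ",", "2+12=14", ",", "answer:", "14"]
steps = response_to_transitions(prompt, response, outcome_reward=1.0,
                                process_rewards={2: 0.5})
for step in steps:
    print(step.action, step.reward, step.done)
```

A bandit formulation would assign a single scalar to the whole response; the per-step rewards above are the kind of signal an MDP-based method can use when process supervision is available.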
### Summary

The main contributions of DQO are:

- An offline RL algorithm suited to LLMs that achieves better performance on multi-step reasoning tasks.
- Techniques such as λ-return and importance sampling that stabilize training and help ensure good performance.
- Experiments on two math benchmarks, GSM8K and MATH, that verify DQO's advantage, with further gains when process rewards are available.

Through these improvements, DQO offers a new and effective way to improve the reasoning ability and alignment of language models on complex tasks.
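The summary credits λ-return and importance sampling with stabilizing training. As a generic illustration, not DQO's exact estimator, the sketch below computes λ-returns with the standard backward recursion G_t = r_t + γ[(1 - λ)V(s_{t+1}) + λG_{t+1}] and, separately, truncated importance weights π(a|s) / μ(a|s) for tokens sampled from an older behavior policy μ, as arises in offline training. The function names, the clipping threshold, and the toy numbers are assumptions; how DQO combines the two in its loss is not specified here.

```python
import numpy as np


def lambda_returns(rewards, next_values, dones, gamma=1.0, lam=0.95):
    """Lambda-returns via the backward recursion
    G_t = r_t + gamma * (1 - done_t) * ((1 - lam) * V(s_{t+1}) + lam * G_{t+1}).

    next_values[t] is the critic's estimate of V(s_{t+1}).
    """
    rewards = np.asarray(rewards, dtype=float)
    next_values = np.asarray(next_values, dtype=float)
    returns = np.zeros_like(rewards)
    # If the last step is not terminal (truncated trajectory), take G_T = V(s_T)
    # so the final target bootstraps fully from the critic.
    next_return = next_values[-1]
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        bootstrap = (1.0 - lam) * next_values[t] + lam * next_return
        returns[t] = rewards[t] + gamma * not_done * bootstrap
        next_return = returns[t]
    return returns


def truncated_importance_weights(logp_current, logp_behavior, clip=1.0):
    """Per-token ratios pi(a|s) / mu(a|s), capped at `clip` to bound variance."""
    ratios = np.exp(np.asarray(logp_current) - np.asarray(logp_behavior))
    return np.minimum(ratios, clip)


# Toy 4-token response with only a final outcome reward (numbers are made up).
rewards = [0.0, 0.0, 0.0, 1.0]
next_values = [0.2, 0.4, 0.6, 0.0]
dones = [False, False, False, True]
print(lambda_returns(rewards, next_values, dones, gamma=1.0, lam=0.5))
print(truncated_importance_weights([-1.0, -2.0], [-1.5, -1.0]))
```

Setting `lam=1.0` recovers the Monte Carlo return (every step sees the final outcome reward), while `lam=0.0` gives one-step TD targets r_t + γV(s_{t+1}); intermediate values trade bias from the value estimate against variance from the sparse final reward.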