Improve Vision Language Model Chain-of-thought Reasoning

Ruohong Zhang, Bowen Zhang, Yanghao Li, Haotian Zhang, Zhiqing Sun, Zhe Gan, Yinfei Yang, Ruoming Pang, Yiming Yang
2024-10-22
Abstract: Chain-of-thought (CoT) reasoning in vision language models (VLMs) is crucial for improving interpretability and trustworthiness. However, current training recipes lack robust CoT reasoning data, relying on datasets dominated by short annotations with minimal rationales. In this work, we show that training VLMs on short answers does not generalize well to reasoning tasks that require more detailed responses. To address this, we propose a two-fold approach. First, we distill rationales from the GPT-4o model to enrich the training data and fine-tune VLMs, boosting their CoT performance. Second, we apply reinforcement learning to further calibrate reasoning quality. Specifically, we construct positive (correct) and negative (incorrect) pairs of model-generated reasoning chains by comparing their predictions with annotated short answers. Using this pairwise data, we apply the Direct Preference Optimization algorithm to refine the model's reasoning abilities. Our experiments demonstrate significant improvements in CoT reasoning on benchmark datasets and better generalization to direct answer prediction as well. This work emphasizes the importance of incorporating detailed rationales in training and leveraging reinforcement learning to strengthen the reasoning capabilities of VLMs.
Artificial Intelligence, Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
This paper addresses the challenges that Vision Language Models (VLMs) face in chain-of-thought (CoT) reasoning. Specifically, it focuses on the following issues:

1. **Limitations of Current Training Data**: Most existing VLM training datasets contain short answers without detailed reasoning or justification. This restricts model performance on complex reasoning tasks.
2. **Relationship between Direct Prediction and Chain-of-Thought Reasoning**: The paper examines whether training only on short answers can implicitly teach a model to perform CoT reasoning. The results show that training on short answers alone is of limited use for improving CoT reasoning ability.
3. **How to Enhance the CoT Reasoning Ability of VLMs**: To solve these problems, the paper proposes a two-step method:
   - **Step 1**: Enrich the training data by distilling detailed rationales from the GPT-4o model, then fine-tune the VLM on this data to improve its CoT reasoning performance.
   - **Step 2**: Apply reinforcement learning (RL), specifically the Direct Preference Optimization (DPO) algorithm, to further calibrate reasoning quality. Positive (correct) and negative (incorrect) pairs of model-generated reasoning chains are constructed by comparing each chain's prediction with the annotated short answer, and the model is optimized on these pairs.
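The pair-construction part of Step 2 can be illustrated with a minimal sketch. It assumes several CoT chains have already been sampled from the fine-tuned model for each question and that every chain ends with a line of the form `Answer: <text>`; that output format, the exact-match comparison, the `max_pairs` cap, and all names below are hypothetical choices for illustration, not the paper's implementation.

```python
import itertools
import re
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class PreferencePair:
    image_id: str   # identifier standing in for the image V
    question: str   # the question x
    chosen: str     # y_w: chain whose final answer matches the annotation
    rejected: str   # y_l: chain whose final answer does not


# Hypothetical output convention: each chain ends with a line "Answer: <text>".
ANSWER_RE = re.compile(r"Answer:\s*(.+?)\s*$", re.IGNORECASE | re.MULTILINE)


def normalize(text: str) -> str:
    """Lowercase and strip whitespace and a trailing period for exact matching."""
    return text.strip().lower().rstrip(".")


def extract_answer(chain: str) -> Optional[str]:
    """Return the normalized text of the last 'Answer:' line, or None if absent."""
    matches = ANSWER_RE.findall(chain)
    return normalize(matches[-1]) if matches else None


def build_pairs(image_id: str, question: str, short_answer: str,
                sampled_chains: List[str], max_pairs: int = 4) -> List[PreferencePair]:
    """Label sampled chains by answer correctness and pair correct with incorrect ones."""
    gold = normalize(short_answer)
    positives, negatives = [], []
    for chain in sampled_chains:
        pred = extract_answer(chain)
        if pred is None:
            continue  # chains without a parsable answer are dropped, not labeled
        (positives if pred == gold else negatives).append(chain)
    # Every (correct, incorrect) combination, capped to limit pairs per question.
    combos = itertools.product(positives, negatives)
    return [PreferencePair(image_id, question, pos, neg)
            for pos, neg in itertools.islice(combos, max_pairs)]
```

For example, given three sampled chains whose final answers are "Stop", "yield", and "stop.", and the annotated short answer "Stop", the first and third chains are labeled positive and the second negative, giving two preference pairs. Only the short answer is needed as supervision, which is why this step can reuse existing short-answer datasets.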
### Formula Representation

When describing the objective function of the DPO algorithm, the paper uses the following formula:

\[ L_{\text{DPO}}(\pi_\theta; \pi_{\text{ref}}) = -\mathbb{E}_{(V, x, y_w, y_l) \sim D_{\text{DPO}}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w \mid x, V)}{\pi_{\text{ref}}(y_w \mid x, V)} - \beta \log \frac{\pi_\theta(y_l \mid x, V)}{\pi_{\text{ref}}(y_l \mid x, V)} \right) \right] \]

where:

- \( \pi_\theta \) is the policy model being optimized,
- \( \pi_{\text{ref}} \) is the fixed reference model,
- \( \sigma \) is the logistic (sigmoid) function,
- \( \beta \) is a constant parameter (set to 0.1) that controls how far the policy may move away from the reference model,
- \( V \), \( x \), \( y_w \), and \( y_l \) denote the image, the question, the positive (preferred) answer, and the negative (dispreferred) answer,
- \( D_{\text{DPO}} \) is a dataset containing images, questions, positive answers, and negative answers.

### Summary

The core issue the paper addresses is the lack of detailed reasoning paths in existing VLM training data. By introducing high-quality CoT data and reinforcement learning, it significantly improves CoT reasoning performance on benchmark datasets. This not only improves the interpretability and trustworthiness of the model but also improves its generalization, including to direct answer prediction.
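The DPO objective in the Formula Representation section can be evaluated numerically with the short sketch below. It assumes the sequence log-probabilities \( \log \pi(y \mid x, V) \) (summed token log-probabilities of each response) have already been computed for both the policy and the reference model; `dpo_loss` and `log_sigmoid` are hypothetical names, plain floats stand in for tensors, and the batch mean stands in for the expectation over \( D_{\text{DPO}} \). This is not the authors' code.

```python
import math
from typing import Sequence


def log_sigmoid(z: float) -> float:
    """Numerically stable log(sigma(z)) that avoids overflow in exp()."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dpo_loss(policy_chosen: Sequence[float], policy_rejected: Sequence[float],
             ref_chosen: Sequence[float], ref_rejected: Sequence[float],
             beta: float = 0.1) -> float:
    """Batch-mean DPO loss from sequence log-probabilities log pi(y | x, V).

    Each argument holds one summed log-probability per (V, x, y_w, y_l) example:
    policy and reference scores for the chosen (y_w) and rejected (y_l) chains.
    """
    n = len(policy_chosen)
    if n == 0 or not (n == len(policy_rejected) == len(ref_chosen) == len(ref_rejected)):
        raise ValueError("all inputs must be non-empty and of equal length")
    total = 0.0
    for pw, pl, rw, rl in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected):
        # beta * log(pi_theta(y_w)/pi_ref(y_w)) - beta * log(pi_theta(y_l)/pi_ref(y_l))
        margin = beta * (pw - rw) - beta * (pl - rl)
        total += -log_sigmoid(margin)
    return total / n  # empirical stand-in for the expectation over D_DPO
```

When the policy equals the reference model, the margin is 0 and the loss is \( \log 2 \approx 0.693 \). Raising the policy's log-probability of the positive chain by 1 relative to the reference (for example `dpo_loss([-9.0], [-12.0], [-10.0], [-12.0])`) gives a margin of 0.1 with \( \beta = 0.1 \) and lowers the loss to about 0.644, so minimizing the loss pushes the model toward correct reasoning chains and away from incorrect ones.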