Implicit Chain of Thought Reasoning via Knowledge Distillation

Yuntian Deng, Kiran Prasad, Roland Fernandez, Paul Smolensky, Vishrav Chaudhary, Stuart Shieber
2023-11-03
Abstract: To augment language models with the ability to reason, researchers usually prompt or finetune them to produce chain of thought reasoning steps before producing the final answer. However, although people use natural language to reason effectively, it may be that LMs could reason more effectively with some intermediate computation that is not in natural language. In this work, we explore an alternative reasoning approach: instead of explicitly producing the chain of thought reasoning steps, we use the language model's internal hidden states to perform implicit reasoning. The implicit reasoning steps are distilled from a teacher model trained on explicit chain-of-thought reasoning, and instead of doing reasoning "horizontally" by producing intermediate words one-by-one, we distill it such that the reasoning happens "vertically" among the hidden states in different layers. We conduct experiments on a multi-digit multiplication task and a grade school math problem dataset and find that this approach enables solving tasks previously not solvable without explicit chain-of-thought, at a speed comparable to no chain-of-thought.
Computation and Language, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
This paper addresses how to improve language model performance on tasks that require multi-step reasoning while keeping generation fast. Chain of Thought (CoT) methods are effective on complex math problems and other reasoning tasks, but they increase the time needed to produce the final answer because the model must generate the intermediate steps word by word. The paper proposes a method called Implicit Chain of Thought (Implicit CoT), which performs reasoning in the internal hidden states of the language model instead of explicitly generating intermediate steps.

The main contributions of the paper are as follows:

1. **From Teacher-Guided to Teacher-Teaching**: Training a student model to predict the internal hidden states of a teacher model, which enables faster generation.
2. **Knowledge Distillation of Explicit Reasoning**: Distilling the explicit reasoning process of the teacher model into an implicit reasoning process in the student model.
3. **Combination and Optimization**: Combining the two steps above into a system that is optimized end to end, which allows the student model to develop reasoning methods of its own that differ from the teacher's.

Experimental results show that Implicit CoT can solve tasks that were previously not solvable without explicit chain of thought, such as 5-digit multiplication, without generating explicit intermediate steps. It also performs well on a real-world task, a grade school math problem dataset. In addition, Implicit CoT is significantly faster at inference than explicit Chain of Thought.
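The speed argument in the summary can be made concrete with a small counting exercise. The sketch below is not the paper's code: the step format produced by `explicit_cot_multiplication` is hypothetical, and counting one character as one generated token is an assumption made only for illustration. It compares how many symbols a model would have to emit when it writes out the partial products of a multiplication versus when it emits only the answer, which is the output pattern shared by no-CoT and Implicit CoT. It does not model the extra internal computation that Implicit CoT performs in the hidden states.

```python
def explicit_cot_multiplication(a, b):
    """Return (chain_of_thought, answer) for a * b via partial products.

    Assumes non-negative integers. The step format is hypothetical and
    chosen only for illustration; it is not the format used in the paper.
    """
    steps = []
    running = 0
    # Walk the digits of b from least to most significant.
    for place, digit_char in enumerate(reversed(str(b))):
        partial = a * int(digit_char) * 10 ** place
        running += partial
        steps.append(f"{a}*{digit_char}e{place}={partial}; sum={running}")
    return " | ".join(steps), str(running)


def decoding_steps(a, b):
    """Count emitted symbols, treating one character as one token (assumption)."""
    cot, answer = explicit_cot_multiplication(a, b)
    return {
        "explicit_cot": len(cot) + len(answer),  # intermediate steps plus answer
        "answer_only": len(answer),              # output pattern of no CoT / Implicit CoT
    }


if __name__ == "__main__":
    cot, answer = explicit_cot_multiplication(12345, 67890)
    print(cot)
    print(answer)
    print(decoding_steps(12345, 67890))
```

Under these assumptions, the explicit variant emits one step per digit of the second operand before the answer, so its output length grows with the size of the operands, while the answer-only variant emits just the digits of the product.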