Abstract: Knowledge distillation, the technique of transferring knowledge from large, complex models to smaller ones, marks a pivotal step towards efficient AI deployment. Distilling Step-by-Step (DSS), a novel method utilizing chain-of-thought (CoT) distillation, has demonstrated promise by imbuing smaller models with the superior reasoning capabilities of their larger counterparts. In DSS, the distilled model acquires the ability to generate rationales and predict labels concurrently through a multi-task learning framework. However, DSS overlooks the intrinsic relationship between the two training tasks, leading to ineffective integration of CoT knowledge with the task of label prediction. To this end, we investigate the mutual relationship of the two tasks from the Information Bottleneck perspective and formulate it as maximizing the mutual information of the representation features of the two tasks. We propose a variational approach to solve this optimization problem using a learning-based method. Our experimental results across four datasets demonstrate that our method outperforms the state-of-the-art DSS. Our findings offer insightful guidance for future research on language model distillation as well as applications involving CoT. Code is available at https://github.com/xinchen9/cot_distillation_ACL2024.
What problem does this paper attempt to address?
### Problems the paper attempts to solve
This paper aims to improve Chain-of-Thought (CoT) distillation by maximizing mutual information. Specifically, the authors focus on how to integrate and transfer knowledge more effectively between the label prediction task and the rationale generation task within the Multi-Task Learning (MTL) framework. Although the existing Distilling Step-by-Step (DSS) method reduces computational cost, it trains the two tasks without modeling their intrinsic relationship, so the CoT knowledge learned through rationale generation is not effectively transferred to label prediction. To address this, the authors propose a method based on the Information Bottleneck (IB) principle that improves the reasoning ability of small models by maximizing the mutual information (MI) between the representations of the label prediction task and the rationale generation task.
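For reference, the DSS multi-task objective that this work builds on can be thought of as a label loss plus a weighted rationale loss, each a token-level cross-entropy. The minimal Python sketch below assumes the student's probability for every gold token is already available; the function names, the weight `lam`, and the toy probabilities are illustrative assumptions, not code from the paper's repository.

```python
import math
from typing import Sequence


def token_cross_entropy(target_probs: Sequence[float]) -> float:
    """Mean negative log-likelihood of the gold tokens of one target sequence.

    target_probs[k] is the probability the student assigned to gold token k
    (of either the label or the rationale).
    """
    if not target_probs:
        raise ValueError("target sequence must be non-empty")
    return -sum(math.log(p) for p in target_probs) / len(target_probs)


def dss_multitask_loss(label_probs: Sequence[float],
                       rationale_probs: Sequence[float],
                       lam: float = 1.0) -> float:
    """Label loss plus lam-weighted rationale loss (DSS-style multi-task sum).

    The two terms are simply added; nothing here explicitly couples the
    label and rationale representations.
    """
    return token_cross_entropy(label_probs) + lam * token_cross_entropy(rationale_probs)


# Toy probabilities for gold tokens (made up for illustration).
label_probs = [0.9]                     # a single-token label
rationale_probs = [0.5, 0.8, 0.7, 0.6]  # a short rationale
print(dss_multitask_loss(label_probs, rationale_probs, lam=1.0))
```

Because the objective is a plain sum, the only link between the two tasks is the shared student parameters; the MI-based auxiliary loss described in this paper targets exactly that gap.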
### Main contributions
1. **Redefine the MTL framework**: The authors reformulate the MTL framework of DSS as a mutual information estimation problem, with the goal of maximizing the mutual information between the label prediction task and the rationale generation task. To this end, they introduce a variational method based on the IB principle to estimate the mutual information effectively. To the best of the authors' knowledge, this is the first work to improve CoT distillation from the IB perspective.
2. **Practical mutual information estimation method**: In addition to establishing the theoretical basis, the authors also propose a simple and effective auxiliary loss function to quantify the shared information between the prediction task and the generation task, thereby enhancing the alignment between the two tasks and promoting the transfer of CoT knowledge.
3. **Experimental verification**: The authors conducted comprehensive experiments on four popular datasets using two T5 student models of different sizes (T5-base and T5-small). The results show that their method outperforms the existing baseline methods on multiple datasets, demonstrating its effectiveness in enhancing the reasoning ability of the distilled model.
4. **Systematic analysis**: The authors systematically analyze the relationship between the label prediction task and the rationale generation task under MTL training, providing qualitative and quantitative results that serve as references for further research.
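This summary does not give the exact form of the paper's IB-based auxiliary loss. To illustrate the general idea of maximizing MI between the two tasks' representations through a learnable variational bound, the sketch below swaps in InfoNCE, a widely used contrastive lower bound: for a batch of N paired examples, `I(Z_label; Z_rationale) >= log N - L_InfoNCE`. This is not necessarily the estimator the authors use. The feature vectors, `temperature`, and all function names are hypothetical. In a setup like this, one would add the InfoNCE loss to the multi-task objective and minimize it, which raises the MI bound.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity between two non-zero feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("feature vectors must be non-zero")
    return dot / (norm_a * norm_b)


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def infonce_loss(label_feats: List[Vector], rationale_feats: List[Vector],
                 temperature: float = 0.1) -> float:
    """InfoNCE: label feature i should score highest against rationale feature i.

    Pairs (i, i) come from the same input example and act as positives;
    every (i, j != i) pair in the batch acts as a negative.
    """
    n = len(label_feats)
    if n < 2 or n != len(rationale_feats):
        raise ValueError("need at least two paired examples")
    total = 0.0
    for i in range(n):
        scores = [cosine(label_feats[i], rationale_feats[j]) / temperature
                  for j in range(n)]
        # Negative log-softmax probability of the positive pair.
        total += log_sum_exp(scores) - scores[i]
    return total / n


def mi_lower_bound(label_feats: List[Vector], rationale_feats: List[Vector],
                   temperature: float = 0.1) -> float:
    """Variational lower bound: I(Z_label; Z_rationale) >= log N - InfoNCE."""
    return math.log(len(label_feats)) - infonce_loss(
        label_feats, rationale_feats, temperature)


# Toy features (hypothetical): row i of each list comes from the same example.
label_feats = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
matched = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
mismatched = [[0.0, 1.0], [-1.0, 0.0], [1.0, 0.0]]
print(mi_lower_bound(label_feats, matched))     # close to log(3)
print(mi_lower_bound(label_feats, mismatched))  # much lower
```

Matched pairs push the bound toward its ceiling of log N, while mismatched pairs drive it far lower; a larger batch raises that ceiling, which is one reason contrastive MI estimators are usually trained with sizeable batches.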
### Experimental setup and results
#### Experimental setup
- **Datasets**: The authors conducted experiments on four widely used benchmark datasets covering three NLP tasks: natural language inference (e-SNLI and ANLI), commonsense question answering (CQA), and arithmetic math word problems (SVAMP).
- **Models**: The student models are T5-base (220 million parameters) and T5-small (60 million parameters).
- **Baseline methods**: The method is compared with the state-of-the-art DSS method and with two further baselines: standard fine-tuning and single-task fine-tuning.
#### Experimental results
- **T5-base model**: On all datasets, the authors' method outperforms standard fine-tuning and single-task fine-tuning. Compared with DSS, it performs better on ANLI, CQA, and SVAMP, and its performance on e-SNLI is close to that of DSS.
- **T5-small model**: The results mirror those for T5-base: the authors' method outperforms standard fine-tuning on all datasets and outperforms DSS on ANLI, CQA, and SVAMP.
- **LLM-generated labels**: When distilling with labels generated by PaLM 540B, the authors' method still outperforms DSS, even when the label quality is poor.
- **Different dataset sizes**: When distilling on training sets of different sizes, the authors' method outperforms DSS in most cases.
### Discussion
#### Model calibration analysis
The authors use Expected Calibration Error (ECE) and Average Confidence Score to evaluate model calibration. Their method achieves lower ECE on some tasks (e.g., e-SNLI and ANLI), indicating better calibration there. On other tasks (e.g., SVAMP and CQA), however, its ECE is higher, which points to a direction for future research.
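ECE is commonly computed by grouping predictions into equal-width confidence bins and averaging the gap between accuracy and mean confidence per bin, weighted by bin size: `ECE = sum_m (|B_m| / n) * |acc(B_m) - conf(B_m)|`. The minimal pure-Python sketch below follows that standard definition and also computes the average confidence. The bin count of 10, the bin-boundary convention, and the toy inputs are assumptions; the paper's exact binning may differ.

```python
from typing import Sequence


def expected_calibration_error(confidences: Sequence[float],
                               correct: Sequence[bool],
                               n_bins: int = 10) -> float:
    """ECE with equal-width confidence bins.

    confidences[k] is the model's probability for its predicted label on
    example k; correct[k] says whether that prediction matched the gold label.
    """
    n = len(confidences)
    if n == 0 or n != len(correct):
        raise ValueError("need equally sized, non-empty inputs")
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        # Bin m covers [m/n_bins, (m+1)/n_bins); a confidence of 1.0 joins the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        # Weight each bin's |accuracy - confidence| gap by its share of examples.
        ece += len(members) / n * abs(accuracy - avg_conf)
    return ece


def average_confidence(confidences: Sequence[float]) -> float:
    """Mean probability the model assigns to its own predictions."""
    return sum(confidences) / len(confidences)


# Toy predictions (made up for illustration).
confs = [0.95, 0.92, 0.65, 0.62]
hits = [True, True, False, True]
print(expected_calibration_error(confs, hits))
print(average_confidence(confs))
```

Reading the two metrics together helps: a model can have a high average confidence and still a low ECE if its accuracy is correspondingly high, so a higher ECE on a task signals a mismatch between confidence and accuracy rather than confidence alone.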