f-Divergence Minimization for Sequence-Level Knowledge Distillation

Yuqiao Wen, Zichao Li, Wenyu Du, Lili Mou
2023-07-28
Abstract: Knowledge distillation (KD) is the process of transferring knowledge from a large model to a small one. It has gained increasing attention in the natural language processing community, driven by the demands of compressing ever-growing language models. In this work, we propose an f-DISTILL framework, which formulates sequence-level knowledge distillation as minimizing a generalized f-divergence function. We propose four distilling variants under our framework and show that existing SeqKD and ENGINE approaches are approximations of our f-DISTILL methods. We further derive step-wise decomposition for our f-DISTILL, reducing intractable sequence-level divergence to word-level losses that can be computed in a tractable manner. Experiments across four datasets show that our methods outperform existing KD approaches, and that our symmetric distilling losses can better force the student to learn from the teacher distribution.
Computation and Language, Machine Learning
What problem does this paper attempt to address?
The paper addresses how to improve small models on natural language generation tasks through sequence-level knowledge distillation. It proposes a framework, f-DISTILL, that formulates sequence-level knowledge distillation as minimizing a generalized f-divergence between the teacher and student distributions. The existing SeqKD and ENGINE methods are shown to be approximations of variants under this framework.

The authors point out that the two asymmetric divergences fail in opposite ways on sequence generation. Minimizing the Kullback-Leibler (KL) divergence leads to mode-averaging: the student spreads probability mass across the teacher's modes and produces an overly smooth distribution. Minimizing the reverse KL divergence (RKL) leads to mode-collapsing: the student concentrates on a few high-probability regions of the teacher's distribution and ignores the rest. To overcome these issues, the paper proposes variants based on symmetric f-divergences, namely the Jensen-Shannon (JS) divergence and the total variation distance (TVD). These symmetric losses better force the student to learn from the teacher distribution, alleviating both mode-averaging and mode-collapsing.

Because the sequence-level divergence is intractable, the paper derives a step-wise decomposition that reduces it to word-level losses, and it proposes an efficient offline sampling method to compute them. The methods are evaluated on four datasets: DART (data-to-text generation), XSum (summarization), WMT16 EN-RO (machine translation), and Commonsense Dialogue (dialogue generation). The experimental results show that the f-DISTILL methods outperform existing knowledge distillation methods on all four datasets.
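The contrast between mode-averaging and mode-collapsing can be illustrated with a toy example. The following is a minimal sketch, not the paper's implementation: it computes the four divergences (KL, RKL, JS, TVD) between a single pair of next-word distributions, whereas the paper applies word-level losses across sampled sequences. The function names, the four-word vocabulary, and the probability values are all assumptions made for illustration.

```python
import math


def kl(p, q, eps=1e-12):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); terms with p_i == 0 contribute 0."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def reverse_kl(teacher, student):
    """Reverse KL: the expectation is taken under the student, KL(student || teacher)."""
    return kl(student, teacher)


def js(p, q):
    """Jensen-Shannon divergence: average KL of p and q to their mixture m."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def tvd(p, q):
    """Total variation distance: half the L1 distance between p and q."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))


# Hypothetical teacher with two modes over a 4-word vocabulary.
teacher = [0.49, 0.01, 0.01, 0.49]
# Hypothetical students: one spreads mass evenly, one keeps a single mode.
spread = [0.25, 0.25, 0.25, 0.25]
collapsed = [0.97, 0.01, 0.01, 0.01]

for name, student in [("spread", spread), ("collapsed", collapsed)]:
    print(
        f"{name:9s}"
        f" KL={kl(teacher, student):.4f}"
        f" RKL={reverse_kl(teacher, student):.4f}"
        f" JS={js(teacher, student):.4f}"
        f" TVD={tvd(teacher, student):.4f}"
    )
```

With these toy values, KL assigns the lower loss to the evenly spread student and RKL assigns the lower loss to the collapsed student, which mirrors the two failure modes described above. JS and TVD return the same value when their arguments are swapped, so neither direction is privileged.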