Multi-Stage Balanced Distillation: Addressing Long-Tail Challenges in Sequence-Level Knowledge Distillation

Yuhang Zhou, Jing Zhu, Paiheng Xu, Xiaoyu Liu, Xiyao Wang, Danai Koutra, Wei Ai, Furong Huang
2024-10-19
Abstract: Large language models (LLMs) have significantly advanced various natural language processing tasks, but deploying them remains computationally expensive. Knowledge distillation (KD) is a promising solution, enabling the transfer of capabilities from larger teacher LLMs to more compact student models. In particular, sequence-level KD, which distills rationale-based reasoning processes instead of merely final outcomes, shows great potential in enhancing students' reasoning capabilities. However, current methods struggle with sequence-level KD under long-tailed data distributions, adversely affecting generalization on sparsely represented domains. We introduce the Multi-Stage Balanced Distillation (BalDistill) framework, which iteratively balances training data within a fixed computational budget. By dynamically selecting representative head domain examples and synthesizing tail domain examples, BalDistill achieves state-of-the-art performance across diverse long-tailed datasets, enhancing both the efficiency and efficacy of the distilled models.
Computation and Language, Artificial Intelligence
What problem does this paper attempt to address?
This paper addresses the challenges of sequence-level knowledge distillation (KD) under long-tailed data distributions. Specifically, it asks how to transfer the capabilities of large language models (LLMs) to smaller student models when the training data is highly imbalanced across domains, so that the student still generalizes well to sparsely represented (tail) domains. The paper argues that current methods struggle in this setting, which hurts generalization on those domains. To address this, it proposes the Multi-Stage Balanced Distillation (BalDistill) framework, which iteratively balances the training data within a fixed computational budget by dynamically selecting representative head-domain examples and synthesizing tail-domain examples. According to the authors, BalDistill achieves state-of-the-art performance across diverse long-tailed datasets and improves both the efficiency and efficacy of the distilled models.
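The balancing idea described above can be illustrated with a small sketch. This is not the authors' implementation: the function names (`balance_stage`, `run_stages`), the equal per-domain quota, and the even split of the budget across stages are all assumptions made for illustration. Random sampling stands in for the selection of "representative" head-domain examples, whose actual criterion the abstract does not specify, and tail-domain synthesis by the teacher is only counted, not performed.

```python
import random


def balance_stage(pool, stage_budget, rng):
    """Plan one balanced batch from a long-tailed pool (hypothetical helper).

    pool: dict mapping domain name -> list of unused real example IDs.
    Returns (selected, to_synthesize): the real examples chosen per domain,
    and how many synthetic examples each domain still needs to fill its quota.
    """
    quota = stage_budget // len(pool)  # assumption: equal share per domain
    selected, to_synthesize = {}, {}
    for domain, examples in pool.items():
        if len(examples) >= quota:
            # Head domain: subsample down to the quota. Random choice is a
            # stand-in for whatever "representative" selection the paper uses.
            picked = rng.sample(examples, quota)
        else:
            # Tail domain: keep every real example that is left.
            picked = list(examples)
        selected[domain] = picked
        # The shortfall would be filled by teacher-generated examples.
        to_synthesize[domain] = quota - len(picked)
    return selected, to_synthesize


def run_stages(pool, total_budget, num_stages, seed=0):
    """Split a fixed budget evenly over stages and plan each stage in turn."""
    rng = random.Random(seed)
    remaining = {domain: list(examples) for domain, examples in pool.items()}
    plan = []
    for _ in range(num_stages):
        selected, to_synthesize = balance_stage(
            remaining, total_budget // num_stages, rng
        )
        # Remove the chosen examples so later stages do not reuse them.
        # Assumes example IDs are unique and hashable.
        for domain, picked in selected.items():
            used = set(picked)
            remaining[domain] = [e for e in remaining[domain] if e not in used]
        plan.append((selected, to_synthesize))
    return plan
```

For example, with a pool of 100 head-domain examples and 3 tail-domain examples, `run_stages(pool, total_budget=40, num_stages=2)` gives each domain a quota of 10 per stage: the head domain is subsampled in both stages, while the tail domain uses its 3 real examples in the first stage and relies entirely on synthesized examples in the second.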