Improving Knowledge Distillation for BERT Models: Loss Functions, Mapping Methods, and Weight Tuning

Apoorv Dankar, Adeem Jassani, Kartikaeya Kumar
2023-08-27
Abstract: The use of large transformer-based models such as BERT, GPT, and T5 has led to significant advancements in natural language processing. However, these models are computationally expensive, necessitating model compression techniques that reduce their size and complexity while maintaining accuracy. This project investigates and applies knowledge distillation for BERT model compression, specifically focusing on the TinyBERT student model. We explore various techniques to improve knowledge distillation, including experimentation with loss functions, transformer layer mapping methods, and tuning the weights of the attention and representation losses, and we evaluate our proposed techniques on a selection of downstream tasks from the GLUE benchmark. The goal of this work is to improve the efficiency and effectiveness of knowledge distillation, enabling the development of more efficient and accurate models for a range of natural language processing tasks.
Computation and Language, Artificial Intelligence
What problem does this paper attempt to address?
The paper focuses on improving knowledge distillation for BERT models to achieve more efficient and accurate model compression. Specifically, the researchers explored three techniques for making knowledge distillation more effective for TinyBERT, a compressed version of BERT:

1. **Loss Function**: The paper experimented with different loss functions for comparing attention distributions, in particular the Kullback-Leibler (KL) divergence, and evaluated their performance on different tasks.
2. **Mapping Methods**: The researchers proposed several mapping functions that determine which teacher layer each student layer learns from. These include random mapping, mean mapping, and learnable mapping, among others.
3. **Weight Adjustment**: By tuning the relative weights of the representation loss and the attention loss, the researchers sought the best balance between the two to further optimize the distillation process.

The goal is to reduce the model's size and complexity while maintaining its performance, making it easier to deploy on resource-constrained devices. Experimental results show that using KL divergence as the loss function can yield significant performance improvements on certain tasks (such as CoLA), especially in data-limited scenarios. Although the learnable mapping method did not bring significant performance improvements, the learned weight vectors tended to converge to the same value, which may suggest the existence of a global optimum. For the STS-B task, the paper found that distilling directly with the prediction loss, without the intermediate transformer-layer distillation stage, might be more effective. These findings offer useful insights for future research.
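The loss-function item can be made concrete with a small sketch. TinyBERT's original attention loss is the mean squared error (MSE) between student and teacher attention score matrices, averaged over heads. The KL variant below reflects an assumption about how such a comparison might be set up, not the paper's code: each row of raw scores is softmax-normalised and the loss is KL(teacher || student), averaged over query positions. It handles a single head, all function names are hypothetical, and plain Python lists stand in for tensors so the example runs without a deep-learning framework.

```python
import math
from typing import List

Matrix = List[List[float]]  # one attention head: rows = query positions, cols = key positions


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of attention scores."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def mse_attention_loss(student: Matrix, teacher: Matrix) -> float:
    """Baseline: mean squared error over all raw attention scores."""
    diffs = [(s - t) ** 2 for s_row, t_row in zip(student, teacher) for s, t in zip(s_row, t_row)]
    return sum(diffs) / len(diffs)


def kl_attention_loss(student: Matrix, teacher: Matrix) -> float:
    """Hypothetical KL variant: softmax-normalise each row, then average
    KL(teacher_row || student_row) over query positions."""
    total = 0.0
    for s_row, t_row in zip(student, teacher):
        log_p = log_softmax(t_row)  # teacher distribution (target)
        log_q = log_softmax(s_row)  # student distribution
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    return total / len(teacher)


teacher = [[1.0, 0.0]]
shifted_student = [[5.0, 4.0]]  # same attention distribution, every score shifted by +4
print(mse_attention_loss(shifted_student, teacher))  # 16.0
print(kl_attention_loss(shifted_student, teacher))   # close to 0 (up to floating-point rounding)
```

The shifted example shows one way the two losses differ. MSE on raw scores penalises a constant offset that leaves the softmax distribution unchanged, while the KL loss does not. Whether this property explains the reported CoLA gains is not established here.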
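The layer-mapping functions can be sketched as follows. Only the uniform mapping is TinyBERT's published scheme: student layer m learns from teacher layer m·N/M, which gives teacher layers 3, 6, 9, and 12 for a 4-layer student and a 12-layer teacher. The random, mean, and learnable variants are hypothetical readings of the method names in the summary, not the paper's exact definitions. Layer representations are plain lists of floats, and all function names are illustrative.

```python
import math
import random
from typing import List, Sequence

Vector = List[float]


def uniform_mapping(num_student: int, num_teacher: int) -> List[int]:
    """TinyBERT-style uniform mapping: student layer m (1-based) -> teacher layer m * N / M."""
    step = num_teacher // num_student
    return [m * step for m in range(1, num_student + 1)]


def random_mapping(num_student: int, num_teacher: int, rng: random.Random) -> List[int]:
    """Hypothetical random mapping: distinct teacher layers drawn at random, kept in depth order."""
    return sorted(rng.sample(range(1, num_teacher + 1), num_student))


def mean_mapping_target(teacher_layers: Sequence[Vector], m: int, num_student: int) -> Vector:
    """Hypothetical mean mapping: student layer m (1-based) targets the element-wise
    mean of the teacher layers in its block of N / M consecutive layers."""
    step = len(teacher_layers) // num_student
    block = teacher_layers[(m - 1) * step : m * step]
    return [sum(values) / len(block) for values in zip(*block)]


def learnable_mapping_target(teacher_layers: Sequence[Vector], logits: Sequence[float]) -> Vector:
    """Hypothetical learnable mapping: a softmax(logits)-weighted sum of teacher layers.
    During training, the logits would be updated together with the student's weights."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * value for w, value in zip(weights, values)) for values in zip(*teacher_layers)]


# Toy teacher: 12 layers, each "representation" is a 1-d vector holding its layer index.
teacher_layers = [[float(i)] for i in range(1, 13)]
print(uniform_mapping(4, 12))                                # [3, 6, 9, 12]
print(random_mapping(4, 12, random.Random(0)))
print(mean_mapping_target(teacher_layers, 1, 4))             # [2.0]
print(learnable_mapping_target(teacher_layers, [0.0] * 12))  # about [6.5]
```

Sorting the random draw preserves the shallow-to-deep ordering of the uniform scheme. The summary does not say whether the paper's random mapping does this.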
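For the weight-adjustment item, note that TinyBERT's transformer-layer loss is the unweighted sum of an attention loss and a hidden-state (representation) loss. The representation loss is the MSE between the student hidden states, projected by a learnable matrix W_h to the teacher's hidden size, and the teacher hidden states. The sketch below adds one weight to each term of that sum. The names `rep_weight` and `attn_weight`, the toy shapes, and the illustrative weight grid are assumptions, and the paper's actual weight values are not reproduced here.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product: (n x k) @ (k x d) -> (n x d)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def mse(a: Matrix, b: Matrix) -> float:
    """Mean squared error over all entries of two equally shaped matrices."""
    diffs = [(x - y) ** 2 for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b)]
    return sum(diffs) / len(diffs)


def representation_loss(h_student: Matrix, w_h: Matrix, h_teacher: Matrix) -> float:
    """MSE(H_S @ W_h, H_T): W_h maps the student's hidden size d' to the teacher's d."""
    return mse(matmul(h_student, w_h), h_teacher)


def weighted_layer_loss(rep_loss: float, attn_loss: float,
                        rep_weight: float = 1.0, attn_weight: float = 1.0) -> float:
    """Hypothetical weighted layer loss; both weights = 1 gives TinyBERT's plain sum."""
    return rep_weight * rep_loss + attn_weight * attn_loss


# Toy shapes: 2 tokens, student hidden size 2, teacher hidden size 3.
h_student = [[1.0, 0.0], [0.0, 1.0]]
w_h = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
h_teacher = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
rep = representation_loss(h_student, w_h, h_teacher)  # 1 / 6
for rep_weight in (0.5, 1.0, 2.0):  # illustrative grid, not the paper's values
    print(rep_weight, weighted_layer_loss(rep, attn_loss=0.5, rep_weight=rep_weight))
```

In a real weight sweep, each (rep_weight, attn_weight) pair would drive a separate distillation run, and the resulting students would be compared on a development set. The loop above only shows how the combined loss changes.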
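The STS-B finding concerns TinyBERT's two-stage procedure: intermediate-layer distillation followed by prediction-layer distillation. In TinyBERT, the prediction loss is a soft cross-entropy between teacher and student logits, both scaled by a temperature t. STS-B is a regression task, so the sketch assumes a squared-error prediction loss for it. That choice, the `skip_intermediate` flag, and all function names are illustrative assumptions rather than the paper's code.

```python
import math
from typing import List


def log_softmax(z: List[float]) -> List[float]:
    """Numerically stable log-softmax."""
    peak = max(z)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in z))
    return [v - log_norm for v in z]


def soft_cross_entropy(student_logits: List[float], teacher_logits: List[float],
                       temperature: float = 1.0) -> float:
    """Cross-entropy of the student's tempered distribution against the teacher's."""
    teacher_probs = [math.exp(v) for v in log_softmax([z / temperature for z in teacher_logits])]
    student_log_probs = log_softmax([z / temperature for z in student_logits])
    return -sum(p * lq for p, lq in zip(teacher_probs, student_log_probs))


def regression_prediction_loss(student_score: float, teacher_score: float) -> float:
    """Assumed prediction loss for a regression task such as STS-B: squared error."""
    return (student_score - teacher_score) ** 2


def distillation_stages(skip_intermediate: bool) -> List[str]:
    """Hypothetical schedule: the default two stages, or prediction-only as suggested for STS-B."""
    return ["prediction"] if skip_intermediate else ["intermediate", "prediction"]


print(soft_cross_entropy([0.0, 0.0], [0.0, 0.0]))  # ln 2, about 0.693
print(regression_prediction_loss(3.5, 4.0))          # 0.25
print(distillation_stages(skip_intermediate=True))   # ['prediction']
```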