Which Pretrain Samples to Rehearse when Finetuning Pretrained Models?

Andrew Bai, Chih-Kuan Yeh, Cho-Jui Hsieh, Ankur Taly
2024-02-13
Abstract: Fine-tuning pretrained foundational models on specific tasks is now the de facto approach for text and vision tasks. A known pitfall of this approach is the forgetting of pretraining knowledge that happens during finetuning. Rehearsing samples randomly from the pretrain dataset is a common approach to alleviate such forgetting. However, we find that random mixing unintentionally includes samples which are not (yet) forgotten or unlearnable by the model. We propose a novel sampling scheme, mix-cd, that identifies and prioritizes samples that actually face forgetting, which we call collateral damage. Since directly identifying collateral damage samples is computationally expensive, we propose a procedure to estimate the distribution of such samples by tracking the statistics of finetuned samples. Our approach is lightweight, easy to implement, and can be seamlessly integrated into existing models, offering an effective means to retain pretrain performance without additional computational costs.
Machine Learning
What problem does this paper attempt to address?
The paper focuses on how to prevent a pretrained model from forgetting its pretraining knowledge during fine-tuning. Specifically, it addresses the following issues:

1. **Knowledge Forgetting between Pre-training and Fine-tuning**: During fine-tuning, the model may forget knowledge learned in the pre-training phase, a phenomenon known as "catastrophic forgetting". Randomly mixing pre-training samples into the fine-tuning data is a common way to mitigate this, but the paper points out that it is inefficient: random mixing also rehearses samples that the model has not (yet) forgotten or that it cannot learn at all.
2. **Optimizing Pre-training Sample Selection**: The paper proposes a new sampling scheme called mix-cd ("cd" for collateral damage), which identifies and prioritizes the pre-training samples that actually face forgetting. These samples are referred to as "collateral damage", i.e., samples whose predictions were correct before fine-tuning and become incorrect because of it.
3. **Efficient Estimation of Collateral Damage Distribution**: Directly identifying collateral damage samples is computationally expensive, so the paper introduces a procedure called mix-cd-sample, which estimates the distribution of such samples by tracking statistics of fine-tuning samples, avoiding additional computational cost.
4. **Maintaining Lightweight and Easy-to-Implement Characteristics**: The proposed scheme is lightweight, easy to implement, and can be integrated into existing models without additional computational cost while preserving pre-training performance.

In experiments on image and text tasks, the paper reports that mix-cd balances pre-training and fine-tuning performance better than random sampling and other baseline methods.
Additionally, the paper explores how different strategies for partitioning the pre-training data affect sampling, and how to estimate the collateral damage ratio within each partition.
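
To make the idea of partition-level estimation concrete, here is a minimal Python sketch. It is not the paper's implementation: the class name `PartitionedRehearsalSampler`, the helper `is_collateral_damage`, the Laplace-style smoothing, and the choice to sample partitions in proportion to their estimated collateral damage ratio are all assumptions made for illustration. The sketch also assumes that each pre-training sample has a cached correctness flag from the pretrained model and that the finetuned model's correctness is observed on samples that already pass through training, so no extra inference is needed.

```python
import random
from collections import defaultdict


def is_collateral_damage(correct_before, correct_after):
    """True if the pretrained model was right and the finetuned model is wrong."""
    return correct_before and not correct_after


class PartitionedRehearsalSampler:
    """Hypothetical sampler: rehearse more from partitions with more observed damage."""

    def __init__(self, partition_of, smoothing=1.0, seed=0):
        # partition_of maps each pretrain sample id to a partition key (assumed given).
        self.partition_of = dict(partition_of)
        self.members = defaultdict(list)
        for sample_id, part in self.partition_of.items():
            self.members[part].append(sample_id)
        self.partitions = list(self.members)
        self.smoothing = smoothing  # must be > 0 so every partition keeps some weight
        self.damaged = {p: 0 for p in self.partitions}
        self.seen = {p: 0 for p in self.partitions}
        self.rng = random.Random(seed)

    def update(self, sample_id, correct_before, correct_after):
        """Record the outcome for one sample that was already evaluated in training."""
        part = self.partition_of[sample_id]
        self.seen[part] += 1
        self.damaged[part] += int(is_collateral_damage(correct_before, correct_after))

    def damage_ratio(self, part):
        """Smoothed estimate of the fraction of collateral damage in a partition."""
        return (self.damaged[part] + self.smoothing) / (
            self.seen[part] + 2 * self.smoothing
        )

    def sample(self, k):
        """Pick k partitions by estimated ratio, then one sample uniformly from each."""
        weights = [self.damage_ratio(p) for p in self.partitions]
        parts = self.rng.choices(self.partitions, weights=weights, k=k)
        return [self.rng.choice(self.members[p]) for p in parts]


if __name__ == "__main__":
    # Toy data: samples 0-49 in partition "a", samples 50-99 in partition "b".
    sampler = PartitionedRehearsalSampler(
        {i: ("a" if i < 50 else "b") for i in range(100)}
    )
    sampler.update(3, correct_before=True, correct_after=False)   # damage in "a"
    sampler.update(70, correct_before=True, correct_after=True)   # retained in "b"
    print(sampler.damage_ratio("a"), sampler.damage_ratio("b"))
    print(sampler.sample(5))
```

In a training loop, one would call `update` for each rehearsed sample after its forward pass and call `sample` to draw the next rehearsal batch. The smoothing term keeps partitions with no observed damage from being excluded entirely, which leaves room for the estimates to change as fine-tuning progresses.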