Understanding Catastrophic Forgetting in Language Models via Implicit Inference

Suhas Kotha, Jacob Mitchell Springer, Aditi Raghunathan
2024-04-14
Abstract: We lack a systematic understanding of the effects of fine-tuning (via methods such as instruction-tuning or reinforcement learning from human feedback), particularly on tasks outside the narrow fine-tuning distribution. In a simplified scenario, we demonstrate that improving performance on tasks within the fine-tuning data distribution comes at the expense of capabilities on other tasks. We hypothesize that language models implicitly infer the task of the prompt and that fine-tuning skews this inference towards tasks in the fine-tuning distribution. To test this, we propose Conjugate Prompting, which artificially makes the task look farther from the fine-tuning distribution while requiring the same capability, and we find that this recovers some of the pretraining capabilities in our synthetic setup. Since real-world fine-tuning distributions are predominantly English, we apply conjugate prompting to recover pretrained capabilities in LLMs by simply translating the prompts to different languages. This allows us to recover in-context learning abilities lost via instruction tuning, natural reasoning capability lost during code fine-tuning, and, more concerningly, harmful content generation suppressed by safety fine-tuning in chatbots like ChatGPT.
Computation and Language, Machine Learning
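The abstract's hypothesis (the model implicitly infers which task a prompt comes from, and fine-tuning skews that inference toward the fine-tuning distribution) can be illustrated with a small Bayesian toy model. This is a sketch under stated assumptions, not the paper's synthetic setup: the two tasks (a discrete prior on a 1-D regression weight standing in for the fine-tuning task, a Gaussian prior standing in for a pretraining capability), the noise level SIGMA and prior scale TAU, modeling fine-tuning as a larger prior mass on the fine-tuning task, and the input-scaling transform used as the conjugate prompt are all hypothetical choices made for illustration.

```python
import math

# Toy assumptions (not the paper's exact setup): 1-D in-context regression
# y = w * x + noise, with two latent "tasks" that differ only in their prior on w.
SIGMA = 0.5                     # observation-noise std (assumed)
TAU = 1.0                       # std of the dense Gaussian prior on w (assumed)
DISCRETE_WEIGHTS = (-1.0, 1.0)  # weights allowed by the fine-tuning task (assumed)


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def gaussian_loglik(xs, ys, w):
    """log p(ys | xs, w) under y = w * x + N(0, SIGMA^2)."""
    sq = sum((y - w * x) ** 2 for x, y in zip(xs, ys))
    return -0.5 * len(xs) * math.log(2 * math.pi * SIGMA ** 2) - 0.5 * sq / SIGMA ** 2


def discrete_task(xs, ys):
    """Log evidence and posterior-mean w when w is uniform over DISCRETE_WEIGHTS."""
    logs = [gaussian_loglik(xs, ys, w) - math.log(len(DISCRETE_WEIGHTS))
            for w in DISCRETE_WEIGHTS]
    log_ev = logsumexp(logs)
    w_mean = sum(w * math.exp(l - log_ev) for w, l in zip(DISCRETE_WEIGHTS, logs))
    return log_ev, w_mean


def dense_task(xs, ys):
    """Closed-form log evidence and posterior-mean w when w ~ N(0, TAU^2)."""
    precision = 1 / TAU ** 2 + sum(x * x for x in xs) / SIGMA ** 2
    b = sum(x * y for x, y in zip(xs, ys)) / SIGMA ** 2
    log_ev = (-0.5 * len(xs) * math.log(2 * math.pi * SIGMA ** 2)
              - 0.5 * math.log(TAU ** 2 * precision)
              - 0.5 * sum(y * y for y in ys) / SIGMA ** 2
              + 0.5 * b * b / precision)
    return log_ev, b / precision


def predict(xs, ys, x_query, prior_discrete):
    """Posterior-mean prediction: infer the task from the prompt, then average.

    prior_discrete is the prior mass on the fine-tuning task; raising it is our
    stand-in for fine-tuning. Returns (prediction, P(fine-tuning task | prompt)).
    """
    log_a, w_a = discrete_task(xs, ys)
    log_b, w_b = dense_task(xs, ys)
    la = math.log(prior_discrete) + log_a
    lb = math.log(1 - prior_discrete) + log_b
    p_a = math.exp(la - logsumexp([la, lb]))
    return x_query * (p_a * w_a + (1 - p_a) * w_b), p_a


# A noiseless prompt from the dense task with w = 0.5, so the correct answer is 0.5.
xs = [-2.0, -1.5, -1.0, -0.5, 0.5, 1.0, 1.5, 2.0]
ys = [0.5 * x for x in xs]
x_query = 1.0

pred_pre, p_pre = predict(xs, ys, x_query, prior_discrete=0.5)    # "pretrained"
pred_ft, p_ft = predict(xs, ys, x_query, prior_discrete=0.999)    # "fine-tuned"

# Hypothetical conjugate transform: scale inputs by alpha. The implied weight becomes
# 0.5 / alpha, far from +-1, so the prompt looks unlike the fine-tuning task, while the
# answer at alpha * x_query is still 0.5 (outputs are untouched, so no inverse step).
alpha = 4.0
pred_conj, p_conj = predict([alpha * x for x in xs], ys, alpha * x_query,
                            prior_discrete=0.999)

for name, pred, p in (("pretrained", pred_pre, p_pre),
                      ("fine-tuned", pred_ft, p_ft),
                      ("conjugate", pred_conj, p_conj)):
    print(f"{name:>10}: P(fine-tuning task)={p:.3g}, prediction={pred:.3f}")
```

In this toy, raising the prior on the fine-tuning task from 0.5 to 0.999 moves the prediction from about 0.49 to about 0.85 even though the prompt's evidence is unchanged, and the scaled prompt brings it back to about 0.50. The exact numbers depend entirely on the assumed SIGMA, TAU, prior values, and prompt.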
What problem does this paper attempt to address?
This paper addresses catastrophic forgetting caused by fine-tuning language models, for example through instruction tuning or reinforcement learning from human feedback (RLHF). Specifically, the authors study how fine-tuning a model to improve performance on particular tasks can significantly degrade its performance on other tasks, especially tasks that lie outside the distribution of the fine-tuning data.

The authors hypothesize that a language model implicitly infers which task a prompt belongs to, and that fine-tuning skews this inference toward tasks in the fine-tuning distribution. To mitigate the resulting forgetting, they propose "Conjugate Prompting": the prompt is transformed so that it looks farther from the fine-tuning distribution while still requiring the same capability, which steers the model toward the capabilities it learned during pretraining instead of the task inference learned during fine-tuning. This restores some of the performance that fine-tuning had suppressed.

In the experiments, the authors translate prompts into languages other than English, since real-world fine-tuning data is predominantly English. This lowers the probability that a prompt is attributed to the fine-tuning distribution and recovers in-context learning performance lost through instruction tuning. They also show that code fine-tuning degrades natural-language reasoning and that conjugate prompting improves performance on such tasks. More concerningly, the same approach recovers harmful content generation that safety fine-tuning had suppressed in chatbots such as ChatGPT.

Overall, through analysis of a simplified synthetic setting and experiments on large language models, the paper argues that a shift in task inference during fine-tuning is a major cause of catastrophic forgetting, and it proposes conjugate prompting as a practical way to recover some of the lost capabilities.
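The translation-based form of conjugate prompting can be sketched as a composition of three steps: map the prompt away from the fine-tuning distribution, query the model there, and map the answer back. This is a minimal sketch under assumptions: `conjugate_prompt`, `toy_translate`, `toy_model`, and the `Translate`/`Model` signatures are hypothetical names, the toy translator and model are stand-ins for a real machine-translation system and a real LLM, and translating the response back to English is an assumed final step rather than a detail stated in the summary.

```python
from typing import Callable

# Hypothetical interfaces, not from the paper or any specific library.
Translate = Callable[[str, str, str], str]  # (text, source_lang, target_lang) -> text
Model = Callable[[str], str]                # prompt -> completion


def conjugate_prompt(prompt: str, model: Model, translate: Translate,
                     pivot_lang: str, source_lang: str = "en") -> str:
    """Translate the prompt away from the (mostly English) fine-tuning
    distribution, query the model there, then translate the answer back
    (the back-translation step is an assumption)."""
    transformed = translate(prompt, source_lang, pivot_lang)
    completion = model(transformed)
    return translate(completion, pivot_lang, source_lang)


# Toy stand-ins so the sketch runs: a word-for-word lexicon and a "model"
# that simply continues the prompt with one fixed word.
LEXICON = {
    ("en", "fr"): {"hello": "bonjour"},
    ("fr", "en"): {"bonjour": "hello", "monde": "world"},
}


def toy_translate(text: str, source_lang: str, target_lang: str) -> str:
    table = LEXICON.get((source_lang, target_lang), {})
    return " ".join(table.get(word, word) for word in text.split())


def toy_model(prompt: str) -> str:
    return prompt + " monde"


print(conjugate_prompt("hello", toy_model, toy_translate, pivot_lang="fr"))
# prints "hello world"
```

Keeping the transform and its inverse as parameters makes the structure explicit: any transform that moves the prompt off the fine-tuning distribution while preserving the required capability could be plugged in, with translation being the one the summary describes.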