Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning

Heshan Fernando, Han Shen, Parikshit Ram, Yi Zhou, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen
2024-10-29
Abstract: Post-training of pre-trained LLMs, which typically consists of the supervised fine-tuning (SFT) stage and the preference learning (RLHF or DPO) stage, is crucial to effective and safe LLM applications. The widely adopted approach in post-training popular open-source LLMs is to sequentially perform SFT and RLHF/DPO. However, sequential training is sub-optimal in terms of the SFT and RLHF/DPO trade-off: the LLM gradually forgets the first stage's training while undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms the sequential post-training framework, while having similar computational cost. Our code is available at https://github.com/heshandevaka/XRIGHT.
Machine Learning, Artificial Intelligence, Computation and Language
What problem does this paper attempt to address?
### Problems the paper attempts to solve

This paper addresses the "forgetting" problem that arises when large language models (LLMs) are post-trained sequentially through supervised fine-tuning (SFT) and preference learning (such as reinforcement learning from human feedback, RLHF, or direct preference optimization, DPO). During the second stage, the model gradually forgets what it learned in the first stage, so sequential training achieves a poor trade-off between SFT and preference learning. The paper proves theoretically that sequential training is sub-optimal and proposes a joint training framework with two variants, ALRIGHT and MAXRIGHT, that balances the SFT and preference learning objectives better while keeping computational cost low.

### Main contributions

1. **Insight into the forgetting problem of two-stage sequential training**:
   - Provides theoretical results on forgetting under sequential training and supports them with experimental evidence.
   - Proves that performing DPO and SFT sequentially can lead to a non-diminishing optimality gap, which is the first theoretical proof of the sub-optimality of learning the SFT and DPO objectives sequentially.
2. **Principled post-training methods with almost no additional cost**:
   - Proposes post-training algorithms with theoretical guarantees that outperform sequential methods and have a lower computational cost than hybrid methods.
   - Specifically proposes two methods:
     - **ALRIGHT**: Alternates between supervised fine-tuning and human-preference alignment, and can converge to any desired trade-off between the DPO and SFT objectives.
     - **MAXRIGHT**: Adaptively switches between optimizing the RLHF and SFT objectives.
3. **Strong empirical performance on standard benchmarks**:
   - With the LLAMA3-8B model, the proposed methods improve on the sequential method by up to 3% on the MMLU (1-shot) benchmark and raise the win rate by up to 31% on the RLHF dataset (as evaluated by GPT-4-Turbo), with only a very small amount of additional computation.

### Technical challenges

The main technical challenge is proving forgetting theoretically when the objectives being optimized are the SFT and DPO losses, both of which are negative log-likelihood functions. Existing theoretical results in continual learning usually rely on quadratic objectives, whereas the gradients of negative log-likelihood objectives are nonlinear, which complicates the analysis. The authors address this and give the detailed analysis in Appendix A.2.

### Experimental results

The paper compares ALRIGHT and MAXRIGHT experimentally with existing baselines in terms of Pareto-front performance and resource consumption (such as computation time and memory usage). The results show that the proposed methods achieve a better balance between the SFT and preference learning objectives while keeping computational cost low.
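
To make the contrast between sequential training and the ALRIGHT and MAXRIGHT descriptions under Main contributions more concrete, here is a minimal toy sketch. It is not the authors' implementation. Everything in it is an assumption made for illustration: the two objectives are replaced by 1-D quadratics with conflicting minimizers (the paper works with negative log-likelihood losses of an LLM), `lam` is a hypothetical trade-off weight on the preference objective, the alternating variant picks its objective at random with probability `lam`, and the adaptive variant picks whichever objective currently has the larger weighted gap to its own optimum.

```python
import random

# Toy 1-D stand-ins for the two objectives (illustrative assumptions only).
def f_sft(theta):
    return 0.5 * (theta - 1.0) ** 2   # minimized at theta = +1

def grad_sft(theta):
    return theta - 1.0

def f_pref(theta):
    return 0.5 * (theta + 1.0) ** 2   # minimized at theta = -1

def grad_pref(theta):
    return theta + 1.0

def sequential(theta, steps, lr):
    """Stage 1 runs SFT only; stage 2 runs the preference objective only."""
    for _ in range(steps):
        theta -= lr * grad_sft(theta)
    for _ in range(steps):
        theta -= lr * grad_pref(theta)
    return theta

def alternating(theta, steps, lr, lam, seed=0):
    """ALRIGHT-style sketch: each step updates one randomly chosen objective."""
    rng = random.Random(seed)
    for _ in range(steps):
        if rng.random() < lam:
            theta -= lr * grad_pref(theta)
        else:
            theta -= lr * grad_sft(theta)
    return theta

def adaptive(theta, steps, lr, lam):
    """MAXRIGHT-style sketch: update the objective with the larger weighted gap."""
    for _ in range(steps):
        # Both toy objectives have optimal value 0, so the gap is the loss itself.
        gap_pref = lam * f_pref(theta)
        gap_sft = (1.0 - lam) * f_sft(theta)
        if gap_pref >= gap_sft:
            theta -= lr * grad_pref(theta)
        else:
            theta -= lr * grad_sft(theta)
    return theta

if __name__ == "__main__":
    runs = {
        "sequential": sequential(3.0, 2000, 0.01),
        "alternating": alternating(3.0, 5000, 0.002, 0.5),
        "adaptive": adaptive(3.0, 2000, 0.01, 0.5),
    }
    for name, theta in runs.items():
        print(f"{name:12s} theta={theta:+.3f} "
              f"sft={f_sft(theta):.3f} pref={f_pref(theta):.3f}")
```

In this toy setting the sequential run ends near the second stage's minimizer, so the first-stage loss climbs back up, which is the forgetting effect in miniature. The two joint variants take exactly one gradient step per iteration, the same per-step cost as sequential training, yet settle near a point that compromises between both objectives.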