PILoRA: Prototype Guided Incremental LoRA for Federated Class-Incremental Learning

Haiyang Guo, Fei Zhu, Wenzhuo Liu, Xu-Yao Zhang, Cheng-Lin Liu
2024-07-15
Abstract: Existing federated learning methods have effectively dealt with decentralized learning in scenarios involving data privacy and non-IID data. However, in real-world situations, each client dynamically learns new classes, requiring the global model to classify all seen classes. To effectively mitigate catastrophic forgetting and data heterogeneity under low communication costs, we propose a simple and effective method named PILoRA. On the one hand, we adopt prototype learning to learn better feature representations and leverage the heuristic information between prototypes and class features to design a prototype re-weight module to solve the classifier bias caused by data heterogeneity without retraining the classifier. On the other hand, we view incremental learning as the process of learning distinct task vectors and encoding them within different LoRA parameters. Accordingly, we propose Incremental LoRA to mitigate catastrophic forgetting. Experimental results on standard datasets indicate that our method outperforms the state-of-the-art approaches significantly. More importantly, our method exhibits strong robustness and superiority in different settings and degrees of data heterogeneity. The code is available at https://github.com/Ghy0501/PILoRA.
Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
The problem this paper attempts to solve is Class-Incremental Learning (CIL) in Federated Learning (FL): under data-privacy constraints and non-independent and identically distributed (non-IID) data, how can the global model keep learning new classes while retaining its ability to recognize existing ones? Specifically, the paper focuses on two key challenges:

1. **Catastrophic Forgetting**: When the model learns data from new classes, it tends to forget the knowledge of previously learned classes.
2. **Data Heterogeneity**: Data distributions differ substantially across clients, which causes classifier bias and degrades the performance of the global model.

To solve these problems, the authors propose a method named PILoRA (Prototype Guided Incremental LoRA). The method combines prototype learning and Incremental LoRA, and aims to alleviate the above challenges in the following ways:

- **Prototype-guided feature representation learning**: Through prototype learning, the model learns more discriminative feature representations, which reduces catastrophic forgetting.
- **Incremental LoRA**: Use LoRA parameters for efficient incremental learning, and apply orthogonal constraints so that the parameters learned for different tasks lie in different subspaces, which better retains old knowledge.
- **Prototype Re-weight Module**: Dynamically adjust the weights according to the distance between the prototypes uploaded by each client and the global features to mitigate classifier bias.

### Specific methods

1. **Prototype learning and prototype re-weighting**:
   - Introduce prototype learning to learn compact intra-class feature representations and well-separated inter-class feature representations.
   - Design a prototype re-weighting module. The weights are adjusted dynamically according to the distance between the prototypes uploaded by each client and the global average feature, which mitigates classifier bias.
2. **Incremental LoRA**:
   - Use LoRA parameters for incremental learning, and apply orthogonal constraints so that the parameters learned for different tasks lie in different subspaces, which better retains old knowledge.
   - In the inference stage, integrate the knowledge of all stages by simple parameter summation.
3. **Loss function**:
   - Define a combined loss function \( l_{\text{total}} = l_{\text{dce}} + \lambda \cdot l_{\text{pl}} + \gamma \cdot l_{\text{ort}} \), where:
     - \( l_{\text{dce}} \) is the distance-based cross-entropy loss, which reduces the distance between sample features and the correct prototype.
     - \( l_{\text{pl}} \) is the prototype learning loss, which encourages a more compact intra-class distribution.
     - \( l_{\text{ort}} \) is the orthogonal constraint loss, which keeps the parameters learned for different tasks in different subspaces.

### Experimental results

The experimental results show that PILoRA outperforms existing methods on standard datasets (such as CIFAR-100 and TinyImageNet), and that it remains robust under different non-IID settings and different degrees of data heterogeneity. The specific results are shown in the following table:

| Dataset | Method | AN (↑) | Avg (↑) |
| ------- | ------ | ------ | ------- |
| CIFAR-100 | Ours | 69.5 | 78.6 |
| TinyImageNet | Ours | 69.6 | 78.5 |

These results indicate that PILoRA has significant advantages and robustness in dealing with the federated class-incremental learning problem.
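
To make the loss function and the inference-time parameter summation from "Specific methods" more concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not the authors' implementation: the summary above gives only the overall form \( l_{\text{total}} = l_{\text{dce}} + \lambda \cdot l_{\text{pl}} + \gamma \cdot l_{\text{ort}} \), so the squared Euclidean distance, the `temperature` parameter, the absolute-value form of the orthogonality penalty, and all function names are hypothetical choices made for this example.

```python
import numpy as np


def dce_loss(features, labels, prototypes, temperature=1.0):
    """Distance-based cross-entropy (assumed form): softmax over
    negative squared Euclidean distances to the class prototypes."""
    # (n, 1, d) - (1, c, d) -> (n, c) matrix of squared distances
    d2 = ((features[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    logits = -d2 / temperature
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def pl_loss(features, labels, prototypes):
    """Prototype loss (assumed form): mean squared distance between each
    feature and the prototype of its own class."""
    return ((features - prototypes[labels]) ** 2).sum(axis=-1).mean()


def ort_loss(A_current, A_previous):
    """Orthogonality penalty (assumed form): sum of absolute inner products
    between rows of the current LoRA A matrix and rows of earlier ones.
    It is zero when the row spaces are mutually orthogonal."""
    return sum(np.abs(A_prev @ A_current.T).sum() for A_prev in A_previous)


def total_loss(features, labels, prototypes, A_current, A_previous,
               lam=0.1, gamma=0.1):
    """l_total = l_dce + lam * l_pl + gamma * l_ort (default weights are
    placeholders, not values from the paper)."""
    return (dce_loss(features, labels, prototypes)
            + lam * pl_loss(features, labels, prototypes)
            + gamma * ort_loss(A_current, A_previous))


def merged_weight(W0, lora_pairs):
    """Inference-time weight: the frozen weight W0 plus the sum of the
    low-rank updates B_t @ A_t from every incremental stage."""
    return W0 + sum(B @ A for B, A in lora_pairs)
```

In this sketch each `A` has shape `(rank, d_in)` and each `B` has shape `(d_out, rank)`, so `merged_weight(W0, [(B1, A1), (B2, A2)])` returns a matrix with the same shape as `W0`. Summing the updates keeps inference cost independent of the number of stages, which matches the "simple parameter summation" described above.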