Federated Class-Incremental Learning with Hierarchical Generative Prototypes

Riccardo Salami, Pietro Buzzega, Matteo Mosconi, Mattia Verasani, Simone Calderara
2024-10-23
Abstract: Federated Learning (FL) aims at unburdening the training of deep models by distributing computation across multiple devices (clients) while safeguarding data privacy. On top of that, Federated Continual Learning (FCL) also accounts for data distributions evolving over time, mirroring the dynamic nature of real-world environments. While previous studies have identified Catastrophic Forgetting and Client Drift as primary causes of performance degradation in FCL, we shed light on the importance of Incremental Bias and Federated Bias, which cause models to prioritize classes that are recently introduced or locally predominant, respectively. Our proposal constrains both biases in the last layer by efficiently finetuning a pre-trained backbone using learnable prompts, resulting in clients that produce less biased representations and more biased classifiers. Therefore, instead of solely relying on parameter aggregation, we leverage generative prototypes to effectively balance the predictions of the global model. Our method significantly improves the current state of the art, providing an average increase of +7.8% in accuracy. Code to reproduce the results is provided in the supplementary material.
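The balancing step the abstract describes (rebalancing the global classifier on the server with generative prototypes instead of relying only on parameter aggregation) can be illustrated with a small sketch. It is not the authors' code and rests on our own simplifying assumptions: each client shares, per class, a sample count, a feature mean and a *diagonal* variance; the server merges each class's client Gaussians by moment matching (a stand-in for the paper's own merging objective); it then draws a class-balanced set of synthetic features and retrains a linear softmax classifier on them with plain SGD. All names (`ClassStats`, `merge_gaussians`, `sample_balanced`, `train_linear_classifier`, `predict`) and the toy 2-D data are hypothetical.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class ClassStats:
    """Feature statistics of one class (hypothetical format: diagonal Gaussian)."""
    count: int          # number of samples the statistics were computed from
    mean: list[float]   # per-dimension feature mean
    var: list[float]    # per-dimension variance (diagonal covariance is an assumption)


def merge_gaussians(stats: list[ClassStats]) -> ClassStats:
    """Moment-match client Gaussians of the same class into one global Gaussian.

    Stand-in merging rule: weights pi_m are proportional to sample counts, and among
    diagonal Gaussians Q, moment matching minimizes sum_m pi_m * KL(N_m || Q).
    """
    total = sum(s.count for s in stats)
    dim = len(stats[0].mean)
    mean = [sum(s.count * s.mean[d] for s in stats) / total for d in range(dim)]
    # Second moment of the weighted mixture minus the squared global mean = variance.
    var = [
        max(sum(s.count * (s.var[d] + s.mean[d] ** 2) for s in stats) / total - mean[d] ** 2, 1e-8)
        for d in range(dim)
    ]
    return ClassStats(total, mean, var)


def sample_balanced(global_stats: dict[int, ClassStats], n_per_class: int,
                    rng: random.Random) -> list[tuple[list[float], int]]:
    """Draw the same number of synthetic features for every class."""
    data = []
    for label, s in global_stats.items():
        for _ in range(n_per_class):
            x = [rng.gauss(m, math.sqrt(v)) for m, v in zip(s.mean, s.var)]
            data.append((x, label))
    rng.shuffle(data)
    return data


def train_linear_classifier(data, num_classes: int, dim: int,
                            epochs: int = 20, lr: float = 0.1):
    """Retrain a linear softmax classifier with plain SGD on (feature, label) pairs."""
    W = [[0.0] * dim for _ in range(num_classes)]
    b = [0.0] * num_classes
    for _ in range(epochs):
        for x, y in data:
            logits = [sum(w * xi for w, xi in zip(W[k], x)) + b[k] for k in range(num_classes)]
            top = max(logits)  # subtract the max for a numerically stable softmax
            exps = [math.exp(l - top) for l in logits]
            z = sum(exps)
            for k in range(num_classes):
                grad = exps[k] / z - (1.0 if k == y else 0.0)  # d(cross-entropy)/d(logit_k)
                for d in range(dim):
                    W[k][d] -= lr * grad * x[d]
                b[k] -= lr * grad
    return W, b


def predict(W, b, x) -> int:
    """Return the index of the highest-scoring class."""
    scores = [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]
    return max(range(len(scores)), key=scores.__getitem__)


# Toy run: two clients with opposite local label skew over 2-D "features".
clients = [
    {0: ClassStats(90, [2.0, 0.0], [0.5, 0.5]), 1: ClassStats(10, [-2.0, 0.0], [0.5, 0.5])},
    {0: ClassStats(10, [2.0, 0.2], [0.5, 0.5]), 1: ClassStats(90, [-2.0, -0.2], [0.5, 0.5])},
]
global_stats = {c: merge_gaussians([cl[c] for cl in clients]) for c in (0, 1)}
rng = random.Random(0)
W, b = train_linear_classifier(sample_balanced(global_stats, 200, rng), num_classes=2, dim=2)
held_out = sample_balanced(global_stats, 500, rng)
accuracy = sum(predict(W, b, x) == y for x, y in held_out) / len(held_out)
```

In the toy run, each client sees one class nine times more often than the other, but the server samples 200 synthetic features per class from the merged Gaussians, so the retrained classifier is fit on a balanced set rather than on either client's skewed label distribution. A full-covariance variant would store covariance matrices and sample through a Cholesky factor instead of per-dimension variances.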
Machine Learning
What problem does this paper attempt to address?
This paper addresses two bias problems in Federated Class-Incremental Learning (FCIL): Incremental Bias (IB) and Federated Bias (FB). Both degrade the model's performance during distributed training.

### Specific problem description:

1. **Incremental Bias (IB)**:
   - As the model learns new classes over successive tasks, it tends to favor the most recently introduced ones. This makes it hard for the model to distinguish classes from different tasks, especially classes that were never seen together.
   - Formally, suppose the model learns the new classes \(C_t\) in the \(t\)-th task. For a class \(c\) learned in an earlier task, the predicted probability may shift after training on task \(t\):
     \[
     P(y = c \mid x, \theta_t) \neq P(y = c \mid x, \theta_{t-1}), \quad \text{for } c \in C_1 \cup \dots \cup C_{t-1}
     \]
2. **Federated Bias (FB)**:
   - Because each client trains only on its local dataset, its outputs are biased towards its local label distribution. Outputs therefore vary widely across clients, which in turn hurts the global model.
   - Formally, if client \(m\) has local data distribution \(D_m\), its output distribution \(P(y \mid x, \theta_m)\) may differ significantly from that of the global model, \(P(y \mid x, \theta)\):
     \[
     P(y \mid x, \theta_m) \neq P(y \mid x, \theta)
     \]

### Solution:

To alleviate these problems, the paper proposes **Hierarchical Generative Prototypes (HGP)**, which combines the following components:

1. **Prompting**:
   - Prompting is used to contain the bias: parameter updates are concentrated in the classification layer rather than spread across the whole network, so both the incremental bias and the federated bias are confined to the last layer.
   - Concretely, learnable vectors \(P_k\) and \(P_v\) are introduced into the multi-head self-attention mechanism and optimized to minimize the local loss of client \(m\):
     \[
     \min_{P_k, P_v} \mathbb{E}_{(x,y)\sim D_m} \mathcal{L}(f_{\theta_m}(x), y)
     \]
2. **Classifier Rebalancing**:
   - The classifier is retrained on the server using data generated from the prototypes, which alleviates its bias. The steps are:
     - Each client \(m\) computes the feature distribution \(\mathcal{N}_{m,c}(\mu_{m,c}, \Sigma_{m,c})\) of each class \(c\).
     - The server finds the global distribution \(\tilde{Q}_c\) closest to all client distributions, with per-client weights \(\pi_{m,c}\), by minimizing the Jensen-Shannon divergence (JSD):
       \[
       \min_{\tilde{Q}_c} \sum_{m=1}^{M} \pi_{m,c} \, D_{KL}(\tilde{Q}_c \,\|\, \mathcal{N}_{m,c})
       \]
     - Finally, the global classifier is retrained on samples drawn from the global generative model \(\tilde{Q}\): \[