Abstract: Federated Learning (FL) enables collaborative model training without sharing private data. Numerous studies have shown that a major factor limiting the performance of federated learning models is data heterogeneity across clients, especially when the data are drawn from different domains. A recent paper introduces variance-aware dual-level prototype clustering together with a novel $\alpha$-sparsity prototype loss, which increases intra-class similarity and reduces inter-class similarity. To ensure that features converge within specific clusters, we introduce an improved algorithm, Federated Prototype Learning with Convergent Clusters (FedPLCC). To increase inter-class distances, we weight each prototype by the size of the cluster it represents. To reduce intra-class distances, and because prototypes that lie far from a feature may come from other domains, we use only a certain proportion of the prototypes when computing the loss. Evaluations on the Digit-5, Office-10, and DomainNet datasets show that our method outperforms existing approaches.
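The dual-level prototype clustering named above can be sketched as follows. Each client clusters its own features per class into local prototypes. The server then re-clusters the uploaded local prototypes per class into global prototypes and keeps track of how many samples each one represents, which is the cluster size that FedPLCC uses as a weight. The greedy cosine-threshold clustering, the `threshold` value, and all function names here are hypothetical stand-ins chosen for brevity. They are not the clustering algorithm used in the papers, and the variance-aware details are not modeled.

```python
import numpy as np

def cluster_prototypes(vectors, weights, threshold=0.9):
    """Hypothetical greedy clustering: a vector joins the first cluster whose
    prototype has cosine similarity >= threshold, otherwise it starts a new one.
    Each prototype is the weighted mean of its members; its size is the summed weight."""
    protos, sizes, members = [], [], []
    for v, w in zip(vectors, weights):
        v_unit = v / np.linalg.norm(v)
        for i, p in enumerate(protos):
            if float(v_unit @ (p / np.linalg.norm(p))) >= threshold:
                members[i].append((v, w))
                sizes[i] += w
                protos[i] = sum(m * mw for m, mw in members[i]) / sizes[i]
                break
        else:  # no existing cluster was close enough
            protos.append(v.astype(float))
            sizes.append(float(w))
            members.append([(v, w)])
    return np.array(protos), np.array(sizes, dtype=float)

def local_level(features, labels, threshold=0.9):
    """Client side: cluster each class's features; every sample has weight 1."""
    out = {}
    for c in np.unique(labels):
        feats = features[labels == c]
        out[int(c)] = cluster_prototypes(feats, np.ones(len(feats)), threshold)
    return out

def global_level(client_outputs, threshold=0.9):
    """Server side: re-cluster all clients' local prototypes per class,
    weighting each local prototype by the number of samples it represents."""
    out = {}
    classes = sorted({c for co in client_outputs for c in co})
    for c in classes:
        P = np.vstack([co[c][0] for co in client_outputs if c in co])
        S = np.concatenate([co[c][1] for co in client_outputs if c in co])
        out[c] = cluster_prototypes(P, S, threshold)
    return out

# Two toy clients with 2-D features for class 0.
A = local_level(np.array([[1.0, 0.0], [0.99, 0.1], [0.0, 1.0]]), np.array([0, 0, 0]))
B = local_level(np.array([[1.0, 0.05]]), np.array([0]))
G = global_level([A, B])
print(G[0][1])  # global cluster sizes for class 0; they sum to the 4 samples
```

Because sizes are summed at both levels, each global prototype's size counts the raw samples behind it, which is what a size-based weight in the loss needs.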
What problem does this paper attempt to address?
### The Problem the Paper Attempts to Solve
The paper aims to address domain shift in Federated Learning (FL): how to improve the performance of federated learning models when data distributions differ across clients. In the setting studied, each client's data may come from a different domain, and this heterogeneity can make it difficult for the model to converge during training, which degrades overall performance.
### Background and Challenges
1. **Data Heterogeneity**: Data in federated learning is usually non-IID (not independent and identically distributed), meaning that data distributions can vary significantly across clients. This heterogeneity can drive each client toward its own local optimum, away from the global objective, which reduces the overall performance of the model.
2. **Limitations of Existing Methods**: Although existing Federated Prototype Learning methods have addressed the data heterogeneity issue to some extent, their handling of the clustering process is relatively coarse, which may lead to slow convergence or decreased final accuracy.
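The pull of heterogeneous clients toward their own optima can be seen in a toy setting. This is a hypothetical illustration, not taken from the paper: two clients with one-dimensional quadratic losses $f_i(w) = a_i (w - c_i)^2$ are trained with FedAvg (local gradient steps followed by server averaging). The function name and all constants are assumptions chosen for the example.

```python
def fedavg_toy(local_steps, rounds=200, lr=0.1):
    """Toy FedAvg on two clients with losses f_i(w) = a_i * (w - c_i)**2.
    Different minimizers c_i stand in for non-IID client data."""
    a = [1.0, 3.0]  # curvature of each client's loss
    c = [0.0, 4.0]  # each client's local optimum
    w = 0.0
    for _ in range(rounds):
        local = []
        for ai, ci in zip(a, c):
            wi = w
            for _ in range(local_steps):
                wi -= lr * 2 * ai * (wi - ci)  # gradient step on f_i
            local.append(wi)
        w = sum(local) / len(local)  # server averages the client models
    return w

# Minimizer of the global objective f_1 + f_2 is (a1*c1 + a2*c2) / (a1 + a2) = 3.0
print(fedavg_toy(local_steps=1))   # ≈ 3.0
print(fedavg_toy(local_steps=50))  # ≈ 2.0, the plain mean of the local optima
```

With one local step per round, FedAvg is gradient descent on the average objective and reaches the global minimizer. With many local steps, each client moves almost all the way to its own optimum, and the averaged model settles at a point that is optimal for neither the clients nor the federation.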
### Solution
To address these challenges, the paper proposes an improved algorithm, FedPLCC (Federated Prototype Learning with Convergent Clusters). It builds on a dual-level clustering framework and adds two modifications:
1. **Weighted Clustering**: Each prototype is weighted by the number of samples in the cluster it represents, and these weights enter the loss computation. This better reflects the local data distribution and prevents prototypes that represent only a few outlier samples from having an excessive influence on the loss.
2. **Selective Loss Calculation**: When computing the intra-class loss, a hyperparameter ϕ sets the proportion of prototypes involved: only the top ϕ fraction of prototypes, ranked by similarity to the current feature, is used. This avoids forcing features to align with prototypes from completely different domains, thereby improving the model's generalization ability.
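A minimal sketch of how the two modifications could enter the loss, assuming cosine similarity, a temperature-scaled contrastive (InfoNCE-style) term for inter-class separation, and a `1 - cosine` term for intra-class compactness. The function names, the temperature `tau`, and the exact loss forms are illustrative assumptions rather than the paper's formulation, and the α-sparsity component is omitted.

```python
import math
import numpy as np

def cosine_to_rows(z, P):
    """Cosine similarity between a feature z of shape (d,) and each row of P (k, d)."""
    z = z / np.linalg.norm(z)
    P = P / np.linalg.norm(P, axis=1, keepdims=True)
    return P @ z

def weighted_inter_class_loss(z, label, protos, proto_labels, sizes, tau=0.5):
    """Hypothetical InfoNCE-style term over all global prototypes; each
    prototype's exp(sim / tau) is scaled by its cluster's share of samples."""
    sims = cosine_to_rows(z, protos)
    weights = sizes / sizes.sum()
    scores = weights * np.exp(sims / tau)
    positive = scores[proto_labels == label].sum()
    return -math.log(positive / scores.sum())

def selective_intra_class_loss(z, label, protos, proto_labels, phi=0.5):
    """Hypothetical intra-class term: pull z only toward the top-phi fraction
    of same-class prototypes, ranked by cosine similarity; ignore the rest."""
    sims = cosine_to_rows(z, protos[proto_labels == label])
    k = max(1, math.ceil(phi * len(sims)))
    top = np.sort(sims)[::-1][:k]
    return float(np.mean(1.0 - top))

# Toy global prototypes: three for class 0 (the last one far from z), two for class 1.
protos = np.array([[1.0, 0.0], [0.6, 0.8], [-1.0, 0.0], [0.0, 1.0], [0.0, -1.0]])
proto_labels = np.array([0, 0, 0, 1, 1])
sizes = np.array([30.0, 10.0, 2.0, 20.0, 20.0])  # samples behind each prototype
z = np.array([1.0, 0.0])

print(selective_intra_class_loss(z, 0, protos, proto_labels, phi=1.0))  # ≈ 0.8
print(selective_intra_class_loss(z, 0, protos, proto_labels, phi=0.5))  # ≈ 0.2
print(weighted_inter_class_loss(z, 0, protos, proto_labels, sizes))
```

With `phi=0.5`, the same-class prototype at `[-1, 0]`, which may come from another domain, is dropped from the intra-class term, so the feature is no longer pulled toward it. In the inter-class term, the small two-sample cluster contributes little because its weight is proportional to its size.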
### Experimental Results
Experiments were conducted on three datasets: Digit-5, Office-10, and DomainNet. FedPLCC achieved higher accuracy than existing methods on multiple sub-datasets. The improvement was most pronounced on Digit-5 and Office-10, and the paper takes the results across datasets of varying difficulty as evidence that the method is robust and effective.
### Conclusion
By systematically comparing and summarizing existing federated prototype learning methods, the paper proposes the improved algorithm FedPLCC. Through weighted clustering and selective loss calculation, this algorithm effectively addresses the domain shift issue in federated learning, enhancing the model's performance and generalization ability.