Tackling Data Heterogeneity in Federated Learning with Class Prototypes

Yutong Dai, Zeyuan Chen, Junnan Li, Shelby Heinecke, Lichao Sun, Ran Xu
2023-12-26
Abstract: Data heterogeneity across clients in federated learning (FL) settings is a widely acknowledged challenge. In response, personalized federated learning (PFL) emerged as a framework to curate local models for clients' tasks. In PFL, a common strategy is to develop local and global models jointly: the global model (for generalization) informs the local models, and the local models (for personalization) are aggregated to update the global model. A key observation is that if we can improve the generalization ability of local models, then we can improve the generalization of global models, which in turn builds better personalized models. In this work, we consider class imbalance, an overlooked type of data heterogeneity, in the classification setting. We propose FedNH, a novel method that improves the local models' performance for both personalization and generalization by combining the uniformity and semantics of class prototypes. FedNH initially distributes class prototypes uniformly in the latent space and smoothly infuses the class semantics into class prototypes. We show that imposing uniformity helps to combat prototype collapse while infusing class semantics improves local models. Extensive experiments were conducted on popular classification datasets under the cross-device setting. Our results demonstrate the effectiveness and stability of our method over recent works.
Machine Learning, Artificial Intelligence
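The abstract says FedNH starts by spreading the class prototypes uniformly in the latent space. As a minimal sketch of that idea (not the authors' code or solver), the NumPy snippet below assumes unit-norm prototypes and minimizes a log-sum-exp surrogate of the largest pairwise cosine similarity with projected gradient descent. The function names and the hyperparameters `steps`, `lr`, and `tau` are hypothetical choices for illustration.

```python
import numpy as np

# Hypothetical sketch: the surrogate objective, solver, and hyperparameters
# are illustrative assumptions, not the procedure used in the paper.


def max_pairwise_cosine(W):
    """Largest cosine similarity between two distinct rows of a unit-norm matrix."""
    S = W @ W.T
    np.fill_diagonal(S, -np.inf)  # ignore each prototype's similarity to itself
    return S.max()


def uniform_prototypes(num_classes, dim, steps=3000, lr=0.01, tau=0.1, seed=0):
    """Spread `num_classes` unit vectors in R^dim as evenly as possible.

    Minimizes tau * logsumexp(S_ij / tau) over pairs i != j (S = W W^T),
    a smooth surrogate for max_{i != j} S_ij, and re-projects every row
    onto the unit sphere after each gradient step.
    """
    rng = np.random.default_rng(seed)
    W = rng.standard_normal((num_classes, dim))
    W /= np.linalg.norm(W, axis=1, keepdims=True)
    for _ in range(steps):
        S = W @ W.T
        np.fill_diagonal(S, -np.inf)
        Z = S / tau
        P = np.exp(Z - Z.max())  # softmax weights over off-diagonal pairs; diagonal -> 0
        P /= P.sum()
        grad = (P + P.T) @ W  # gradient of the surrogate with respect to W
        W -= lr * grad
        W /= np.linalg.norm(W, axis=1, keepdims=True)  # project back onto the sphere
    return W


if __name__ == "__main__":
    W = uniform_prototypes(num_classes=10, dim=64)
    print(max_pairwise_cosine(W), -1 / 9)
```

For \( C \) unit vectors, the largest pairwise cosine similarity can never fall below \( -1/(C-1) \) (a regular simplex attains this bound when \( C \le d+1 \)), so comparing against that value is a quick sanity check on the output.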
What problem does this paper attempt to address?
### Problems the Paper Aims to Solve

This paper addresses data heterogeneity and class imbalance in federated learning (FL). Specifically:

1. **Data Heterogeneity**:
   - In FL, the data distribution can differ across clients, violating the independent and identically distributed (i.i.d.) assumption commonly made in centralized machine learning. This non-i.i.d. phenomenon, known as data heterogeneity, can degrade the performance of the global model.
2. **Class Imbalance**:
   - Class imbalance refers to uneven class distributions in clients' local data; on a given client, some classes may have no samples at all. For example, medical institutions may hold very different distributions of disease records, yet each personalized model should detect every class present in its local training dataset with equal accuracy.
3. **Limitations of Existing Methods**:
   - The standard protocol for evaluating personalized federated learning (PFL) methods may be biased. A balanced test dataset is typically split into local test datasets that match each client's training distribution; each personalized model is tested on its local test dataset, and the average accuracy is reported. Under class imbalance, this protocol can reward models that overfit to the dominant classes, because minority classes make up only a small share of each local test dataset.

### Solution

To address these issues, the authors propose FedNH (Non-parametric Head), a method that combines the uniformity and semantics of class prototypes to tackle data heterogeneity and class imbalance. The main contributions are:

1. **Proposing the FedNH Method**:
   - FedNH improves the quality of learned representations by distributing class prototypes uniformly in the latent space and keeping them fixed during local training; only the server refines them. This reduces the likelihood that representations of minority classes are overshadowed by those of majority classes.
2. **Designing a New Evaluation Metric**:
   - The authors design an evaluation metric that is insensitive to class imbalance and reflects how well personalized models generalize to minority classes.
3. **Experimental Validation**:
   - Numerical experiments on CIFAR-10, CIFAR-100, and TinyImageNet show that FedNH improves the classification accuracy of both personalized and global models, at a significantly lower computational cost than existing methods.

### Method Overview

1. **Initialization**:
   - The server initializes the parameters of the network body and the classification head. For the head parameters \( W \) (the class prototypes), the authors generate uniformly distributed prototypes by solving a constrained optimization problem.
2. **Client Update**:
   - In each communication round, a client receives the body parameters \( \theta_t \) and the head parameters \( W_t \), and uses its local training dataset \( D_k \) to learn a robust body that extracts the representations needed for classification.
3. **Server Update**:
   - The server aggregates the body parameters returned by the clients and updates the global class prototypes with a new strategy that incorporates class semantic information.

Through these components, FedNH improves model performance and generalization while addressing data heterogeneity and class imbalance.
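The client and server steps of the method overview can be sketched in NumPy as follows. This is a hedged illustration, not the authors' implementation. It assumes that logits are scaled cosine similarities between normalized representations and the fixed prototypes, that each client reports per-class means of its normalized representations together with class counts, and that the server blends the count-weighted average of those means into the previous prototypes with a smoothing factor `rho`. The names `prototype_logits`, `local_class_means`, `server_update`, `scale`, and `rho` are hypothetical, and random unit vectors stand in for the uniform initialization.

```python
import numpy as np

# Hypothetical sketch of FedNH-style client/server steps; the logit scale,
# the count weighting, and the smoothing rule are illustrative assumptions.


def normalize(x, axis=-1):
    """Scale vectors to unit L2 norm along `axis`."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def prototype_logits(features, prototypes, scale=10.0):
    """Client forward pass: scaled cosine similarity to the prototypes.

    The prototypes are treated as constants; only the body producing
    `features` would be trained locally (assumption).
    """
    return scale * normalize(features) @ prototypes.T


def local_class_means(features, labels, num_classes):
    """Client side: per-class mean of normalized features and per-class counts."""
    z = normalize(features)
    sums = np.zeros((num_classes, z.shape[1]))
    np.add.at(sums, labels, z)  # accumulate each sample into its class row
    counts = np.bincount(labels, minlength=num_classes)
    means = np.divide(sums, counts[:, None], out=np.zeros_like(sums),
                      where=counts[:, None] > 0)  # absent classes stay zero
    return means, counts


def server_update(prototypes, client_stats, rho=0.9):
    """Server side: blend aggregated class semantics into the prototypes."""
    num_classes, dim = prototypes.shape
    weighted = np.zeros((num_classes, dim))
    total = np.zeros(num_classes)
    for means, counts in client_stats:
        weighted += counts[:, None] * means  # count-weighted class means
        total += counts
    new = prototypes.copy()
    seen = total > 0  # classes observed by at least one client this round
    semantic = normalize(weighted[seen] / total[seen][:, None])
    new[seen] = normalize(rho * prototypes[seen] + (1 - rho) * semantic)
    return new


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    C, d = 4, 8
    W = normalize(rng.standard_normal((C, d)))  # stand-in for the uniform init
    # Two clients with imbalanced label distributions; class 3 is unseen.
    x1, y1 = rng.standard_normal((20, d)), np.array([0] * 18 + [1] * 2)
    x2, y2 = rng.standard_normal((10, d)), np.array([1] * 5 + [2] * 5)
    print(prototype_logits(x1, W).shape)  # (20, 4)
    stats = [local_class_means(x1, y1, C), local_class_means(x2, y2, C)]
    W_next = server_update(W, stats, rho=0.9)
    print(np.linalg.norm(W_next, axis=1))
```

In this sketch, classes that no client observed in a round keep their previous prototype, and a `rho` close to 1 makes each round's semantic update small, in the spirit of the smooth infusion described in the abstract.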