Dung Thuy Nguyen, Taylor T. Johnson, Kevin Leach
Abstract: Federated Learning (FL) shows promise in preserving privacy and enabling collaborative learning. However, most current solutions focus on private data collected from a single domain. A significant challenge arises when client data comes from diverse domains (i.e., domain shift), leading to poor performance on unseen domains. Existing Federated Domain Generalization approaches address this problem but assume each client holds data for an entire domain, limiting their practicality in real-world scenarios with domain-based heterogeneity and client sampling.
To overcome this, we introduce FISC, a novel FL domain generalization paradigm that handles more complex domain distributions across clients. FISC enables learning across domains by extracting an interpolative style from local styles and employing contrastive learning. This strategy gives clients multi-domain representations and unbiased convergent targets. Empirical results on multiple datasets, including PACS, Office-Home, and IWildCam, show FISC outperforms state-of-the-art (SOTA) methods. Our method achieves accuracy improvements ranging from 3.64% to 57.22% on unseen domains. Our code is available at https://anonymous.4open.science/r/FISC-AAAI-16107.
Machine Learning; Computer Vision and Pattern Recognition; Distributed, Parallel, and Cluster Computing
### What problem does this paper attempt to address?
This paper aims to solve the problem of domain generalization (DG) in federated learning (FL), especially when client data comes from different domains. Specifically, the paper explores the following issues:
1. **Domain Shift**:
- Most existing federated learning methods assume that clients' private data are collected from a single domain. When client data instead come from multiple domains (i.e., domain shift), the trained model performs significantly worse on unseen domains.
2. **Limitations of Existing Federated Domain Generalization Methods**:
- Existing federated domain generalization methods assume that each client holds the data of an entire domain, which is rarely true in practice. In medical settings, for example, hospitals in different locations collect data with different sensors and cameras, so data distributions differ geographically: one domain may be spread across several clients, and one client may hold data from several domains (domain-based heterogeneity).
3. **Domain Generalization in the Client Sampling Scenario**:
- The paper also considers client sampling, in which only a subset of clients participates in each training round. The challenge is to keep the model generalizing well to unseen domains even though each round sees data from only some clients, and therefore possibly from only some domains.
4. **Privacy Protection and Performance Improvement**:
- Privacy protection is a defining feature of federated learning, so domain generalization must be achieved without leaking clients' private data. At the same time, the model's accuracy under complex domain distributions must improve.
To address these problems, the paper proposes FISC (Federated Interpolative Style Transfer and Contrastive Learning), a new federated domain generalization method. FISC extracts an interpolative style from clients' local styles and applies contrastive learning. This gives each client multi-domain representations and keeps training from being biased toward local data, which improves the model's generalization to unseen domains.
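Re-styling local data with a shared style is what gives clients multi-domain representations. One common way to apply a style vector made of channel means and standard deviations is adaptive instance normalization (AdaIN). The sketch below assumes that mechanism, applied directly to a toy feature map; this is an illustration only, since the summary does not say whether FISC transfers style on feature maps or on images through an encoder-decoder. The function name `adain`, the `eps` value, and the toy numbers are hypothetical.

```python
import statistics
from typing import List

# One feature map: d channels, each a flat list of activations.
FeatureMap = List[List[float]]


def adain(content: FeatureMap, style: List[float], eps: float = 1e-5) -> FeatureMap:
    """Re-style `content` with a target style [mu_1..mu_d, sigma_1..sigma_d].

    Each channel is normalised to zero mean and (roughly) unit standard
    deviation, then scaled by the target sigma and shifted by the target mu.
    AdaIN is an assumed stand-in for FISC's style-transfer step.
    """
    d = len(content)
    if len(style) != 2 * d:
        raise ValueError("style must hold d means followed by d std devs")
    target_mu, target_sigma = style[:d], style[d:]
    restyled: FeatureMap = []
    for channel, mu_t, sigma_t in zip(content, target_mu, target_sigma):
        mu = statistics.fmean(channel)
        sigma = statistics.pstdev(channel)
        restyled.append([(x - mu) / (sigma + eps) * sigma_t + mu_t for x in channel])
    return restyled


if __name__ == "__main__":
    local_features = [[1.0, 3.0], [0.0, 4.0]]  # toy map with d = 2 channels
    target_style = [10.0, -1.0, 2.0, 0.5]      # hypothetical means, then std devs
    print(adain(local_features, target_style))
```

After the call, each channel's mean and standard deviation match the target style (up to `eps`), while the relative pattern of activations inside the channel is preserved; that is the sense in which content is kept and style is swapped.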
### Formula Summary
- Interpolative Style Extraction Formula:
\[
S_g=\text{median}(S_j|\forall S_j\in\Gamma_L)\in\mathbb{R}^{2d}
\]
where \(S_g\) is the global interpolative style, \(S_j\) is the style statistic of the \(j\)-th cluster, and \(\Gamma_L\) is the set of cluster styles produced by FINCH clustering.
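The median-of-clusters step can be sketched as follows. This is a minimal sketch under several assumptions not stated in the summary: each style vector in \(\mathbb{R}^{2d}\) is taken to be \(d\) channel means followed by \(d\) channel standard deviations; only the first FINCH partition is computed (each point is linked to its nearest neighbour and connected components become clusters), with Euclidean distance as an assumed metric; each cluster is summarised by the mean of its members; and the median is taken coordinate-wise. All function names are hypothetical.

```python
import math
import statistics
from typing import Dict, List

Vector = List[float]


def channel_style(feature_map: List[List[float]]) -> Vector:
    """Style vector in R^{2d}: d channel means, then d channel std devs (assumed)."""
    means = [statistics.fmean(ch) for ch in feature_map]
    stds = [statistics.pstdev(ch) for ch in feature_map]
    return means + stds


def finch_first_partition(points: List[Vector]) -> List[List[int]]:
    """First FINCH partition: link every point to its nearest neighbour and
    return the connected components of that graph (Euclidean distance here)."""
    n = len(points)
    if n == 1:
        return [[0]]
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for i in range(n):
        nearest = min((j for j in range(n) if j != i),
                      key=lambda j: math.dist(points[i], points[j]))
        parent[find(i)] = find(nearest)  # union i with its first neighbour

    groups: Dict[int, List[int]] = {}
    for i in range(n):
        groups.setdefault(find(i), []).append(i)
    return list(groups.values())


def cluster_styles(styles: List[Vector], clusters: List[List[int]]) -> List[Vector]:
    """S_j: element-wise mean of the member styles of cluster j (assumption)."""
    return [[statistics.fmean(col) for col in zip(*(styles[i] for i in c))]
            for c in clusters]


def interpolative_style(cluster_level: List[Vector]) -> Vector:
    """S_g = median(S_j | S_j in Gamma_L), taken coordinate-wise here."""
    return [statistics.median(col) for col in zip(*cluster_level)]


if __name__ == "__main__":
    # Toy styles with d = 1 (one mean, one std), forming three tight pairs.
    styles = [[0.0, 1.0], [0.2, 1.0], [5.0, 2.0], [5.1, 2.2], [9.0, 3.0], [9.1, 3.1]]
    clusters = finch_first_partition(styles)
    s_g = interpolative_style(cluster_styles(styles, clusters))
    print(clusters, s_g)
```

Taking a median over cluster-level styles, rather than a mean over all samples, keeps one large or extreme cluster from dragging \(S_g\) toward a single domain, which fits the goal of an interpolative rather than a dominant style.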
- Contrastive Learning Loss Function:
\[
L_T = \sum_{i = 1}^{|B|}\left[\|z_a^i - z_p^i\|_2^2-\|z_a^i - z_n^i\|_2^2+\alpha\right]_+
\]
where \(|B|\) is the batch size, \(z_a^i\) (the anchor) is the feature representation of the \(i\)-th sample, \(z_p^i\) (the positive) is the embedding of its style-transferred version, \(z_n^i\) (the negative) is a style-transferred embedding from a different class, \(\alpha\) is the margin, and \([x]_+ = \max(0, x)\).
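A minimal sketch of the triplet loss as written, summed over the batch with a hinge at zero. The negative-selection rule (take the first style-transferred embedding in the batch whose label differs) and the margin value `alpha = 0.2` are hypothetical choices for illustration; the summary does not say how FISC mines negatives or which margin it uses.

```python
from typing import List, Sequence

Vector = List[float]


def squared_l2(u: Sequence[float], v: Sequence[float]) -> float:
    """||u - v||_2^2"""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def pick_negatives(transferred: List[Vector], labels: List[int]) -> List[Vector]:
    """Hypothetical rule: for sample i, use the first style-transferred
    embedding in the batch whose label differs from labels[i]."""
    negatives = []
    for y in labels:
        j = next((k for k, other in enumerate(labels) if other != y), None)
        if j is None:
            raise ValueError("batch needs at least two classes")
        negatives.append(transferred[j])
    return negatives


def triplet_loss(anchors: List[Vector], positives: List[Vector],
                 negatives: List[Vector], alpha: float = 0.2) -> float:
    """L_T = sum_i [ ||z_a - z_p||^2 - ||z_a - z_n||^2 + alpha ]_+"""
    return sum(max(0.0, squared_l2(a, p) - squared_l2(a, n) + alpha)
               for a, p, n in zip(anchors, positives, negatives))


if __name__ == "__main__":
    z_a = [[0.0, 0.0], [0.2, 0.0]]   # anchors: original-sample features
    z_p = [[0.3, 0.0], [0.1, 0.0]]   # positives: style-transferred versions
    labels = [0, 1]
    z_n = pick_negatives(z_p, labels)
    print(triplet_loss(z_a, z_p, z_n))  # approximately 0.48
```

In the toy batch both triplets violate the margin, so both contribute: the first gives \(0.09 - 0.01 + 0.2 = 0.28\) and the second \(0.01 - 0.01 + 0.2 = 0.2\). A triplet whose negative is already farther than the positive by at least \(\alpha\) contributes zero.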
- Overall Objective Function:
\[
L = L_{CE}+\gamma_1L_T+\gamma_2L_{reg}
\]
where \(L_{CE}\) is the cross-entropy loss, \(L_T\) is the triplet loss defined above, \(L_{reg}\) is an L2 regularization term, and \(\gamma_1\) and \(\gamma_2\) are weighting coefficients.
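A minimal sketch of how the three terms combine, assuming mean-reduced softmax cross-entropy, a squared L2 norm over a flat parameter vector for \(L_{reg}\) (the summary does not say which quantity it penalizes), and hypothetical weights \(\gamma_1 = 0.5\), \(\gamma_2 = 0.01\). The function names are illustrative only.

```python
import math
from typing import List, Sequence


def cross_entropy(logits_batch: List[Sequence[float]], labels: List[int]) -> float:
    """Mean softmax cross-entropy, computed with a numerically stable log-sum-exp."""
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        total += log_z - logits[y]
    return total / len(labels)


def l2_regularization(weights: Sequence[float]) -> float:
    """Squared L2 norm of a flat parameter vector (what it covers is assumed)."""
    return sum(w * w for w in weights)


def total_loss(l_ce: float, l_t: float, l_reg: float,
               gamma1: float, gamma2: float) -> float:
    """L = L_CE + gamma1 * L_T + gamma2 * L_reg"""
    return l_ce + gamma1 * l_t + gamma2 * l_reg


if __name__ == "__main__":
    l_ce = cross_entropy([[2.0, 0.0], [0.0, 0.0]], [0, 1])
    l_reg = l2_regularization([1.0, -2.0])  # 1 + 4 = 5
    loss = total_loss(l_ce, l_t=0.48, l_reg=l_reg, gamma1=0.5, gamma2=0.01)
    print(l_ce, l_reg, loss)
```

With these toy values the triplet term adds \(0.5 \times 0.48 = 0.24\) and the regularizer adds \(0.01 \times 5 = 0.05\) on top of the cross-entropy, which shows how \(\gamma_1\) and \(\gamma_2\) trade the contrastive and regularization signals off against classification accuracy.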
Together, these components let FISC improve generalization to unseen domains under complex domain distributions across clients, while raw client data stay on the clients as in standard federated learning.