(FL)$^2$: Overcoming Few Labels in Federated Semi-Supervised Learning

Seungjoo Lee, Thanh-Long V. Le, Jaemin Shin, Sung-Ju Lee
2024-10-31
Abstract: Federated Learning (FL) is a distributed machine learning framework that trains accurate global models while preserving clients' privacy-sensitive data. However, most FL approaches assume that clients possess labeled data, which is often not the case in practice. Federated Semi-Supervised Learning (FSSL) addresses this label deficiency problem, targeting situations where only the server has a small amount of labeled data while clients do not. However, a significant performance gap exists between Centralized Semi-Supervised Learning (SSL) and FSSL. This gap arises from confirmation bias, which is more pronounced in FSSL due to multiple local training epochs and the separation of labeled and unlabeled data. We propose $(FL)^2$, a robust training method for unlabeled clients using sharpness-aware consistency regularization. We show that regularizing the original pseudo-labeling loss is suboptimal, and hence we carefully select unlabeled samples for regularization. We further introduce client-specific adaptive thresholding and learning status-aware aggregation to adjust the training process based on the learning progress of each client. Our experiments on three benchmark datasets demonstrate that our approach significantly improves performance and bridges the gap with SSL, particularly in scenarios with scarce labeled data.
Machine Learning
What problem does this paper attempt to address?
### The Problem the Paper Attempts to Solve

The paper "(FL)²: Overcoming Few Labels in Federated Semi-Supervised Learning" addresses the scarcity of labeled data in Federated Semi-Supervised Learning (FSSL). Specifically, it aims to improve model performance in the Federated Learning (FL) framework when only the server holds a small amount of labeled data and the clients hold none.

### Background and Challenges

1. **Federated Learning (FL)**: FL is a distributed machine learning framework that trains an accurate global model while keeping client data private. Each client trains a model on its local data, and the server aggregates these local models into a global model.
2. **Scarcity of Labeled Data**: Most existing FL methods assume that clients have labeled data, which is often unrealistic in practice because:
   - Clients are often unwilling or unmotivated to label data.
   - Some types of data, such as medical data and multi-dimensional sensor data, require expert knowledge to label.
3. **Federated Semi-Supervised Learning (FSSL)**: FSSL targets this label scarcity, especially the setting in which only the server has a small amount of labeled data and the clients have none. However, FSSL still lags significantly behind centralized semi-supervised learning (SSL).
4. **Confirmation Bias**: Confirmation bias is more severe in FSSL because multiple local training epochs and the separation of labeled and unlabeled data make the model prone to overfitting to easy-to-learn samples or to samples with incorrect pseudo-labels.

### Solution

To overcome these issues, the paper proposes (FL)², which consists of three key components:

1. **Client-Specific Adaptive Thresholding (CAT)**:
   - Dynamically adjusts the confidence threshold for pseudo-label generation based on the learning status of each client.
   - Uses a lower threshold early in training to exploit more unlabeled data, then gradually raises it as training progresses to obtain more accurate pseudo-labels.
2. **Sharpness-Aware Consistency Regularization (SACR)**:
   - Applies consistency regularization, via adversarial weight perturbation, only to carefully selected high-confidence samples, limiting the impact of incorrect pseudo-labels.
   - Encourages the model's outputs to stay consistent under the original and the perturbed weights, thereby improving generalization.
3. **Learning Status-Aware Aggregation (LSAA)**:
   - Adjusts aggregation weights based on the learning status of each client, giving higher weights to clients with greater learning difficulty.
   - Lets the global model learn more effectively from these clients' updates.

### Experimental Results

The paper evaluates (FL)² on three benchmark datasets (CIFAR10, CIFAR100, and SVHN). The results show that (FL)² significantly improves performance, especially when labeled data is extremely limited. Specifically:

- (FL)² achieves the best or near-best performance in all settings.
- The gains are largest when labels are extremely scarce. For example, in the IID setting, performance improved by 22.2% on CIFAR10 with only 10 labels and by 21.9% on SVHN with only 40 labels.
- The components (CAT, SACR, and LSAA) work synergistically to reduce the impact of confirmation bias and improve the model's overall performance.
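The CAT behavior described under Solution (a low threshold early in training that rises as the model grows more confident) can be illustrated with a minimal NumPy sketch. It is not the paper's exact rule: it assumes, as a FreeMatch-style stand-in, that each client keeps an exponential moving average of its mean top-1 confidence on unlabeled data, starting from chance level 1/C. The class name `ClientAdaptiveThreshold` and the `momentum` parameter are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


class ClientAdaptiveThreshold:
    """Hypothetical per-client threshold tracker (FreeMatch-style EMA, an assumption).

    The threshold tracks an exponential moving average of the client's mean
    top-1 confidence on its unlabeled data, so it starts low while the model
    is uncertain and rises as the model becomes more confident.
    """

    def __init__(self, num_classes: int, momentum: float = 0.9):
        self.momentum = momentum
        # Start at chance-level confidence (1 / C): a deliberately low threshold.
        self.tau = 1.0 / num_classes

    def update(self, unlabeled_logits: np.ndarray) -> float:
        """Fold the current batch's mean top-1 confidence into the EMA threshold."""
        mean_conf = float(softmax(unlabeled_logits).max(axis=1).mean())
        self.tau = self.momentum * self.tau + (1.0 - self.momentum) * mean_conf
        return self.tau

    def mask(self, unlabeled_logits: np.ndarray) -> np.ndarray:
        """Boolean mask of samples whose confidence clears the current threshold."""
        return softmax(unlabeled_logits).max(axis=1) >= self.tau


# Usage: low-confidence logits early in training, high-confidence logits later.
rng = np.random.default_rng(0)
cat = ClientAdaptiveThreshold(num_classes=10)
early = rng.normal(scale=0.5, size=(256, 10))
late = rng.normal(scale=5.0, size=(256, 10))
tau_early = cat.update(early)
for _ in range(20):
    tau_late = cat.update(late)
```

Because each client keeps its own tracker, a client whose model is still uncertain on its local data keeps a lower threshold (and therefore trains on more pseudo-labels) than a client that is already confident.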
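The SACR step described under Solution can be sketched as follows, under several stated assumptions. The model is a linear softmax classifier, so gradients can be written by hand. Pseudo-labels come from a weak view and the loss is computed on a strong view (FixMatch-style). Samples are selected with a hypothetical fixed threshold `tau_high`. The weight perturbation is SAM-style, with radius `rho`. The paper's exact selection rule and consistency loss may differ. The function names `ce_and_grad` and `sacr_step` are illustrative.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax with max-subtraction for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def ce_and_grad(W: np.ndarray, X: np.ndarray, y: np.ndarray):
    """Mean cross-entropy of a linear softmax model X @ W and its gradient w.r.t. W."""
    probs = softmax(X @ W)
    n = X.shape[0]
    loss = -np.log(probs[np.arange(n), y] + 1e-12).mean()
    onehot = np.eye(W.shape[1])[y]
    grad = X.T @ (probs - onehot) / n
    return float(loss), grad


def sacr_step(W, x_weak, x_strong, tau_high=0.95, rho=0.05, lr=0.1):
    """One hypothetical sharpness-aware consistency step on selected samples."""
    # 1. Pseudo-label the weak view with the current (unperturbed) weights.
    probs_weak = softmax(x_weak @ W)
    conf = probs_weak.max(axis=1)
    pseudo = probs_weak.argmax(axis=1)
    # 2. Keep only high-confidence samples (assumed fixed threshold tau_high).
    keep = conf >= tau_high
    if not keep.any():
        return W, 0.0, 0
    xs, ys = x_strong[keep], pseudo[keep]
    # 3. Adversarial weight perturbation: move rho along the normalized loss gradient.
    _, g = ce_and_grad(W, xs, ys)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # 4. Evaluate the loss at the perturbed weights and apply that gradient
    #    to the original weights (SAM-style update).
    loss_adv, g_adv = ce_and_grad(W + eps, xs, ys)
    return W - lr * g_adv, loss_adv, int(keep.sum())


# Usage: Gaussian noise stands in for strong augmentation.
rng = np.random.default_rng(0)
W = rng.normal(scale=3.0, size=(8, 4))
x_weak = rng.normal(size=(128, 8))
x_strong = x_weak + rng.normal(scale=0.3, size=x_weak.shape)
W_new, loss, n_used = sacr_step(W, x_weak, x_strong)
```

Restricting the perturbed loss to the selected samples keeps low-confidence, likely-wrong pseudo-labels from steering the sharpness-aware update. This matches the summary's point that regularizing the original pseudo-labeling loss on all samples is suboptimal.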
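LSAA, as summarized under Solution, can be pictured as a reweighted FedAvg. This sketch assumes, hypothetically, that each client reports a learning-status score in [0, 1], such as its mean pseudo-label confidence. Each client's data-size weight is then scaled by `1 + alpha * (1 - status)`, so clients with greater learning difficulty count more, and `alpha = 0` recovers plain FedAvg. The scoring and scaling rule (and the names `lsaa_weights`, `aggregate`, `alpha`) are illustrative, not the paper's formula.

```python
import numpy as np


def lsaa_weights(num_samples, learning_status, alpha=1.0):
    """Hypothetical aggregation weights: FedAvg data-size weights, scaled up
    for clients whose learning status (e.g. mean confidence in [0, 1]) is low."""
    n = np.asarray(num_samples, dtype=float)
    s = np.asarray(learning_status, dtype=float)
    difficulty = 1.0 - s
    raw = n * (1.0 + alpha * difficulty)
    return raw / raw.sum()


def aggregate(client_params, weights):
    """Weighted average of client parameter vectors (FedAvg when weights are proportional to data size)."""
    stacked = np.stack(client_params)
    return np.tensordot(weights, stacked, axes=1)


# Usage: three equally sized clients; the third has the weakest learning status.
w = lsaa_weights([100, 100, 100], [0.9, 0.6, 0.3])
global_params = aggregate([np.full(3, 1.0), np.full(3, 2.0), np.full(3, 3.0)], w)
```

Here the weights come out as 110/420, 140/420 and 170/420, so the global parameters lean toward the struggling third client (15/7 instead of the plain FedAvg mean of 2).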