Smart Information Exchange for Unsupervised Federated Learning via Reinforcement Learning

Seohyun Lee, Anindya Bijoy Das, Satyavrat Wagle, Christopher G. Brinton
2024-02-15
Abstract: One of the main challenges of decentralized machine learning paradigms such as Federated Learning (FL) is the presence of local non-i.i.d. datasets. Device-to-device (D2D) transfers between distributed devices have been shown to be an effective tool for dealing with this problem and to be robust to stragglers. In an unsupervised setting, however, it is not obvious how data exchanges should take place, due to the absence of labels. In this paper, we propose an approach that uses Reinforcement Learning to create an optimal graph for data transfer. The goal is to form the links that provide the most benefit under the environment's constraints and improve convergence speed in an unsupervised FL environment. Numerical analysis shows the advantages of the proposed method, in terms of convergence speed and straggler resilience, across different available FL schemes and benchmark datasets.
Machine Learning
What problem does this paper attempt to address?
The problem that this paper attempts to solve is how to optimize data exchange through device-to-device (D2D) communication in an unsupervised federated learning (FL) environment, so as to improve the convergence speed and robustness of the model. Specifically, the paper focuses on the following points:

1. **Non-independent and identically distributed (non-i.i.d.) data**: In federated learning, the data on different devices are often non-i.i.d., which slows the convergence of the global model and can bias it. The paper proposes to alleviate this problem through D2D communication.
2. **Data exchange in an unsupervised setting**: Without label information, it is difficult to decide which data should be exchanged between devices. The paper proposes to use reinforcement learning (RL) to discover an optimal data transmission graph that guides data exchange between devices.
3. **Communication cost and transmission failure**: In practice, communication between devices may fail or be costly. By incorporating the communication failure probability \(P_D(i, j)\) and the communication cost into its design, the paper aims for a scheme that improves model performance while also reducing communication overhead and increasing the transmission success rate.
4. **Robustness to stragglers**: In federated learning, some devices (stragglers) may be unable to participate in the aggregation of the global model for various reasons. The proposed method maintains good performance in the presence of stragglers.
### Overview of the solution

To address the above problems, the paper proposes the following solutions:

- **Feature extraction and clustering based on PCA and K-means++**: Principal component analysis (PCA) reduces the data dimension while retaining the most important features; K-means++ clustering is then applied to quantify how the data differ between devices.
- **Reinforcement-learning-driven optimal graph discovery**: Each device is treated as an RL agent, and a reward function encourages links between devices whose data characteristics differ significantly, while taking the communication cost into account. The reward function is defined as:

  \[ r_{ij}=\alpha_1\cdot\lambda_{ij}-\alpha_2\cdot P_D(i, j) \]

  where \(\lambda_{ij}\) represents the degree of data difference between devices \(i\) and \(j\), \(P_D(i, j)\) is the probability of communication failure between devices \(i\) and \(j\), and \(\alpha_1\) and \(\alpha_2\) are user-specified weights.
- **Autoencoder for unsupervised learning**: After the D2D exchange, an autoencoder is trained as the unsupervised learning method by minimizing the global reconstruction loss, which improves the generalization ability and convergence speed of the model.

Through these methods, the paper shows that its scheme outperforms existing methods on multiple benchmark datasets (such as FashionMNIST and CIFAR-10), especially in terms of convergence speed, classification accuracy, and robustness to stragglers.
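The summary does not give the exact formula for \(\lambda_{ij}\). The sketch below shows one plausible way to turn PCA plus K-means++ into a data-difference score and then into the reward \(r_{ij}\), assuming (hypothetically) that \(\lambda_{ij}\) is the total-variation distance between two devices' cluster-membership histograms over a shared set of centroids. For simplicity, PCA and the centroids are fit on a pooled sample, which a real FL deployment would avoid. It uses NumPy, and every function name (`pca_fit`, `kmeans_pp`, `data_difference`, `reward`, ...) is illustrative rather than taken from the paper.

```python
import numpy as np


def pca_fit(X, k):
    """Return (mean, components) for the top-k principal directions of X (rows = samples)."""
    mean = X.mean(axis=0)
    # The right singular vectors of the centred data are the principal directions.
    _, _, vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, vt[:k]


def pca_transform(X, mean, components):
    """Project samples onto the retained principal components."""
    return (X - mean) @ components.T


def sq_dists(Z, C):
    """Squared Euclidean distance from every row of Z to every row of C."""
    return ((Z[:, None, :] - C[None, :, :]) ** 2).sum(axis=-1)


def kmeans_pp(Z, n_clusters, rng, iters=50):
    """Lloyd's k-means initialised with k-means++ seeding; returns the centroids."""
    C = Z[[rng.integers(len(Z))]]                  # first centroid: uniform at random
    while len(C) < n_clusters:
        d2 = sq_dists(Z, C).min(axis=1)            # distance to nearest chosen centroid
        nxt = rng.choice(len(Z), p=d2 / d2.sum())  # sample proportional to D(x)^2
        C = np.vstack([C, Z[nxt]])
    for _ in range(iters):
        labels = sq_dists(Z, C).argmin(axis=1)
        # Move each centroid to the mean of its members (keep it if the cluster is empty).
        C = np.array([Z[labels == c].mean(axis=0) if np.any(labels == c) else C[c]
                      for c in range(n_clusters)])
    return C


def cluster_histogram(Z, C):
    """Fraction of a device's samples assigned to each centroid."""
    labels = sq_dists(Z, C).argmin(axis=1)
    return np.bincount(labels, minlength=len(C)) / len(Z)


def data_difference(h_i, h_j):
    """Hypothetical lambda_ij: total-variation distance between two histograms."""
    return 0.5 * np.abs(h_i - h_j).sum()


def reward(lam_ij, p_fail_ij, alpha1=1.0, alpha2=1.0):
    """r_ij = alpha1 * lambda_ij - alpha2 * P_D(i, j)."""
    return alpha1 * lam_ij - alpha2 * p_fail_ij


# Toy setup: devices 0 and 1 hold samples from one blob, device 2 from another.
rng = np.random.default_rng(0)
blob_a = rng.normal(0.0, 0.3, size=(100, 5))
blob_b = rng.normal(3.0, 0.3, size=(100, 5))
devices = [blob_a[:50], blob_a[50:], blob_b]

mean, comps = pca_fit(np.vstack(devices), k=2)
Z_all = [pca_transform(X, mean, comps) for X in devices]
C = kmeans_pp(np.vstack(Z_all), n_clusters=2, rng=rng)
hists = [cluster_histogram(Z, C) for Z in Z_all]
lam = [[data_difference(hi, hj) for hj in hists] for hi in hists]
print(np.round(lam, 2))
```

In this toy setup, devices 0 and 1 see the same distribution, so their score is near zero, while either of them paired with device 2 scores near one; with equal failure probabilities, \(r_{ij}\) would therefore favour links to device 2.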
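The summary does not say which RL algorithm the device agents use. As a deliberately simple stand-in (not the paper's method), the sketch below treats each device as an independent epsilon-greedy multi-armed bandit, i.e. a single-state RL agent, that learns which partner to link to. It assumes, hypothetically, that each transmission attempt fails with probability \(P_D(i, j)\) and that a failure costs \(\alpha_2\), so the expected reward of choosing partner \(j\) equals \(r_{ij}=\alpha_1\lambda_{ij}-\alpha_2 P_D(i, j)\). The function `learn_links` and the toy matrices are illustrative.

```python
import random


def learn_links(lam, p_fail, alpha1=1.0, alpha2=1.0, episodes=2000, eps=0.1, seed=0):
    """Each device i runs an epsilon-greedy bandit over candidate partners j != i.

    Choosing j yields the realized reward alpha1 * lam[i][j] - alpha2 * failed,
    where failed ~ Bernoulli(p_fail[i][j]). Its expected value is exactly
    r_ij = alpha1 * lambda_ij - alpha2 * P_D(i, j).
    Returns each device's greedy partner and the learned value table.
    """
    rng = random.Random(seed)
    n = len(lam)
    q = [[0.0] * n for _ in range(n)]   # q[i][j]: device i's estimate of r_ij
    counts = [[0] * n for _ in range(n)]
    for _ in range(episodes):
        for i in range(n):
            partners = [j for j in range(n) if j != i]
            if rng.random() < eps:
                j = rng.choice(partners)                   # explore
            else:
                j = max(partners, key=lambda k: q[i][k])   # exploit
            failed = rng.random() < p_fail[i][j]           # simulated D2D failure
            r = alpha1 * lam[i][j] - alpha2 * failed
            counts[i][j] += 1
            q[i][j] += (r - q[i][j]) / counts[i][j]        # running sample mean
    links = [max((j for j in range(n) if j != i), key=lambda k: q[i][k]) for i in range(n)]
    return links, q


# Toy inputs: lam_toy[i][j] is the data difference, p_fail_toy[i][j] the failure probability.
lam_toy = [[0.0, 0.05, 0.95],
           [0.05, 0.0, 0.9],
           [0.95, 0.9, 0.0]]
p_fail_toy = [[0.0, 0.1, 0.2],
              [0.1, 0.0, 0.5],
              [0.2, 0.5, 0.0]]
links, q = learn_links(lam_toy, p_fail_toy)
print(links)
```

With \(\alpha_1=\alpha_2=1\), device 1's expected rewards are \(r_{10}=0.05-0.1=-0.05\) and \(r_{12}=0.9-0.5=0.4\), so it should learn to link with device 2 even though that link fails more often: the larger data difference outweighs the higher failure probability. A sample-average bandit is used only because it is the smallest RL agent that exhibits this trade-off.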