Understanding and Improving Model Averaging in Federated Learning on Heterogeneous Data

Tailin Zhou, Zehong Lin, Jun Zhang, Danny H.K. Tsang
DOI: https://doi.org/10.1109/TMC.2024.3406554
2024-05-31
Abstract: Model averaging is a widely adopted technique in federated learning (FL) that aggregates multiple client models to obtain a global model. Remarkably, model averaging in FL yields a superior global model, even when client models are trained with non-convex objective functions and on heterogeneous local datasets. However, the rationale behind its success remains poorly understood. To shed light on this issue, we first visualize the loss landscape of FL over client and global models to illustrate their geometric properties. The visualization shows that the client models encompass the global model within a common basin, and interestingly, the global model may deviate from the basin's center while still outperforming the client models. To gain further insights into model averaging in FL, we decompose the expected loss of the global model into five factors related to the client models. Specifically, our analysis reveals that the global model loss after early training mainly arises from i) the client model's loss on non-overlapping data between client datasets and the global dataset and ii) the maximum distance between the global and client models. Based on the findings from our loss landscape visualization and loss decomposition, we propose utilizing iterative moving averaging (IMA) on the global model at the late training phase to reduce its deviation from the expected minimum, while constraining client exploration to limit the maximum distance between the global and client models. Our experiments demonstrate that incorporating IMA into existing FL methods significantly improves their accuracy and training speed on various heterogeneous data setups of benchmark datasets. Code is available at https://github.com/TailinZhou/FedIMA.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
This paper asks how model averaging (MA) works in federated learning (FL) on heterogeneous data, and why it yields a well-performing global model. Specifically, the paper focuses on the following aspects:

1. **Understanding why model averaging succeeds in FL**:
   - Even when client data is highly heterogeneous, model averaging still produces a well-performing global model. The paper examines the geometry behind this by visualizing the loss landscape over client and global models.
   - The visualization shows that the client models surround the global model within a common basin, and that the global model can outperform the client models even when it deviates from the center of the basin.
2. **Decomposing the global model loss**:
   - The paper decomposes the expected loss of the global model into five factors: Training Bias, Heterogeneous Bias, Model-Prediction Variance, Covariance Between Client Models, and Locality.
   - This decomposition shows that, after the early training stage, the global model loss is mainly determined by the client models' losses on non-overlapping data (data in the global dataset that a client's own dataset does not cover) and by the maximum distance between the global and client models.
3. **Proposing improvements**:
   - Based on these findings, the paper proposes applying Iterative Moving Averaging (IMA) to the global model in the late training phase to reduce its deviation from the expected minimum.
   - The paper also constrains client exploration to limit the maximum distance between the global model and the client models.
4. **Experimental verification**:
   - The experiments show that integrating IMA into existing FL methods significantly improves accuracy and training speed across various heterogeneous data setups of benchmark datasets.
In conclusion, the paper uses loss landscape visualization and loss decomposition to explain why model averaging works in federated learning, and proposes IMA together with constrained client exploration to improve FL performance on heterogeneous data.
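The two averaging steps summarized above can be illustrated with a minimal Python sketch. It is not the authors' implementation (that lives in the linked repository); it assumes models are flat lists of floats, that client models are weighted by local sample count, and that IMA is a plain mean over the most recent global models. The names `average_models`, `MovingAverager`, `window`, and `start_round` are hypothetical and chosen only for illustration. How the averaged model is fed back to clients, and the constraint on client exploration, are not modeled here.

```python
from collections import deque
from typing import Deque, List, Sequence

Vector = List[float]


def average_models(client_models: Sequence[Vector], num_samples: Sequence[int]) -> Vector:
    """Weighted average of client parameter vectors.

    Assumption: each client is weighted by its local sample count.
    """
    total = sum(num_samples)
    dim = len(client_models[0])
    return [
        sum(n * model[i] for model, n in zip(client_models, num_samples)) / total
        for i in range(dim)
    ]


class MovingAverager:
    """Averages the most recent `window` global models from `start_round` on.

    Both `window` and `start_round` are hypothetical knobs for this sketch.
    """

    def __init__(self, window: int, start_round: int) -> None:
        self.start_round = start_round
        # deque with maxlen drops the oldest model automatically
        self.history: Deque[Vector] = deque(maxlen=window)

    def update(self, round_idx: int, global_model: Vector) -> Vector:
        if round_idx < self.start_round:
            # Early phase: return the plain averaged model unchanged
            return global_model
        self.history.append(global_model)
        dim = len(global_model)
        return [sum(m[i] for m in self.history) / len(self.history) for i in range(dim)]
```

For example, `average_models([[1.0, 2.0], [3.0, 6.0]], [1, 3])` returns `[2.5, 5.0]`, and a `MovingAverager(window=2, start_round=1)` fed the global models `[0.0]`, `[2.0]`, `[4.0]` in rounds 0, 1, 2 returns `[0.0]`, `[2.0]`, `[3.0]`. Keeping only a fixed window of late-phase models is one simple way to smooth the global model's trajectory once it is already inside the basin.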