Why Batch Normalization Damage Federated Learning on Non-IID Data?

Yanmeng Wang, Qingjiang Shi, Tsung-Hui Chang
2023-11-09
Abstract: As a promising distributed learning paradigm, federated learning (FL) involves training deep neural network (DNN) models at the network edge while protecting the privacy of the edge clients. To train a large-scale DNN model, batch normalization (BN) has been regarded as a simple and effective means to accelerate the training and improve the generalization capability. However, recent findings indicate that BN can significantly impair the performance of FL in the presence of non-i.i.d. data. While several FL algorithms have been proposed to address this issue, their performance still falls significantly when compared to the centralized scheme. Furthermore, none of them have provided a theoretical explanation of how the BN damages the FL convergence. In this paper, we present the first convergence analysis to show that under the non-i.i.d. data, the mismatch between the local and global statistical parameters in BN causes the gradient deviation between the local and global models, which, as a result, slows down and biases the FL convergence. In view of this, we develop a new FL algorithm that is tailored to BN, called FedTAN, which is capable of achieving robust FL performance under a variety of data distributions via iterative layer-wise parameter aggregation. Comprehensive experimental results demonstrate the superiority of the proposed FedTAN over existing baselines for training BN-based DNN models.
Machine Learning, Distributed, Parallel, and Cluster Computing
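The abstract's central claim, that under non-i.i.d. data the local BN statistics stop matching the global ones, can be checked numerically. The following is a minimal NumPy sketch, not taken from the paper: it assumes two hypothetical clients with one-dimensional Gaussian activations, and the helper `bn_batch_stats` and all data are illustrative. It compares each client's batch mean with the pooled (global) mean for a non-IID split and for an IID (shuffled) split of the same samples.

```python
import numpy as np


def bn_batch_stats(x):
    # Hypothetical helper: per-feature batch mean and biased variance,
    # which is what BN uses to normalize in training mode.
    return x.mean(axis=0), x.var(axis=0)


rng = np.random.default_rng(0)
# Illustrative data only: two clients whose 1-D activations follow
# different distributions (a non-IID split).
client_a = rng.normal(-2.0, 1.0, size=(500, 1))
client_b = rng.normal(2.0, 1.0, size=(500, 1))
pooled = np.concatenate([client_a, client_b])
global_mean, global_var = bn_batch_stats(pooled)

# Non-IID: each client's local mean is far from the global mean.
noniid_stats = [bn_batch_stats(c) for c in (client_a, client_b)]
noniid_mean_gap = max(np.abs(m - global_mean).max() for m, _ in noniid_stats)

# IID: shuffle the same pooled samples before splitting them between clients.
iid_stats = [bn_batch_stats(c) for c in np.split(rng.permutation(pooled), 2)]
iid_mean_gap = max(np.abs(m - global_mean).max() for m, _ in iid_stats)

# Averaging the non-IID local variances ignores the spread between the
# client means, so it underestimates the global variance.
avg_local_var = float(np.mean([v for _, v in noniid_stats]))

print(f"mean gap: non-IID={noniid_mean_gap:.2f}, IID={iid_mean_gap:.2f}")
print(f"variance: global={global_var[0]:.2f}, average of local={avg_local_var:.2f}")
```

With these assumed distributions the non-IID mean gap should be roughly 2 while the IID gap stays close to 0, and the average of the local variances (about 1) falls well below the global variance (about 5). Even simple averaging of local statistics therefore does not recover the global BN statistics in general.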
What problem does this paper attempt to address?
### The Problem the Paper Attempts to Solve

This paper explores why Batch Normalization (BN) harms the performance of Federated Learning (FL) on non-independent and identically distributed (non-IID) data. Specifically, it analyzes how, under non-IID data, BN causes a mismatch between the local and global statistical parameters, which produces gradient deviation and, in turn, slows down and biases FL convergence. The paper also proposes a new FL algorithm, FedTAN, which eliminates these mismatches by aggregating the statistical parameters and their gradients layer by layer, thereby achieving robust FL performance under different data distributions.

### Main Contributions

1. **Convergence Analysis of FL with BN**:
   - The authors present the first theoretical analysis of how BN affects the convergence of FedAvg. The results show that if the statistical parameters of a local BN layer, and their associated gradients, are inconsistent with those computed over the global dataset, the local model's gradients deviate from the global ones. This both slows down FL convergence and biases the point to which it converges.
2. **Relationship Between Local and Global Statistical Parameters and Their Gradients**:
   - The paper formulates the local and global statistical parameters and their gradients under IID and non-IID data and concludes that IID data keeps them consistent, whereas non-IID data causes mismatches. Moreover, eliminating the mismatch in the statistical parameters alone, without also making the local and global gradients consistent, cannot guarantee FL convergence.
3. **FedTAN Algorithm**:
   - A new FL algorithm, FedTAN, eliminates the mismatch between the local and global statistical parameters and their gradients by aggregating them layer by layer, thereby achieving robust FL performance under different data distributions.
4. **Experimental Validation**:
   - Experiments on the CIFAR-10 dataset show that FedTAN outperforms the baseline schemes. Although FedTAN requires additional message exchanges, it achieves satisfactory training performance and faster convergence under different data distributions.

### Background and Motivation

- **Federated Learning**: FL is a distributed learning paradigm for training deep neural network (DNN) models at the network edge while protecting the privacy of the edge clients.
- **Batch Normalization**: BN is a simple and effective technique for accelerating DNN training and improving generalization. However, recent studies have found that BN significantly impairs FL performance on non-IID data.
- **Limitations of Existing Methods**: Although several FL algorithms have been proposed to address this issue, their performance still falls well short of the centralized scheme, and none of them explains theoretically how BN damages FL convergence.

### Research Methods

- **Theoretical Analysis**: Mathematical derivation is used to characterize the mismatches between the local and global statistical parameters and gradients that BN causes under non-IID data, and to analyze how these mismatches affect FL convergence.
- **Algorithm Design**: FedTAN eliminates these mismatches by aggregating the statistical parameters and their gradients layer by layer, thereby improving FL performance.
- **Experimental Validation**: Extensive experiments on the CIFAR-10 dataset verify the effectiveness and superiority of FedTAN.
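The layer-wise aggregation of statistical parameters in the Algorithm Design step can be sketched for the forward pass. This is a minimal NumPy illustration under stated assumptions, not the authors' FedTAN implementation: each client uploads its sample count, per-feature mean, and variance for one BN layer; a server combines them into the global batch statistics; and every client normalizes with those shared statistics before computing the next layer. The two-layer network, the helper names (`local_moments`, `aggregate_moments`, `layerwise_federated_forward`), the `EPS` value, and the data are all hypothetical.

```python
import numpy as np

EPS = 1e-5  # BN numerical-stability constant (assumed value)


def local_moments(z):
    # Hypothetical client upload for one BN layer: sample count, mean, variance.
    return z.shape[0], z.mean(axis=0), z.var(axis=0)


def aggregate_moments(uploads):
    # Hypothetical server step: combine client moments into the global batch
    # mean and variance (law of total variance, weighted by sample counts).
    n_total = sum(n for n, _, _ in uploads)
    mean = sum(n * m for n, m, _ in uploads) / n_total
    var = sum(n * (v + (m - mean) ** 2) for n, m, v in uploads) / n_total
    return mean, var


def bn_apply(z, mean, var, gamma, beta):
    return gamma * (z - mean) / np.sqrt(var + EPS) + beta


def centralized_forward(x, layers):
    # Reference: every layer normalizes with statistics of the whole batch x.
    h = x
    for w, gamma, beta in layers:
        z = h @ w
        h = np.maximum(bn_apply(z, z.mean(axis=0), z.var(axis=0), gamma, beta), 0.0)
    return h


def layerwise_federated_forward(client_xs, layers):
    # Hypothetical protocol: one upload/download round per BN layer, because
    # later layers' inputs depend on the globally normalized earlier outputs.
    hs = list(client_xs)
    for w, gamma, beta in layers:
        zs = [h @ w for h in hs]
        mean, var = aggregate_moments([local_moments(z) for z in zs])
        hs = [np.maximum(bn_apply(z, mean, var, gamma, beta), 0.0) for z in zs]
    return hs


rng = np.random.default_rng(1)
# Hypothetical two-layer model: (weights, BN scale, BN shift) per layer.
layers = [
    (rng.normal(size=(4, 8)), np.ones(8), np.zeros(8)),
    (rng.normal(size=(8, 3)), np.ones(3), np.zeros(3)),
]
# Illustrative non-IID clients with shifted inputs and unequal sizes.
client_xs = [rng.normal(-1.0, 1.0, size=(30, 4)), rng.normal(1.0, 1.0, size=(50, 4))]

central = centralized_forward(np.concatenate(client_xs), layers)
federated = np.concatenate(layerwise_federated_forward(client_xs, layers))
local_only = np.concatenate([centralized_forward(x, layers) for x in client_xs])

print(np.allclose(federated, central))   # True
print(np.allclose(local_only, central))  # False
```

Because each BN layer's input depends on the normalized output of the previous layer, the statistics cannot be gathered in a single exchange; the sketch needs one round per BN layer in the forward pass. This is consistent with the additional message exchanges the summary attributes to FedTAN.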
### Conclusion

Through theoretical analysis and experimental validation, this paper reveals the mechanism by which BN harms FL performance on non-IID data and proposes a new FL algorithm, FedTAN, to address it. By aggregating the statistical parameters and their gradients layer by layer, FedTAN eliminates the mismatches between the local and global statistical parameters and gradients, thereby achieving robust FL performance under different data distributions.
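The conclusion's point that both the statistics and their gradients must be aggregated can be made concrete for a single BN layer. The sketch below is an illustrative NumPy example, not the paper's code. It assumes a hypothetical loss L = sum(dy * y) with a fixed upstream gradient `dy`. Clients first share local sums so that everyone obtains the global mean and variance. In the backward pass, each client uploads three per-feature partial sums, which the server adds before each client finishes its own input gradient. The helper `aggregated_backward`, the `EPS` value, and the data are hypothetical. The result is compared with centralized BN on the pooled data and with plain local BN, in which each client normalizes with its own statistics.

```python
import numpy as np

EPS = 1e-5  # BN numerical-stability constant (assumed value)


def bn_forward(x, gamma, beta):
    mean, var = x.mean(axis=0), x.var(axis=0)
    return gamma * (x - mean) / np.sqrt(var + EPS) + beta


def bn_backward(x, dy, gamma):
    # Standard BN input gradient for L = sum(dy * y), with the mean and
    # variance computed from the batch x itself.
    n = x.shape[0]
    mean, var = x.mean(axis=0), x.var(axis=0)
    std = np.sqrt(var + EPS)
    g = dy * gamma  # dL/dx_hat
    d_var = (g * (x - mean)).sum(axis=0) * -0.5 * std ** -3
    d_mean = -g.sum(axis=0) / std + d_var * -2.0 * (x - mean).sum(axis=0) / n
    return g / std + d_var * 2.0 * (x - mean) / n + d_mean / n


def aggregated_backward(client_xs, client_dys, gamma):
    # Hypothetical FedTAN-style exchange for one BN layer (not the paper's code).
    # Forward: share local sums to form the global mean, then the global variance.
    n = sum(x.shape[0] for x in client_xs)
    mean = sum(x.sum(axis=0) for x in client_xs) / n
    var = sum(((x - mean) ** 2).sum(axis=0) for x in client_xs) / n
    std = np.sqrt(var + EPS)
    gs = [dy * gamma for dy in client_dys]
    # Backward: each client uploads three per-feature partial sums; the server adds them.
    s_var = sum((g * (x - mean)).sum(axis=0) for g, x in zip(gs, client_xs))
    s_g = sum(g.sum(axis=0) for g in gs)
    s_x = sum((x - mean).sum(axis=0) for x in client_xs)
    d_var = s_var * -0.5 * std ** -3
    d_mean = -s_g / std + d_var * -2.0 * s_x / n
    # Each client finishes its own input gradient using the shared global terms.
    return [g / std + d_var * 2.0 * (x - mean) / n + d_mean / n for g, x in zip(gs, client_xs)]


rng = np.random.default_rng(2)
# Illustrative non-IID clients and a fixed upstream gradient dy per sample.
client_xs = [rng.normal(-2.0, 1.0, size=(40, 3)), rng.normal(2.0, 0.5, size=(60, 3))]
client_dys = [rng.normal(size=x.shape) for x in client_xs]
gamma = np.array([1.0, 0.5, 2.0])

central = bn_backward(np.concatenate(client_xs), np.concatenate(client_dys), gamma)
aggregated = np.concatenate(aggregated_backward(client_xs, client_dys, gamma))
local_only = np.concatenate([bn_backward(x, dy, gamma) for x, dy in zip(client_xs, client_dys)])

print(np.allclose(aggregated, central))  # True
print(np.allclose(local_only, central))  # False: local statistics shift the gradient
```

The three exchanged sums follow directly from the chain rule for BN. The gradients with respect to the batch mean and variance are sums over every sample in the global batch, so no single client can compute them from its own data alone.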