A Bayesian Federated Learning Framework with Online Laplace Approximation

Liangxi Liu, Xi Jiang, Feng Zheng, Hong Chen, Guo-Jun Qi, Heng Huang, Ling Shao
DOI: https://doi.org/10.1109/TPAMI.2023.3322743
2023-12-02
Abstract: Federated learning (FL) allows multiple clients to collaboratively learn a globally shared model through cycles of model aggregation and local model training, without the need to share data. Most existing FL methods train local models separately on different clients, and then simply average their parameters to obtain a centralized model on the server side. However, these approaches generally suffer from large aggregation errors and severe local forgetting, which are particularly bad in heterogeneous data settings. To tackle these issues, in this paper, we propose a novel FL framework that uses online Laplace approximation to approximate posteriors on both the client and server side. On the server side, a multivariate Gaussian product mechanism is employed to construct and maximize a global posterior, largely reducing the aggregation errors induced by large discrepancies between local models. On the client side, a prior loss that uses the global posterior probabilistic parameters delivered from the server is designed to guide the local training. Binding such learning constraints from other clients enables our method to mitigate local forgetting. Finally, we achieve state-of-the-art results on several benchmarks, clearly demonstrating the advantages of the proposed method.
Machine Learning; Artificial Intelligence; Distributed, Parallel, and Cluster Computing
### What problems does this paper attempt to solve?

This paper targets two key problems in Federated Learning (FL), **Aggregation Error (AE)** and **Local Forgetting (LF)**, both of which are most severe when client data is heterogeneous.

#### Aggregation Error (AE)

In standard FL frameworks such as FedAvg, the server builds the global model by taking a weighted average of the local model parameters uploaded by the clients. With heterogeneous data, however, the local parameters of different clients follow different posterior distributions, so averaging them directly produces a large aggregation error. Specifically, simple averaging makes the aggregated posterior more uncertain, which lowers the confidence of the model's predictions and degrades its generalization ability.

#### Local Forgetting (LF)

After aggregation, the server distributes the global model to the clients for further local training. When the training data is homogeneous, the local likelihood is the same for all clients, so a locally optimized model still generalizes well. With heterogeneous data, local training gradually pulls the local parameters away from the global model, and the local model forgets the knowledge learned from other clients. This hurts performance on data from other clients and, in turn, increases the aggregation error in the next round.

### Solutions

To address these problems, the authors propose a new FL method built on a Bayesian framework, called **Bayesian Federated Learning Framework with Online Laplace Approximation (FOLA)**. Its main contributions are:

1. **Global Posterior Construction**: On the server side, the local posteriors of the clients are multiplied using a Gaussian product to construct and maximize the global posterior. This reduces the aggregation error caused by discrepancies between local model parameters.
2. **Prior Iteration Strategy**: On the client side, a Prior Iteration (PI) strategy uses the global posterior parameters sent by the server as a prior to guide local training. Minimizing the resulting prior loss maximizes the global posterior, which alleviates local forgetting.
3. **Online Laplace Approximation**: A Federated Online Laplace Approximation (FOLA) module efficiently approximates the Gaussian posteriors. It updates the local posterior parameters online during training without increasing the computational complexity.
4. **Experimental Verification**: Experiments on several commonly used FL benchmark datasets verify the effectiveness of the method and show superior performance on various evaluation metrics.

In summary, the paper addresses aggregation error and local forgetting in Federated Learning by combining a Bayesian perspective with online Laplace approximation, and reports improved generalization and overall model performance as a result.
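The server-side Gaussian product and the client-side prior loss described above can be illustrated with a small sketch. It is not the authors' implementation: it assumes diagonal covariances, equal client weights, and hypothetical names (`gaussian_product`, `prior_loss`, `strength`), and it takes the client means and precisions as given rather than estimating them with a Laplace approximation.

```python
import numpy as np


def gaussian_product(means, precisions):
    """Multiply K diagonal Gaussians N(mu_k, diag(1 / p_k)).

    The product is again Gaussian: its precision is the sum of the
    client precisions, and its mean is the precision-weighted average
    of the client means.
    """
    means = np.asarray(means, dtype=float)            # shape (K, D)
    precisions = np.asarray(precisions, dtype=float)  # shape (K, D)
    global_precision = precisions.sum(axis=0)
    global_mean = (precisions * means).sum(axis=0) / global_precision
    return global_mean, global_precision


def prior_loss(theta, global_mean, global_precision, strength=1.0):
    """Quadratic penalty 0.5 * strength * (theta - mu)^T diag(p) (theta - mu).

    Up to a constant this is the negative log-density of the global
    Gaussian, so minimizing it keeps local parameters near the global mode.
    `strength` is a hypothetical weighting coefficient.
    """
    diff = np.asarray(theta, dtype=float) - global_mean
    return 0.5 * strength * float(np.sum(global_precision * diff ** 2))


# Two clients, two parameters. Client 0 is more certain about parameter 0.
client_means = [[1.0, 0.0], [3.0, 4.0]]
client_precisions = [[3.0, 1.0], [1.0, 1.0]]

mu, lam = gaussian_product(client_means, client_precisions)
plain_average = np.mean(client_means, axis=0)

print("Gaussian product mean:", mu)
print("Plain average:        ", plain_average)
print("Prior loss at theta=[2.5, 2.0]:", prior_loss([2.5, 2.0], mu, lam))
```

For the first parameter, the product mean sits closer to the more confident client than the plain average does. The penalty grows fastest along directions where the global precision is high, which is how a loss of this form can discourage a client from drifting away from what the other clients have learned.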