Variational Federated Multi-Task Learning

Luca Corinzia, Ami Beuret, Joachim M. Buhmann
DOI: https://doi.org/10.48550/arXiv.1906.06268
2021-02-04
Abstract: In federated learning, a central server coordinates the training of a single model on a massively distributed network of devices. This setting can be naturally extended to a multi-task learning framework, to handle real-world federated datasets that typically show strong statistical heterogeneity among devices. Despite federated multi-task learning being shown to be an effective paradigm for real-world datasets, it has been applied only on convex models. In this work, we introduce VIRTUAL, an algorithm for federated multi-task learning for general non-convex models. In VIRTUAL the federated network of the server and the clients is treated as a star-shaped Bayesian network, and learning is performed on the network using approximated variational inference. We show that this method is effective on real-world federated datasets, outperforming the current state-of-the-art for federated learning, and concurrently allowing sparser gradient updates.
Machine Learning
What problem does this paper attempt to address?
The paper addresses how to perform multi-task learning (MTL) effectively within the federated learning (FL) framework, particularly for non-convex models. Specifically, it tackles the following issues:

1. **Data heterogeneity**: In federated learning, the data distributions of different clients typically exhibit strong statistical heterogeneity (they are non-IID). Standard federated learning methods such as FedAvg, which train a single shared model, perform poorly on such heterogeneous data. The paper proposes a new algorithm, VIRTUAL, to address this challenge.
2. **Privacy protection and communication cost**: A central goal of federated learning is to protect user privacy by keeping data on the devices instead of centralizing it on a server, while keeping communication costs low. Building on variational inference, the paper improves model performance while allowing sparser gradient updates, which reduces the amount of information each client must communicate.
3. **Support for non-convex models**: Prior federated multi-task learning methods had been applied only to convex models. The paper proposes a framework for general non-convex models, extending the scope of federated multi-task learning.

### Main contributions of the paper

1. **First federated multi-task learning method for non-convex models**: The paper designs a new metric and proposes the VIRTUAL algorithm for joint training when client data distributions are strongly non-IID.
2. **Extensive experimental evaluation**: Experiments on multiple real-world federated datasets show that VIRTUAL outperforms current state-of-the-art federated learning methods while also allowing sparser gradient updates, which lowers communication cost.
### Core idea of the VIRTUAL algorithm

VIRTUAL treats the server and the clients of the federated network as a star-shaped Bayesian network and performs learning with approximate variational inference. Each client has a task-specific model that benefits from the server model, much as in transfer learning: one part of the parameters is shared among all clients, while the other part is private to each client and adapted separately. The server maintains a posterior distribution over the shared parameters, which summarizes what all clients have learned about them.

### Mathematical formulas

In the Bayesian network, let \(\theta\) denote the server (shared) parameters and \(\{\varphi_i\}_{i = 1}^K\) the private parameters of the \(K\) clients. The data set \(D_i\) of client \(i\) is generated by a client-dependent probability distribution. Assuming that the client data are conditionally independent given the server and client parameters, and that the prior factorizes as \(p(\theta)\prod_{i=1}^K p(\varphi_i)\), the posterior given all data sets \(D_{1:K}\) is:

\[p(\theta, \varphi_1,\ldots,\varphi_K \mid D_{1:K}) \propto \frac{\prod_{i = 1}^K p(\theta,\varphi_i \mid D_i)}{p(\theta)^{K - 1}}\]

Each factor \(p(\theta,\varphi_i \mid D_i)\) contains one copy of the prior \(p(\theta)\), so dividing by \(p(\theta)^{K-1}\) leaves exactly one copy in the product.

Client \(i\) updates its variational factors at round \(t\) by minimizing the variational free energy \(L_i\):

\[L_i=D_{KL}\left(\frac{s_i^{(t)}(\theta)s^{(t - 1)}(\theta)}{s_i^{(t - 1)}(\theta)}\middle\|\frac{p(\theta)}{K}\frac{s^{(t - 1)}(\theta)}{s_i^{(t - 1)}(\theta)}\right)+D_{KL}(c_i^{(t)}(\varphi_i)\|p(\varphi_i))-\mathbb{E}_{s^{(t)}(\theta)c_i^{(t)}(\varphi_i)}[\log p(D_i|\theta,\varphi_i)]\]

In this notation, \(s^{(t)}(\theta)\) is the variational approximation of the server posterior, \(s_i^{(t)}(\theta)\) is the factor contributed by client \(i\), and \(c_i^{(t)}(\varphi_i)\) is the variational posterior over the client's private parameters. Because \(L_i\) depends only on the local data set \(D_i\), each client can minimize it on its own device, so the raw data never leave the client.

### Summary

Building on variational inference, the paper proposes VIRTUAL, a new federated multi-task learning algorithm that addresses statistical heterogeneity across clients and extends federated multi-task learning to non-convex models.
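The posterior factorization in the Mathematical formulas section can be checked numerically. The sketch below is a hypothetical toy model, not the paper's code: \(\theta\) and each \(\varphi_i\) take a few discrete values, and the likelihood values \(p(D_i \mid \theta, \varphi_i)\) for a fixed observed \(D_i\) are random numbers. All function names (`toy_model`, `client_posterior`, `factorized_posterior`) are illustrative. It assumes, as above, that the prior factorizes as \(p(\theta)\prod_i p(\varphi_i)\), and compares the brute-force joint posterior with the product of per-client posteriors divided by \(p(\theta)^{K-1}\).

```python
import itertools
import random


def normalize(table):
    """Scale a dict of non-negative weights so that they sum to one."""
    total = sum(table.values())
    return {k: v / total for k, v in table.items()}


def toy_model(K=3, n_theta=3, n_phi=2, seed=0):
    """Hypothetical discrete star-shaped model.

    Returns a prior p(theta), independent priors p(phi_i) (assumed to
    factorize with p(theta)), and likelihood values p(D_i | theta, phi_i)
    for one fixed observed data set D_i per client.
    """
    rng = random.Random(seed)
    p_theta = normalize({t: rng.random() + 0.1 for t in range(n_theta)})
    p_phi = [normalize({f: rng.random() + 0.1 for f in range(n_phi)})
             for _ in range(K)]
    lik = [{(t, f): rng.random() + 0.1
            for t in range(n_theta) for f in range(n_phi)}
           for _ in range(K)]
    return p_theta, p_phi, lik


def joint_posterior(p_theta, p_phi, lik):
    """Brute-force p(theta, phi_1..phi_K | D_1:K) from the full joint."""
    table = {}
    for t in p_theta:
        for phis in itertools.product(*(p.keys() for p in p_phi)):
            w = p_theta[t]
            for i, f in enumerate(phis):
                w *= p_phi[i][f] * lik[i][(t, f)]
            table[(t,) + phis] = w
    return normalize(table)


def client_posterior(p_theta, p_phi_i, lik_i):
    """Per-client posterior p(theta, phi_i | D_i), using only client i's data."""
    return normalize({(t, f): p_theta[t] * p_phi_i[f] * lik_i[(t, f)]
                      for t in p_theta for f in p_phi_i})


def factorized_posterior(p_theta, p_phi, lik):
    """prod_i p(theta, phi_i | D_i) / p(theta)^(K-1), renormalized."""
    K = len(lik)
    local = [client_posterior(p_theta, p_phi[i], lik[i]) for i in range(K)]
    table = {}
    for t in p_theta:
        for phis in itertools.product(*(p.keys() for p in p_phi)):
            w = 1.0
            for i, f in enumerate(phis):
                w *= local[i][(t, f)]
            # K client posteriors carry K copies of p(theta); keep only one.
            table[(t,) + phis] = w / p_theta[t] ** (K - 1)
    return normalize(table)


p_theta, p_phi, lik = toy_model()
exact = joint_posterior(p_theta, p_phi, lik)
approx = factorized_posterior(p_theta, p_phi, lik)
max_err = max(abs(exact[k] - approx[k]) for k in exact)
print(f"max abs difference: {max_err:.2e}")
```

The two tables agree up to floating-point rounding, which illustrates why each client can compute its own posterior factor from local data alone: the server only needs to combine the factors and remove the \(K-1\) surplus copies of the shared prior.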