Unmasking Efficiency: Learning Salient Sparse Models in Non-IID Federated Learning

Riyasat Ohib, Bishal Thapaliya, Gintare Karolina Dziugaite, Jingyu Liu, Vince Calhoun, Sergey Plis
2024-05-15
Abstract: In this work, we propose Salient Sparse Federated Learning (SSFL), a streamlined approach for sparse federated learning with efficient communication. SSFL identifies a sparse subnetwork prior to training, leveraging parameter saliency scores computed separately on local client data in non-IID scenarios, and then aggregated, to determine a global mask. Only the sparse model weights are communicated each round between the clients and the server. We validate SSFL's effectiveness using standard non-IID benchmarks, noting marked improvements in the sparsity-accuracy trade-offs. Finally, we deploy our method in a real-world federated learning framework and report improvement in communication time.
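The mask-selection step described in the abstract (per-client saliency, aggregation, global mask) can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the SNIP-style score `|g_j * w_j|` (gradient times weight), the per-client normalization, weighting clients by their local dataset sizes, and the helper names `client_saliency` and `global_mask` are all hypothetical choices made here for clarity.

```python
from typing import List, Sequence


def client_saliency(weights: Sequence[float], grads: Sequence[float]) -> List[float]:
    """Assumed SNIP-style connection sensitivity |g_j * w_j|, computed on one
    client's local data and normalized to sum to 1 (assumption)."""
    raw = [abs(g * w) for w, g in zip(weights, grads)]
    total = sum(raw)
    return [s / total for s in raw] if total > 0 else raw


def global_mask(
    client_scores: Sequence[Sequence[float]],
    client_sizes: Sequence[int],
    density: float,
) -> List[int]:
    """Weighted-average the client scores and keep the top `density` fraction."""
    n_total = sum(client_sizes)
    num_params = len(client_scores[0])
    # Global score S_j = sum_k (n_k / n) * s_kj; weighting by data size is assumed.
    agg = [
        sum(size / n_total * scores[j] for scores, size in zip(client_scores, client_sizes))
        for j in range(num_params)
    ]
    k = min(num_params, max(1, round(density * num_params)))
    # Indices of the k largest aggregated scores (ties broken by lower index).
    keep = sorted(range(num_params), key=lambda j: (-agg[j], j))[:k]
    mask = [0] * num_params
    for j in keep:
        mask[j] = 1
    return mask


# Usage with toy numbers: every client starts from the same initial weights.
w0 = [0.5, -1.0, 0.2, 0.8]
grads_a = [0.1, 0.0, 2.0, 0.1]  # hypothetical local gradients on client A
grads_b = [0.4, 0.2, 0.0, 0.5]  # hypothetical local gradients on client B
scores = [client_saliency(w0, grads_a), client_saliency(w0, grads_b)]
mask = global_mask(scores, client_sizes=[100, 300], density=0.5)
# mask == [1, 0, 0, 1]
```

Client A on its own would rank parameter 2 highest, but the larger client B outweighs it in the aggregate. This shows how combining scores across heterogeneous clients can yield a mask that neither client would choose alone.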
Machine Learning; Artificial Intelligence; Distributed, Parallel, and Cluster Computing
What problem does this paper attempt to address?
This paper addresses how to reduce communication and computational costs in federated learning when client data are non-independent and identically distributed (non-IID). Specifically, the authors propose a new framework, Salient Sparse Federated Learning (SSFL), which reduces the communication volume between clients and the server by sparsifying model parameters, and which trains efficient sparse models on resource-constrained edge devices.

### Main problems and solutions in the paper

#### 1. **Problem description**

- **Data heterogeneity**: In federated learning, the data distributions of different clients may be non-IID, which makes model training difficult.
- **Limited communication bandwidth**: Because many clients participate and they are geographically dispersed, bandwidth is limited, and frequently transmitting large numbers of model parameters incurs high communication costs.
- **Limited computational resources**: Many edge devices (such as IoT and mobile devices) have limited compute and cannot efficiently run large deep learning models.

#### 2. **Solutions**

To address these problems, the authors propose SSFL, whose core ideas include:

- **Sparse subnetwork identification**: Before training starts, each client computes parameter saliency scores on its local data, and these scores are aggregated to determine a global mask that selects a sparse subnetwork.
- **Sparse model communication**: Throughout training, only the weights of the sparse subnetwork are transmitted, rather than the full set of model parameters, which greatly reduces communication volume.
- **Initialization consistency**: All client models start from the same initial weights and mask, so dense model parameters never need to be shared during training.

### Specific implementation steps

1. **Compute parameter saliency scores**: Each client computes saliency scores for the parameters on its local data; these scores reflect each parameter's importance to the loss function.
2. **Generate the global mask**: The clients' saliency scores are combined by a weighted average into a global saliency score, and the most important connections according to this score form the global mask.
3. **Sparse model training**: The global mask is applied to each client's model, and only the weights of the sparse subnetwork are trained and transmitted.

### Experimental verification

The authors evaluate SSFL on standard non-IID benchmark datasets (such as CIFAR-10, CIFAR-100, and TinyImageNet). The results show a significant improvement in the sparsity-accuracy trade-off, as well as reduced communication time when the method is deployed in a real-world federated learning framework.

### Summary

The main contribution of this paper is SSFL, a novel sparse federated learning method that reduces communication overhead under non-IID data distributions while maintaining high model performance, making it especially suitable for resource-constrained edge devices.
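Step 3 (sparse training and communication) relies on every client and the server holding the same fixed mask, so only the surviving weight values need to be sent, not their indices. The sketch below illustrates one communication round under that property. It is an illustration only: the helper names `pack`, `unpack`, and `server_aggregate` are hypothetical, and the FedAvg-style data-size-weighted average is an assumed aggregation rule, not necessarily the one used in SSFL.

```python
from typing import List, Sequence


def pack(weights: Sequence[float], mask: Sequence[int]) -> List[float]:
    """Keep only the weights under the shared mask. No indices are sent,
    because every party already knows the mask."""
    return [w for w, m in zip(weights, mask) if m]


def unpack(values: Sequence[float], mask: Sequence[int]) -> List[float]:
    """Scatter packed values back into a dense vector; pruned entries are 0."""
    it = iter(values)
    return [next(it) if m else 0.0 for m in mask]


def server_aggregate(
    packed_updates: Sequence[Sequence[float]], client_sizes: Sequence[int]
) -> List[float]:
    """Assumed FedAvg-style average of packed client weights, weighted by data size."""
    n_total = sum(client_sizes)
    return [
        sum(size / n_total * vec[i] for vec, size in zip(packed_updates, client_sizes))
        for i in range(len(packed_updates[0]))
    ]


# Usage: one round with a fixed mask and hypothetical locally trained weights.
# (Assumes local training masks gradients, so pruned entries stay at 0.)
mask = [1, 0, 0, 1]
client_a = [0.6, 0.0, 0.0, 0.9]
client_b = [0.2, 0.0, 0.0, 0.5]
uploads = [pack(client_a, mask), pack(client_b, mask)]  # 2 floats each instead of 4
new_global = unpack(server_aggregate(uploads, client_sizes=[100, 300]), mask)
# new_global ≈ [0.3, 0.0, 0.0, 0.6]
```

At mask density d, each upload (and the broadcast back to clients) carries roughly a fraction d of the dense parameter count. This is where the communication savings described above come from.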