Non-Convex Optimization in Federated Learning via Variance Reduction and Adaptive Learning

Dipanwita Thakur, Antonella Guzzo, Giancarlo Fortino, Sajal K. Das
2024-12-16
Abstract: This paper proposes a novel federated algorithm that leverages momentum-based variance reduction with adaptive learning to address non-convex settings across heterogeneous data. We intend to minimize communication and computation overhead, thereby fostering a sustainable federated learning system. We aim to overcome challenges related to gradient variance, which hinders the model's efficiency, and the slow convergence resulting from learning rate adjustments with heterogeneous data. The experimental results on image classification tasks with heterogeneous data reveal the effectiveness of our suggested algorithms in non-convex settings, with an improved communication complexity of $\mathcal{O}(\epsilon^{-1})$ to converge to an $\epsilon$-stationary point, compared to the communication complexity of $\mathcal{O}(\epsilon^{-2})$ in most prior works. The proposed federated version maintains the trade-off between the convergence rate, number of communication rounds, and test accuracy while mitigating client drift in heterogeneous settings. The experimental results demonstrate the efficiency of our algorithms in image classification tasks (MNIST, CIFAR-10) with heterogeneous data.
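For a rough sense of scale, and assuming the constants hidden inside the two $\mathcal{O}(\cdot)$ bounds are comparable (the abstract does not state them), the number of communication rounds needed for a target accuracy of $\epsilon = 10^{-2}$ compares as follows:

$$
\epsilon^{-1} = \left(10^{-2}\right)^{-1} = 10^{2}
\qquad\text{versus}\qquad
\epsilon^{-2} = \left(10^{-2}\right)^{-2} = 10^{4}.
$$

The ratio $\epsilon^{-2} / \epsilon^{-1} = \epsilon^{-1}$ grows as the target tightens, so under this assumption the improved bound corresponds to roughly $100\times$ fewer rounds at $\epsilon = 10^{-2}$ and $1000\times$ fewer at $\epsilon = 10^{-3}$.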
Machine Learning
### What problems does this paper attempt to solve?

This paper addresses several key problems in federated learning (FL), particularly those that arise with heterogeneous data in non-convex optimization. Specifically, it focuses on the following aspects:

1. **Communication and Computation Overhead**:
   - **Communication Complexity**: Traditional federated learning methods need many communication rounds between the clients and the server to converge, which leads to high communication costs. The proposed method lowers the communication complexity by reducing the number of rounds needed to reach an $\epsilon$-stationary point, from $\mathcal{O}(\epsilon^{-2})$ in most prior works to $\mathcal{O}(\epsilon^{-1})$.
   - **Computation Complexity**: Because the data distribution differs across clients (non-IID), traditional FL methods need more local computation for model updates, which increases the computational overhead.
2. **Gradient Variance Problem**:
   - In federated learning, each client has a different data distribution, so the stochastic gradients computed during local model updates have high variance. This variance slows convergence and degrades model performance.
   - The paper uses momentum-based variance reduction to lower the gradient variance, which improves the convergence speed and stability of the model.
3. **Learning Rate Adjustment Problem**:
   - The choice of learning rate is crucial for convergence. A single fixed global learning rate cannot suit all parameters and may lead to slow or unstable convergence.
   - The paper introduces an adaptive learning rate mechanism in which each parameter's step size is adjusted dynamically according to its historical gradients, accelerating convergence and improving model performance.
4. **Client Drift Problem**:
   - With heterogeneous data, each client's local model may gradually deviate from the global model, a phenomenon known as "client drift". This degrades the convergence and accuracy of the global model.
   - The paper mitigates client drift by combining momentum-based variance reduction with the adaptive learning rate, helping the global model converge more reliably.
5. **Non-convex Optimization Problem**:
   - Modern machine learning models such as deep neural networks are usually non-convex, so effective non-convex optimization in the federated setting is challenging.
   - The proposed method targets non-convex settings, and its effectiveness is verified experimentally on image classification tasks.

### Main Contributions

- Proposes a new federated learning algorithm that combines momentum-based variance reduction with an adaptive learning rate to accelerate convergence and reduce computational overhead.
- Provides a convergence analysis in non-convex settings that supports the effectiveness of the method in non-IID data environments.
- Shows experimentally that the method achieves higher communication efficiency on image classification tasks (MNIST and CIFAR-10) and effectively reduces client drift.

In conclusion, this paper tackles common problems in federated learning, namely high communication complexity, large gradient variance, difficult learning rate tuning, and client drift, by introducing momentum-based variance reduction and an adaptive learning rate, thereby improving the convergence speed and performance of the model.
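The two mechanisms described above, momentum-based variance reduction and per-parameter adaptive learning rates, can be illustrated with a small federated loop. The summary does not give the paper's exact update rules, so this is only a sketch under stated assumptions: the local estimator follows the STORM-style recursion $d_t = g(x_t; s) + (1 - a)\,(d_{t-1} - g(x_{t-1}; s))$, where the same sample $s$ is evaluated at both iterates; the step size is AdaGrad-style per coordinate with an initial accumulator value; the server simply averages local models (FedAvg-style); and each round's estimator starts from the full local gradient. The toy loss $\tfrac12\lVert x - p\rVert^2$ is convex, unlike the paper's non-convex setting, and all function and parameter names (`local_update`, `federated_train`, `a`, `lr`, `init_accum`) are hypothetical.

```python
import math
import random


def stochastic_grad(x, point):
    """Gradient of the toy per-sample loss 0.5 * ||x - point||^2 (assumed loss)."""
    return [xj - pj for xj, pj in zip(x, point)]


def full_grad(x, data):
    """Exact local gradient: mean of the per-sample gradients over one client's data."""
    n = len(data)
    return [sum(xj - p[j] for p in data) / n for j, xj in enumerate(x)]


def local_update(x_global, data, rng, steps=10, a=0.1, lr=0.5, init_accum=1.0):
    """Hypothetical client update: STORM-style momentum estimator + AdaGrad-style steps."""
    x_prev = list(x_global)
    d = full_grad(x_prev, data)          # initial estimator from the full local batch (assumption)
    accum = [init_accum] * len(x_prev)   # per-coordinate running sum of squared estimates
    for _ in range(steps):
        # AdaGrad-style adaptive step: each coordinate is scaled by 1 / sqrt(its history).
        x = []
        for j in range(len(x_prev)):
            accum[j] += d[j] ** 2
            x.append(x_prev[j] - lr * d[j] / math.sqrt(accum[j]))
        # Momentum-based variance reduction: evaluate the SAME sample at x_t and x_{t-1},
        # d_t = g(x_t; s) + (1 - a) * (d_{t-1} - g(x_{t-1}; s)).
        point = rng.choice(data)
        g_new = stochastic_grad(x, point)
        g_old = stochastic_grad(x_prev, point)
        d = [gn + (1.0 - a) * (dj - go) for gn, dj, go in zip(g_new, d, g_old)]
        x_prev = x
    return x_prev


def federated_train(clients, x0, rounds, rng):
    """Each round: broadcast the global model, run local updates, average the results."""
    x = list(x0)
    for _ in range(rounds):
        local_models = [local_update(x, data, rng) for data in clients]
        x = [sum(m[j] for m in local_models) / len(local_models) for j in range(len(x))]
    return x


if __name__ == "__main__":
    rng = random.Random(0)
    # Two non-IID clients whose data are centred at different points.
    centres = [(2.0, 1.0), (-2.0, 1.0)]
    clients = [[(cx + rng.gauss(0, 1), cy + rng.gauss(0, 1)) for _ in range(50)]
               for cx, cy in centres]
    model = federated_train(clients, [5.0, 5.0], rounds=30, rng=rng)
    print("global model:", model)
```

In this toy loss the correction term $g(x_t; s) - g(x_{t-1}; s)$ equals the true gradient difference exactly, so the estimator error decays geometrically with factor $1 - a$ and only an $a$-scaled share of each new sample's noise enters it; with real networks the cancellation is only approximate. The `init_accum` term keeps the first adaptive step from jumping by a full `lr` when the gradient is tiny.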