On the Convergence of Continual Federated Learning Using Incrementally Aggregated Gradients

Satish Kumar Keshri, Nazreen Shah, Ranjitha Prasad
2024-11-13
Abstract: The holy grail of machine learning is to enable Continual Federated Learning (CFL) to enhance the efficiency, privacy, and scalability of AI systems while learning from streaming data. The primary challenge of a CFL system is to overcome global catastrophic forgetting, wherein the accuracy of the global model trained on new tasks declines on the old tasks. In this work, we propose Continual Federated Learning with Aggregated Gradients (C-FLAG), a novel replay-memory based federated strategy consisting of edge-based gradient updates on memory and aggregated gradients on the current data. We provide convergence analysis of the C-FLAG approach which addresses forgetting and bias while converging at a rate of $O(1/\sqrt{T})$ over $T$ communication rounds. We formulate an optimization sub-problem that minimizes catastrophic forgetting, translating CFL into an iterative algorithm with adaptive learning rates that ensure seamless learning across tasks. We empirically show that C-FLAG outperforms several state-of-the-art baselines on both task and class-incremental settings with respect to metrics such as accuracy and forgetting.
Machine Learning; Distributed, Parallel, and Cluster Computing
### What problems does this paper attempt to solve?

This paper addresses several key challenges in **Continual Federated Learning (CFL)**, chiefly **global catastrophic forgetting**. The authors propose a new method, **Continual Federated Learning with Aggregated Gradients (C-FLAG)**, to enhance the efficiency, privacy, and scalability of AI systems while learning from streaming data.

#### Main problems

1. **Global catastrophic forgetting**:
   - When the global model is trained on new tasks, its accuracy on old tasks declines. This is one of the main challenges in CFL systems.
   - Formula representation: Let \( P \) denote the data set of past tasks and \( C \) the data set of the current task. The goal is for the model to retain the knowledge in \( P \) while learning \( C \).
2. **Bias and deviation in federated learning**:
   - Because edge devices have limited memory, the replay buffer can store only a small part of the historical data, which introduces sampling bias.
   - Formula representation: Let \( M_i \) be the replay buffer of the \( i \)-th client. Then \( M_i\subset P_i \), where \( P_i \) is all the historical data of the \( i \)-th client.
3. **Non-stationary data streams**:
   - Streaming data is usually non-stationary, whereas traditional federated learning frameworks assume stationary data. A convergence proof is therefore needed to show that the proposed strategy converges when handling both new and old tasks.

#### Proposed methods

To address the above challenges, the C-FLAG method combines the following features:

- **Local learning steps and global aggregation**:
  - On each client, C-FLAG computes an effective gradient based on the combination of the local replay buffer \( M_i \) and the current task data set \( C_i \).
  - Formula representation: The effective gradient is \( \nabla g'(x_t) \), where \( \nabla g'(x_t)=\frac{1}{|C_i|}\sum_{j\in C_i}\nabla g_{i,j}(x_t) \).
- **Incrementally Aggregated Gradients (IAG)**:
  - The IAG method approximates gradients to reduce computational cost and client drift.
  - Formula representation: \( \nabla g'(x_t)\approx\nabla g(x_t)-\nabla g_i(x_t)+\nabla g'_i(x_t) \).
- **Adaptive learning rate**:
  - Adjusting the learning rates balances learning new tasks against retaining old knowledge, which alleviates catastrophic forgetting.
  - Formula representation: For the \( t \)-th communication round, the learning rates are \( \alpha_t \) and \( \beta_t \), where \( \alpha_t \) is used for memory data and \( \beta_t \) is used for current data.

#### Theoretical analysis

- **Convergence**:
  - The authors prove that C-FLAG converges to a stationary point in the non-convex setting at a rate of \( O\left(\frac{1}{\sqrt{T}}\right) \), where \( T \) is the number of communication rounds.
  - Formula representation:
    \[ \min_t\mathbb{E}\left[\|\nabla f(x_t)\|^2\right]\leq O\left(\frac{1}{\sqrt{T}}\right) \]

Together, these components are how C-FLAG mitigates global catastrophic forgetting in CFL.
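The two mechanisms described above, the incremental "subtract the stale gradient, add the fresh one" aggregation and the separate learning rates for memory and current data, can be illustrated with a minimal Python sketch. This is not the authors' implementation. The names `IncrementalAggregator` and `two_rate_step` are hypothetical, the aggregate is assumed to be a plain mean over clients, and the update form \( x_{t+1} = x_t - \alpha_t\,(\text{memory gradient}) - \beta_t\,(\text{current gradient}) \) is an assumption made only to show how the two rates could enter a step.

```python
from typing import List

Vector = List[float]


class IncrementalAggregator:
    """Hypothetical server-side table holding each client's latest gradient.

    Assumption: the aggregate is the plain mean of the stored gradients.
    """

    def __init__(self, num_clients: int, dim: int) -> None:
        self.num_clients = num_clients
        self.table: List[Vector] = [[0.0] * dim for _ in range(num_clients)]
        self.total: Vector = [0.0] * dim

    def update(self, client: int, fresh: Vector) -> Vector:
        # IAG-style refresh: total <- total - stale_i + fresh_i,
        # so only one client's gradient is recomputed per update.
        stale = self.table[client]
        self.total = [t - s + f for t, s, f in zip(self.total, stale, fresh)]
        self.table[client] = list(fresh)
        return [t / self.num_clients for t in self.total]


def two_rate_step(x: Vector, memory_grad: Vector, current_grad: Vector,
                  alpha: float, beta: float) -> Vector:
    # Assumed update form: alpha scales the memory (replay) gradient,
    # beta scales the aggregated gradient on current data.
    return [xi - alpha * m - beta * c
            for xi, m, c in zip(x, memory_grad, current_grad)]


if __name__ == "__main__":
    agg = IncrementalAggregator(num_clients=2, dim=2)
    agg.update(0, [2.0, 4.0])
    current = agg.update(1, [4.0, 0.0])   # mean of both clients' gradients
    x_next = two_rate_step([1.0, 1.0], [1.0, 0.0], current, alpha=0.5, beta=0.25)
    print(current, x_next)
```

Keeping a running total means that refreshing one client's contribution costs a single subtraction and addition per coordinate instead of a full re-aggregation, which is the cost saving the IAG bullet refers to. How \( \alpha_t \) and \( \beta_t \) are actually chosen is determined by the paper's optimization sub-problem and is not reproduced here.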