Hierarchical Federated Learning with Multi-Timescale Gradient Correction

Wenzhi Fang, Dong-Jun Han, Evan Chen, Shiqiang Wang, Christopher G. Brinton
2024-09-27
Abstract: While traditional federated learning (FL) typically focuses on a star topology where clients are directly connected to a central server, real-world distributed systems often exhibit hierarchical architectures. Hierarchical FL (HFL) has emerged as a promising solution to bridge this gap, leveraging aggregation points at multiple levels of the system. However, existing algorithms for HFL encounter challenges in dealing with multi-timescale model drift, i.e., model drift occurring across hierarchical levels of data heterogeneity. In this paper, we propose a multi-timescale gradient correction (MTGC) methodology to resolve this issue. Our key idea is to introduce distinct control variables to (i) correct the client gradient towards the group gradient, i.e., to reduce client model drift caused by local updates based on individual datasets, and (ii) correct the group gradient towards the global gradient, i.e., to reduce group model drift caused by FL over clients within the group. We analytically characterize the convergence behavior of MTGC under general non-convex settings, overcoming challenges associated with couplings between correction terms. We show that our convergence bound is immune to the extent of data heterogeneity, confirming the stability of the proposed algorithm against multi-level non-i.i.d. data. Through extensive experiments on various datasets and models, we validate the effectiveness of MTGC in diverse HFL settings. The code for this project is available at https://github.com/wenzhifang/MTGC.
Machine Learning
What problem does this paper attempt to address?
The problem this paper attempts to solve is: **how to deal effectively with multi-timescale model drift in hierarchical federated learning (HFL) over non-independent and identically distributed (non-i.i.d.) data, improving convergence without resorting to frequent model aggregations**.

### Problem Background

Traditional federated learning (FL) usually assumes that clients communicate directly with a central server, forming a star topology. In real-world distributed systems, however, the communication network is often hierarchical, for example with intermediate edge servers in edge computing and software-defined networks. This hierarchy brings new challenges for handling data heterogeneity: differences in data distribution at each level cause model drift, and that drift unfolds on more than one timescale.

### Specific Problems

1. **Multi-timescale model drift**: HFL aggregates at multiple levels (from clients to group aggregators, then to the central server), and each level aggregates at a different frequency. Model drift therefore occurs on different timescales.
2. **Data heterogeneity**: Data heterogeneity in HFL falls into two categories:
   - **Intra-group non-i.i.d.**: Clients within the same group have different data distributions.
   - **Inter-group non-i.i.d.**: Different groups also have different data distributions.
3. **Limitations of existing methods**: Existing FL algorithms (such as FedProx, SCAFFOLD, and FedDyn) can correct model drift in traditional FL, but they are hard to apply directly in HFL because they do not account for the multi-timescale communication architecture.

### Paper's Solution

To address these problems, the authors propose the **Multi-Timescale Gradient Correction (MTGC)** method. MTGC corrects the gradient by introducing two control variables:

- **Client-group correction term**: Corrects the client gradient towards the group gradient, reducing the client model drift caused by local updates on individual datasets.
- **Group-global correction term**: Corrects the group gradient towards the global gradient, reducing the group model drift caused by federated learning over the clients within a group.

### Main Contributions

1. **Theoretical guarantee**: The paper establishes a convergence bound for MTGC under general non-convex settings, and this bound is immune to the extent of data heterogeneity.
2. **Linear speedup**: MTGC achieves linear speedup in the number of local iterations, the number of group aggregations, and the number of clients.
3. **Experimental verification**: Extensive experiments on multiple datasets and models validate the effectiveness of MTGC in diverse non-i.i.d. HFL settings.

### Summary

The paper addresses multi-timescale model drift in HFL by proposing MTGC, and it supports the method with both a convergence analysis and experiments on heterogeneous data.
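
The two correction terms described above can be illustrated with a small toy sketch. This is not the authors' implementation and it does not claim to reproduce the paper's update rules. It assumes scalar quadratic client losses `f(x) = 0.5 * a * (x - c)^2`, equal-sized groups, and correction variables that are simply recomputed from exact gradients at each synchronization point (the client-group term `z` once per group round, the group-global term `y` once per global round). The names `run_hfl`, `GROUPS`, and `X_STAR` are hypothetical.

```python
# Toy two-level gradient correction on scalar quadratics (illustrative only).
# Each client is a pair (a, c) with loss 0.5 * a * (x - c)^2, so grad = a * (x - c).

def mean(values):
    values = list(values)
    return sum(values) / len(values)

def grad(client, x):
    a, c = client
    return a * (x - c)

def group_grad(group, x):
    # Average gradient of the clients in one group, evaluated at x.
    return mean(grad(client, x) for client in group)

def run_hfl(groups, correct, rounds=20, group_rounds=3, local_steps=10, lr=0.1):
    x = 0.0  # global model
    for _ in range(rounds):
        global_grad = mean(group_grad(g, x) for g in groups)
        # Group-global correction (assumed form), refreshed once per global round.
        y = [(global_grad - group_grad(g, x)) if correct else 0.0 for g in groups]
        group_models = [x for _ in groups]
        for _ in range(group_rounds):
            for j, group in enumerate(groups):
                xg = group_models[j]
                gg = group_grad(group, xg)
                client_models = []
                for client in group:
                    # Client-group correction (assumed form), refreshed once per group round.
                    z = (gg - grad(client, xg)) if correct else 0.0
                    xi = xg
                    for _ in range(local_steps):
                        xi -= lr * (grad(client, xi) + z + y[j])
                    client_models.append(xi)
                group_models[j] = mean(client_models)  # group aggregation
        x = mean(group_models)  # global aggregation
    return x

# Two groups of two clients with different curvatures and optima (made-up values).
GROUPS = [[(2.0, 1.0), (0.5, -1.0)], [(1.0, 3.0), (0.5, 0.0)]]
# Minimizer of the average loss: sum(a * c) / sum(a).
X_STAR = (sum(a * c for g in GROUPS for a, c in g)
          / sum(a for g in GROUPS for a, _ in g))

err_plain = abs(run_hfl(GROUPS, correct=False) - X_STAR)
err_corrected = abs(run_hfl(GROUPS, correct=True) - X_STAR)
print(f"uncorrected error: {err_plain:.2e}, corrected error: {err_corrected:.2e}")
```

With `correct=False` the loop reduces to plain hierarchical averaging, whose fixed point is biased away from `X_STAR` because clients with different curvatures drift by different amounts during their local steps. With `correct=True`, every local step follows an estimate of the global gradient, so the global minimizer becomes a fixed point of the whole procedure in this toy setting.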