Hiding Communication Cost in Distributed LLM Training via Micro-batch Co-execution

Haiquan Wang, Chaoyi Ruan, Jia He, Jiaqi Ruan, Chengjie Tang, Xiaosong Ma, Cheng Li
2024-11-24
Abstract: The growth of Large Language Models (LLMs) has necessitated large-scale distributed training. Highly optimized frameworks, however, still suffer significant losses in Model FLOPS Utilization (MFU), often below 50%, due to large communication volumes. Meanwhile, our comprehensive profiling shows that computation- and communication-intensive operators overlap well. Inspired by the structure of DNA, this paper introduces DHelix, a novel micro-structure that dramatically improves the efficiency of LLM training. Central to DHelix's design is Strand Interleaving (SI), which views the continuous stream of training micro-batches through a GPU as two strands. DHelix juxtaposes the forward and backward passes of the two strands and performs a systematic optimization for an SI plan that co-schedules the operators from the opposite strands, enabled by operator-level overlap profiling results and a dynamic-programming-based search algorithm. In addition, DHelix enables the two strands to share model states and space for activation data, effectively accommodating two micro-batches with under 3% extra memory space. DHelix seamlessly integrates with all forms of existing data/model parallelism, the most challenging being pipeline parallelism, thanks to its unique model folding design that results in a W-shaped pipeline. We evaluate DHelix training with the popular Llama and GPT dense models, plus the Phi Mixture-of-Experts (MoE) model, across 3 GPU clusters (A40, A800, and H100). Results show that it achieves 12-40% (up to 58% MFU) and 2-29% (up to 71% MFU) improvement on the 64-A40 and 64-A800 clusters, respectively, significantly outperforming state-of-the-art methods. On the H100 cluster, although the faster network reduces DHelix's profit margin, DHelix makes cross-node tensor parallelism promising, a practice currently prohibitive due to communication costs.
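The abstract describes finding an SI plan with a dynamic-programming search over operator-level overlap profiles. The sketch below is not DHelix's actual algorithm; it is a simplified, sequence-alignment-style DP under stated assumptions: each operator is labeled compute- or communication-bound with a profiled standalone time, a compute/communication pair co-runs in the longer of the two times multiplied by a hypothetical 1.1 interference factor, and two operators of the same kind gain nothing from co-running. The names `Op`, `corun_time`, and `plan_si` and all timings are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str
    kind: str    # "comp" (compute-bound) or "comm" (communication-bound)
    time: float  # standalone runtime in ms, e.g. from per-operator profiling


def corun_time(a: Op, b: Op, slowdown: float = 1.1) -> float:
    """Hypothetical cost model for co-running two operators on one GPU.

    A compute/communication pair overlaps: it finishes in about the longer
    runtime, inflated by an assumed interference factor. Two operators of the
    same kind contend for the same resource, so co-running them is modeled as
    no faster than running them back to back.
    """
    if a.kind != b.kind:
        return max(a.time, b.time) * slowdown
    return a.time + b.time


def plan_si(fwd: list[Op], bwd: list[Op]) -> tuple[float, list[tuple[str, ...]]]:
    """Align one strand's forward ops with the other strand's backward ops.

    best[i][j] is the minimum time to finish fwd[:i] and bwd[:j]. Each strand
    keeps its own operator order; every step either runs the next op of one
    strand alone or co-runs the next op of both strands.
    """
    n, m = len(fwd), len(bwd)
    best = [[float("inf")] * (m + 1) for _ in range(n + 1)]
    step = [[None] * (m + 1) for _ in range(n + 1)]
    best[0][0] = 0.0
    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 and j == 0:
                continue
            cands = []  # (cost, ops taken from fwd, ops taken from bwd, group)
            if i > 0:
                cands.append((best[i - 1][j] + fwd[i - 1].time, 1, 0, (fwd[i - 1].name,)))
            if j > 0:
                cands.append((best[i][j - 1] + bwd[j - 1].time, 0, 1, (bwd[j - 1].name,)))
            if i > 0 and j > 0:
                pair = corun_time(fwd[i - 1], bwd[j - 1])
                cands.append((best[i - 1][j - 1] + pair, 1, 1,
                              (fwd[i - 1].name, bwd[j - 1].name)))
            step[i][j] = min(cands, key=lambda c: c[0])
            best[i][j] = step[i][j][0]
    # Walk the recorded choices back from (n, m) to recover the schedule.
    schedule = []
    i, j = n, m
    while i > 0 or j > 0:
        _, di, dj, group = step[i][j]
        schedule.append(group)
        i, j = i - di, j - dj
    schedule.reverse()
    return best[n][m], schedule


if __name__ == "__main__":
    # Made-up profile: alpha-strand forward ops and beta-strand backward ops.
    fwd = [Op("a.attn", "comp", 4.0), Op("a.allgather", "comm", 3.0), Op("a.mlp", "comp", 5.0)]
    bwd = [Op("b.reducescatter", "comm", 3.0), Op("b.mlp_grad", "comp", 8.0),
           Op("b.allgather", "comm", 4.0)]
    total, schedule = plan_si(fwd, bwd)
    print(f"{total:.1f} ms vs {sum(o.time for o in fwd + bwd):.1f} ms serial")
    for group in schedule:
        print(group)
```

With these made-up numbers the search pairs each compute operator with a communication operator from the opposite strand, giving 18.7 ms instead of 27.0 ms when all six operators run serially. A real planner would also need to model dependencies, memory, and finer-grained operator splitting, which this sketch ignores.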
Distributed, Parallel, and Cluster Computing
What problem does this paper attempt to address?
The problem this paper addresses is **low Model FLOPS Utilization (MFU) in distributed large language model (LLM) training caused by communication overhead**. Specifically:

1. **Communication bottleneck**:
   - As LLMs grow, distributed training becomes indispensable. Yet even highly optimized frameworks remain bottlenecked by communication, so GPU MFU is often below 50%.
   - The bottleneck comes mainly from the many collective communication operations (such as AllGather and ReduceScatter) introduced by parallelization strategies such as tensor, sequence, context, and expert parallelism, which together account for a large share of execution time.
2. **Limitations of existing methods**: the two main existing approaches, "intra-batch" and "inter-batch" overlap, each have their own limitations:
   - Intra-batch overlap decomposes communication and computation operations into smaller units so they can run concurrently, but it is constrained by data dependencies within a micro-batch, and the finer-grained operations run less efficiently.
   - Inter-batch overlap executes two batches concurrently, overlapping the forward pass of one with the backward pass of the other and exploiting their complementary GPU memory usage patterns. However, it has fundamental limitations under pipeline parallelism (PP) and incurs a large memory overhead.

To address these problems, the paper proposes **DHelix**, a novel micro-structure that aims to substantially improve LLM training efficiency through Strand Interleaving (SI). The core idea of DHelix is to treat consecutive training micro-batches as two interleaved strands (an α strand and a β strand) and to co-schedule operators from the opposite strands through systematic optimization, thereby hiding communication cost and raising GPU utilization.
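To make the link between exposed communication and MFU concrete, here is some back-of-the-envelope arithmetic. It assumes the common approximation of about 6N FLOPs per token for one dense-transformer training step (N = parameter count, attention-score FLOPs ignored) and an idealized model in which any communication hidden behind computation costs nothing. The workload numbers and the function names `mfu` and `step_time` are hypothetical, not figures from the paper.

```python
def mfu(n_params: float, tokens_per_step: float, step_time_s: float,
        n_gpus: int, peak_flops_per_gpu: float) -> float:
    """Model FLOPS Utilization: achieved model FLOPs/s over aggregate peak FLOPs/s.

    Uses the common ~6 * N FLOPs-per-token estimate for one training step of a
    dense transformer (forward + backward), ignoring attention-score FLOPs.
    """
    achieved_flops_per_s = 6 * n_params * tokens_per_step / step_time_s
    return achieved_flops_per_s / (n_gpus * peak_flops_per_gpu)


def step_time(comp_s: float, comm_s: float, hidden_fraction: float) -> float:
    """Idealized step time when `hidden_fraction` of the communication time is
    overlapped with computation; exposed communication adds directly to the step."""
    if not 0.0 <= hidden_fraction <= 1.0:
        raise ValueError("hidden_fraction must be in [0, 1]")
    return comp_s + (1.0 - hidden_fraction) * comm_s


if __name__ == "__main__":
    # Made-up workload: 1B parameters, 1M tokens per step, 8 GPUs at 1 PFLOP/s peak,
    # 1.0 s of compute and 0.8 s of communication per step.
    for hidden in (0.0, 0.75):
        t = step_time(1.0, 0.8, hidden)
        print(f"hidden={hidden:.2f}  step={t:.2f}s  MFU={mfu(1e9, 1e6, t, 8, 1e15):.1%}")
```

With these made-up numbers, hiding 75% of the communication cuts the step from 1.80 s to 1.20 s and lifts MFU from 41.7% to 62.5%. Because the model FLOPs per step are fixed, MFU scales inversely with step time, which is why overlap techniques target exposed communication rather than raw FLOPs.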