Optimizing Large Model Training through Overlapped Activation Recomputation

Ping Chen, Wenjie Zhang, Shuibing He, Yingjie Gu, Zhuwei Peng, Kexin Huang, Xuan Zhan, Weijian Chen, Yi Zheng, Zhefeng Wang, Yanlong Yin, Gang Chen
2024-06-27
Abstract: Large model training uses recomputation to alleviate memory pressure and pipelining to exploit data, tensor, and device parallelism. Existing recomputation approaches can incur up to 40% overhead when training real-world models, e.g., a GPT model with 22B parameters, because they execute on demand in the critical training path. In this paper, we design a new recomputation framework, Lynx, that reduces this overhead by overlapping recomputation with the communication that occurs in training pipelines. Lynx provides an optimal scheduling algorithm (OPT) and a heuristic-based scheduling algorithm (HEU). OPT achieves a global optimum but suffers from a long search time. HEU builds on our observation that large DNN models contain identical structures, so the same scheduling policy can be applied to all of them. HEU achieves a local optimum but reduces the search time by 99% compared to OPT. Our comprehensive evaluation using GPT models with 1.3B-20B parameters shows that both OPT and HEU outperform state-of-the-art recomputation approaches (e.g., Megatron-LM and Checkmate) by 1.02-1.53x. HEU achieves performance similar to OPT with an average search time of 0.16s.
Distributed, Parallel, and Cluster Computing; Machine Learning
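The observation behind HEU, that identical structures in a large model can share one scheduling policy, can be illustrated with a small Python sketch. This is not Lynx's actual algorithm. The per-layer cost model (`Layer`, `comm_window_ms`, `mem_budget_mb`), the exhaustive search in `best_policy`, the helper `transformer_block`, and all numbers are hypothetical. The sketch only shows that memoizing the schedule on a layer's structure lets the expensive search run once for 24 identical blocks.

```python
# Hypothetical sketch of HEU's key idea: solve the schedule once per
# distinct layer structure, then reuse it for every identical layer.
from functools import lru_cache
from itertools import product
from typing import NamedTuple, Optional, Tuple


class Layer(NamedTuple):
    # Assumed cost model (not from the paper):
    acts: Tuple[Tuple[float, float], ...]  # (recompute_ms, size_mb) per activation
    comm_window_ms: float                  # communication time that can hide recomputation
    mem_budget_mb: float                   # memory available for keeping activations


search_calls = 0  # how many times the exhaustive search really runs


@lru_cache(maxsize=None)
def best_policy(layer: Layer) -> Optional[Tuple[bool, ...]]:
    """Exhaustively choose which activations to keep (True) or recompute (False).

    Cost = recomputation time that does not fit in the communication window.
    Structurally identical layers compare and hash equal, so the cache
    returns the stored answer instead of searching again.
    """
    global search_calls
    search_calls += 1
    best: Optional[Tuple[bool, ...]] = None
    best_cost = float("inf")
    for keep in product((True, False), repeat=len(layer.acts)):
        mem = sum(size for (_, size), k in zip(layer.acts, keep) if k)
        if mem > layer.mem_budget_mb:
            continue  # this choice exceeds the memory budget
        recompute = sum(t for (t, _), k in zip(layer.acts, keep) if not k)
        exposed = max(0.0, recompute - layer.comm_window_ms)  # not hidden by comm
        if exposed < best_cost:
            best, best_cost = keep, exposed
    return best


def transformer_block() -> Layer:
    # Hypothetical numbers for one repeated block.
    return Layer(acts=((2.0, 40.0), (1.0, 10.0), (3.0, 80.0)),
                 comm_window_ms=2.5, mem_budget_mb=60.0)


# Usage: 24 separately constructed but structurally identical blocks.
model = [transformer_block() for _ in range(24)]
schedule = [best_policy(layer) for layer in model]
print(schedule[0], search_calls)  # (True, True, False) 1
```

Keying the cache on structural equality is what makes the reuse safe in this toy: two layers share a policy only if every modeled attribute matches. A real system would need a richer definition of "identical structure" than this tuple of costs.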
What problem does this paper attempt to address?
This paper addresses the significant overhead that recomputation adds to large-scale model training. Existing recomputation methods can add up to 40% overhead when training deep neural network (DNN) models with billions of parameters, because they run on demand on the critical training path and therefore lengthen the overall training time.

### Main problems of the paper

1. **Memory pressure**: Large-scale model training needs a large amount of GPU memory to store activations, which can lead to out-of-memory errors.
2. **Communication overhead**: When training is distributed with tensor parallelism (TP) and pipeline parallelism (PP), communication accounts for a considerable share of the training time, especially when it crosses GPUs.
3. **Recomputation overhead**: Existing recomputation methods execute on the critical path, adding significant time and reducing training efficiency.

### Solutions

To address these problems, the authors propose a new recomputation framework, Lynx, with three main goals:

- **Overlap recomputation with communication**: Run recomputation while communication is in progress, hiding the extra time that recomputation would otherwise add.
- **Optimize GPU memory utilization**: Selectively keep tensors in otherwise idle GPU memory to avoid unnecessary recomputation.
- **Balance load across pipeline stages**: Keep the execution times of the pipeline stages similar to maximize training throughput.

### Main contributions

1. **Lynx framework**: According to the authors, the first framework to fully exploit overlapping recomputation with communication and to use idle GPU memory to eliminate unnecessary tensor recomputation.
2. **Two scheduling algorithms**:
   - **OPT**: Finds the globally optimal schedule, but with a long search time.
   - **HEU**: A heuristic that finds a locally optimal schedule, cutting search time by 99% relative to OPT while performing close to OPT.
3. **Recomputation-aware model partitioning algorithm**: Balances load across pipeline stages, thereby maximizing training throughput.
4. **Comprehensive evaluation**: On GPT models of different scales, Lynx outperforms state-of-the-art recomputation methods (such as Megatron-LM and Checkmate) by 1.02x to 1.53x.

### Summary of mathematical formulas

To minimize the end-to-end training time, including forward time, backward time, and recomputation overhead, Lynx's objective function can be expressed as:

\[
\text{minimize} \quad \sum_{t=1}^{n} \sum_{i=1}^{t} C_i \times R_{t,i} - \sum_{t \in COMM} \sum_{i=1}^{t-1} C_i \times R_{t,i}
\]

where:

- \(n\) is the number of stages.
- \(C_i\) is the computation time of operation \(OP_i\).
- \(R_{t,i}\) indicates whether operation \(OP_i\) is executed (computed or recomputed) in stage \(t\).
- \(S_{t,i}\) indicates whether the output of operation \(OP_i\) is retained in GPU memory between stage \(t-1\) and stage \(t\) (used in the constraints).
- \(COMM\) is the set of stages that perform communication; the second term subtracts the recomputation executed in those stages, since it is hidden behind communication.

The constraints cover dependency relationships, communication limits, and memory limits, ensuring that recomputation can be executed effectively without affecting the correctness of training.
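A minimal Python sketch of how the objective scores a fixed schedule may make the two terms concrete. It assumes the first sum runs over \(n\) stages (one per operation), that `COMM` is the set of stages performing communication, and that each row of `R` is a 0/1 vector. The function name `lynx_objective` and the four-operation example are hypothetical, and the sketch does not enforce any of the dependency, communication, or memory constraints (for example, that recomputation placed in a communication stage actually fits inside the transfer time).

```python
# Hypothetical sketch: evaluate a Lynx-style objective for a fixed schedule.
from typing import Iterable, Sequence


def lynx_objective(C: Sequence[float],
                   R: Sequence[Sequence[int]],
                   comm: Iterable[int]) -> float:
    """Return sum_t sum_{i<=t} C_i R_{t,i} - sum_{t in COMM} sum_{i<t} C_i R_{t,i}.

    Indices follow the formula (1-based): C[i-1] is C_i, R[t-1][i-1] is R_{t,i},
    and `comm` holds the 1-based stages assumed to perform communication.
    """
    n = len(C)  # assumed: n stages, one per operation
    total = sum(C[i - 1] * R[t - 1][i - 1]
                for t in range(1, n + 1) for i in range(1, t + 1))
    # Recomputation of earlier ops (i < t) inside a communication stage is hidden.
    hidden = sum(C[i - 1] * R[t - 1][i - 1]
                 for t in comm for i in range(1, t))
    return total - hidden


# Four ops; OP_3 is a communication op (C_3 is its transfer time).
C = [2, 3, 3, 4]
overlapped = [[1, 0, 0, 0],   # stage 1: compute OP_1
              [0, 1, 0, 0],   # stage 2: compute OP_2
              [1, 0, 1, 0],   # stage 3: communicate, recompute OP_1 meanwhile
              [0, 1, 0, 1]]   # stage 4: recompute OP_2 on demand, compute OP_4
on_demand = [[1, 0, 0, 0],
             [0, 1, 0, 0],
             [0, 0, 1, 0],
             [1, 1, 0, 1]]    # stage 4: recompute OP_1 and OP_2 on the critical path
print(lynx_objective(C, overlapped, {3}), lynx_objective(C, on_demand, {3}))  # 15 17
```

Both schedules perform the same recomputation work; moving the recomputation of \(OP_1\) into the communication stage removes its 2 units from the critical path, which is exactly what the subtracted term captures.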