PrimePar: Efficient Spatial-temporal Tensor Partitioning for Large Transformer Model Training

Haobo Xu, Yinhe Han, Ying Wang, Yuming Li, Lei Wang, Haoran Wang
DOI: https://doi.org/10.1145/3620666.3651357
2024-04-27
Abstract: With the rapid scaling of transformer-based large language models (LLMs), training them increasingly depends on novel parallel training techniques. Tensor partitioning is an extensively researched parallelization technique that encompasses both data and model parallelism, and it significantly affects LLM training performance. However, existing state-of-the-art parallel training systems are built on an incomplete tensor partitioning space, in which partitioned sub-operators are distributed only along the spatial dimension (across devices). We find that introducing the temporal dimension into the tensor partitioning of LLM training instances creates additional opportunities to avoid collective communication across devices, save memory, and overlap device-to-device communication with computation. In this paper, we propose a new tensor partition primitive that distributes sub-operators along both the spatial and temporal dimensions, further reducing communication and memory overhead compared with current solutions. This primitive opens a broader parallelization space and yields parallel solutions that achieve higher training throughput with lower peak memory occupancy than state-of-the-art techniques. To efficiently deploy optimized parallel transformer training on multiple devices, we further present an optimization algorithm that finds optimal parallel solutions in our spatial-temporal tensor partition space within acceptable search time. Our evaluation shows that, when training LLMs, our optimized tensor partitioning achieves up to 1.68× the training throughput of state-of-the-art distributed training systems while using 69% of their peak memory occupancy. When scaling to 32 GPUs, the geometric-mean speedup across benchmarks is 1.30×. When applied within 3D parallelism, it achieves up to 1.46× training throughput.
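The abstract does not define the spatial-temporal primitive precisely. As a rough intuition for how a temporal dimension can stand in for a collective, here is a toy sketch under stated assumptions: it is not PrimePar's primitive, but a ring-rotation schedule in the spirit of Cannon-style block algorithms, simulated in pure Python. The function names `matmul`, `add_inplace`, and `spatial_temporal_matmul` are hypothetical. C = A @ B is computed on p simulated devices: each device keeps one row block of A and holds one row block of B (a slice of the reduction dimension K) at a time. At every time step it multiplies the matching column slice of its A block by the B block it currently holds, then passes that B block to a ring neighbor point-to-point. After p steps each device owns a complete row block of C, so no all-reduce is required, and no device ever stores more than K/p rows of B.

```python
# Toy illustration (assumption, not the paper's method): p simulated devices
# compute C = A @ B by rotating row blocks of B around a ring over p time steps.

def matmul(a, b):
    """Dense matrix multiply on lists of lists."""
    n_inner, n_cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(n_inner)) for j in range(n_cols)]
            for row in a]

def add_inplace(acc, x):
    """acc += x, elementwise."""
    for i, row in enumerate(x):
        for j, v in enumerate(row):
            acc[i][j] += v

def spatial_temporal_matmul(a, b, p):
    """Return (C, number_of_point_to_point_sends) for a hypothetical ring schedule."""
    m, k, n = len(a), len(b), len(b[0])
    assert m % p == 0 and k % p == 0, "toy version needs evenly divisible shapes"
    rm, rk = m // p, k // p
    # Spatial split: device i permanently owns row block i of A (all K columns).
    a_blocks = [a[i * rm:(i + 1) * rm] for i in range(p)]
    # Device i starts out holding row block i of B (rk rows of the K dimension).
    held = [b[i * rk:(i + 1) * rk] for i in range(p)]
    held_idx = list(range(p))
    out = [[[0] * n for _ in range(rm)] for _ in range(p)]
    sends = 0
    # Temporal split: at step t, device d works on B block (d + t) mod p.
    for step in range(p):
        for dev in range(p):
            j = held_idx[dev]
            a_sub = [row[j * rk:(j + 1) * rk] for row in a_blocks[dev]]
            add_inplace(out[dev], matmul(a_sub, held[dev]))
        if step < p - 1:
            # Ring shift: device d receives the block held by device (d + 1) mod p.
            # A real system could overlap this transfer with the next step's compute.
            held = [held[(dev + 1) % p] for dev in range(p)]
            held_idx = [held_idx[(dev + 1) % p] for dev in range(p)]
            sends += p
    # Each device holds a finished row block of C; just concatenate them.
    return [row for blk in out for row in blk], sends

if __name__ == "__main__":
    a = [[1, 2, 3, 4], [5, 6, 7, 8]]
    b = [[1, 0], [0, 1], [1, 1], [2, -1]]
    print(spatial_temporal_matmul(a, b, 2))  # ([[12, 1], [28, 5]], 2)
```

By contrast, splitting only the K dimension across devices would leave each device with a partial sum of the full C, which must then be combined with an all-reduce. Rotating B over time trades that collective for p(p - 1) neighbor-to-neighbor transfers that a real runtime could hide behind computation. Row-splitting A and rotating B is just one possible schedule, chosen here for simplicity.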