SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile

Ruisi Zhang, Tianyu Liu, Will Feng, Andrew Gu, Sanket Purandare, Wanchao Liang, Francisco Massa
2024-11-01
Abstract: Distributed training of large models consumes enormous computation resources and requires substantial engineering efforts to compose various training techniques. This paper presents SimpleFSDP, a PyTorch-native compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. SimpleFSDP's novelty lies in its unique torch.compile-friendly implementation of collective communications using existing PyTorch primitives, namely parametrizations, selective activation checkpointing, and DTensor. It also features the first-of-its-kind intermediate representation (IR) nodes bucketing and reordering in the TorchInductor backend for effective computation-communication overlapping. As a result, users can employ the aforementioned optimizations to automatically or manually wrap model components for minimal communication exposure. Extensive evaluations of SimpleFSDP on Llama 3 models (including the ultra-large 405B) using TorchTitan demonstrate up to 28.54% memory reduction and 68.67% throughput improvement compared to the most widely adopted FSDP2 eager framework, when composed with other distributed training techniques.
Distributed, Parallel, and Cluster Computing; Artificial Intelligence
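The collective communications the abstract refers to follow the standard FSDP pattern. Each rank stores only one shard of a parameter. The forward pass all-gathers the full parameter just before it is used. The backward pass reduce-scatters the gradients, so each rank ends up with the reduced gradient for its own shard only.

The sketch below is a pure-Python simulation of that pattern. It uses no PyTorch and no real process group, and the helpers `shard`, `all_gather`, `reduce_scatter`, and `fsdp_step` are hypothetical; they are not PyTorch or DTensor APIs. Per the paper, SimpleFSDP obtains this pattern from parametrizations and DTensor rather than from explicit calls like these.

```python
from typing import List, Tuple


def shard(full: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter into equal contiguous shards, one per rank."""
    assert len(full) % world_size == 0, "sketch assumes an evenly divisible parameter"
    n = len(full) // world_size
    return [full[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[float]:
    """Concatenate all ranks' shards; every rank receives the same full parameter."""
    return [x for s in shards for x in s]


def reduce_scatter(per_rank_grads: List[List[float]], world_size: int) -> List[List[float]]:
    """Sum full-size gradients elementwise across ranks, then give each rank its slice."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)


def fsdp_step(param_shards: List[List[float]], inputs: List[List[float]],
              lr: float) -> Tuple[List[float], List[List[float]]]:
    """One simulated step of a toy model whose per-rank loss is dot(w, x_rank)."""
    world_size = len(param_shards)
    # Forward: all-gather the full parameter just before it is used
    # (it exists unsharded only inside this function).
    full_w = all_gather(param_shards)
    outputs = [sum(w * x for w, x in zip(full_w, x_r)) for x_r in inputs]
    # Backward: d(dot(w, x_r))/dw = x_r, a full-size gradient on every rank.
    per_rank_grads = [list(x_r) for x_r in inputs]
    # Reduce-scatter: each rank keeps only the summed gradient for its own shard.
    grad_shards = reduce_scatter(per_rank_grads, world_size)
    # Optimizer: each rank updates only the shard it owns.
    new_shards = [[w - lr * g for w, g in zip(ws, gs)]
                  for ws, gs in zip(param_shards, grad_shards)]
    return outputs, new_shards


if __name__ == "__main__":
    shards = shard([1.0, 2.0, 3.0, 4.0], world_size=2)
    outputs, shards = fsdp_step(shards, [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0]], lr=0.5)
    print(outputs, shards)  # [1.0, 6.0] [[0.5, 1.5], [3.0, 3.5]]
```

The updated shards match what unsharded SGD would produce from the summed gradient `[1, 1, 0, 1]`, which is the correctness property any FSDP implementation must preserve. This sketch sums gradients; data-parallel training often averages them instead, dividing by the world size.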
What problem does this paper attempt to address?
This paper addresses two problems in distributed training of large-scale models: the enormous computational resources it consumes and the engineering complexity of combining the techniques needed to make it efficient. To tackle them, the paper introduces SimpleFSDP, a PyTorch-native, compiler-based Fully Sharded Data Parallel (FSDP) framework that aims to simplify implementation, improve performance, and compose well with other distributed training techniques.

### Main problems:

1. **Computational resource consumption**: Training large-scale models requires enormous amounts of compute. For example, training the Llama 3.1 405B model required 30.84 million H100 GPU hours.
2. **Engineering complexity**: Optimizing training performance requires combining many techniques, such as several parallelism strategies (data, tensor, and pipeline parallelism), memory optimizations, and communication optimizations. This makes the codebase complex and hard to maintain and debug.

### Main contributions of SimpleFSDP:

1. **Simplified implementation**: Users get the performance benefits of full-model compilation without modifying their existing distributed training codebase.
2. **Composability**: SimpleFSDP composes seamlessly with other distributed training techniques, such as tensor parallelism, pipeline parallelism, meta-initialization, mixed-precision training, and activation checkpointing.
3. **Performance improvement**: Through full-graph tracing and compiler backend optimizations, SimpleFSDP achieves significant performance gains. On Llama 3 models, it reduces peak memory by up to 28.54% and increases throughput by up to 68.67% compared with the widely adopted FSDP2 eager-mode framework, when composed with other distributed training techniques.
4. **Debuggability**: SimpleFSDP retains the debuggability and flexibility of PyTorch eager mode, allowing users to develop and iterate quickly.

### Technical details:

- **Collective communication implementation**: SimpleFSDP implements FSDP semantics with existing PyTorch primitives (parametrizations, selective activation checkpointing, and the DTensor API). With these primitives, SimpleFSDP all-gathers parameters in the forward pass and automatically performs the corresponding reduce-scatter of gradients in the backward pass. Selective activation checkpointing lets the all-gather be recomputed in the backward pass instead of keeping the unsharded parameters in memory between forward and backward.
- **Optimization components**: SimpleFSDP introduces two optimization components, bucketing and reordering, to reduce communication frequency and minimize exposed communication. Bucketing merges multiple communication operations into one larger operation. Reordering prefetches the parameters needed by later stages so that their communication overlaps with the computation of the current stage.
- **Model-wrapping interface**: SimpleFSDP provides two interfaces, manual wrapping and auto wrapping. With manual wrapping, users specify how communication IR nodes are bucketed and reordered via a list of modules. With auto wrapping, the system performs fine-grained bucketing and reordering automatically.

### Summary:

SimpleFSDP addresses the high computational cost and engineering complexity of distributed training for large-scale models by simplifying the implementation, improving performance, and composing with other training techniques. It improves training efficiency while preserving debuggability and flexibility, making it well suited to distributed training of large-scale models.
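The bucketing and reordering components described under Technical details can be illustrated with a toy timeline model. This sketch is a pure-Python simulation of a single forward pass that ignores backward reduce-scatters. It rests on several assumptions introduced for illustration:

- **Cost model:** each collective pays a fixed latency plus a per-MB cost, in arbitrary integer time units.
- **Bucketing rule:** consecutive layers are merged greedily up to a size cap.
- **Reordering rule:** each all-gather is prefetched exactly one bucket ahead.

The names `Layer`, `bucket_layers`, `simulate_forward`, `COLLECTIVE_LATENCY`, and `COST_PER_MB` are all hypothetical. This is not SimpleFSDP's TorchInductor pass, which operates on IR nodes. It only shows why merging all-gathers (fewer fixed latencies) and issuing them early (overlap with compute) both reduce exposed communication.

```python
from dataclasses import dataclass
from typing import List, Tuple

# Hypothetical cost model (not from the paper), in arbitrary integer time units.
COLLECTIVE_LATENCY = 4  # fixed cost paid by every all-gather
COST_PER_MB = 1         # transfer cost per MB of gathered parameters


@dataclass
class Layer:
    param_mb: int  # size of the layer's all-gathered parameters
    compute: int   # forward compute time of the layer


def bucket_layers(layers: List[Layer], cap_mb: int) -> List[List[Layer]]:
    """Greedily merge consecutive layers until a bucket would exceed cap_mb."""
    buckets: List[List[Layer]] = []
    current: List[Layer] = []
    for layer in layers:
        if current and sum(l.param_mb for l in current) + layer.param_mb > cap_mb:
            buckets.append(current)
            current = []
        current.append(layer)
    if current:
        buckets.append(current)
    return buckets


def simulate_forward(buckets: List[List[Layer]], prefetch: bool) -> Tuple[int, int]:
    """Return (total time, exposed communication time) for one forward pass.

    One serial communication stream and one serial compute stream. Bucket i's
    compute waits for bucket i's all-gather (one collective per bucket). Without
    prefetching, that all-gather is issued only after bucket i-1 finishes
    computing; with prefetching (reordering), it is issued when bucket i-1
    starts computing, i.e. one bucket of lookahead.
    """
    comm_free = 0           # when the communication stream becomes idle
    compute_end = 0         # when the previous bucket's compute finished
    prev_compute_start = 0  # when the previous bucket's compute started
    exposed = 0             # compute-stream idle time spent waiting on all-gathers
    for bucket in buckets:
        comm = COLLECTIVE_LATENCY + COST_PER_MB * sum(l.param_mb for l in bucket)
        issue = prev_compute_start if prefetch else compute_end
        gather_end = max(issue, comm_free) + comm
        comm_free = gather_end
        start = max(compute_end, gather_end)
        exposed += start - compute_end
        prev_compute_start = start
        compute_end = start + sum(l.compute for l in bucket)
    return compute_end, exposed


if __name__ == "__main__":
    layers = [Layer(param_mb=2, compute=5) for _ in range(4)]
    for cap in (2, 4):  # cap 2: one layer per bucket; cap 4: two layers per bucket
        for prefetch in (False, True):
            print(cap, prefetch, simulate_forward(bucket_layers(layers, cap), prefetch))
```

The example uses four identical layers. With neither optimization, the forward pass takes 44 time units, 24 of them exposed communication. Prefetching alone brings it to 29 units and bucketing pairs of layers alone to 36 units. Combining both gives the shortest pass at 28 units, with 8 exposed.

Larger buckets are not automatically better. The first all-gather is always exposed and grows with bucket size, and each bucket holds more unsharded parameters in memory at once. That is why bucket boundaries need to be chosen rather than simply maximized.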