PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel

Yanli Zhao, Andrew Gu, Rohan Varma, Liang Luo, Chien-Chin Huang, Min Xu, Less Wright, Hamid Shojanazeri, Myle Ott, Sam Shleifer, Alban Desmaison, Can Balioglu, Pritam Damania, Bernard Nguyen, Geeta Chauhan, Yuchen Hao, Ajit Mathews, Shen Li
2023-09-13
Abstract: It is widely acknowledged that large models have the potential to deliver superior performance across a broad range of domains. Despite the remarkable progress made in the field of machine learning systems research, which has enabled the development and exploration of large models, such abilities remain confined to a small group of advanced users and industry leaders, resulting in an implicit technical barrier for the wider community to access and leverage these technologies. In this paper, we introduce PyTorch Fully Sharded Data Parallel (FSDP) as an industry-grade solution for large model training. FSDP has been closely co-designed with several key PyTorch core components including Tensor implementation, dispatcher system, and CUDA memory caching allocator, to provide non-intrusive user experiences and high training efficiency. Additionally, FSDP natively incorporates a range of techniques and settings to optimize resource utilization across a variety of hardware configurations. The experimental results demonstrate that FSDP is capable of achieving comparable performance to Distributed Data Parallel while providing support for significantly larger models with near-linear scalability in terms of TFLOPS.
Distributed, Parallel, and Cluster Computing; Artificial Intelligence; Machine Learning; Performance
What problem does this paper attempt to address?
The paper addresses several key challenges in training large-scale neural network models:

1. **User Friendliness**: Existing distributed training methods such as DistributedDataParallel (DDP) require the entire model, including its parameters, gradients, and optimizer states, to fit into the memory of a single GPU, which rules out large models. The paper therefore proposes PyTorch Fully Sharded Data Parallel (FSDP), which supports training large-scale models while keeping a user experience similar to local training.
2. **Hardware Heterogeneity**: Modern GPU clusters combine high-bandwidth interconnects within a machine with lower-bandwidth links across machines. FSDP's design takes this heterogeneity into account, for example by allowing parameters to be sharded within a machine and replicated across machines, so that each kind of link is used efficiently.
3. **Resource Utilization**: To keep GPUs busy during distributed training, FSDP uses several techniques, such as overlapping communication with computation, to minimize idle time caused by non-computational operations.
4. **Memory Management**: Memory planning is crucial for large-scale model training. FSDP limits the number of memory blocks allocated to unsharded parameters and, when necessary, pauses CPU execution so that new blocks are not allocated too early.

In summary, FSDP aims to be an industry-grade solution that makes large-scale model training more efficient and easier to adopt, lowering the learning curve and technical barriers. Experimental results show that FSDP performs comparably to DDP on smaller models and scales near-linearly, in terms of TFLOPS, on significantly larger models.
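To make the memory limitation behind item 1 concrete, the back-of-the-envelope sketch below compares per-GPU memory for model states under full replication (DDP-style) and full sharding (FSDP-style). It assumes plain fp32 training with Adam, i.e. 16 bytes per parameter, and ignores activations and the parameters FSDP temporarily unshards; `ddp_bytes_per_gpu` and `fsdp_bytes_per_gpu` are hypothetical helpers, not PyTorch APIs.

```python
# Back-of-the-envelope per-GPU memory for model states (parameters,
# gradients, optimizer states). Assumption: fp32 training with Adam,
# i.e. 4 B param + 4 B grad + 8 B for Adam's two moment estimates.
# Activations and FSDP's temporarily unsharded parameters are ignored.
BYTES_PER_PARAM = 4 + 4 + 8


def ddp_bytes_per_gpu(num_params: int) -> int:
    """DDP-style replication: every rank holds all model states."""
    return num_params * BYTES_PER_PARAM


def fsdp_bytes_per_gpu(num_params: int, world_size: int) -> float:
    """Full sharding: each rank holds a 1/world_size slice of the states."""
    return num_params * BYTES_PER_PARAM / world_size


if __name__ == "__main__":
    n = 10_000_000_000  # hypothetical 10B-parameter model
    print(f"DDP: {ddp_bytes_per_gpu(n) / 1e9:.0f} GB per GPU")
    for w in (8, 64, 512):
        print(f"FSDP, {w} GPUs: {fsdp_bytes_per_gpu(n, w) / 1e9:.1f} GB per GPU")
```

Under these assumptions a 10-billion-parameter model needs 160 GB of model state on every GPU with replication, which exceeds a single 80 GB GPU, while full sharding over 8 GPUs brings this down to 20 GB per GPU.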
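Item 2 can be illustrated by how ranks might be grouped when parameters are sharded within a group of F ranks (for example, the GPUs of one machine) and replicated across the W/F such groups. The sketch below only computes this rank layout, assuming consecutive ranks share a machine; `hybrid_groups` is a hypothetical helper, not part of PyTorch. In such a layout, gradients would typically be reduce-scattered inside each sharding group over fast local links and then all-reduced across the matching ranks of each replication group over the slower cross-machine links.

```python
from typing import List, Tuple


def hybrid_groups(world_size: int,
                  shard_factor: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Split ranks 0..world_size-1 into sharding and replication groups.

    Hypothetical helper: assumes consecutive ranks live on the same machine,
    so each sharding group of `shard_factor` ranks stays on fast local links.
    """
    if shard_factor < 1 or world_size % shard_factor != 0:
        raise ValueError("shard_factor must divide world_size")
    # Parameters are sharded among the ranks of each group...
    shard_groups = [list(range(start, start + shard_factor))
                    for start in range(0, world_size, shard_factor)]
    # ...and replicated across the ranks holding the same shard in other groups.
    replicate_groups = [list(range(offset, world_size, shard_factor))
                        for offset in range(shard_factor)]
    return shard_groups, replicate_groups


if __name__ == "__main__":
    shards, replicas = hybrid_groups(world_size=8, shard_factor=4)
    print("shard groups:", shards)
    print("replicate groups:", replicas)
```

For `world_size=8, shard_factor=4` this yields sharding groups `[0, 1, 2, 3]` and `[4, 5, 6, 7]` and replication groups `[0, 4]`, `[1, 5]`, `[2, 6]`, `[3, 7]`. Setting the factor equal to the world size gives full sharding, and setting it to 1 gives DDP-style replication.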
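Items 3 and 4 pull in opposite directions: issuing all-gathers early hides communication behind computation, but every prefetched unit keeps its unsharded parameters in memory. The toy timeline below, with made-up integer durations, models each FSDP unit as an all-gather followed by its computation and optionally caps how many units may be unsharded at once (`max_inflight`, a hypothetical knob). It is a simplified scheduling model under these assumptions, not FSDP's actual implementation.

```python
from typing import List, Optional, Tuple


def simulate(gather: List[int], compute: List[int],
             max_inflight: Optional[int] = None) -> Tuple[int, int]:
    """Return (step_time, peak number of simultaneously unsharded units).

    Assumptions: one communication stream and one computation stream; unit i
    is unsharded from the start of its all-gather until its computation ends.
    max_inflight=None lets all-gathers run arbitrarily far ahead; a value K
    delays unit i's all-gather until unit i-K has finished computing.
    """
    comm_free = 0  # time at which the communication stream becomes idle
    comp_end = 0   # end time of the most recent computation
    comp_ends: List[int] = []
    intervals: List[Tuple[int, int]] = []
    for i, (g, c) in enumerate(zip(gather, compute)):
        start = comm_free
        if max_inflight is not None and i >= max_inflight:
            # Rate limiting: wait until an older unit frees its memory.
            start = max(start, comp_ends[i - max_inflight])
        comm_free = start + g
        # Computation needs the previous unit done and this unit gathered.
        comp_end = max(comp_end, comm_free) + c
        comp_ends.append(comp_end)
        intervals.append((start, comp_end))
    return comp_end, _peak_live(intervals)


def _peak_live(intervals: List[Tuple[int, int]]) -> int:
    # Sweep over start (+1) and end (-1) events; ends sort first on ties.
    events = sorted([(s, 1) for s, _ in intervals] +
                    [(e, -1) for _, e in intervals])
    live = peak = 0
    for _, delta in events:
        live += delta
        peak = max(peak, live)
    return peak


if __name__ == "__main__":
    for k in (None, 2, 1):
        print(k, simulate([1] * 4, [3] * 4, max_inflight=k))
```

With four units whose all-gathers take 1 time unit and whose computations take 3, unbounded prefetching finishes in 13 units with 4 units unsharded at the peak, `max_inflight=2` also finishes in 13 with a peak of 2, and `max_inflight=1` (no prefetching) takes 16 with a peak of 1. In this toy setting, allowing two units in flight keeps the full overlap benefit at half the peak memory of unbounded prefetching.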