BurstAttention: An Efficient Distributed Attention Framework for Extremely Long Sequences

Ao Sun, Weilin Zhao, Xu Han, Cheng Yang, Zhiyuan Liu, Chuan Shi, Maosong Sun
2024-06-06
Abstract: Effective attention modules have played a crucial role in the success of Transformer-based large language models (LLMs), but the quadratic time and memory complexities of these attention modules also pose a challenge when processing long sequences. One potential solution for the long sequence problem is to utilize distributed clusters to parallelize the computation of attention modules across multiple devices (e.g., GPUs). However, adopting a distributed approach inevitably introduces extra memory overheads to store local attention results and incurs additional communication costs to aggregate local results into global ones. In this paper, we propose a distributed attention framework named "BurstAttention" to optimize memory access and communication operations at both the global cluster and local device levels. In our experiments, we compare BurstAttention with other competitive distributed attention solutions for long sequence processing. The experimental results under different length settings demonstrate that BurstAttention offers significant advantages for processing long sequences compared with these competitive baselines, reducing communication overheads by 40% and achieving a 1.37× speedup when training on 128K-length sequences with 32 × A100 GPUs.
Distributed, Parallel, and Cluster Computing, Machine Learning
What problem does this paper attempt to address?
The problem that this paper attempts to solve is the time and memory complexity of the attention mechanism in Transformer models when processing extremely long sequences. Specifically:

- **Problem Background**: Transformer models dominate large language models (LLMs) thanks to their strong semantic understanding and controllable response generation. However, the multi-head attention module in these models has time and memory complexity that is quadratic in the sequence length, which makes long sequences expensive in both computation time and memory.
- **Existing Solutions and Their Shortcomings**:
  - Single-device optimization methods (such as FlashAttention) accelerate the attention module by making better use of fast on-chip static random-access memory (SRAM), but they focus mainly on optimization within a single device.
  - Distributed cluster methods (such as RingAttention) handle long sequences by dividing the sequence into multiple sub-sequences and processing them on different devices. However, this approach introduces extra memory overhead for storing local attention results and extra communication costs for aggregating local results into global results.
- **Solution Proposed in the Paper**: The paper proposes a distributed attention framework named "BurstAttention", which addresses these problems by optimizing memory access and communication operations at both the global cluster and local device levels. Its main features are:
  - **Global Attention Optimization (GAO)**: Local attention results are dynamically accumulated into global results using the online softmax technique. This avoids storing the intermediate result \( QK^T \), whose memory cost is quadratic in the sequence length; these intermediate results are recomputed during back-propagation instead.
  - **Local Attention Optimization (LAO)**: The local attention operation is further divided into smaller blocks, exploiting the high bandwidth of SRAM while minimizing accesses to the lower-bandwidth HBM.
  - **Overlapping of Communication and Computation**: A double-buffering technique lets communication proceed concurrently with computation, improving overall efficiency.
  - **Integration with Sparse Attention Methods**: BurstAttention can be easily combined with various sparse attention methods to further reduce time and memory overhead.

Through experiments, the authors compared BurstAttention with other competitive distributed attention solutions for long sequences and demonstrated its significant advantages on long inputs. In particular, when training with a sequence length of 128K on 32 A100 GPUs, BurstAttention reduced communication overhead by 40% compared with the baselines and achieved a 1.37× speedup.
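The GAO idea of accumulating local attention results into a global result with online softmax can be illustrated by a small pure-Python simulation. This is a minimal sketch, not the authors' implementation. It assumes a ring-style schedule in which every simulated device keeps its own query shard while key/value shards are handed to a neighbour after each step, and it covers only the forward pass. All helper names (`scores`, `merge_block`, `ring_attention`) are hypothetical. Each device keeps only a running row maximum, row sum, and unnormalized output, so it never holds the full \( QK^T \) matrix. Its result still matches dense softmax attention up to floating-point rounding.

```python
import math
import random


def scores(q_rows, k_rows):
    """Scaled dot-product scores q_rows @ k_rows^T / sqrt(d) as nested lists."""
    d = len(q_rows[0])
    return [[sum(a * b for a, b in zip(qr, kr)) / math.sqrt(d) for kr in k_rows]
            for qr in q_rows]


def reference_attention(q, k, v):
    """Dense softmax attention that materializes the full score matrix."""
    out = []
    for s in scores(q, k):
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def merge_block(state, q_rows, k_blk, v_blk):
    """Online softmax: fold one K/V block into (row max, row sum, unnormalized output)."""
    m, l, o = state
    for r, s in enumerate(scores(q_rows, k_blk)):
        m_new = max(m[r], max(s))
        scale = math.exp(m[r] - m_new)  # rescale old accumulators to the new max
        p = [math.exp(x - m_new) for x in s]
        l[r] = l[r] * scale + sum(p)
        o[r] = [o[r][c] * scale + sum(pj * vj[c] for pj, vj in zip(p, v_blk))
                for c in range(len(o[r]))]
        m[r] = m_new
    return m, l, o


def ring_attention(q, k, v, num_devices):
    """Simulated devices keep one Q shard each while K/V shards travel around a ring."""
    size = len(q) // num_devices  # assumes len(q) is divisible by num_devices
    dv = len(v[0])
    q_sh = [q[i * size:(i + 1) * size] for i in range(num_devices)]
    kv_sh = [(k[i * size:(i + 1) * size], v[i * size:(i + 1) * size])
             for i in range(num_devices)]
    states = [([-math.inf] * size, [0.0] * size, [[0.0] * dv for _ in range(size)])
              for _ in range(num_devices)]
    for _ in range(num_devices):
        for dev in range(num_devices):
            k_blk, v_blk = kv_sh[dev]
            states[dev] = merge_block(states[dev], q_sh[dev], k_blk, v_blk)
        kv_sh = kv_sh[1:] + kv_sh[:1]  # each device receives its neighbour's K/V shard
    out = []
    for _, l, o in states:  # normalize once, after every shard has been seen
        out.extend([[x / l_r for x in o_r] for o_r, l_r in zip(o, l)])
    return out


if __name__ == "__main__":
    random.seed(0)

    def rand_matrix(rows, cols):
        return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

    q, k, v = rand_matrix(8, 4), rand_matrix(8, 4), rand_matrix(8, 4)
    ref, got = reference_attention(q, k, v), ring_attention(q, k, v, num_devices=4)
    # Expected to be close to 0: only floating-point rounding differs.
    print(max(abs(a - b) for ra, rb in zip(ref, got) for a, b in zip(ra, rb)))
```

The rescaling factor `exp(m_old - m_new)` is what makes the merge order-independent. Each device can therefore fold in K/V shards in whatever order they arrive and normalize only at the end.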
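LAO applies the same online-softmax accumulation at a finer granularity inside a single device, and a block-sparse mask fits naturally into that loop. The sketch below is an illustrative assumption rather than the paper's kernel. Tiles of `tile` queries and keys stand in for blocks sized to fit in SRAM, and a causal sliding window of width `window` serves as an example sparsity pattern. The names `allowed`, `masked_reference`, and `tiled_attention` are hypothetical. The loop skips any tile whose query/key pairs are all masked out, which is where a sparse pattern saves work.

```python
import math
import random


def allowed(i, j, window):
    """Mask rule: dense if window is None, else causal sliding window of width `window`."""
    return window is None or (j <= i and i - j < window)


def masked_reference(q, k, v, window=None):
    """Row-by-row masked softmax attention, used only to check the tiled version."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if allowed(i, j, window)
             else -math.inf for j, kj in enumerate(k)]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) / z for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, tile, window=None):
    """Attention computed tile by tile (as if each tile fit in SRAM), skipping empty tiles."""
    n, d, dv = len(q), len(q[0]), len(v[0])
    out, skipped = [], 0
    for q0 in range(0, n, tile):
        rows = range(q0, min(q0 + tile, n))
        m = {i: -math.inf for i in rows}
        l = {i: 0.0 for i in rows}
        o = {i: [0.0] * dv for i in rows}
        for k0 in range(0, n, tile):
            cols = range(k0, min(k0 + tile, n))
            if not any(allowed(i, j, window) for i in rows for j in cols):
                skipped += 1  # block-sparse skip: this tile contributes nothing
                continue
            for i in rows:
                s = [sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d)
                     if allowed(i, j, window) else -math.inf for j in cols]
                m_new = max(m[i], max(s))
                if m_new == -math.inf:
                    continue  # row i has no allowed key yet; nothing to accumulate
                scale = math.exp(m[i] - m_new)
                p = [math.exp(x - m_new) for x in s]
                l[i] = l[i] * scale + sum(p)
                o[i] = [o[i][c] * scale + sum(pj * v[j][c] for pj, j in zip(p, cols))
                        for c in range(dv)]
                m[i] = m_new
        out.extend([[x / l[i] for x in o[i]] for i in rows])
    return out, skipped


if __name__ == "__main__":
    random.seed(0)

    def rand_matrix(rows, cols):
        return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

    q, k, v = rand_matrix(8, 4), rand_matrix(8, 4), rand_matrix(8, 4)
    out, skipped = tiled_attention(q, k, v, tile=2, window=3)
    print(skipped)  # 9
```

With `n = 8`, `tile = 2`, and `window = 3`, 9 of the 16 tiles are skipped: 6 lie entirely in the future of their queries, and 3 lie entirely outside the window. A real kernel would decide which tiles to skip from the tile indices alone instead of scanning every pair.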
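The double-buffering idea behind overlapping communication and computation can be sketched with a worker thread that acts as a communication stream. In this hypothetical toy, `fake_recv` and its `time.sleep` delay stand in for transferring the next K/V shard between GPUs, and `compute` stands in for local attention. Neither reflects the paper's actual code. The main thread computes on the front buffer while the worker fills the back buffer. The two buffers swap once the transfer completes.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_recv(shards, idx, delay):
    """Hypothetical stand-in for receiving shard `idx` from a neighbouring device."""
    time.sleep(delay)  # pretend transfer latency; sleeping releases the GIL
    return list(shards[idx])  # copy into a fresh buffer


def double_buffered_loop(shards, compute, delay=0.01):
    """Compute on the front buffer while a worker thread fills the back buffer."""
    if not shards:
        return []
    results = []
    with ThreadPoolExecutor(max_workers=1) as comm:  # plays the role of a communication stream
        front = fake_recv(shards, 0, delay)  # the first transfer cannot be hidden
        for i in range(len(shards)):
            back = None
            if i + 1 < len(shards):
                back = comm.submit(fake_recv, shards, i + 1, delay)  # start the next transfer
            results.append(compute(front))  # runs while the next shard is in flight
            if back is not None:
                front = back.result()  # wait for the transfer, then swap buffers
    return results


if __name__ == "__main__":
    def slow_sum(buf):
        time.sleep(0.01)  # pretend compute cost similar to the transfer cost
        return sum(buf)

    shards = [[i, i + 1] for i in range(0, 10, 2)]
    start = time.perf_counter()
    print(double_buffered_loop(shards, slow_sum))  # [1, 5, 9, 13, 17]
    print(f"{time.perf_counter() - start:.3f} s elapsed")
```

Ideally, only the first transfer is exposed. After that, each step costs roughly the slower of transfer and compute rather than their sum, which is the benefit double buffering is meant to provide.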