FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
2022-06-24
Abstract: Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length. Approximate attention methods have attempted to address this problem by trading off model quality to reduce the compute complexity, but often do not achieve wall-clock speedup. We argue that a missing principle is making attention algorithms IO-aware -- accounting for reads and writes between levels of GPU memory. We propose FlashAttention, an IO-aware exact attention algorithm that uses tiling to reduce the number of memory reads/writes between GPU high bandwidth memory (HBM) and GPU on-chip SRAM. We analyze the IO complexity of FlashAttention, showing that it requires fewer HBM accesses than standard attention, and is optimal for a range of SRAM sizes. We also extend FlashAttention to block-sparse attention, yielding an approximate attention algorithm that is faster than any existing approximate attention method. FlashAttention trains Transformers faster than existing baselines: 15% end-to-end wall-clock speedup on BERT-large (seq. length 512) compared to the MLPerf 1.1 training speed record, 3$\times$ speedup on GPT-2 (seq. length 1K), and 2.4$\times$ speedup on long-range arena (seq. length 1K-4K). FlashAttention and block-sparse FlashAttention enable longer context in Transformers, yielding higher quality models (0.7 better perplexity on GPT-2 and 6.4 points of lift on long-document classification) and entirely new capabilities: the first Transformers to achieve better-than-chance performance on the Path-X challenge (seq. length 16K, 61.4% accuracy) and Path-256 (seq. length 64K, 63.1% accuracy).
Machine Learning
What problem does this paper attempt to address?
The paper addresses the cost of training Transformer models on long sequences: the time and memory complexity of self-attention grow quadratically with sequence length, which makes training slow and memory-intensive. Approximate attention methods try to solve this by reducing computational complexity, but they trade away model quality and often fail to deliver wall-clock speedup. The paper argues that this is because existing methods focus on reducing floating-point operations (FLOPs) while ignoring the overhead of memory access (IO). It therefore proposes FlashAttention, an exact attention algorithm that reduces the number of reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM, giving faster training and lower memory usage.

### Main Contributions

1. **FlashAttention Algorithm**: Uses tiling and recomputation to reduce the number of HBM accesses, giving faster training and lower memory usage.
2. **Theoretical Analysis**: Shows that FlashAttention requires fewer HBM accesses than standard attention, and proves a lower bound showing that its number of HBM accesses is optimal for a range of SRAM sizes.
3. **Extension to Sparse Attention**: Proposes block-sparse FlashAttention, an approximate attention algorithm that further improves speed on long sequences.
4. **Experimental Validation**: Evaluates FlashAttention on BERT, GPT-2, and the Long Range Arena (LRA) benchmark.

### Experimental Results

- **Training Speed**: FlashAttention is 15% faster end to end than the MLPerf 1.1 training speed record on BERT-large (sequence length 512), 3 times faster than HuggingFace and 1.7 times faster than Megatron-LM on GPT-2, and 2.4 times faster on LRA (sequence length 1K-4K).
- **Model Quality**: FlashAttention lets Transformers train on longer sequences, which improves model quality. For example, GPT-2 with FlashAttention at a 4K context length is 30% faster than Megatron-LM at a 1K context length and achieves 0.7 better perplexity.
- **New Capabilities**: FlashAttention yields the first Transformers with better-than-chance performance on the Path-X challenge (sequence length 16K, 61.4% accuracy) and Path-256 (sequence length 64K, 63.1% accuracy).

### Summary

The paper proposes FlashAttention, an attention algorithm that optimizes memory access patterns to address the slow training and high memory usage of Transformers on long sequences. The experiments show that FlashAttention speeds up training and, by enabling longer context, also improves model quality and the ability to handle long sequences.
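The tiling idea listed under Main Contributions can be illustrated with a minimal NumPy sketch. This is not the authors' CUDA kernel: NumPy has no notion of SRAM versus HBM, so the code only shows the arithmetic that lets exact softmax attention be computed one tile at a time without ever materializing the full N x N score matrix. The function names, the block sizes, and the loop order (query blocks outside, key/value blocks inside) are assumptions chosen for readability and may differ from the paper's algorithm.

```python
import numpy as np


def standard_attention(Q, K, V):
    """Reference implementation: builds the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


def tiled_attention(Q, K, V, block_q=4, block_k=4):
    """Exact attention computed tile by tile with a running softmax.

    Illustrative sketch only; block_q and block_k are hypothetical
    parameters standing in for tile sizes that would fit in on-chip SRAM.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, V.shape[1]))
    for i in range(0, N, block_q):
        q = Q[i:i + block_q]
        m = np.full(q.shape[0], -np.inf)          # running row-wise max
        l = np.zeros(q.shape[0])                  # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[1]))  # unnormalized output
        for j in range(0, K.shape[0], block_k):
            s = (q @ K[j:j + block_k].T) * scale  # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            p = np.exp(s - m_new[:, None])
            correction = np.exp(m - m_new)        # rescale earlier partial sums
            l = l * correction + p.sum(axis=1)
            acc = acc * correction[:, None] + p @ V[j:j + block_k]
            m = m_new
        O[i:i + block_q] = acc / l[:, None]       # normalize once per row block
    return O


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
    print(np.allclose(tiled_attention(Q, K, V), standard_attention(Q, K, V)))
```

The running maximum `m` and denominator `l` are what make the result exact rather than approximate: each time a new tile raises the row maximum, the earlier partial sums are rescaled by `exp(m - m_new)`, so the final output matches the standard computation up to floating-point error.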