FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision

Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao
2024-07-13
Abstract: Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for large language models and long-context applications. FlashAttention elaborated an approach to speed up attention on GPUs through minimizing memory reads/writes. However, it has yet to take advantage of new capabilities present in recent hardware, with FlashAttention-2 achieving only 35% utilization on the H100 GPU. We develop three main techniques to speed up attention on Hopper GPUs: exploiting asynchrony of the Tensor Cores and TMA to (1) overlap overall computation and data movement via warp-specialization and (2) interleave block-wise matmul and softmax operations, and (3) block quantization and incoherent processing that leverages hardware support for FP8 low-precision. We demonstrate that our method, FlashAttention-3, achieves speedup on H100 GPUs by 1.5-2.0$\times$ with FP16 reaching up to 740 TFLOPs/s (75% utilization), and with FP8 reaching close to 1.2 PFLOPs/s. We validate that FP8 FlashAttention-3 achieves 2.6$\times$ lower numerical error than a baseline FP8 attention.
Machine Learning, Artificial Intelligence
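The "block-wise matmul and softmax" in the abstract builds on FlashAttention's tiled computation: keys and values are processed one block at a time, and running row-wise max and sum statistics (an online softmax) let the kernel avoid materializing the full score matrix in GPU memory. The NumPy sketch below illustrates only that numerical recurrence for a single head without masking; it has none of the Hopper-specific asynchrony (TMA, warp specialization, interleaved scheduling), and the block size `block_n` and the function names are illustrative assumptions.

```python
import numpy as np

def naive_attention(q, k, v):
    """Reference attention that materializes the full N x M score matrix."""
    s = q @ k.T / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def tiled_attention(q, k, v, block_n=16):
    """Attention over K/V blocks with an online softmax (no full score matrix).

    Illustrative sketch: block_n is an arbitrary choice, not a kernel parameter.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    row_max = np.full(n, -np.inf)    # running max of each score row
    row_sum = np.zeros(n)            # running sum of exp(score - row_max)
    acc = np.zeros((n, v.shape[1]))  # unnormalized output accumulator
    for start in range(0, k.shape[0], block_n):
        k_blk = k[start:start + block_n]
        v_blk = v[start:start + block_n]
        s = (q @ k_blk.T) * scale                    # matmul 1: S = Q K^T for this block
        new_max = np.maximum(row_max, s.max(axis=1))
        p = np.exp(s - new_max[:, None])             # softmax numerator for this block
        correction = np.exp(row_max - new_max)       # rescales earlier blocks (0 on the first)
        row_sum = row_sum * correction + p.sum(axis=1)
        acc = acc * correction[:, None] + p @ v_blk  # matmul 2: accumulate P V
        row_max = new_max
    return acc / row_sum[:, None]                    # final softmax normalization

rng = np.random.default_rng(0)
q = rng.standard_normal((50, 32))
k = rng.standard_normal((70, 32))
v = rng.standard_normal((70, 24))
out = tiled_attention(q, k, v, block_n=16)
print(np.allclose(out, naive_attention(q, k, v)))
```

With `block_n` equal to the number of keys the loop runs once and reduces to the ordinary softmax; smaller blocks trade extra rescaling work for a smaller working set, which is what lets real kernels keep each tile in on-chip memory. The two matmuls and the exponentials inside the loop are the operations that FlashAttention-3 schedules to overlap on Hopper.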
What problem does this paper attempt to address?
### Problems the paper attempts to solve

The paper addresses the computational bottleneck of the attention mechanism in the Transformer architecture, focusing on how to accelerate attention on GPUs, especially for long-context applications. Computing the self-attention scores between queries and keys has quadratic complexity in the sequence length, which makes attention a major bottleneck for large language models and long-context applications. FlashAttention already speeds up attention by reducing memory reads and writes, but it still makes poor use of the latest hardware: FlashAttention-2 reaches only 35% utilization on the H100 GPU.

To improve the efficiency of attention further, the paper proposes three main techniques:

1. **Overlapping computation and data movement**: exploit the asynchrony of the Tensor Cores and the Tensor Memory Accelerator (TMA) through warp specialization, so that loading data overlaps with computation.
2. **Interleaving matmul and softmax**: overlap the block-wise matrix multiplications with the softmax operations, so that softmax runs while the Tensor Cores are busy with matmuls.
3. **Low-precision computation**: use the hardware's FP8 support for higher throughput, combined with block quantization (a separate scale for each block of the input tensors) and incoherent processing (multiplying the queries and keys by a random orthogonal matrix to spread out outliers) to reduce the numerical error of FP8, which matters for the outliers found in large models.

Together, these techniques let FlashAttention-3 reach 740 TFLOPs/s (75% utilization) in FP16 and close to 1.2 PFLOPs/s in FP8 on the H100 GPU.
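The effect of block quantization and incoherent processing on FP8 error can be illustrated numerically. The sketch below is a simulation under several assumptions: `round_to_e4m3` only approximates the FP8 e4m3 format (3 mantissa bits, round-to-nearest, saturating at ±448, no NaN handling); the per-row-block scaling with 64-row blocks, the outlier magnitudes, and the helper names `fake_quant` and `random_hadamard` are hypothetical; and the random orthogonal matrix is a Hadamard matrix with random signs, a common choice for this kind of rotation. Per-tensor scaling serves as the baseline. None of this is the paper's kernel code or its exact experiment.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value of the FP8 e4m3 format

def round_to_e4m3(x):
    """Approximate e4m3 rounding: 3 mantissa bits, round-to-nearest, saturate at +/-448."""
    x = np.clip(x, -E4M3_MAX, E4M3_MAX)
    # Values below the smallest normal (2**-6) share the subnormal spacing 2**-9.
    exp = np.floor(np.log2(np.maximum(np.abs(x), 2.0 ** -6)))
    step = 2.0 ** (exp - 3)
    return np.round(x / step) * step

def fake_quant(x, block_rows=None):
    """Quantize then dequantize x with one scale per block of rows (per tensor if None).

    Hypothetical helper: the row-block layout is an assumption of this sketch.
    """
    rows = block_rows or x.shape[0]
    out = np.empty_like(x)
    for start in range(0, x.shape[0], rows):
        blk = x[start:start + rows]
        scale = np.abs(blk).max() / E4M3_MAX or 1.0  # map the block's max magnitude to 448
        out[start:start + rows] = round_to_e4m3(blk / scale) * scale
    return out

def random_hadamard(d, rng):
    """Orthogonal matrix H @ diag(random signs) / sqrt(d); d must be a power of two.

    Assumption: a randomized Hadamard matrix stands in for the random orthogonal matrix.
    """
    h = np.array([[1.0]])
    while h.shape[0] < d:
        h = np.block([[h, h], [h, -h]])  # Sylvester construction of a Hadamard matrix
    signs = rng.choice([-1.0, 1.0], size=d)
    return h * signs / np.sqrt(d)

def rmse(a, b):
    return float(np.sqrt(np.mean((a - b) ** 2)))

rng = np.random.default_rng(0)
n, d = 256, 64

# Block quantization: the first 64 rows are 1e5 times larger than the rest (hypothetical data).
x = rng.standard_normal((n, d))
x[:64] *= 1e5
err_per_tensor = rmse(fake_quant(x)[64:], x[64:])
err_per_block = rmse(fake_quant(x, block_rows=64)[64:], x[64:])

# Incoherent processing: Q and K share one outlier channel (hypothetical data).
q = rng.standard_normal((n, d))
k = rng.standard_normal((n, d))
q[:, 3] *= 50.0
k[:, 3] *= 50.0
m = random_hadamard(d, rng)
s_ref = q @ k.T  # (q m)(k m)^T equals q k^T in exact arithmetic
s_plain = fake_quant(q) @ fake_quant(k).T
s_incoherent = fake_quant(q @ m) @ fake_quant(k @ m).T
print(err_per_block, err_per_tensor)
print(rmse(s_incoherent, s_ref), rmse(s_plain, s_ref))
```

Because the rotation is orthogonal, it leaves the attention scores unchanged in exact arithmetic and only changes how the values are distributed before rounding: the outlier channel's energy is spread across all 64 dimensions, so no single product dominates the quantization error. Block quantization addresses a different failure mode, where one large block forces a per-tensor scale that pushes the remaining values into the coarse subnormal range.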
### Main contributions

1. **Asynchronous execution**: warp specialization overlaps data movement (via TMA) with Tensor Core computation, and interleaving the block-wise matmuls with softmax keeps the Tensor Cores busy while softmax runs.
2. **Low-precision computation**: FP8 support increases throughput, while block quantization and incoherent processing reduce the resulting numerical error, especially in the presence of outliers in large models.

### Experimental verification

The paper evaluates FlashAttention-3 on an H100 SXM5 GPU. The results show that:

- In FP16, FlashAttention-3 is 1.5-2.0 times faster than FlashAttention-2 in the forward pass, reaching up to 740 TFLOPs/s.
- In FP8, FlashAttention-3 reaches close to 1.2 PFLOPs/s.
- For long sequences, FP16 FlashAttention-3 outperforms the attention implementation in NVIDIA's cuDNN library, and FP8 FlashAttention-3 is competitive with it.
- In FP16, FlashAttention-3 has the same numerical error as FlashAttention-2, which is lower than that of a standard attention implementation. In FP8, block quantization and incoherent processing give FlashAttention-3 2.6 times lower numerical error than a baseline FP8 attention implementation.

### Conclusion

Through these techniques, FlashAttention-3 achieves a significant speedup on Hopper GPUs such as the H100, especially for long-context applications.
These contributions are expected to advance the development of large language models and long-context applications.
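The throughput and utilization figures reported for FlashAttention-3 (740 TFLOPs/s at 75% utilization in FP16, close to 1.2 PFLOPs/s in FP8) can be sanity-checked with a few lines of arithmetic. This sketch assumes NVIDIA's published dense Tensor Core peaks for the H100 SXM5 (about 989 TFLOPs/s for FP16 and 1979 TFLOPs/s for FP8) and the usual convention of counting the attention forward pass as 4·N²·d FLOPs per head (two matmuls, two FLOPs per multiply-add), halved for causal masking. The function names and the example problem size are illustrative.

```python
# Assumption: dense (non-sparse) Tensor Core peaks for H100 SXM5, in TFLOPs/s.
H100_SXM5_PEAK_TFLOPS = {"fp16": 989.4, "fp8": 1978.9}

def attention_fwd_flops(batch, nheads, seqlen, headdim, causal=False):
    """FLOPs of S = Q K^T and O = P V, each 2 * seqlen^2 * headdim per head."""
    flops = 4 * batch * nheads * seqlen * seqlen * headdim
    return flops / 2 if causal else flops  # causal masking skips about half the work

def utilization(achieved_tflops, dtype):
    """Fraction of the assumed hardware peak that a measured throughput reaches."""
    return achieved_tflops / H100_SXM5_PEAK_TFLOPS[dtype]

def runtime_ms(flops, achieved_tflops):
    """Time needed to execute `flops` at a given sustained throughput."""
    return flops / (achieved_tflops * 1e12) * 1e3

print(f"FP16 utilization: {utilization(740, 'fp16'):.0%}")
print(f"FP8 utilization: {utilization(1200, 'fp8'):.0%}")

# Hypothetical problem size, chosen only to illustrate the formula.
flops = attention_fwd_flops(batch=2, nheads=16, seqlen=8192, headdim=128)
print(f"{flops:.3e} FLOPs, {runtime_ms(flops, 740):.2f} ms at 740 TFLOPs/s")
```

Under these assumed peaks, 740 TFLOPs/s is about 75% of the FP16 peak, consistent with the utilization quoted in the abstract, and 1.2 PFLOPs/s is about 61% of the FP8 peak.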