Bifurcated Attention: Accelerating Massively Parallel Decoding with Shared Prefixes in LLMs

Ben Athiwaratkun, Sujan Kumar Gonugondla, Sanjay Krishna Gouda, Haifeng Qian, Hantian Ding, Qing Sun, Jun Wang, Jiacheng Guo, Liangfu Chen, Parminder Bhatia, Ramesh Nallapati, Sudipta Sengupta, Bing Xiang
2024-07-12
Abstract: This study introduces bifurcated attention, a method designed to enhance language model inference in shared-context batch decoding scenarios. Our approach addresses the challenge of redundant memory IO costs, a critical factor contributing to latency in high batch sizes and extended context lengths. Bifurcated attention achieves this by strategically dividing the attention mechanism during incremental decoding into two separate GEMM operations: one focusing on the KV cache from prefill, and another on the decoding process itself. While maintaining the computational load (FLOPs) of standard attention mechanisms, bifurcated attention ensures precise computation with significantly reduced memory IO. Our empirical results show over 2.1$\times$ speedup when sampling 16 output sequences and more than 6.2$\times$ speedup when sampling 32 sequences at context lengths exceeding 8k tokens on a 7B model that uses multi-head attention. The efficiency gains from bifurcated attention translate into lower latency, making it particularly suitable for real-time applications. For instance, it enables massively parallel answer generation without substantially increasing latency, thus enhancing performance when integrated with post-processing techniques such as re-ranking.
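The split described in the abstract can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. The axis letters follow the paper's einsum notation under an assumed reading (b batch, g key/value groups, p query heads per group, n query length, m key length, k and v head dimensions). The function names, tensor sizes, the 1/sqrt(k) scaling, and the idealized read counts are all illustrative assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the key axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(q, K, V):
    # q: (b, g, p, n, k); K: (b, g, m, k); V: (b, g, m, v)
    scale = 1.0 / np.sqrt(q.shape[-1])  # usual scaling, assumed here
    logits = np.einsum("bgpnk,bgmk->bgpnm", q, K) * scale
    w = softmax(logits)
    return np.einsum("bgpnm,bgmv->bgpnv", w, V)

def bifurcated_attention(q, K_c, V_c, K_d, V_d):
    # K_c, V_c: shared context cache WITHOUT a batch axis: (g, m_c, k) and (g, m_c, v)
    # K_d, V_d: per-sample decoding cache: (b, g, m_d, k) and (b, g, m_d, v)
    m_c = K_c.shape[1]
    scale = 1.0 / np.sqrt(q.shape[-1])
    logits_c = np.einsum("bgpnk,gmk->bgpnm", q, K_c)   # context part
    logits_d = np.einsum("bgpnk,bgmk->bgpnm", q, K_d)  # decoding part
    # Join the two logit blocks along the key axis, then normalize jointly.
    w = softmax(np.concatenate([logits_c, logits_d], axis=-1) * scale)
    w_c, w_d = w[..., :m_c], w[..., m_c:]
    # The weighted sum over all keys splits into a sum over the two key ranges.
    return (np.einsum("bgpnm,gmv->bgpnv", w_c, V_c)
            + np.einsum("bgpnm,bgmv->bgpnv", w_d, V_d))

rng = np.random.default_rng(0)
b, g, p, n, k, v = 4, 2, 3, 1, 8, 8   # toy sizes (hypothetical)
m_c, m_d = 16, 5                      # context length, decoded length

q = rng.standard_normal((b, g, p, n, k))
K_c = rng.standard_normal((g, m_c, k))
V_c = rng.standard_normal((g, m_c, v))
K_d = rng.standard_normal((b, g, m_d, k))
V_d = rng.standard_normal((b, g, m_d, v))

# Baseline: the shared context is replicated across the batch axis.
K = np.concatenate([np.broadcast_to(K_c, (b, g, m_c, k)), K_d], axis=2)
V = np.concatenate([np.broadcast_to(V_c, (b, g, m_c, v)), V_d], axis=2)

ref = standard_attention(q, K, V)
out = bifurcated_attention(q, K_c, V_c, K_d, V_d)
max_abs_diff = float(np.abs(ref - out).max())

# Idealized count of K and V elements read per decoding step.
kv_reads_standard = b * g * (m_c + m_d) * (k + v)
kv_reads_bifurcated = g * m_c * (k + v) + b * g * m_d * (k + v)
```

In this sketch the two paths agree up to floating-point rounding, while the context cache contributes to the read count once rather than once per sample, so the gap between the two counts grows with batch size and context length.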
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
### Problems the paper attempts to solve

The paper "Bifurcated Attention: Accelerating Massively Parallel Decoding with Shared Prefixes in LLMs" aims to solve the high memory I/O cost and high latency that large language models (LLMs) incur in shared-context batch decoding. Specifically, it proposes Bifurcated Attention, which divides the attention mechanism into two parts during incremental decoding to reduce memory I/O cost, thereby improving the efficiency and speed of model inference.

### Background and motivation

1. **Problem background**:
   - Large language models (LLMs) perform well on many tasks, but their practical use is constrained by inference latency and efficiency.
   - In the single-context batch sampling scenario in particular, generating multiple completion sequences requires a large number of memory I/O operations, which becomes the main bottleneck at high batch sizes and long context lengths.
2. **Limitations of existing methods**:
   - **Quantization**: It can reduce memory usage, but its effect is limited at long sequence lengths and large batch sizes.
   - **Sparse attention**: It can reduce the complexity of the attention mechanism, but offers limited improvement for long contexts and fast inference.
   - **Multi-query attention**: It can reduce the memory I/O of the KV cache, but reduces the expressive power of the model.
   - **Paged attention**: It manages the KV cache through block tables to reduce memory storage requirements, but does not reduce the memory reads of the KV cache.

### Bifurcated Attention

1. **Method overview**:
   - **Bifurcated Attention** divides the attention mechanism into two parts during incremental decoding:
     - **Context-related attention** (⟨q, K_c⟩): the context KV cache only needs to be loaded once.
     - **Incremental-decoding-step-related attention** (⟨q, K_d⟩): loaded on demand.
   - This method significantly reduces the memory I/O cost while keeping the computational load (FLOPs) unchanged, thereby reducing latency.
2. **Technical details**:
   - **Formula representation**:
     - Attention calculation formula:
       \[ \text{logits} = \langle q, K \rangle : \text{einsum}(bgpnk, bgmk) \rightarrow bgpnm \quad (1) \]
       \[ o = \langle w, V \rangle : \text{einsum}(bgpnm, bgmv) \rightarrow bgpnv \quad (2) \]
     - Bifurcated attention calculation formula:
       \[ \langle q, K \rangle = \langle q, K_c \rangle \oplus \langle q, K_d \rangle \quad (3) \]
       \[ \langle q, K_c \rangle : \text{einsum}(bgpnk, g m_c k) \rightarrow bgpn m_c \]
       \[ \langle q, K_d \rangle : \text{einsum}(bgpnk, bg m_d k) \rightarrow bgpn m_d \]
       \[ \langle w, V \rangle = \langle w_c, V_c \rangle