TidalDecode: Fast and Accurate LLM Decoding with Position Persistent Sparse Attention

Lijie Yang, Zhihao Zhang, Zhuofu Chen, Zikun Li, Zhihao Jia
2024-10-07
Abstract: Large language models (LLMs) have driven significant advancements across diverse NLP tasks, with long-context models gaining prominence for handling extended inputs. However, the expanding key-value (KV) cache size required by Transformer architectures intensifies the memory constraints, particularly during the decoding phase, creating a significant bottleneck. Existing sparse attention mechanisms designed to address this bottleneck have two limitations: (1) they often fail to reliably identify the most relevant tokens for attention, and (2) they overlook the spatial coherence of token selection across consecutive Transformer layers, which can lead to performance degradation and substantial overhead in token selection. This paper introduces TidalDecode, a simple yet effective algorithm and system for fast and accurate LLM decoding through position persistent sparse attention. TidalDecode leverages the spatial coherence of tokens selected by existing sparse attention methods and introduces a few token selection layers that perform full attention to identify the tokens with the highest attention scores, while all other layers perform sparse attention with the pre-selected tokens. This design enables TidalDecode to substantially reduce the overhead of token selection for sparse attention without sacrificing the quality of the generated results. Evaluation on a diverse set of LLMs and tasks shows that TidalDecode closely matches the generative performance of full attention methods while reducing the LLM decoding latency by up to 2.1x.
Machine Learning, Artificial Intelligence, Computation and Language
What problem does this paper attempt to address?
### Problems the paper attempts to solve

This paper addresses the memory bottleneck that large language models (LLMs) encounter on long-context tasks. The key-value (KV) cache required by the Transformer architecture keeps growing during decoding, which tightens memory constraints and becomes a significant bottleneck. Existing sparse attention mechanisms target this bottleneck but have two main limitations:

1. **Unreliable identification of the most relevant tokens**: Existing sparse attention mechanisms often fail to select the tokens that matter most for the attention computation.
2. **Ignoring the spatial coherence of token selection**: Existing methods overlook the fact that token selections overlap across consecutive Transformer layers, which leads to performance degradation and substantial token-selection overhead.

To address these limitations, the paper introduces TidalDecode, a fast and accurate LLM decoding algorithm and system built on Position Persistent Sparse Attention (PPSA). TidalDecode exploits the spatial coherence of the tokens selected by existing sparse attention methods and introduces a few token selection layers. These layers perform full attention to identify the tokens with the highest attention scores, while all other layers perform sparse attention over the pre-selected tokens. This design substantially reduces the token-selection overhead of sparse attention without sacrificing the quality of the generated results.

### Main contributions

1. **TidalDecode**: An efficient, high-quality LLM decoding algorithm and system that uses position persistent sparse attention.
2. **A cache-correction mechanism**: The KV cache is periodically refilled using full attention to mitigate KV cache distribution shift.
3. **Experimental validation**: Extensive evaluations demonstrate the accuracy and efficiency of TidalDecode across multiple LLMs and tasks, especially long-context tasks.

### Experimental results

- **Needle-in-the-Haystack**: TidalDecode reaches 100% accuracy at context lengths of 10K, 32K, and 100K with a very small token budget (about 0.5% of the context).
- **Language Modeling**: On the PG-19 dataset, TidalDecode's perplexity is consistently lower than that of Quest, indicating that it retains key information under long-context inputs.
- **LongBench**: On document question answering, summarization, and retrieval tasks, TidalDecode outperforms Quest under a 4096-token budget and achieves a higher average score than the full attention baseline.

In conclusion, TidalDecode improves both the accuracy and the efficiency of long-context LLM decoding through position persistent sparse attention combined with cache correction.
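
The split between token selection layers (full attention plus top-k selection) and the remaining layers (sparse attention over the already selected positions) can be illustrated with a small sketch. This is a minimal, single-head, pure-Python illustration and not the authors' implementation: the function names, the `kv_cache` layout, and the `selection_layers` and `budget` parameters are hypothetical, and the fallback to full attention before the first selection layer is an assumption made here for completeness.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query, keys, values, positions):
    """Scaled dot-product attention for one query over the given cache positions."""
    scale = math.sqrt(len(query))
    weights = softmax([dot(query, keys[p]) / scale for p in positions])
    dim = len(values[0])
    out = [sum(w * values[p][d] for w, p in zip(weights, positions))
           for d in range(dim)]
    return out, weights


def decode_step(queries, kv_cache, selection_layers, budget):
    """One decoding step across all layers (hypothetical layout).

    queries[l] is the query vector at layer l; kv_cache[l] is (keys, values).
    Returns the per-layer outputs and the positions each layer attended to.
    """
    outputs, attended = [], []
    selected = None  # positions chosen by the most recent selection layer
    for layer, query in enumerate(queries):
        keys, values = kv_cache[layer]
        all_positions = list(range(len(keys)))
        if layer in selection_layers:
            # Selection layer: full attention, then keep the top-`budget` positions.
            out, weights = attend(query, keys, values, all_positions)
            ranked = sorted(all_positions, key=lambda p: weights[p], reverse=True)
            selected = sorted(ranked[:budget])
            attended.append(all_positions)
        elif selected is None:
            # Assumption: layers before the first selection layer use full attention.
            out, _ = attend(query, keys, values, all_positions)
            attended.append(all_positions)
        else:
            # Sparse layer: reuse the persisted positions, no new selection cost.
            out, _ = attend(query, keys, values, selected)
            attended.append(selected)
        outputs.append(out)
    return outputs, attended


if __name__ == "__main__":
    random.seed(0)
    num_layers, num_tokens, dim = 4, 16, 8

    def rand_vec():
        return [random.gauss(0, 1) for _ in range(dim)]

    kv_cache = [([rand_vec() for _ in range(num_tokens)],
                 [rand_vec() for _ in range(num_tokens)])
                for _ in range(num_layers)]
    queries = [rand_vec() for _ in range(num_layers)]
    _, attended = decode_step(queries, kv_cache, selection_layers={0, 2}, budget=4)
    for layer, positions in enumerate(attended):
        print(layer, len(positions))
```

With `selection_layers={0, 2}` and `budget=4`, layers 0 and 2 attend to all 16 cached positions, while layers 1 and 3 attend to only the 4 positions chosen by the preceding selection layer. A real system would run this per head in GPU kernels; the sketch only shows why the selection cost is paid at a few layers and not at every layer.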