Quest: Query-Aware Sparsity for Efficient Long-Context LLM Inference

Jiaming Tang, Yilong Zhao, Kan Zhu, Guangxuan Xiao, Baris Kasikci, Song Han
2024-08-27
Abstract: As the demand for long-context large language models (LLMs) increases, models with context windows of up to 128K or 1M tokens are becoming increasingly prevalent. However, long-context LLM inference is challenging since the inference speed decreases significantly as the sequence length grows. This slowdown is primarily caused by loading a large KV cache during self-attention. Previous works have shown that a small portion of critical tokens will dominate the attention outcomes. However, we observe the criticality of a token highly depends on the query. To this end, we propose Quest, a query-aware KV cache selection algorithm. Quest keeps track of the minimal and maximal Key values in KV cache pages and estimates the criticality of a given page using Query vectors. By only loading the Top-K critical KV cache pages for attention, Quest significantly speeds up self-attention without sacrificing accuracy. We show that Quest can achieve up to 7.03x self-attention speedup, which reduces inference latency by 2.23x while performing well on tasks with long dependencies with negligible accuracy loss. Code is available at http://github.com/mit-han-lab/Quest.
Computation and Language, Machine Learning
What problem does this paper attempt to address?
### Problems Addressed by the Paper

The paper addresses the inference-speed bottleneck of large language models (LLMs) with long contexts. As demand for long contexts grows, context windows have expanded to 128K or even 1M tokens. The main challenge is that inference speed drops significantly as the sequence length grows, primarily because self-attention has to load a large key-value (KV) cache.

### Main Contributions

1. **Analysis of the Self-Attention Mechanism**: The paper shows that which tokens are critical depends on the query, which motivates query-aware sparsity.
2. **Proposing the Quest Algorithm**: An efficient KV cache acceleration algorithm that uses query-aware sparsity to select the critical KV cache pages for self-attention, backed by specialized operator design and implementation.
3. **Comprehensive Evaluation**: Quest speeds up self-attention by up to 7.03x, which reduces end-to-end inference latency by up to 2.23x.

### Method Overview

1. **High Cost of Long-Context Inference**: LLM inference consists of a pre-fill phase and a decoding phase. The pre-fill phase converts input tokens into embeddings and generates key (K), query (Q), and value (V) vectors; the K and V vectors are stored in the KV cache. The decoding phase runs self-attention over the cached keys and values for each newly generated token, which is particularly time-consuming in long contexts.
2. **High Sparsity in Self-Attention Operations**: Prior work shows that a small number of critical tokens accumulate most of the attention score and capture the most important inter-token relationships. If these critical tokens can be estimated, self-attention can be computed on them alone, significantly reducing memory movement and improving efficiency.
3. **Critical Tokens Depend on Queries**: Which tokens are critical depends strongly on the current query vector Q. The paper shows experimentally that the set of critical tokens changes from query to query, which is why the sparsity must be query-aware.
4. **Dynamic Estimation of Token Criticality**: Quest estimates the criticality of each KV cache page by approximating its attention weights. Specifically, Quest keeps the element-wise minimum and maximum Key vectors of each page, multiplies each of them element-wise with the current query vector, takes the larger product in each channel, and sums over channels. The result is an upper bound on the attention weight of any token in the page and serves as the page's criticality estimate. The Top-K critical pages are then selected for self-attention.

### Experimental Results

1. **Language Modeling Task**: On the PG19 dataset, Quest's performance is close to that of the baseline model with the full KV cache.
2. **Long-Text Passkey Retrieval Task**: On tasks with long-distance dependencies, Quest achieves almost perfect accuracy with a small KV cache budget.
3. **LongBench Dataset**: Across multiple long-context datasets, Quest consistently outperforms the other baseline methods and stays comparable to the full-cache model even with a smaller KV cache budget.

### Conclusion

By leveraging query-aware sparsity, Quest reduces memory movement and computational cost in long-context LLM inference, significantly improving inference speed while maintaining model accuracy.
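
The page-criticality estimate described under Method Overview can be illustrated with a small pure-Python sketch. It is not the authors' GPU implementation: the function names (`page_min_max`, `page_score`, `select_top_k_pages`), the page size, and the toy data are all hypothetical and chosen for illustration only. The sketch assumes that a page's score is the sum, over channels, of the larger of `q * min` and `q * max`, which upper-bounds the dot product of the query with any key in that page.

```python
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def page_min_max(keys: Sequence[Vector], page_size: int) -> List[Tuple[List[float], List[float]]]:
    """Split keys into pages and record per-channel min and max Key values."""
    pages = []
    for start in range(0, len(keys), page_size):
        page = keys[start:start + page_size]
        mins = [min(channel) for channel in zip(*page)]
        maxs = [max(channel) for channel in zip(*page)]
        pages.append((mins, maxs))
    return pages


def page_score(query: Vector, mins: Vector, maxs: Vector) -> float:
    """Upper bound on q.k for any key k in the page.

    Per channel, the larger of q*min and q*max is taken, so the sign of
    each query component decides which extreme matters.
    """
    return sum(max(q * lo, q * hi) for q, lo, hi in zip(query, mins, maxs))


def select_top_k_pages(query: Vector, keys: Sequence[Vector], page_size: int, k: int) -> List[int]:
    """Return the indices of the k highest-scoring pages, in ascending order."""
    pages = page_min_max(keys, page_size)
    scores = [page_score(query, lo, hi) for lo, hi in pages]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


# Toy data (hypothetical): six 2-dimensional keys, grouped into pages of two.
keys = [[1.0, 0.0], [0.5, 0.2], [-1.0, 2.0], [0.0, 1.5], [0.1, -0.1], [0.2, 0.1]]

print(select_top_k_pages([1.0, 0.0], keys, page_size=2, k=1))  # [0]
print(select_top_k_pages([0.0, 1.0], keys, page_size=2, k=1))  # [1]
```

The two queries pick different pages from the same cache, which is the query dependence the summary describes. Only the selected pages would then be loaded for the actual attention computation.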