MagicPIG: LSH Sampling for Efficient LLM Generation

Zhuoming Chen, Ranajoy Sadhukhan, Zihao Ye, Yang Zhou, Jianyu Zhang, Niklas Nolte, Yuandong Tian, Matthijs Douze, Leon Bottou, Zhihao Jia, Beidi Chen
2024-10-28
Abstract: Large language models (LLMs) with long context windows have gained significant attention. However, the KV cache, stored to avoid re-computation, becomes a bottleneck. Various dynamic sparse or TopK-based attention approximation methods have been proposed to leverage the common insight that attention is sparse. In this paper, we first show that TopK attention itself suffers from quality degradation in certain downstream tasks because attention is not always as sparse as expected. Rather than selecting the keys and values with the highest attention scores, sampling with theoretical guarantees can provide a better estimation for attention output. To make the sampling-based approximation practical in LLM generation, we propose MagicPIG, a heterogeneous system based on Locality Sensitive Hashing (LSH). MagicPIG significantly reduces the workload of attention computation while preserving high accuracy for diverse tasks. MagicPIG stores the LSH hash tables and runs the attention computation on the CPU, which allows it to serve longer contexts and larger batch sizes with high approximation accuracy. MagicPIG can improve decoding throughput by $1.9\sim3.9\times$ across various GPU hardware and achieve 110ms decoding latency on a single RTX 4090 for Llama-3.1-8B-Instruct model with a context of 96k tokens. The code is available at https://github.com/Infini-AI-Lab/MagicPIG.
Computation and Language, Machine Learning
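The abstract's core mechanism, LSH-based sampling, can be made concrete with a minimal NumPy sketch of SimHash (random-hyperplane LSH), a standard LSH family for angular similarity. The function names, the (K, L) values, and the retrieval rule used here (a key is retrieved when its K-bit code matches the query's in at least one of L tables) are illustrative assumptions, not MagicPIG's actual kernels or settings. The sketch compares the closed-form retrieval probability 1 - (1 - (1 - θ/π)^K)^L with a Monte Carlo estimate. Keys closer in angle to the query are retrieved more often, and the retrieval probability is computable, which is what lets a hashing step act as a similarity-biased sampler.

```python
import numpy as np

def simhash_buckets(x: np.ndarray, planes: np.ndarray) -> np.ndarray:
    # x: (n, d) vectors; planes: (L, K, d) Gaussian hyperplanes (L tables, K bits each).
    # Each bit is the sign of one projection; the K bits are packed into a bucket id per table.
    bits = (np.einsum("nd,lkd->nlk", x, planes) > 0).astype(np.int64)
    weights = 1 << np.arange(planes.shape[1])  # 1, 2, 4, ..., 2^(K-1)
    return bits @ weights  # (n, L) bucket ids

def retrieval_prob(q: np.ndarray, k: np.ndarray, K: int, L: int) -> float:
    # For SimHash, one bit collides with probability 1 - theta/pi (theta = angle between q and k).
    cos = float(q @ k / (np.linalg.norm(q) * np.linalg.norm(k)))
    p_bit = 1.0 - np.arccos(np.clip(cos, -1.0, 1.0)) / np.pi
    # Assumed rule: retrieved if all K bits match in at least one of the L independent tables.
    return 1.0 - (1.0 - p_bit**K) ** L

def empirical_prob(q, k, K, L, trials=2000, seed=0):
    # Monte Carlo check: redraw all hyperplanes each trial and count retrievals.
    rng = np.random.default_rng(seed)
    hits = 0
    for _ in range(trials):
        planes = rng.standard_normal((L, K, q.shape[0]))
        codes = simhash_buckets(np.stack([q, k]), planes)
        hits += bool(np.any(codes[0] == codes[1]))
    return hits / trials

rng = np.random.default_rng(1)
q = rng.standard_normal(16)
near = q + 0.3 * rng.standard_normal(16)  # key pointing in nearly the same direction as q
far = rng.standard_normal(16)             # unrelated key
for name, k in [("near", near), ("far", far)]:
    print(name, round(retrieval_prob(q, k, K=4, L=8), 3), empirical_prob(q, k, K=4, L=8))
```

Raising K makes each table more selective, while raising L gives a key more chances to be retrieved; tuning the pair trades the number of retrieved keys against how strongly retrieval favors similar keys.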
What problem does this paper attempt to address?
This paper addresses the bottleneck that the KV cache creates in large language models (LLMs) with long context windows. Specifically, the paper makes the following points:
1. **KV cache problem**: During autoregressive generation, the intermediate attention keys and values (the KV cache) are stored to avoid recomputation. However, the KV cache grows linearly with both batch size and sequence length, occupying a large amount of GPU memory and increasing decoding time. This makes LLM generation heavily memory-bound and leaves GPU compute underutilized.
2. **Limitations of TopK attention**: Many works try to mitigate the KV cache problem with dynamic sparse or TopK attention approximations, which assume that attention is naturally sparse. The paper identifies three problems with TopK attention:
   - **Quality degradation**: In some downstream tasks, even exact TopK attention significantly reduces accuracy.
   - **High overhead**: Identifying the TopK keys is itself expensive and can even become the bottleneck.
   - **No memory savings**: These methods reduce the time spent loading the KV cache, but not the total memory it occupies, so the maximum context length and batch size remain limited.
3. **Rethinking attention sparsity**: The paper observes that attention is not always sparse, especially in tasks that use the full context. Moreover, the "attention sink" phenomenon, in which a few tokens absorb a large share of the attention scores, makes the distribution look sparser than it is, while the remaining scores are in fact fairly uniform. These findings indicate that relying on TopK attention alone is not enough.
To address these problems, the paper proposes MagicPIG, which uses locality-sensitive hashing (LSH) sampling for efficient LLM generation.
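The linear growth of the KV cache can be checked with simple arithmetic. The sketch below assumes the publicly documented Llama-3.1-8B attention shape (32 layers, 8 key-value heads under grouped-query attention, head dimension 128), fp16 storage (2 bytes per element), and reads "96k" as 96 × 1024 tokens; these are assumptions for illustration, and `kv_cache_bytes` is a hypothetical helper. Under them, each token costs 128 KiB of KV cache, so a single 96k-token sequence needs 12 GiB, on top of roughly 16 GB of fp16 weights, which already exceeds the 24 GB of an RTX 4090.

```python
GiB = 1024**3

def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # Factor 2: every token stores one key and one value vector per KV head in every layer.
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

# Assumed Llama-3.1-8B attention shape (from its public model config); fp16 by default.
LLAMA31_8B = dict(n_layers=32, n_kv_heads=8, head_dim=128)

per_token = kv_cache_bytes(1, 1, **LLAMA31_8B)  # 2 * 32 * 8 * 128 * 2 = 131072 bytes (128 KiB)
for batch, seq_len in [(1, 96 * 1024), (4, 96 * 1024), (16, 32 * 1024)]:
    size = kv_cache_bytes(batch, seq_len, **LLAMA31_8B)
    print(f"batch={batch:>2}  seq_len={seq_len:>6}  KV cache={size / GiB:.1f} GiB")
```

Doubling either the batch size or the context length doubles the footprint, which is why moving the LSH hash tables and attention computation to the CPU, as the abstract describes, helps serve longer contexts and larger batches.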
MagicPIG addresses these problems as follows:
- **Sampling estimation**: Instead of keeping only the highest-scoring keys as TopK attention does, MagicPIG estimates the attention output by sampling, which gives a more accurate estimate.
- **System design**: By offloading part of the computation (such as hash table lookups and attention computation) to the CPU, MagicPIG supports larger batch sizes or longer contexts while maintaining high accuracy.
- **Performance improvement**: Experiments show that MagicPIG improves decoding throughput by 1.9 to 3.9 times across various GPU hardware and achieves a decoding latency of 110 ms on a single RTX 4090 for Llama-3.1-8B-Instruct with a 96k-token context.
In summary, the paper tackles the KV cache bottleneck in long-context LLM generation with MagicPIG, improving decoding efficiency while estimating attention more accurately than TopK-based approximations.
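The contrast between TopK selection and sampling estimation can be illustrated on a toy attention distribution that mimics the situation described earlier: one sink token holds most of the attention mass while the remaining scores are nearly uniform. The scores, sizes, and budget of 64 keys below are made-up assumptions, not data from the paper. TopK renormalizes the softmax over the keys it keeps, so it over-weights the sink. The sampling estimator draws keys i.i.d. from a proposal distribution and reweights them (self-normalized importance sampling); with the oracle proposal used here (the true attention weights), every weight is equal and the estimate reduces to the mean of the sampled values. MagicPIG builds its sampler from LSH instead of this oracle; the sketch does not model that step.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def exact_attention(s, V):
    return softmax(s) @ V

def topk_attention(s, V, k):
    # Keep the k largest scores and renormalize the softmax over them only.
    idx = np.argpartition(s, -k)[-k:]
    return softmax(s[idx]) @ V[idx]

def sampled_attention(s, V, q, m, rng):
    # Self-normalized importance sampling: draw m keys i.i.d. from proposal q and
    # weight each draw by its unnormalized attention weight divided by q_i.
    idx = rng.choice(len(s), size=m, p=q)
    w = np.exp(s[idx] - s.max()) / q[idx]
    return (w / w.sum()) @ V[idx]

rng = np.random.default_rng(0)
n, d, budget = 4096, 64, 64
s = 0.5 * rng.standard_normal(n)  # nearly uniform tail of scores
s[0] = 10.0                       # hypothetical "sink" token with a dominant score
V = rng.standard_normal((n, d))
o = exact_attention(s, V)

def rel_err(est):
    return np.linalg.norm(est - o) / np.linalg.norm(o)

topk_err = rel_err(topk_attention(s, V, budget))
p = softmax(s)  # oracle proposal: sample in proportion to the true attention weights
samp_err = np.mean([rel_err(sampled_attention(s, V, p, budget, rng)) for _ in range(200)])
print(f"sink mass={p[0]:.2f}  TopK rel. err={topk_err:.3f}  sampling rel. err={samp_err:.3f}")
```

At the same budget, the averaged sampling error should come out below the TopK error: TopK's error here is a systematic bias from dropping the uniform tail and inflating the sink's share, whereas the sampling error is variance that shrinks as the budget grows.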