SparQ Attention: Bandwidth-Efficient LLM Inference

Luka Ribar, Ivan Chelombiev, Luke Hudlass-Galley, Charlie Blake, Carlo Luschi, Douglas Orr
2024-07-19
Abstract: The computational difficulties of large language model (LLM) inference remain a significant obstacle to their widespread deployment. The need for many applications to support long input sequences and process them in large batches typically causes token-generation to be bottlenecked by data transfer. For this reason, we introduce SparQ Attention, a technique for increasing the inference throughput of LLMs by utilising memory bandwidth more efficiently within the attention layers, through selective fetching of the cached history. Our proposed technique can be applied directly to off-the-shelf LLMs during inference, without requiring any modification to the pre-training setup or additional fine-tuning. We show that SparQ Attention brings up to 8x savings in attention data transfers without substantial drops in accuracy, by evaluating Llama 2 and 3, Mistral, Gemma and Pythia models on a wide range of downstream tasks.
Machine Learning
What problem does this paper attempt to address?
This paper addresses the high resource cost of large language model (LLM) inference, particularly with long input sequences and large batches, where data transfer becomes the performance bottleneck. It proposes a technique called SparQ Attention, which uses memory bandwidth more efficiently by selectively fetching the cached history in the attention layers, thereby increasing the inference throughput of LLMs.

### Main Issues

1. **High computational resource consumption**: LLM inference requires significant computational resources, especially when processing long input sequences.
2. **Data transfer bottleneck**: With long input sequences and large batches, data transfer becomes the main bottleneck and limits token-generation speed.
3. **Limitations of existing optimization methods**: Existing optimizations such as KV caching help support in-context learning, but on long sequences they are still limited by the large amount of data that must be read from memory.

### Solution

SparQ Attention addresses these issues as follows:

1. **Selective fetching of cached history**: The method predicts which tokens will receive high attention scores and transfers only the key-value pairs of those tokens, which significantly reduces the data transfer volume.
2. **No need for retraining or fine-tuning**: SparQ Attention can be applied directly to off-the-shelf LLMs at inference time, without modifying the pre-training setup or performing additional fine-tuning.
3. **Performance validation**: Evaluations of Llama 2 and 3, Mistral, Gemma, and Pythia models on multiple downstream tasks show that SparQ Attention achieves up to 8x savings in attention data transfers without substantial drops in accuracy.
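
The selective-fetching idea can be illustrated with a minimal sketch. The summary does not spell out how the high-scoring tokens are predicted, so the code below assumes a cheap approximation that uses only the `r` largest-magnitude query components, then fetches full key-value pairs for the `k` best-ranked tokens. The function names and the parameters `r` and `k` are hypothetical and chosen for illustration; this is not the authors' implementation.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dense_attention(q, keys, values):
    """Reference: attend over every cached token (reads all keys and values)."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([scale * sum(a * b for a, b in zip(q, key)) for key in keys])
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


def selective_attention(q, keys, values, r, k):
    """Hypothetical sketch: predict scores cheaply, then fetch only k tokens.

    r: number of query components used for the score prediction (assumption).
    k: number of cached tokens whose full key/value pairs are fetched.
    Returns (attention_output, sorted_indices_of_fetched_tokens).
    """
    # Step 1: predict scores from the r largest-magnitude query components,
    # so only r components of each cached key would need to be read.
    top_dims = sorted(range(len(q)), key=lambda i: abs(q[i]), reverse=True)[:r]
    approx = [sum(q[i] * key[i] for i in top_dims) for key in keys]

    # Step 2: keep the k tokens with the highest predicted scores.
    ranked = sorted(range(len(keys)), key=lambda t: approx[t], reverse=True)
    chosen = sorted(ranked[:k])

    # Step 3: exact attention over the fetched subset only.
    out = dense_attention(q, [keys[t] for t in chosen], [values[t] for t in chosen])
    return out, chosen
```

When `k` equals the number of cached tokens, the sketch reduces to dense attention; smaller `r` and `k` trade accuracy for less data read from the cache.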
### Experimental Results

- **Performance comparison**: SparQ Attention performs well on multiple tasks, including SQuAD, TriviaQA, CNN/DailyMail, WikiText, and a text repetition task.
- **Long sequence processing capability**: SparQ Attention maintains performance on sequences of up to 128k in length, which indicates that it scales to long contexts.
- **Performance under sparse settings**: Even at high sparsity, the performance of SparQ Attention stays very close to the dense attention baseline.

### Conclusion

SparQ Attention uses memory bandwidth more efficiently and thereby significantly reduces the data transferred during LLM inference, which increases inference throughput. The technique holds up across multiple tasks and models, which suggests it is broadly applicable.
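
To make the notion of data transfer savings concrete, the following sketch counts the elements read from the KV cache per attention head for each generated token. It assumes a two-stage scheme that this summary does not describe in detail: `r` components of every cached key are read to predict scores, then the full key and value of `k` selected tokens are read. The function names and all numeric values are illustrative assumptions, not figures from the paper.

```python
def dense_transfer(seq_len, head_dim):
    """Elements read per head per generated token: every key and every value."""
    return 2 * seq_len * head_dim


def selective_transfer(seq_len, head_dim, r, k):
    """Elements read under the assumed two-stage scheme.

    Stage 1 reads r components of every cached key to predict scores.
    Stage 2 reads the full key and value of the k selected tokens.
    """
    return seq_len * r + 2 * k * head_dim


def compression(seq_len, head_dim, r, k):
    """Ratio of dense to selective transfer (higher means more savings)."""
    return dense_transfer(seq_len, head_dim) / selective_transfer(seq_len, head_dim, r, k)
```

For example, with the illustrative values `seq_len=2048`, `head_dim=128`, `r=32` and `k=128`, dense attention reads 524288 elements and the selective scheme reads 98304, a ratio of about 5.3. Under this model the savings grow as `r` and `k` shrink relative to the head dimension and sequence length.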