GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai
2023-12-24
Abstract: Multi-query attention (MQA), which only uses a single key-value head, drastically speeds up decoder inference. However, MQA can lead to quality degradation, and moreover it may not be desirable to train a separate model just for faster inference. We (1) propose a recipe for uptraining existing multi-head language model checkpoints into models with MQA using 5% of original pre-training compute, and (2) introduce grouped-query attention (GQA), a generalization of multi-query attention which uses an intermediate (more than one, less than number of query heads) number of key-value heads. We show that uptrained GQA achieves quality close to multi-head attention with comparable speed to MQA.
Computation and Language, Machine Learning
What problem does this paper attempt to address?
The paper addresses the memory-bandwidth bottleneck that Transformer models hit during autoregressive decoding: at every decoding step, the decoder weights and all cached attention keys and values must be loaded from memory. Specifically, it asks how to speed up inference for large language models without significantly sacrificing model quality. The authors make two main contributions:

1. **Uptraining multi-head checkpoints into multi-query models (MQA)**: The paper proposes converting existing multi-head attention (MHA) checkpoints into models that use multi-query attention (MQA), followed by a small amount of additional pre-training (uptraining). This requires only 5% of the original pre-training compute, making it a cost-effective way to obtain a fast-inference MQA model while retaining high quality.
2. **Grouped-query attention (GQA)**: GQA is an attention mechanism that sits between multi-head and multi-query attention. It uses an intermediate number of key-value heads: more than one, but fewer than the number of query heads. Experiments show that uptrained GQA models reach quality close to multi-head attention at a speed close to that of MQA.

Together, these two methods aim to reduce memory-bandwidth overhead and substantially improve inference speed while maintaining, or closely approaching, the quality of the multi-head attention model.

### Formula Summary

All three variants use the same scaled dot-product attention; they differ only in which key and value heads each query head uses. Let \(H\) be the number of query heads, let \(Q_i\) be the query of head \(i\) (for \(i = 1, \dots, H\)), and let \(d_k\) be the per-head key dimension.

- **Multi-head attention (MHA)**: Each query head has its own key head and value head.
  \[ \text{head}_i=\text{softmax}\left(\frac{Q_iK_i^{T}}{\sqrt{d_{k}}}\right)V_i \]
- **Multi-query attention (MQA)**: All query heads share a single key head \(K\) and a single value head \(V\).
  \[ \text{head}_i=\text{softmax}\left(\frac{Q_iK^{T}}{\sqrt{d_{k}}}\right)V \]
- **Grouped-query attention (GQA)**: The \(H\) query heads are divided into \(G\) groups, and each group shares one key head and one value head.
  \[ \text{head}_i=\text{softmax}\left(\frac{Q_iK_{g(i)}^{T}}{\sqrt{d_{k}}}\right)V_{g(i)}, \qquad g(i)=\left\lceil \frac{i\,G}{H}\right\rceil \]
  GQA with \(G = H\) is equivalent to MHA, and GQA with \(G = 1\) is equivalent to MQA.

Because the key-value cache that must be read at every decoding step grows with the number of key-value heads (\(H\) for MHA, \(G\) for GQA, one for MQA), choosing an intermediate \(G\) trades a small amount of quality for a large reduction in memory traffic, giving a better balance between inference speed and quality.
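The grouping in the formulas above can be made concrete with a small pure-Python sketch of a single decoding step: each query head scores the cached keys of the key-value head assigned to its group and takes a softmax-weighted sum of that head's cached values. It assumes consecutive query heads share a key-value head and that the number of query heads is divisible by the number of key-value heads. A second helper estimates key-value cache size per token; the model dimensions in its usage (32 layers, 32 query heads of size 128, 16-bit values) are a hypothetical configuration, not one reported in the paper. All function names here are illustrative.

```python
import math
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    # Subtract the max for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def grouped_query_attention(
    queries: List[Vector],        # H query heads, each of length d_k
    keys: List[List[Vector]],     # G key heads, each a list of T cached key vectors
    values: List[List[Vector]],   # G value heads, each a list of T cached value vectors
) -> List[Vector]:
    """One decoding step of GQA. G == H behaves like MHA; G == 1 like MQA."""
    num_heads, num_kv_heads = len(queries), len(keys)
    if num_heads % num_kv_heads != 0:
        raise ValueError("query heads must be divisible by key-value heads")
    group_size = num_heads // num_kv_heads
    d_k = len(queries[0])

    outputs: List[Vector] = []
    for i, q in enumerate(queries):
        # Index of the shared key/value head for query head i (0-indexed).
        g = i // group_size
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys[g]]
        weights = softmax(scores)
        d_v = len(values[g][0])
        outputs.append([sum(w * v[j] for w, v in zip(weights, values[g])) for j in range(d_v)])
    return outputs


def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, d_head: int,
                             bytes_per_elem: int = 2) -> int:
    # Keys and values (factor 2) are cached for every layer and every KV head.
    return 2 * num_layers * num_kv_heads * d_head * bytes_per_elem


# Toy GQA step: 4 query heads, 2 key-value heads, 3 cached positions, d_k = 2.
queries = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]]
shared_keys = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
keys = [shared_keys, shared_keys]
values = [[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]],   # used by query heads 0 and 1
          [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]]   # used by query heads 2 and 3
out = grouped_query_attention(queries, keys, values)
print([round(x, 3) for x in out[2]])  # [2.0, 2.0]

# Cache size for the hypothetical 32-layer model with 32 query heads of size 128.
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    print(name, kv_cache_bytes_per_token(32, kv_heads, 128) // 1024, "KiB/token")
# MHA 512 KiB/token
# GQA-8 128 KiB/token
# MQA 16 KiB/token
```

Here GQA-8 means 8 key-value heads. Under these assumptions it reads a quarter of the MHA cache per token, while still keeping more than one key-value head, which is the intermediate point the paper argues for.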