Multi-Head Attention: Collaborate Instead of Concatenate

Jean-Baptiste Cordonnier, Andreas Loukas, Martin Jaggi
DOI: https://doi.org/10.48550/arXiv.2006.16362
2021-05-20
Abstract: Attention layers are widely used in natural language processing (NLP) and are beginning to influence computer vision architectures. Training very large transformer models allowed significant improvement in both fields, but once trained, these networks show symptoms of over-parameterization. For instance, it is known that many attention heads can be pruned without impacting accuracy. This work aims to enhance current understanding on how multiple heads interact. Motivated by the observation that attention heads learn redundant key/query projections, we propose a collaborative multi-head attention layer that enables heads to learn shared projections. Our scheme decreases the number of parameters in an attention layer and can be used as a drop-in replacement in any transformer architecture. Our experiments confirm that sharing key/query dimensions can be exploited in language understanding, machine translation and vision. We also show that it is possible to re-parametrize a pre-trained multi-head attention layer into our collaborative attention layer. Collaborative multi-head attention reduces the size of the key and query projections by 4 for same accuracy and speed. Our code is public.
Machine Learning, Computation and Language
What problem does this paper attempt to address?
This paper addresses parameter redundancy in Multi-Head Attention (MHA). The authors observe that, in standard multi-head attention, the key/query projections learned by different heads are largely redundant: the heads end up projecting onto overlapping subspaces. This redundancy inflates the parameter count of each attention layer without a corresponding gain in accuracy. To address it, the paper proposes a Collaborative Multi-Head Attention mechanism in which the heads share their key/query projections. This reduces the number of parameters in the layer; the abstract reports a 4-fold reduction in the size of the key and query projections at the same accuracy and speed.

The main contributions of the paper are:

1. **Redundancy Analysis**: Use Principal Component Analysis (PCA) to quantify how redundant the key/query matrices are across heads.
2. **Collaborative Multi-Head Attention Mechanism**: Propose a multi-head attention mechanism in which heads share key/query projections, reducing the number of parameters.
3. **Tensor Decomposition**: Use tensor decomposition to re-parameterize a pre-trained multi-head attention layer as a Collaborative Multi-Head Attention layer.
4. **Experimental Verification**: Evaluate the Collaborative Multi-Head Attention mechanism on Neural Machine Translation (NMT), Natural Language Understanding (NLU), and image classification tasks.
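The redundancy analysis in contribution 1 can be illustrated with a small NumPy sketch. It is not the authors' code: the weights are synthetic and built so that every head draws from the same low-dimensional subspace, and the helper `components_for_variance`, the 90% threshold, and the choice to treat matrix rows as PCA observations are all assumptions made for illustration. The point is only that the heads' projections, once concatenated, can need far fewer principal components than their total number of columns.

```python
import numpy as np


def components_for_variance(matrix, threshold=0.9):
    """Smallest number of principal components capturing `threshold` of the variance.

    Hypothetical helper: rows are treated as observations, columns as features.
    """
    centered = matrix - matrix.mean(axis=0, keepdims=True)
    singular_values = np.linalg.svd(centered, compute_uv=False)
    captured = np.cumsum(singular_values**2) / np.sum(singular_values**2)
    count = int(np.searchsorted(captured, threshold)) + 1
    return min(count, len(captured))


rng = np.random.default_rng(0)
n_heads, d_in, d_k, d_shared = 4, 32, 8, 8

# Synthetic assumption: every head's projection lies in one shared subspace.
shared_basis = rng.normal(size=(d_in, d_shared))
head_projections = [
    shared_basis @ rng.normal(size=(d_shared, d_k)) for _ in range(n_heads)
]

# Components needed for each head on its own.
per_head = [components_for_variance(w) for w in head_projections]

# Components needed once all heads are concatenated column-wise.
concatenated = np.concatenate(head_projections, axis=1)  # shape (d_in, n_heads * d_k)
across_heads = components_for_variance(concatenated)

print("per head:", per_head)
print("across heads:", across_heads, "of", concatenated.shape[1], "columns")
```

If the heads were independent, the concatenated matrix would need roughly `n_heads` times as many components as a single head; here it needs no more than the dimension of the shared subspace.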
### Formula Summary

- **Standard Multi-Head Attention Mechanism**:
  \[ \text{MultiHead}(X, Y)=\text{concat}_{i\in [N_h]}\left[H^{(i)}\right]W_O \]
  where
  \[ H^{(i)}=\text{Attention}\left(XW_Q^{(i)}, YW_K^{(i)}, YW_V^{(i)}\right) \]
- **Collaborative Multi-Head Attention Mechanism**:
  \[ \text{CollabHead}(X, Y)=\text{concat}_{i\in [N_h]}\left[H^{(i)}\right]W_O \]
  where
  \[ H^{(i)}=\text{Attention}\left(X\tilde{W}_Q \text{diag}(m_i), Y\tilde{W}_K, YW_V^{(i)}\right) \]
  Here, \(\tilde{W}_Q\) and \(\tilde{W}_K\) are the query and key projections shared by all heads, and each head \(i\) learns a mixing vector \(m_i\) that weights the shared key/query dimensions.
- **Tensor Decomposition**: The per-head key/query products are stacked into a third-order tensor
  \[ W_{QK}:=\text{stack}_{i\in [N_h]}\left[W_Q^{(i)}W_K^{(i)\top}\right]\in \mathbb{R}^{N_h\times D_{\text{in}}\times D_{\text{in}}} \]
  which is then factorized with a Tucker decomposition. For a generic third-order tensor \(T\) (here \(W_{QK}\)), this reads
  \[ T\approx G\times_1 A\times_2 B\times_3 C = \sum_{p = 1}^P\sum_{q = 1}^Q\sum_{r = 1}^R g_{pqr}\, a_p\circ b_q\circ c_r \]

### Experimental Results

- In the neural machine translation task, the Collaborative Multi-Head Attention mechanism maintains or even improves performance while reducing the number of parameters.
- In the image classification task, the Collaborative Multi-Head Attention mechanism also performs well while significantly reducing the number of parameters.
- For the natural language understanding task, re-parameterizing the pre-trained model as Collaborative Multi-Head Attention compresses the model without significantly degrading performance.

In conclusion, the paper addresses parameter redundancy in multi-head attention by introducing the Collaborative Multi-Head Attention mechanism and evaluates it on translation, language understanding, and vision tasks.
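As an illustration of the standard and collaborative formulas in the Formula Summary, the following NumPy sketch implements both layers and checks one relationship between them: concatenating the per-head projections into the shared matrices and using block indicator mixing vectors makes the collaborative layer reproduce standard multi-head attention. This is a minimal sketch, not the authors' implementation. The function names, the tensor shapes, and the softmax scaling factor `scale` are assumptions, since the summary does not spell out the `Attention` function.

```python
import numpy as np


def softmax(scores):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)


def multi_head(X, Y, W_Q, W_K, W_V, W_O, scale):
    """Standard MHA. W_Q, W_K: (N_h, D_in, d_k); W_V: (N_h, D_in, d_v)."""
    heads = []
    for w_q, w_k, w_v in zip(W_Q, W_K, W_V):
        scores = (X @ w_q) @ (Y @ w_k).T * scale
        heads.append(softmax(scores) @ (Y @ w_v))
    return np.concatenate(heads, axis=-1) @ W_O


def collab_head(X, Y, W_Q_shared, W_K_shared, M, W_V, W_O, scale):
    """Collaborative MHA. Shared projections: (D_in, D_shared); M: (N_h, D_shared)."""
    queries = X @ W_Q_shared  # computed once, reused by every head
    keys = Y @ W_K_shared
    heads = []
    for m, w_v in zip(M, W_V):
        # queries * m broadcasts over rows, i.e. queries @ diag(m)
        scores = (queries * m) @ keys.T * scale
        heads.append(softmax(scores) @ (Y @ w_v))
    return np.concatenate(heads, axis=-1) @ W_O


rng = np.random.default_rng(0)
n_heads, d_in, d_k, d_v = 4, 16, 4, 4
X = rng.normal(size=(5, d_in))  # 5 query tokens
Y = rng.normal(size=(7, d_in))  # 7 key/value tokens
W_Q = rng.normal(size=(n_heads, d_in, d_k))
W_K = rng.normal(size=(n_heads, d_in, d_k))
W_V = rng.normal(size=(n_heads, d_in, d_v))
W_O = rng.normal(size=(n_heads * d_v, d_in))
scale = 1.0 / np.sqrt(d_k)  # assumed scaling; applied identically in both layers

# Special case: concatenated projections plus block one-hot mixing vectors.
W_Q_shared = np.concatenate(list(W_Q), axis=1)  # (d_in, n_heads * d_k)
W_K_shared = np.concatenate(list(W_K), axis=1)
M = np.repeat(np.eye(n_heads), d_k, axis=1)     # row i selects head i's block

out_standard = multi_head(X, Y, W_Q, W_K, W_V, W_O, scale)
out_collab = collab_head(X, Y, W_Q_shared, W_K_shared, M, W_V, W_O, scale)
print(np.allclose(out_standard, out_collab))
```

In this special case the shared dimension equals `n_heads * d_k`, so nothing is saved. The parameter reduction described in the summary comes from choosing a smaller shared dimension and letting the mixing vectors `M` be learned rather than fixed.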