Automatic Channel Pruning for Multi-Head Attention

Eunho Lee, Youngbae Hwang
2024-05-31
Abstract: Despite the strong performance of Transformers, their quadratic computational complexity makes them challenging to apply to vision tasks. Automatic pruning is an effective way to reduce computational complexity without relying on heuristics. However, applying it directly to multi-head attention is not straightforward because of channel misalignment. In this paper, we propose an automatic channel pruning method that takes the multi-head attention mechanism into account. First, we incorporate channel similarity-based weights into the pruning indicator to preserve the more informative channels in each head. Then, we adjust the pruning indicator to enforce the removal of channels in equal proportions across all heads, preventing channel misalignment. We also add a reweight module to compensate for the information loss resulting from channel removal, and an effective initialization step for the pruning indicator based on the difference in attention between the original structure and each channel. Our method can be applied not only to the original attention but also to linear attention, which is more efficient because its complexity is linear in the number of tokens. On ImageNet-1K, applying our pruning method to the FLattenTransformer, which includes both attention mechanisms, yields higher accuracy at several MAC budgets than previous state-of-the-art efficient models and pruning methods. Code will be available soon.
Computer Vision and Pattern Recognition, Artificial Intelligence, Computational Complexity
What problem does this paper attempt to address?
The paper addresses the difficulty of applying automatic channel pruning to multi-head attention. Pruning channels without regard to head boundaries causes channel misalignment, that is, heads ending up with unequal numbers of channels, so existing automatic pruning methods cannot be applied to multi-head attention directly. The paper proposes an automatic channel pruning method that reduces computational complexity while avoiding the performance degradation caused by channel misalignment. Its main contributions are:

1. **Channel Similarity Weights**: Weights based on channel similarity are incorporated into the pruning indicator so that the more informative channels are retained in each head.
2. **Pruning Indicator Adjustment**: The pruning indicator is adjusted so that all heads remove channels in equal proportions, which prevents channel misalignment.
3. **Reweighting Module**: A reweight module is added to compensate for the information loss caused by channel removal.
4. **Initialization Step**: The pruning indicator is initialized from the difference in attention between the original structure and each channel.

The method applies to both the original (quadratic) attention and linear attention. The paper evaluates it on ImageNet-1K by pruning the FLattenTransformer, which includes both attention mechanisms. The authors report that the pruned models achieve higher accuracy at several MAC (multiply-accumulate operation) budgets than previous state-of-the-art efficient models and pruning methods.
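
The equal-proportion constraint can be illustrated with a minimal sketch. This is not the authors' implementation: the paper's pruning indicator is learned and similarity-weighted, whereas the `scores` below are made-up numbers, and the function names `prune_equal_per_head` and `prune_global` are hypothetical. The sketch only contrasts ranking channels inside each head with ranking all channels together.

```python
from typing import List


def prune_equal_per_head(scores: List[List[float]], keep_ratio: float) -> List[List[int]]:
    """Return, for each head, the sorted indices of the channels to keep.

    scores[h][c] is a hypothetical importance score for channel c of head h.
    Every head keeps the same number of channels, so the pruned projection
    can still be split evenly into heads.
    """
    head_dim = len(scores[0])
    if any(len(head) != head_dim for head in scores):
        raise ValueError("all heads must have the same number of channels")
    keep = max(1, round(head_dim * keep_ratio))
    kept = []
    for head in scores:
        # Rank channels inside this head only, highest score first.
        ranked = sorted(range(head_dim), key=lambda c: head[c], reverse=True)
        kept.append(sorted(ranked[:keep]))
    return kept


def prune_global(scores: List[List[float]], keep_ratio: float) -> List[List[int]]:
    """Naive baseline: rank all channels together, ignoring head boundaries."""
    flat = [(s, h, c) for h, head in enumerate(scores) for c, s in enumerate(head)]
    total_keep = max(1, round(len(flat) * keep_ratio))
    flat.sort(key=lambda t: t[0], reverse=True)
    kept: List[List[int]] = [[] for _ in scores]
    for _, h, c in flat[:total_keep]:
        kept[h].append(c)
    return [sorted(k) for k in kept]


# Made-up scores: head 0 scores uniformly higher than head 1.
scores = [
    [0.9, 0.8, 0.7, 0.6],
    [0.5, 0.1, 0.4, 0.2],
]
per_head = prune_equal_per_head(scores, keep_ratio=0.5)
global_ = prune_global(scores, keep_ratio=0.5)
print(per_head)  # [[0, 1], [0, 2]]
print(global_)   # [[0, 1, 2, 3], []]
```

With these scores, the global ranking keeps all four channels of head 0 and none of head 1, so the heads no longer have matching widths. The per-head ranking keeps two channels in each head, which is the property the pruning indicator adjustment is described as enforcing.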