CAST: Clustering Self-Attention using Surrogate Tokens for Efficient Transformers

Adjorn van Engelenhoven, Nicola Strisciuglio, Estefanía Talavera
2024-02-07
Abstract: The Transformer architecture has been shown to be a powerful tool for a wide range of tasks. It is based on the self-attention mechanism, which is an inherently computationally expensive operation with quadratic computational complexity: memory usage and compute time increase quadratically with the length of the input sequences, thus limiting the application of Transformers. In this work, we propose a novel Clustering self-Attention mechanism using Surrogate Tokens (CAST), to optimize the attention computation and achieve efficient transformers. CAST utilizes learnable surrogate tokens to construct a cluster affinity matrix, used to cluster the input sequence and generate novel cluster summaries. The self-attention from within each cluster is then combined with the cluster summaries of other clusters, enabling information flow across the entire input sequence. CAST improves efficiency by reducing the complexity from $O(N^2)$ to $O(\alpha N)$, where $N$ is the sequence length and $\alpha$ is a constant determined by the number of clusters and samples per cluster. We show that CAST performs better than or comparably to the baseline Transformers on long-range sequence modeling tasks, while also achieving better time and memory efficiency than other efficient transformers.
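To make the complexity claim concrete, the rough count below compares the number of similarity scores computed by full self-attention with the number computed under a clustered scheme. It is only a sketch: it assumes equal-size clusters and three cost terms (token-to-surrogate affinities, attention inside each cluster, and one summary read per token per cluster). The function names and this cost breakdown are illustrative assumptions, not the paper's exact definition of $\alpha$.

```python
def full_attention_scores(n):
    """Number of query-key scores in full self-attention over n tokens."""
    return n * n


def clustered_attention_scores(n, n_clusters):
    """Score count for an assumed clustered scheme with equal-size clusters."""
    if n % n_clusters:
        raise ValueError("this sketch assumes n is divisible by n_clusters")
    size = n // n_clusters                 # tokens per cluster
    affinity = n * n_clusters              # every token scored against every surrogate token
    intra = n_clusters * size * size       # full attention inside each cluster
    summaries = n * n_clusters             # every token weighs one summary per cluster
    return affinity + intra + summaries


n, n_clusters = 4096, 16
print("full:     ", full_attention_scores(n))
print("clustered:", clustered_attention_scores(n, n_clusters))
```

Under this counting, the clustered total equals `n * (2 * n_clusters + n // n_clusters)`, so the per-token factor depends only on the number of clusters and the cluster size.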
Machine Learning
What problem does this paper attempt to address?
### The problem the paper attempts to solve

This paper addresses the low computational efficiency of the self-attention mechanism in the Transformer architecture. Self-attention has quadratic computational complexity \(O(N^2)\): memory usage and computation time grow quadratically with the length of the input sequence, which limits the use of Transformers on long-sequence tasks. To address this, the authors propose a clustering-based self-attention mechanism, CAST (Clustering self-Attention using Surrogate Tokens), which uses learnable surrogate tokens to reduce the cost of the attention computation and thereby improve the efficiency of the Transformer.

### Main contributions

1. **CAST mechanism**: CAST constructs a cluster affinity matrix using learnable surrogate tokens, clusters the input sequence, and generates cluster summaries. The self-attention within each cluster is combined with the summaries of the other clusters, so that information flows across the entire input sequence.
2. **Reduction in computational complexity**: CAST reduces the computational complexity from \(O(N^2)\) to \(O(\alpha N)\), where \(N\) is the sequence length and \(\alpha\) is a constant determined by the number of clusters and the number of samples in each cluster.
3. **Performance**: Experimental results show that CAST performs better than or comparably to the baseline Transformer on long-sequence modeling tasks, and is more time- and memory-efficient than other efficient Transformer variants.

### Core ideas of the solution

- **Learnable clustering directions**: By learning the clustering directions, CAST ensures that each token can receive information from all clusters and avoids the problems caused by random initialization.
- **Cluster summaries**: Each cluster generates a summary, and these summaries carry information between clusters, preserving the sequence-wide information flow of the original Transformer.
- **Multi-head attention**: CAST supports both single-head and multi-head attention; the multi-head variant divides the surrogate tokens into multiple heads.

### Experimental verification

- **Efficiency evaluation**: Experiments on the Long Range Arena (LRA) benchmark measure the speed and memory efficiency of CAST at different sequence lengths.
- **Performance evaluation**: A small-scale hyperparameter search on the LRA benchmark evaluates the performance of CAST on classification tasks.
- **Comparison of clustering mechanisms**: Ablation experiments compare two clustering mechanisms, Top-K and Single Assignment Top-K, on different tasks, and analyze the influence of cluster size on performance, peak memory usage, and the number of training steps.

### Conclusion

By introducing learnable surrogate tokens and cluster summaries, CAST significantly improves the computational efficiency of the Transformer while maintaining its performance on long-sequence modeling tasks. Experimental results show that CAST matches or exceeds the baseline Transformer in task performance while being more time- and memory-efficient than other efficient Transformer variants.
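
The mechanism summarized under "Main contributions" and "Core ideas" can be sketched in a few lines of plain Python. This is a simplified illustration, not the authors' implementation. It assumes queries, keys, and values are all the raw token vectors (no learned projections), uses fixed rather than learned surrogate tokens, and replaces the paper's Top-K and Single Assignment Top-K procedures with a hypothetical greedy balanced assignment. It also mixes local attention and summaries with a softmax over each token's affinities, which is an assumption. All function names are made up for this example.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights, vectors):
    dim = len(vectors[0])
    return [sum(w * v[j] for w, v in zip(weights, vectors)) for j in range(dim)]


def cast_like_attention(tokens, surrogates):
    n, n_clusters, dim = len(tokens), len(surrogates), len(tokens[0])
    if n % n_clusters:
        raise ValueError("this sketch assumes equal-size clusters")
    size = n // n_clusters

    # 1. Cluster affinity matrix: similarity of every token to every surrogate token.
    affinity = [[dot(t, s) for s in surrogates] for t in tokens]

    # 2. Greedy balanced assignment (stand-in for the paper's clustering step):
    #    visit (token, cluster) pairs by decreasing affinity and fill clusters up to `size`.
    pairs = sorted(((affinity[i][c], i, c) for i in range(n) for c in range(n_clusters)),
                   reverse=True)
    clusters, assigned = [[] for _ in range(n_clusters)], set()
    for _, i, c in pairs:
        if i not in assigned and len(clusters[c]) < size:
            clusters[c].append(i)
            assigned.add(i)

    # 3. Cluster summaries: affinity-weighted average of each cluster's members.
    summaries = []
    for c, members in enumerate(clusters):
        weights = softmax([affinity[i][c] for i in members])
        summaries.append(weighted_sum(weights, [tokens[i] for i in members]))

    # 4. Attention inside each cluster, mixed with the summaries of the other clusters.
    outputs = [None] * n
    for c, members in enumerate(clusters):
        keys = [tokens[j] for j in members]
        for i in members:
            attn = softmax([dot(tokens[i], k) / math.sqrt(dim) for k in keys])
            local = weighted_sum(attn, keys)
            gate = softmax(affinity[i])  # one mixing weight per cluster
            parts = [local if other == c else summaries[other]
                     for other in range(n_clusters)]
            outputs[i] = weighted_sum(gate, parts)
    return outputs, clusters


tokens = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
surrogates = [[1.0, 0.0], [0.0, 1.0]]
outputs, clusters = cast_like_attention(tokens, surrogates)
print(clusters)
```

In this toy input the two surrogate tokens point along the two axes, so the first two tokens and the last two tokens end up in separate clusters. Each output still contains a contribution from the other cluster through its summary, which is the cross-cluster information flow the summary describes.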