Approximate attention with MLP: a pruning strategy for attention-based model in multivariate time series forecasting

Suhan Guo, Jiahong Deng, Yi Wei, Hui Dou, Furao Shen, Jian Zhao
2024-10-31
Abstract: Attention-based architectures have become ubiquitous in time series forecasting tasks, including spatio-temporal forecasting (STF) and long-term time series forecasting (LTSF). Yet our understanding of the reasons for their effectiveness remains limited. This work proposes a new way to understand self-attention networks: we show empirically that the entire attention mechanism in the encoder can be reduced to an MLP formed by feedforward, skip-connection, and layer normalization operations for temporal and/or spatial modeling in multivariate time series forecasting. Specifically, the Q, K, and V projections, the attention score calculation, the dot product between the attention scores and V, and the final projection can be removed from attention-based networks without significantly degrading performance, so that the resulting network remains top-tier compared to other SOTA methods. For spatio-temporal networks, the MLP-replace-attention network achieves a reduction in FLOPs of $62.579\%$ with a performance loss of less than $2.5\%$; for LTSF, it achieves a reduction in FLOPs of $42.233\%$ with a performance loss of less than $2\%$.
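To see roughly where FLOP savings of this size can come from, the sketch below is a back-of-the-envelope cost model, not the paper's measurement. It counts FLOPs for one encoder layer under the common convention that a multiply-add costs 2 FLOPs, ignores softmax, LayerNorm, biases, and activations, and assumes a feed-forward width `d_ff = 4 * d_model`. The function names and sizes are hypothetical. The reported reductions are for whole networks, which also contain embeddings, decoders, and graph or patching modules, so this per-layer figure is not expected to reproduce 62.579% or 42.233%.

```python
# Hypothetical per-layer FLOP model for one Transformer encoder layer.
# Assumed convention: one multiply-add = 2 FLOPs; softmax, LayerNorm, biases and
# activations are ignored as lower-order terms. Multi-head attention with
# n_heads * d_head == d_model gives the same totals as a single head.


def attention_flops(seq_len: int, d_model: int) -> int:
    """FLOPs of the parts removed: Q/K/V projections, scores, score-V product, output projection."""
    qkv = 3 * 2 * seq_len * d_model * d_model   # three d_model x d_model projections
    scores = 2 * seq_len * seq_len * d_model    # Q @ K^T
    weighted = 2 * seq_len * seq_len * d_model  # softmax(scores) @ V
    out_proj = 2 * seq_len * d_model * d_model  # final output projection
    return qkv + scores + weighted + out_proj


def ffn_flops(seq_len: int, d_model: int, d_ff: int) -> int:
    """FLOPs of the position-wise feed-forward block that is kept (d_model -> d_ff -> d_model)."""
    return 2 * (2 * seq_len * d_model * d_ff)


def attention_flop_share(seq_len: int, d_model: int, d_ff: int) -> float:
    """Percentage of this layer's FLOPs that disappears when attention is dropped."""
    attn = attention_flops(seq_len, d_model)
    total = attn + ffn_flops(seq_len, d_model, d_ff)
    return 100.0 * attn / total


if __name__ == "__main__":
    # Illustrative sizes only; they are not taken from the paper.
    for seq_len in (12, 96, 336):
        share = attention_flop_share(seq_len, d_model=64, d_ff=256)
        print(f"seq_len={seq_len}: {share:.2f}% of layer FLOPs are attention")
```

Because the score and score-V terms grow with the square of the sequence length while the feed-forward cost grows linearly, the share of FLOPs removed by dropping attention rises for longer inputs in this model.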
Machine Learning
What problem does this paper attempt to address?
### Problems the paper attempts to solve

This paper explores the effectiveness and necessity of the attention mechanism in multivariate time series forecasting (MTSF). Specifically, the authors focus on two aspects:

1. **Reducing computational complexity and model size**: Current attention-based models (such as Transformers) incur high computational cost and large parameter counts in long-term time series forecasting (LTSF) and spatio-temporal forecasting (STF). The authors aim to reduce these costs by simplifying the attention mechanism.
2. **Evaluating the actual contribution of attention**: Although attention performs well in many tasks, its actual contribution to MTSF has not been fully verified. The authors test experimentally whether model performance declines significantly when attention is removed or simplified.

### Main research contents

To achieve these goals, the authors propose the following:

- **Approximating attention with an MLP**: Removing the key components of the attention mechanism (such as the Q, K, and V projections and the dot-product computation) while keeping the remaining operations (feed-forward layers, skip connections, and layer normalization) yields a simplified MLP structure. Experiments show that this simplification does not significantly degrade performance while greatly reducing computational cost and parameter count.
- **Empirical analysis**: The authors evaluate several backbone models, such as ASTGNN, STAEFormer, PatchTST, and iTransformer, on multiple real-world datasets, such as METR-LA and PEMS-Bay, to verify the effectiveness of the simplified MLP structure on both LTSF and STF tasks.
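The simplification described above can be sketched in plain Python. This is a minimal illustration, not the authors' code, and it rests on several assumptions: a single-head, post-LN encoder layer; no biases; LayerNorm without learnable gain or bias; and the premise that removing attention leaves only `LN(x + FFN(x))`. All helper names (`encoder_layer`, `random_params`, and so on) are hypothetical. Rows of `x` stand for tokens, for example time steps or nodes.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain-Python matrix product a @ b."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum, used for the skip connection."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def layer_norm(a: Matrix, eps: float = 1e-5) -> Matrix:
    """Row-wise LayerNorm without learnable gain/bias (a simplification)."""
    out = []
    for row in a:
        mean = sum(row) / len(row)
        var = sum((x - mean) ** 2 for x in row) / len(row)
        out.append([(x - mean) / math.sqrt(var + eps) for x in row])
    return out


def softmax_rows(a: Matrix) -> Matrix:
    """Numerically stable softmax over each row."""
    out = []
    for row in a:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        s = sum(e)
        out.append([x / s for x in e])
    return out


def attention(x: Matrix, p: dict) -> Matrix:
    """Single-head self-attention: the block the MLP variant removes."""
    q, k, v = matmul(x, p["wq"]), matmul(x, p["wk"]), matmul(x, p["wv"])
    k_t = [list(col) for col in zip(*k)]
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[s * scale for s in row] for row in matmul(q, k_t)]
    return matmul(matmul(softmax_rows(scores), v), p["wo"])


def ffn(x: Matrix, p: dict) -> Matrix:
    """Position-wise feed-forward block (ReLU, biases omitted)."""
    hidden = [[max(0.0, h) for h in row] for row in matmul(x, p["w1"])]
    return matmul(hidden, p["w2"])


def encoder_layer(x: Matrix, p: dict, use_attention: bool = True) -> Matrix:
    """Post-LN encoder layer; use_attention=False gives the MLP-only variant."""
    if use_attention:
        x = layer_norm(add(x, attention(x, p)))  # mixes information across rows
    return layer_norm(add(x, ffn(x, p)))  # acts on each row independently


def random_params(d_model: int, d_ff: int, seed: int = 0) -> dict:
    """Small random weights for the demo (hypothetical initialization)."""
    rng = random.Random(seed)

    def mat(rows: int, cols: int) -> Matrix:
        return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

    return {
        "wq": mat(d_model, d_model), "wk": mat(d_model, d_model),
        "wv": mat(d_model, d_model), "wo": mat(d_model, d_model),
        "w1": mat(d_model, d_ff), "w2": mat(d_ff, d_model),
    }


if __name__ == "__main__":
    params = random_params(d_model=8, d_ff=16)
    rng = random.Random(1)
    x = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(4)]
    x_perturbed = [row[:] for row in x]
    x_perturbed[0] = [v + 1.0 for v in x_perturbed[0]]  # change token 0 only

    for flag in (True, False):
        before = encoder_layer(x, params, use_attention=flag)
        after = encoder_layer(x_perturbed, params, use_attention=flag)
        # Largest change in tokens 1..3 caused by perturbing token 0
        drift = max(abs(a - b) for ra, rb in zip(before[1:], after[1:]) for a, b in zip(ra, rb))
        print(f"use_attention={flag}: max change in other tokens = {drift:.3g}")
```

The demo perturbs only token 0. With attention enabled, the outputs of the other tokens change; with it disabled, they stay exactly the same, because the remaining feed-forward, skip-connection, and LayerNorm operations act on each row independently. In this sketch, that is the sense in which the MLP variant stops modeling dependencies between tokens.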
### Main contributions

- **First systematic evaluation of the attention module**: This is the first study to systematically evaluate the effectiveness of the attention module in attention-based multivariate time series forecasting models (AMTSFM).
- **Substantial reductions in computation and model size**: Replacing the attention mechanism with an MLP structure reduces FLOPs by 62.247% and parameters by 35.330% on average.
- **The importance of skip connections in the encoder**: The skip connections in the MLP structure turn out to be the core structure of AMTSFM. This suggests that current multivariate time series forecasting tasks can be treated as univariate forecasting tasks in which the dependencies between time steps and nodes are ignored.

### Conclusion

In summary, the paper shows experimentally that the attention mechanism is not indispensable for multivariate time series forecasting and can be replaced by a simpler MLP structure, substantially reducing computational cost and parameter count while maintaining good predictive performance. This finding suggests new directions for designing future time series forecasting models.
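A quick count shows where parameter savings of this kind come from. The sketch below is an illustration under stated assumptions, not the paper's accounting: it counts the parameters of one standard encoder layer (biases included, two LayerNorms, `d_ff = 4 * d_model`) and reports the share held by the four attention projections that the MLP replacement removes. The function names and widths are hypothetical. Whole-model figures such as the 35.330% above also depend on embeddings, output heads, and other components, so this per-layer number is not expected to match them.

```python
# Hypothetical per-layer parameter count for one Transformer encoder layer.
# Assumptions: every linear map has a bias; the layer has two LayerNorms,
# each with a learnable gain and bias per feature.


def attention_params(d_model: int, bias: bool = True) -> int:
    """W_Q, W_K, W_V and the output projection W_O, each d_model x d_model."""
    return 4 * d_model * d_model + (4 * d_model if bias else 0)


def ffn_params(d_model: int, d_ff: int, bias: bool = True) -> int:
    """Two linear maps: d_model -> d_ff -> d_model."""
    return 2 * d_model * d_ff + ((d_ff + d_model) if bias else 0)


def layer_norm_params(d_model: int) -> int:
    """Learnable gain and bias per feature."""
    return 2 * d_model


def attention_param_share(d_model: int, d_ff: int) -> float:
    """Percentage of one encoder layer's parameters held by the attention projections."""
    attn = attention_params(d_model)
    total = attn + ffn_params(d_model, d_ff) + 2 * layer_norm_params(d_model)
    return 100.0 * attn / total


if __name__ == "__main__":
    # Illustrative widths only (d_ff = 4 * d_model); not taken from the paper.
    for d_model in (64, 128, 512):
        share = attention_param_share(d_model, 4 * d_model)
        print(f"d_model={d_model}: {share:.2f}% of layer parameters are attention")
```

Unlike attention FLOPs, which include terms quadratic in sequence length, these parameter counts do not depend on sequence length at all; they are fixed by `d_model` and `d_ff`.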