SpikingSSMs: Learning Long Sequences with Sparse and Parallel Spiking State Space Models

Shuaijie Shen, Chao Wang, Renzhuo Huang, Yan Zhong, Qinghai Guo, Zhichao Lu, Jianguo Zhang, Luziwei Leng
2024-08-27
Abstract: Known as low-energy-consumption networks, spiking neural networks (SNNs) have gained a lot of attention within the past decades. While SNNs are increasingly competitive with artificial neural networks (ANNs) on vision tasks, they are rarely used for long-sequence tasks, despite their intrinsic temporal dynamics. In this work, we develop spiking state space models (SpikingSSMs) for long-sequence learning by leveraging the sequence learning abilities of state space models (SSMs). Inspired by dendritic neuron structure, we hierarchically integrate neuronal dynamics with the original SSM block while realizing sparse synaptic computation. Furthermore, to resolve the conflict between event-driven neuronal dynamics and parallel computing, we propose a lightweight surrogate dynamic network which accurately predicts the after-reset membrane potential and is compatible with learnable thresholds, enabling orders-of-magnitude acceleration in training speed compared with conventional iterative methods. On the Long Range Arena benchmark, SpikingSSM achieves performance competitive with state-of-the-art SSMs while realizing on average 90% network sparsity. On language modeling, our network significantly surpasses existing spiking large language models (spikingLLMs) on the WikiText-103 dataset with only a third of the model size, demonstrating its potential as a backbone architecture for low-computation-cost LLMs.
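The abstract relies on the ability of SSMs to learn long sequences while still training in parallel. As a minimal sketch, assuming a generic diagonal linear SSM rather than the specific SSM block SpikingSSM builds on, the NumPy code below shows that the same outputs can be produced either by a step-by-step recurrence or all at once as a causal convolution with kernel $K_j = c^\top \mathrm{diag}(a)^j b$. The function names and the parameters `a`, `b`, `c` are hypothetical and chosen for illustration.

```python
import numpy as np

# Assumed model: generic diagonal linear SSM x_k = a*x_{k-1} + b*u_k, y_k = c . x_k
# (illustrative only, not SpikingSSM's actual SSM block or discretization).

def ssm_recurrent(a, b, c, u):
    """Sequential scan: one time step at a time."""
    x = np.zeros_like(a)
    y = np.empty(len(u))
    for k, u_k in enumerate(u):
        x = a * x + b * u_k  # elementwise because A = diag(a)
        y[k] = c @ x
    return y

def ssm_kernel(a, b, c, L):
    """Closed-form kernel K_j = sum_n c_n * b_n * a_n**j, built without a loop over time."""
    powers = a[:, None] ** np.arange(L)[None, :]  # shape (N, L)
    return (c * b) @ powers                       # shape (L,)

def ssm_convolution(a, b, c, u):
    """Parallel evaluation: causal convolution y = K * u via FFT."""
    L = len(u)
    K = ssm_kernel(a, b, c, L)
    n = 2 * L  # zero-pad so the circular FFT convolution equals the linear one
    y = np.fft.irfft(np.fft.rfft(K, n) * np.fft.rfft(u, n), n)
    return y[:L]

rng = np.random.default_rng(0)
N, L = 4, 64
a = rng.uniform(0.5, 0.95, N)  # stable decay rates, |a_n| < 1
b = rng.normal(size=N)
c = rng.normal(size=N)
u = rng.normal(size=L)

y_scan = ssm_recurrent(a, b, c, u)
y_conv = ssm_convolution(a, b, c, u)
print(np.allclose(y_scan, y_conv))  # True
```

The FFT route costs O(L log L) and has no sequential dependency across time steps, which is what makes SSM training parallel; the recurrent form stays useful for step-by-step inference.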
Computation and Language, Machine Learning, Neural and Evolutionary Computing
What problem does this paper attempt to address?
The main problem this paper addresses is how to use spiking neural networks (SNNs) for efficient, low-power computation on long-sequence tasks. Specifically, the authors develop a new model, Spiking State Space Models (SpikingSSMs), which combines the parallel computation and long-sequence modeling abilities of state space models (SSMs) with the sparse computation of SNNs, targeting the efficiency and performance limitations of existing methods on long sequences.

### Main problem decomposition:

1. **Efficient computation in long-sequence tasks**:
   - Deep learning models such as the Transformer are computationally expensive on long sequences (for example, self-attention has time complexity O(L^2) in the sequence length L), which makes training and inference slow.
   - RNNs and their variants have lower time complexity, but they are limited by the capacity of their hidden state and by vanishing gradients, so they struggle to handle long sequences effectively.
2. **Low-power computation**:
   - SNNs are attractive for their low power consumption, but they are rarely used for long-sequence tasks, mainly because their event-driven neuron dynamics conflict with parallel computation.
3. **Sparse computation and parallelism**:
   - Achieving sparse computation while keeping efficient parallel computation requires resolving the conflict between the sequential, event-driven dynamics of spiking neurons and parallel computing.

### Solution overview:

- **Introducing SpikingSSMs**: By combining SNNs with SSMs, the authors propose a new architecture that handles long-sequence tasks effectively while being more computationally efficient and energy-saving.
- **Proposing a lightweight Surrogate Dynamic Network (SDN)**: To resolve the conflict between event-driven neuron dynamics and parallel computing in SNNs, the authors design a lightweight SDN that predicts the after-reset membrane potential and is compatible with learnable thresholds, accelerating training by orders of magnitude compared with conventional iterative methods.
- **Optimizing neuron thresholds**: Making neuron thresholds learnable parameters further improves network performance.

### Experimental verification:

- On the Long Range Arena benchmark, SpikingSSMs achieve performance competitive with state-of-the-art SSMs while realizing an average network sparsity of 90%. On language modeling, the model significantly outperforms existing spiking large language models (spikingLLMs) on the WikiText-103 dataset with only one third of the model size.

In conclusion, by introducing SpikingSSMs and the SDN, this paper addresses efficient, low-power computation on long-sequence tasks and demonstrates the architecture's potential as a backbone for low-computation-cost LLMs.
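The SDN described above targets a concrete obstacle: the reset step of a spiking neuron makes each time step depend on the previous spike decision. As a minimal sketch, assuming a discrete leaky integrate-and-fire (LIF) neuron with decay `beta` and a hard reset to zero (a common variant; the paper's exact neuron model may differ), the code contrasts the iterative update with the reset-free membrane potential, which is a causal convolution and can be computed in parallel. The names `lif_iterative`, `potential_without_reset`, `beta`, and `v_th` are hypothetical. The sketch does not implement the SDN itself; per the summary, the SDN is a learned network that predicts the after-reset potential so that this sequential loop can be avoided during training.

```python
import numpy as np

def lif_iterative(I, beta=0.9, v_th=1.0):
    """Step-by-step LIF neuron with hard reset (assumed variant).

    Each step needs the previous step's spike decision, so this loop cannot
    be parallelized over time as written.
    """
    v = 0.0
    h_all = np.empty(len(I))
    spikes = np.empty(len(I))
    for t, i_t in enumerate(I):
        h = beta * v + i_t       # charge: leaky integration of the input current
        s = float(h >= v_th)     # fire: compare with the threshold (a fixed float here)
        v = h * (1.0 - s)        # reset: membrane potential drops to 0 after a spike
        h_all[t], spikes[t] = h, s
    return h_all, spikes

def potential_without_reset(I, beta=0.9):
    """Reset-free membrane potential h_t = sum_j beta**j * I_{t-j}.

    This is a causal convolution with a fixed exponential kernel, so all
    time steps can be computed at once.
    """
    L = len(I)
    kernel = beta ** np.arange(L)
    return np.convolve(kernel, I)[:L]

I = np.full(20, 0.3)                       # constant input current for 20 steps
h_iter, spikes = lif_iterative(I)
h_par = potential_without_reset(I)

print(np.flatnonzero(spikes))              # [ 3  7 11 15 19]
print(np.allclose(h_iter[:4], h_par[:4]))  # True: identical up to the first spike
print(np.allclose(h_iter, h_par))          # False: the reset makes them diverge
```

The two potentials agree until the first spike at t = 3 and diverge afterwards, because every later value depends on where the resets happened; that after-reset trajectory is what a surrogate network has to predict. Here `v_th` is a fixed float; making it learnable, as the summary describes, additionally requires handling the non-differentiable step function during gradient-based training.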