SpikingBERT: Distilling BERT to Train Spiking Language Models Using Implicit Differentiation

Malyaban Bal, Abhronil Sengupta
2024-02-19
Abstract: Large language models (LLMs), though growing exceedingly powerful, comprise orders of magnitude fewer neurons and synapses than the human brain. However, they require significantly more power/energy to operate. In this work, we propose a novel bio-inspired spiking language model (LM) which aims to reduce the computational cost of conventional LMs by drawing motivation from the synaptic information flow in the brain. In this paper, we demonstrate a framework that leverages the average spiking rate of neurons at equilibrium to train a neuromorphic spiking LM using an implicit differentiation technique, thereby overcoming the non-differentiability problem of spiking neural network (SNN) based algorithms without using any type of surrogate gradient. The steady-state convergence of the spiking neurons also allows us to design a spiking attention mechanism, which is critical in developing a scalable spiking LM. Moreover, the convergence of the average spiking rate of neurons at equilibrium is utilized to develop a novel ANN-SNN knowledge-distillation-based technique wherein we use a pre-trained BERT model as "teacher" to train our "student" spiking architecture. While the primary architecture proposed in this paper is motivated by BERT, the technique can potentially be extended to different kinds of LLMs. Our work is the first one to demonstrate the performance of an operational spiking LM architecture on multiple different tasks in the GLUE benchmark.
Neural and Evolutionary Computing
What problem does this paper attempt to address?
The main problem this paper addresses is the high computational cost and energy consumption of large language models (LLMs). Although LLMs such as GPT-3 show strong capabilities on natural language processing (NLP) tasks, they require large amounts of computational resources and energy during both training and inference. The paper proposes a bio-inspired spiking language model (spiking LM) that aims to reduce the computational cost of conventional language models by drawing on the synaptic information flow in the brain, yielding a more energy-efficient language model design.

### Core Contributions of the Paper

1. **SpikingBERT and Spiking Attention Mechanism**:
   - A fully operational spiking language model, SpikingBERT, is proposed. Its architecture is based on BERT, and it is evaluated on several tasks (classification and regression) from the GLUE benchmark.
   - An efficient spiking attention mechanism is designed. At equilibrium, the average spiking rate (ASR) of this mechanism approximates the output of conventional non-spiking attention.
2. **Training Methods for Spiking Language Models**:
   - The paper verifies, both theoretically and empirically, that the proposed spiking language model (including its linear and nonlinear operations) converges to an equilibrium state. Implicit differentiation is then used to overcome the non-differentiability problem in spiking neural network (SNN) training, which also reduces memory usage during training.
   - This method allows spiking language models to be trained at a scale beyond that of existing spiking models, enabling deeper models that can handle complex tasks.
3. **ANN-SNN Knowledge Distillation Framework Based on the Equilibrium State**:
   - Using the equilibrium state that neurons reach after convergence, a novel ANN-SNN knowledge distillation framework is proposed. The framework trains the spiking model by matching the equilibrium-state ASR of specific intermediate layers to the activations of the corresponding target layers of a larger pre-trained "teacher" model.
   - With this method, a larger BERT model serves as the "teacher" for efficiently developing a smaller spiking "student" model, improving performance without requiring a large number of parameters.

### Key Technologies of the Solution

- **Implicit Modeling and Implicit Differentiation**: The paper adopts implicit modeling. Rather than explicitly specifying the computation from input to output, it specifies conditions that the model's state must satisfy. During training, implicit differentiation is used to compute the gradient once the model has reached its equilibrium state, which avoids the non-differentiability of the spiking function.
- **Spiking Attention Mechanism**: A computationally efficient spiking attention mechanism is designed in which the input arrives as spikes from the previous layer. As a result, matrix multiplication at each time step requires only accumulation operations, which greatly reduces the computational cost.
- **Knowledge Distillation**: Knowledge is transferred from pre-trained large language models (such as BERT) to spiking language models. In particular, knowledge is transferred between corresponding internal layers (such as Transformer layers, embedding layers, and prediction layers), which improves the training efficiency and performance of the model.
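
The role of the equilibrium average spiking rate (ASR) can be illustrated with a small simulation. The sketch below is not the paper's neuron model: it assumes a simplified integrate-and-fire neuron with subtractive reset, a constant input current, and a hypothetical helper named `average_spiking_rate`. Under those assumptions, the fraction of time steps on which a neuron spikes settles near the input divided by the threshold, clipped to the range [0, 1].

```python
import numpy as np

def average_spiking_rate(current, v_th=1.0, steps=1000):
    """Simulate integrate-and-fire neurons driven by a constant input current.

    Assumed simplified model (not taken from the paper): no leak,
    subtractive reset. Returns the fraction of steps each neuron spiked.
    """
    current = np.asarray(current, dtype=float)
    u = np.zeros_like(current)            # membrane potentials
    spike_count = np.zeros_like(current)  # spikes emitted so far
    for _ in range(steps):
        u = u + current                   # integrate the input
        spikes = (u >= v_th).astype(float)
        u = u - v_th * spikes             # subtract the threshold after a spike
        spike_count += spikes
    return spike_count / steps

currents = np.array([-0.5, 0.0, 0.25, 0.7, 1.5])
asr = average_spiking_rate(currents)
expected = np.clip(currents / 1.0, 0.0, 1.0)  # clipped input / threshold
print(np.round(asr, 3))
print(expected)
```

Because the rate settles to a deterministic function of the input, a layer's equilibrium ASR can be compared with a real-valued activation, which is the property the distillation framework relies on.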
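
Implicit differentiation at an equilibrium state can be sketched on a toy fixed-point model. The code below is an illustration under stated assumptions, not the paper's implementation: it uses a smooth `tanh` update as a stand-in for the spiking dynamics, a small random weight matrix scaled so that the iteration is a contraction, and hypothetical helper names (`solve_equilibrium`, `implicit_grad_wrt_input`). The gradient is obtained from the fixed-point condition a* = f(a*, x) alone, so the intermediate iterations never need to be stored.

```python
import numpy as np

def solve_equilibrium(W, x, iters=200):
    """Iterate a <- tanh(W a + x) until it settles at the fixed point a*."""
    a = np.zeros_like(x)
    for _ in range(iters):
        a = np.tanh(W @ a + x)
    return a

def implicit_grad_wrt_input(W, x, grad_at_equilibrium):
    """Return dL/dx given dL/da*, without backpropagating through the loop."""
    a_star = solve_equilibrium(W, x)
    d = 1.0 - a_star ** 2        # tanh'(z) evaluated at the equilibrium
    J = d[:, None] * W           # Jacobian of the update with respect to a
    # Implicit function theorem: da*/dx = (I - J)^{-1} diag(d),
    # hence dL/dx = diag(d) (I - J)^{-T} dL/da*.
    v = np.linalg.solve((np.eye(len(x)) - J).T, grad_at_equilibrium)
    return d * v

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 4))
W *= 0.5 / np.linalg.norm(W, 2)  # spectral norm 0.5 makes the update a contraction
x = rng.standard_normal(4)

# Toy loss L = sum(a*), so dL/da* is a vector of ones.
grad = implicit_grad_wrt_input(W, x, np.ones(4))
print(grad)
```

Only the equilibrium state enters the backward computation, which is why this style of training uses less memory than unrolling every time step.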
### Summary

By introducing the SpikingBERT model and its training methods, the paper aims to reduce the high computational and energy demands of large language models. Drawing on biological mechanisms of the brain and combining implicit differentiation with knowledge distillation, it proposes a spiking language model that is efficient, energy-saving, and competitive.