Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models

Aviv Bick, Kevin Y. Li, Eric P. Xing, J. Zico Kolter, Albert Gu
2024-08-20
Abstract: Transformer architectures have become a dominant paradigm for domains like language modeling but suffer in many inference settings due to their quadratic-time self-attention. Recently proposed subquadratic architectures, such as Mamba, have shown promise, but have been pretrained with substantially less computational resources than the strongest Transformer models. In this work, we present a method that is able to distill a pretrained Transformer architecture into alternative architectures such as state space models (SSMs). The key idea to our approach is that we can view both Transformers and SSMs as applying different forms of mixing matrices over the token sequences. We can thus progressively distill the Transformer architecture by matching different degrees of granularity in the SSM: first matching the mixing matrices themselves, then the hidden units at each block, and finally the end-to-end predictions. Our method, called MOHAWK, is able to distill a Mamba-2 variant based on the Phi-1.5 architecture (Phi-Mamba) using only 3B tokens and a hybrid version (Hybrid Phi-Mamba) using 5B tokens. Despite using less than 1% of the training data typically used to train models from scratch, Phi-Mamba boasts substantially stronger performance compared to all past open-source non-Transformer models. MOHAWK allows models like SSMs to leverage computational resources invested in training Transformer-based architectures, highlighting a new avenue for building such models.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The paper targets a bottleneck of the Transformer architecture in large language models: self-attention has quadratic time complexity in the sequence length, which hurts many inference scenarios. The authors propose distilling a pretrained Transformer into an alternative architecture with lower time complexity, specifically a state space model (SSM).

The method, called MOHAWK, proceeds in three stages:

1. **Matrix Orientation**: align the sequence transformation (mixing) matrices of the teacher model (Transformer) and the student model (SSM).
2. **Hidden-State Alignment**: align the hidden-state representations of each block in the two models.
3. **Weight-Transfer and Knowledge Distillation**: train the student end to end, using the teacher's predictions as supervision.

With this method, the authors distilled a Transformer based on the Phi-1.5 architecture into a Mamba-2 variant called Phi-Mamba (using 3B tokens) and into a hybrid version called Hybrid Phi-Mamba (using 5B tokens). Although this is less than 1% of the data typically used to train such models from scratch, Phi-Mamba substantially outperformed all previous open-source non-Transformer models on multiple benchmarks and approached the performance of the original Phi-1.5 model.

The paper also examines how much each stage contributes to final performance and shows that the Mamba-2 architecture is an effective and expressive sequence transformation. The experiments indicate that the three-stage distillation improves the student model and that SSMs can reach good performance with far less training data.
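The three stages can be read as three matching objectives at increasing granularity. The NumPy sketch below is illustrative only and is not the authors' code: the specific distances (Frobenius norm, squared L2, cross-entropy against the teacher's soft targets), the toy scalar-decay SSM used as the student, and every function name are assumptions made here for the example. The weight-transfer part of stage 3 is not shown.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax; rows must contain at least one finite entry.
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def causal_attention_matrix(q, k):
    """Teacher mixing matrix: causal softmax(Q K^T / sqrt(d)), shape (T, T)."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    mask = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(mask, scores, -np.inf)  # block attention to future tokens
    return np.exp(log_softmax(scores, axis=-1))

def toy_ssm_mixing_matrix(a, b, c):
    """Hypothetical student: M[i, j] = (c_i . b_j) * a_{j+1} * ... * a_i for j <= i.

    a: (T,) decay factors in (0, 1]; b, c: (T, N) input/output projections.
    """
    log_cum = np.cumsum(np.log(a))
    decay = np.exp(log_cum[:, None] - log_cum[None, :])  # product of a over (j, i]
    return np.tril(decay * (c @ b.T))  # causal: zero above the diagonal

def matrix_orientation_loss(teacher_mix, student_mix):
    """Stage 1 (assumed form): Frobenius distance between (T, T) mixing matrices."""
    return np.linalg.norm(teacher_mix - student_mix, ord="fro")

def hidden_state_loss(teacher_hidden, student_hidden):
    """Stage 2 (assumed form): mean squared L2 distance between block outputs (T, D)."""
    return np.mean(np.sum((teacher_hidden - student_hidden) ** 2, axis=-1))

def distillation_loss(teacher_logits, student_logits):
    """Stage 3 (assumed form): cross-entropy of the student against teacher soft targets."""
    p = np.exp(log_softmax(teacher_logits))
    return -np.mean(np.sum(p * log_softmax(student_logits), axis=-1))

# Toy usage with random tensors standing in for real model activations.
rng = np.random.default_rng(0)
T, d, N, V = 6, 4, 3, 10
teacher_mix = causal_attention_matrix(rng.normal(size=(T, d)), rng.normal(size=(T, d)))
student_mix = toy_ssm_mixing_matrix(
    rng.uniform(0.5, 1.0, size=T), rng.normal(size=(T, N)), rng.normal(size=(T, N))
)
teacher_logits = rng.normal(size=(T, V))
student_logits = rng.normal(size=(T, V))

stage1 = matrix_orientation_loss(teacher_mix, student_mix)
stage2 = hidden_state_loss(rng.normal(size=(T, d)), rng.normal(size=(T, d)))
stage3 = distillation_loss(teacher_logits, student_logits)
```

In a real pipeline each quantity would be minimized with respect to the student's parameters by gradient descent, one stage after another, so that the student first imitates the teacher's token mixing, then its per-block outputs, and finally its predictions.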