PolySketchFormer: Fast Transformers via Sketching Polynomial Kernels

Praneeth Kacham, Vahab Mirrokni, Peilin Zhong
2024-03-18
Abstract: The quadratic time and memory complexity inherent to self-attention mechanisms, with respect to sequence length, presents a critical computational bottleneck in the training and deployment of large-scale Transformer-based language models. Recent theoretical results indicate the intractability of sub-quadratic softmax attention approximation under reasonable complexity assumptions. This paper addresses this challenge by first demonstrating that polynomial attention with high degree can effectively replace softmax without sacrificing model quality. Next, we develop polynomial sketching techniques from numerical linear algebra to achieve linear-time polynomial attention with approximation guarantees. Crucially, our approach achieves this speedup without requiring the sparsification of attention matrices. We also present a block-based algorithm to apply causal masking efficiently. Combining these techniques, we provide PolySketchFormer, a practical linear-time Transformer architecture for language modeling that offers provable guarantees. We validate PolySketchFormer empirically by training language models capable of handling long contexts. These experiments utilize both synthetic and real-world datasets (PG19, Wikipedia and C4) on Google Cloud TPUs. For context lengths of 32k and GPT-2 style models, our model achieves a 2.5-4x speedup in training compared to FlashAttention, with no observed degradation in quality across our experiments.
Machine Learning
What problem does this paper attempt to address?
The paper addresses the quadratic time and memory cost of self-attention with respect to sequence length, a critical computational bottleneck in training and deploying large-scale Transformer models. It tackles the problem in three steps:

1. **High-Degree Polynomial Attention**: The paper first demonstrates that polynomial attention of high degree p can replace softmax attention without sacrificing model quality.
2. **Polynomial Sketching**: To compute polynomial attention in linear time, the authors develop polynomial sketching techniques from numerical linear algebra that come with approximation guarantees. The speedup does not require sparsifying the attention matrix.
3. **Block-Based Algorithm**: The paper proposes a block-based algorithm that applies causal masking efficiently, removing a bottleneck that attention linearization techniques encounter on long sequences.

Combining these techniques, the paper introduces PolySketchFormer, a practical linear-time Transformer architecture for language modeling with provable guarantees. In experiments at context lengths of 32k, the model achieves a twofold training speedup over the fastest FlashAttention configuration, with no observed quality degradation.
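
To make the linearization idea concrete, here is a minimal NumPy sketch, not the authors' implementation. It assumes degree p = 2, where the exact feature map phi(x) = x ⊗ x satisfies <phi(q), phi(k)> = <q, k>^2; the paper instead uses lower-dimensional polynomial sketches as approximate features. The function names, the `1 +` term in the normalizer, and the block size are illustrative assumptions. The block-wise routine handles causal masking in the spirit of the block-based algorithm: attention inside each block is computed exactly, while contributions from earlier blocks come from running sums of features.

```python
import numpy as np


def poly_attention_quadratic(q, k, v):
    """Causal degree-2 polynomial attention that materializes the n x n matrix."""
    scores = np.tril((q @ k.T) ** 2)  # even degree keeps weights nonnegative
    # Assumed normalizer: 1 + row sum (avoids division by zero).
    return (scores @ v) / (1.0 + scores.sum(axis=1, keepdims=True))


def degree2_features(x):
    """Exact degree-2 feature map: row-wise x (tensor) x, shape (n, d) -> (n, d*d)."""
    return np.einsum("ni,nj->nij", x, x).reshape(x.shape[0], -1)


def poly_attention_blockwise(q, k, v, block=4):
    """Same output as the quadratic version, without an n x n matrix.

    Cost grows linearly in n for a fixed block size and feature dimension.
    """
    n = q.shape[0]
    fq, fk = degree2_features(q), degree2_features(k)
    kv = np.zeros((fk.shape[1], v.shape[1]))  # sum of phi(k_j) v_j^T over earlier blocks
    ksum = np.zeros(fk.shape[1])              # sum of phi(k_j) over earlier blocks
    out = np.empty((n, v.shape[1]))
    for s in range(0, n, block):
        e = min(s + block, n)
        # Exact causal attention scores inside the current block.
        local = np.tril((q[s:e] @ k[s:e].T) ** 2)
        num = fq[s:e] @ kv + local @ v[s:e]
        den = 1.0 + fq[s:e] @ ksum + local.sum(axis=1)
        out[s:e] = num / den[:, None]
        # Fold this block's keys and values into the running sums.
        kv += fk[s:e].T @ v[s:e]
        ksum += fk[s:e].sum(axis=0)
    return out
```

With the exact feature map the two functions agree up to floating-point error, but the feature dimension is d^2. Replacing `degree2_features` with an approximate sketch of smaller dimension would trade exactness for speed, which is the trade-off the paper's approximation guarantees address.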