Blockwise Parallel Transformer for Large Context Models

Hao Liu, Pieter Abbeel
2023-08-29
Abstract: Transformers have emerged as the cornerstone of state-of-the-art natural language processing models, showcasing exceptional performance across a wide range of AI applications. However, the memory demands posed by the self-attention mechanism and the large feedforward network in Transformers limit their ability to handle long sequences, thereby creating challenges for tasks involving multiple long sequences or long-term dependencies. We present a distinct approach, Blockwise Parallel Transformer (BPT), that leverages blockwise computation of self-attention and feedforward network fusion to minimize memory costs. By processing longer input sequences while maintaining memory efficiency, BPT enables training sequences 32 times longer than vanilla Transformers and up to 4 times longer than previous memory-efficient methods. Extensive experiments on language modeling and reinforcement learning tasks demonstrate the effectiveness of BPT in reducing memory requirements and improving performance.
Computation and Language, Machine Learning
### What problem does this paper attempt to solve?

This paper addresses the memory bottleneck that the Transformer model hits when processing long sequences. Because of its self-attention mechanism and large feedforward network, the Transformer needs a large amount of memory for long-sequence or multi-sequence tasks, which limits its use in tasks involving long-term dependencies or multiple long sequences.

#### Main problems:

1. **Excessively high memory requirements**: Standard self-attention has time and memory complexity O(s²), where s is the sequence length, because it materializes an s × s attention matrix. As the sequence length increases, memory requirements rise sharply, making it impractical to process very long sequences.
2. **Memory consumption of the feedforward network**: In addition to the self-attention mechanism, the feedforward network in the Transformer contains a large number of parameters and generates high-dimensional intermediate vectors, further increasing memory consumption.

#### Solutions:

To solve these problems, the authors propose the Blockwise Parallel Transformer (BPT). BPT reduces memory requirements and improves performance in the following ways:

1. **Blockwise parallel self-attention**: Divide the input sequence into multiple small blocks and compute self-attention one block at a time. This avoids computing the attention matrix of the entire sequence at once, thereby significantly reducing memory usage.
2. **Fusion of the feedforward network and self-attention**: After completing the self-attention computation for each query block, immediately apply the feedforward network to that block, without waiting for the self-attention computation of the entire sequence to finish.
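The blockwise self-attention step described above can be illustrated with a small sketch. It is not the authors' implementation: it assumes a single head, no causal mask, and plain Python lists, and it uses a running (online) softmax, which is one common way to get exact attention without materializing the full s × s score matrix. The names `full_attention` and `blockwise_attention` are hypothetical.

```python
import math
import random


def full_attention(q, k, v):
    """Reference version: every query scores against all keys at once."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(x - m) for x in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def blockwise_attention(q, k, v, block_size):
    """Same result, but scores only exist for one query/key block pair at a time."""
    scale = 1.0 / math.sqrt(len(q[0]))
    dv = len(v[0])
    out = []
    for qs in range(0, len(q), block_size):
        q_blk = q[qs:qs + block_size]
        # Running softmax statistics, one entry per query row in the block.
        run_max = [-math.inf] * len(q_blk)
        run_sum = [0.0] * len(q_blk)
        run_out = [[0.0] * dv for _ in q_blk]
        for ks in range(0, len(k), block_size):
            k_blk = k[ks:ks + block_size]
            v_blk = v[ks:ks + block_size]
            for i, qi in enumerate(q_blk):
                scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_blk]
                new_max = max(run_max[i], max(scores))
                # Rescale what was accumulated under the old maximum.
                correction = math.exp(run_max[i] - new_max)
                weights = [math.exp(x - new_max) for x in scores]
                run_sum[i] = run_sum[i] * correction + sum(weights)
                run_out[i] = [o * correction
                              + sum(w * vj[c] for w, vj in zip(weights, v_blk))
                              for c, o in enumerate(run_out[i])]
                run_max[i] = new_max
        # Normalize once all key/value blocks have been visited.
        out.extend([[o / run_sum[i] for o in run_out[i]] for i in range(len(q_blk))])
    return out


random.seed(0)
s, d = 10, 4  # toy sizes chosen for illustration only
q = [[random.gauss(0, 1) for _ in range(d)] for _ in range(s)]
k = [[random.gauss(0, 1) for _ in range(d)] for _ in range(s)]
v = [[random.gauss(0, 1) for _ in range(d)] for _ in range(s)]

ref = full_attention(q, k, v)
blk = blockwise_attention(q, k, v, block_size=3)
max_diff = max(abs(a - b) for ra, rb in zip(ref, blk) for a, b in zip(ra, rb))
print(max_diff < 1e-9)
```

The two functions agree up to floating-point rounding, while the blockwise version never holds more than a `block_size` × `block_size` tile of scores.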
Because the feedforward network is applied block by block, its intermediate activations never have to be held for the entire sequence at once, which reduces memory overhead. Through these improvements, BPT can process longer input sequences while keeping memory use low. The experimental results show that BPT can train sequences 32 times longer than the vanilla Transformer and 2 to 4 times longer than the most efficient existing memory-optimization methods (such as FlashAttention and Memory Efficient Attention). In addition, the paper demonstrates the effectiveness of BPT in language modeling and reinforcement learning tasks, especially those that are complex and involve long-term dependencies.
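The memory effect of applying the feedforward network per block can be shown with a toy sketch. It relies on the fact that the Transformer feedforward network is position-wise, so running it on one block of rows at a time gives the same output as running it on the whole sequence. The ReLU activation, the toy sizes, and the names `ffn`, `blockwise_ffn`, and `peak_hidden_elements` are assumptions made for illustration, not details taken from the paper.

```python
import random


def ffn(rows, w1, w2, stats):
    """Position-wise feedforward over a batch of rows (ReLU assumed)."""
    d_ff = len(w1[0])
    # The hidden activations for all rows passed in are held at once.
    hidden = [[max(0.0, sum(x[i] * w1[i][j] for i in range(len(x))))
               for j in range(d_ff)] for x in rows]
    stats["peak_hidden_elements"] = max(stats["peak_hidden_elements"],
                                        len(hidden) * d_ff)
    return [[sum(h[j] * w2[j][c] for j in range(d_ff))
             for c in range(len(w2[0]))] for h in hidden]


def blockwise_ffn(rows, w1, w2, block_size, stats):
    """Apply the same feedforward network one block of rows at a time."""
    out = []
    for start in range(0, len(rows), block_size):
        out.extend(ffn(rows[start:start + block_size], w1, w2, stats))
    return out


random.seed(0)
s, d_model, d_ff = 12, 4, 16  # toy sizes chosen for illustration only
x = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(s)]
w1 = [[random.gauss(0, 1) for _ in range(d_ff)] for _ in range(d_model)]
w2 = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(d_ff)]

full_stats = {"peak_hidden_elements": 0}
block_stats = {"peak_hidden_elements": 0}
full_out = ffn(x, w1, w2, full_stats)
block_out = blockwise_ffn(x, w1, w2, block_size=3, stats=block_stats)

print(block_out == full_out)
print(full_stats["peak_hidden_elements"], block_stats["peak_hidden_elements"])  # 192 48
```

In this sketch the peak hidden-activation size drops from `s * d_ff` to `block_size * d_ff`. In BPT the rows fed to the feedforward step would be the attention output of a query block, which is why the two steps can be fused.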