Stick-breaking Attention

Shawn Tan, Yikang Shen, Songlin Yang, Aaron Courville, Rameswar Panda
2024-10-23
Abstract: The self-attention mechanism traditionally relies on the softmax operator, necessitating positional embeddings like RoPE or position biases to account for token order. But current methods still face length generalisation challenges. We propose an alternative attention mechanism based on the stick-breaking process: For each token before the current one, we determine a break point $\beta_{i,j}$, which represents the proportion of the remaining stick to allocate to that token. We repeat the process until the stick is fully allocated, resulting in a sequence of attention weights. This process naturally incorporates recency bias, which has linguistic motivations for grammar parsing (Shen et al., 2017). We study the implications of replacing the conventional softmax-based attention mechanism with stick-breaking attention. We then discuss implementation of numerically stable stick-breaking attention and adapt Flash Attention to accommodate this mechanism. When used as a drop-in replacement for current softmax+RoPE attention systems, we find that stick-breaking attention performs competitively with current methods on length generalisation and downstream tasks. Stick-breaking also performs well at length generalisation, allowing a model trained with a $2^{11}$ context window to perform well at $2^{14}$ with perplexity improvements.
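The break-point process in the abstract can be written out explicitly. This is a sketch under stated assumptions: we take $\beta_{i,j} = \sigma(z_{i,j})$ for some query-key score $z_{i,j}$ between query $j$ and earlier token $i$, which the abstract does not spell out. Nearer tokens break the stick first, so token $i$ only receives a share of what tokens $i+1, \dots, j-1$ left behind:

$$\beta_{i,j} = \sigma(z_{i,j}), \qquad A_{i,j} = \beta_{i,j} \prod_{k=i+1}^{j-1} \left(1 - \beta_{k,j}\right), \qquad \sum_{i=1}^{j-1} A_{i,j} = 1 - \prod_{k=1}^{j-1} \left(1 - \beta_{k,j}\right) \le 1$$

Using $\log\sigma(z) = z - \operatorname{softplus}(z)$ and $\log(1 - \sigma(z)) = -\operatorname{softplus}(z)$, the same weight can be computed in log space:

$$\log A_{i,j} = z_{i,j} - \operatorname{softplus}(z_{i,j}) - \sum_{k=i+1}^{j-1} \operatorname{softplus}(z_{k,j})$$

The Python sketch below compares the direct product form with the log-space form for a single query. The function names are hypothetical, indices are 0-based (oldest earlier token first), and this is not the authors' Flash Attention kernel.

```python
import math
from typing import List


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)); exp(-|x|) never overflows.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def stick_breaking_naive(logits: List[float]) -> List[float]:
    """Direct product form: A_i = beta_i * prod_{k>i} (1 - beta_k).

    logits[i] stands for the (assumed) score z_{i,j} between the current
    query j and earlier token i, ordered oldest (i = 0) to newest.
    """
    n = len(logits)
    weights = []
    for i in range(n):
        w = sigmoid(logits[i])
        for k in range(i + 1, n):  # every token more recent than i breaks first
            w *= 1.0 - sigmoid(logits[k])
        weights.append(w)
    return weights


def stick_breaking_logspace(logits: List[float]) -> List[float]:
    """Same weights computed in log space, walking from the newest token back."""
    weights = [0.0] * len(logits)
    log_remaining = 0.0  # log of the stick left after the more recent tokens
    for i in reversed(range(len(logits))):
        z = logits[i]
        # log(sigmoid(z)) = z - softplus(z): token i's share of what remains
        weights[i] = math.exp(log_remaining + z - softplus(z))
        # log(1 - sigmoid(z)) = -softplus(z): what is left for older tokens
        log_remaining -= softplus(z)
    return weights


scores = [0.5, -1.0, 2.0]  # hypothetical scores for three earlier tokens
print(stick_breaking_naive(scores))
print(stick_breaking_logspace(scores))
```

Both functions agree on moderate scores. With extreme scores such as `[-1000.0, 1000.0]`, the naive helper as written raises `OverflowError` inside `math.exp`, while the log-space version returns `[0.0, 1.0]`. Working in log space also turns long chains of factors below 1 into sums, which is why it is a natural route to the numerical stability the abstract mentions.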
Machine Learning
What problem does this paper attempt to address?
### What problems does this paper attempt to solve?

This paper mainly addresses the length generalization challenges that existing self-attention mechanisms face on long sequences. Traditional softmax-based self-attention relies on position embeddings (such as RoPE) to encode token order, but these methods perform poorly on sequences longer than the context length used during training.

#### Main problems:

1. **Length generalization**: The performance of existing self-attention mechanisms degrades on sequences longer than those seen during training.
2. **Encoding position information**: Traditional methods need position embeddings or position biases to handle token order, which adds model complexity and may hurt generalization.
3. **Attention dispersion**: In long sequences, high attention scores may be spread across irrelevant tokens, degrading model performance.

#### Proposed solution:

To address these problems, the authors propose a new attention mechanism based on the stick-breaking process. For each query, the tokens before it are visited from the most recent to the oldest, and each takes a proportion $\beta_{i,j}$ of the attention mass left over by the more recent tokens. Because nearer tokens are served first, the mechanism naturally introduces a recency bias, which helps the model handle token order in long sequences.

#### Specific improvements:

- **No need for position embeddings**: Stick-breaking attention can handle token order without additional position embeddings, simplifying the model structure.
- **Better length generalization**: Experiments show that stick-breaking attention performs well on sequences longer than those used during training, especially in language modeling tasks.
- **Reduced attention dispersion**: Because of the recency bias, the model tends to focus on the most recent relevant tokens, reducing interference from irrelevant ones.

### Conclusion

By introducing stick-breaking attention, the paper aims to improve the length generalization of Transformer models on long sequences, while simplifying the model structure and reducing the dependence on position embeddings.
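The recency bias and reduced attention dispersion described above can be seen in a toy comparison. The sketch below is illustrative only: it assumes sigmoid break points, uses no positional embeddings at all, and the scores are hypothetical values chosen by hand, not data from the paper. The oldest and the newest earlier tokens receive the same high score.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Standard softmax, shifted by the max score for stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stick_breaking(logits: List[float]) -> List[float]:
    # Tokens are ordered oldest -> newest; breaking starts at the newest.
    weights = [0.0] * len(logits)
    remaining = 1.0  # length of the stick not yet allocated
    for i in reversed(range(len(logits))):
        beta = 1.0 / (1.0 + math.exp(-logits[i]))  # break proportion (sigmoid assumed)
        weights[i] = remaining * beta
        remaining *= 1.0 - beta
    return weights


# Hypothetical scores: the oldest and the newest earlier token are equally relevant.
scores = [4.0, 0.0, 0.0, 4.0]
print([round(w, 3) for w in softmax(scores)])
print([round(w, 3) for w in stick_breaking(scores)])
```

Softmax, which has no notion of order here, gives the two high-scoring tokens about 0.49 each. Stick-breaking gives roughly 0.98 to the newest one and under 0.005 to the oldest, because the newer token takes most of the stick before the older one is reached. Unlike softmax, the stick-breaking weights need not sum exactly to 1; the leftover is the part of the stick that was never broken off.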