Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth

Yihe Dong, Jean-Baptiste Cordonnier, Andreas Loukas
2023-08-01
Abstract: Attention-based architectures have become ubiquitous in machine learning, yet our understanding of the reasons for their effectiveness remains limited. This work proposes a new way to understand self-attention networks: we show that their output can be decomposed into a sum of smaller terms, each involving the operation of a sequence of attention heads across layers. Using this decomposition, we prove that self-attention possesses a strong inductive bias towards "token uniformity". Specifically, without skip connections or multi-layer perceptrons (MLPs), the output converges doubly exponentially to a rank-1 matrix. On the other hand, skip connections and MLPs stop the output from degenerating. Our experiments verify the identified convergence phenomena on different variants of standard transformer architectures.
Machine Learning
What problem does this paper attempt to address?
This paper sets out to understand how the self-attention mechanism behaves in deep networks and where it can fail. Specifically, the authors show that the output of a pure self-attention network (SAN, i.e., a Transformer without skip connections and multi-layer perceptrons) converges to a rank-1 matrix at a doubly exponential rate as depth increases, which causes information loss and reduces expressive power.

The main contributions of the paper are:

1. **Systematic study of Transformer building blocks**: Revealing how the self-attention mechanism, skip connections, and multi-layer perceptrons (MLPs) interact, and how each affects the rank collapse of the network.
2. **Proposing a path decomposition method**: Decomposing multi-head self-attention networks into multiple simple single-head network paths, revealing that SANs can be viewed as a collection of shallow networks.
3. **Experimental validation of the theory**: Experimentally validating the theoretical results and demonstrating the crucial role of skip connections and MLPs in preventing rank collapse.

The core question of the paper is: why do Transformers perform well in practice, while pure self-attention networks suffer from severe rank collapse? By analyzing the roles of skip connections and MLPs, the authors explain how these components mitigate rank collapse and so preserve the network's effectiveness and expressive power.
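
The rank-collapse claim can be probed numerically. The following is a minimal NumPy sketch, not the authors' experimental setup: it stacks randomly initialized single-head attention layers, with no MLP or layer normalization, and tracks how far the token matrix is from having identical rows (the rank-1 case). The helper names (`attention`, `relative_residual`, `residual_trace`), the Gaussian initialization, the sizes, and the choice of the Frobenius norm with the column mean as the shared row are all assumptions made for illustration; the paper's own residual measure may differ.

```python
import numpy as np


def softmax(z):
    # Row-wise softmax, shifted by the row maximum for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attention(X, Wq, Wk, Wv):
    # Single-head self-attention without bias, skip connection, or MLP.
    scores = (X @ Wq) @ (X @ Wk).T / np.sqrt(Wq.shape[1])
    return softmax(scores) @ X @ Wv


def relative_residual(X):
    # Distance of X from the nearest matrix whose rows are all identical,
    # relative to the size of X (Frobenius norm; an illustrative choice).
    R = X - X.mean(axis=0, keepdims=True)
    return float(np.linalg.norm(R) / np.linalg.norm(X))


def residual_trace(depth=6, n_tokens=8, d=16, use_skip=False, seed=0):
    # Hypothetical experiment: fresh random weights per layer, same seed
    # for both variants so they share the input and all weight matrices.
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n_tokens, d))
    trace = [relative_residual(X)]
    for _ in range(depth):
        Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))
        out = attention(X, Wq, Wk, Wv)
        X = X + out if use_skip else out
        trace.append(relative_residual(X))
    return trace


pure = residual_trace(use_skip=False)  # pure self-attention stack
skip = residual_trace(use_skip=True)   # same stack with skip connections

for layer, (p, s) in enumerate(zip(pure, skip)):
    print(f"layer {layer}: pure={p:.3e}  skip={s:.3e}")
```

Comparing `pure` with `skip` layer by layer gives a rough feel for the contrast described above: without the skip connection the rows should become nearly identical after a few layers, whereas with it the residual should stay well away from zero.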