Transformers Get Stable: An End-to-End Signal Propagation Theory for Language Models

Akhil Kedia, Mohd Abbas Zaidi, Sushil Khyalia, Jungho Jung, Harshith Goka, Haejun Lee
2024-07-19
Abstract: In spite of their huge success, transformer models remain difficult to scale in depth. In this work, we develop a unified signal propagation theory and provide formulae that govern the moments of the forward and backward signal through the transformer model. Our framework can be used to understand and mitigate vanishing/exploding gradients, rank collapse, and instability associated with high attention scores. We also propose DeepScaleLM, an initialization and scaling scheme that conserves unit output/gradient moments throughout the model, enabling the training of very deep models with 1000 layers. We find that transformer models could be much deeper: our deep models with fewer parameters outperform shallow models in Language Modeling, Speech Translation, and Image Classification, across encoder-only, decoder-only and encoder-decoder variants, for both Pre-LN and Post-LN transformers, for multiple datasets and model sizes. These improvements also translate into improved performance on downstream Question Answering tasks and improved robustness for Image Classification.
Computation and Language, Artificial Intelligence, Computer Vision and Pattern Recognition, Machine Learning
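The abstract's central idea, tracking the moments of the forward signal and of the backward gradient as depth grows, can be illustrated with a small numerical sketch. It is not the paper's model: each layer is a hypothetical random linear map f(x) = Wx with i.i.d. N(0, 1/d) entries standing in for an attention or feed-forward block. The scaled update x ← λx + βf(x) with λ² + β² = 1 and β² = 1/depth is an assumed choice for illustration, not a statement of the exact DeepScaleLM formulas. The helper names (`propagate`, `run`) are likewise hypothetical.

```python
import math
import random


def random_matrix(d, rng):
    """d x d matrix with i.i.d. N(0, 1/d) entries, so E||W x||^2 = ||x||^2."""
    std = 1.0 / math.sqrt(d)
    return [[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]


def matvec(w, x):
    """Computes W x."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def matvec_t(w, g):
    """Computes W^T g without materializing the transpose."""
    out = [0.0] * len(g)
    for i, row in enumerate(w):
        gi = g[i]
        for j, wij in enumerate(row):
            out[j] += wij * gi
    return out


def second_moment(v):
    """Mean of squared entries, used as the per-coordinate signal moment."""
    return sum(x * x for x in v) / len(v)


def propagate(weights, x, g, lam, beta):
    """Forward pass x <- lam*x + beta*W x through every layer, then
    backpropagate g through the same linear layers: g <- lam*g + beta*W^T g.
    Returns (forward second moment, backward second moment)."""
    for w in weights:
        wx = matvec(w, x)
        x = [lam * xi + beta * yi for xi, yi in zip(x, wx)]
    for w in reversed(weights):
        wtg = matvec_t(w, g)
        g = [lam * gi + beta * yi for gi, yi in zip(g, wtg)]
    return second_moment(x), second_moment(g)


def run(depth=32, d=64, seed=0):
    rng = random.Random(seed)
    weights = [random_matrix(d, rng) for _ in range(depth)]
    x0 = [rng.gauss(0.0, 1.0) for _ in range(d)]  # unit-variance input signal
    g0 = [rng.gauss(0.0, 1.0) for _ in range(d)]  # unit-variance output gradient
    # Unscaled residual stream: x <- x + W x
    plain = propagate(weights, x0, g0, lam=1.0, beta=1.0)
    # Hypothetical scaling (assumption, not taken from the paper):
    # lam^2 + beta^2 = 1 with beta^2 = 1/depth
    beta = math.sqrt(1.0 / depth)
    lam = math.sqrt(1.0 - beta * beta)
    scaled = propagate(weights, x0, g0, lam=lam, beta=beta)
    return plain, scaled


if __name__ == "__main__":
    plain, scaled = run()
    print(f"plain : forward {plain[0]:.3g}, backward {plain[1]:.3g}")
    print(f"scaled: forward {scaled[0]:.3g}, backward {scaled[1]:.3g}")
```

With the plain update x ← x + f(x), each layer roughly doubles the second moment, so both the forward and backward moments grow like 2^depth. With λ² + β² = 1, the expected second moment is unchanged from layer to layer, because the skip term and the branch term are uncorrelated in expectation over the random W.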
What problem does this paper attempt to address?
The paper primarily addresses several key issues in deep Transformer models:

1. **Vanishing and Exploding Gradients**: As depth increases, Transformer models become prone to vanishing or exploding gradients during training.
2. **Rank Collapse**: The hidden representations of different input tokens can become increasingly correlated as depth grows, especially at initialization, so that different tokens end up with nearly identical representations.
3. **Instability due to High Attention Scores**: When the query-key dot products (attention scores) grow too large, training can become unstable.

To tackle these issues, the paper develops a unified signal propagation theory and provides formulae that govern the moments (such as the mean and variance) of the forward signal (outputs) and the backward signal (gradients) throughout the model. Building on this framework, the authors propose DeepScaleLM, an initialization and scaling scheme that keeps output and gradient moments at unit scale across the full depth of the model, which makes it possible to train very deep models with up to 1000 layers.

Using DeepScaleLM, the paper shows that deeper Transformer models with fewer parameters can outperform shallower models on language modeling, speech translation, and image classification. These gains carry over to downstream question-answering tasks and to improved robustness in image classification. The theory also explains the causes of phenomena such as exploding gradients and rank collapse in deep models and points to corresponding remedies.
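Why large query/key magnitudes destabilize attention follows from a short calculation: if the entries of q and k are independent and zero-mean with variances σ_q² and σ_k², the scaled logit q·k/√d_k has variance σ_q²σ_k², so the spread of attention scores grows quadratically with the per-entry scale. The stdlib-only Python sketch below checks this empirically and shows the softmax collapsing toward a one-hot distribution as σ grows. The sizes, seed, and the function name `attention_stats` are arbitrary illustrative choices, not values or code from the paper.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax (subtracts the max logit first)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats; zero-probability entries contribute nothing."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0.0)


def attention_stats(sigma, d_k=64, n_keys=128, trials=20, seed=0):
    """Draw a query q and keys k_j with i.i.d. N(0, sigma^2) entries and score
    them with scaled dot-product logits q.k_j / sqrt(d_k).
    Returns (empirical logit variance, mean max attention weight, mean entropy)."""
    rng = random.Random(seed)
    logits_all, max_w, ent = [], 0.0, 0.0
    for _ in range(trials):
        q = [rng.gauss(0.0, sigma) for _ in range(d_k)]
        logits = []
        for _ in range(n_keys):
            k = [rng.gauss(0.0, sigma) for _ in range(d_k)]
            logits.append(sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k))
        p = softmax(logits)
        logits_all.extend(logits)
        max_w += max(p)
        ent += entropy(p)
    mean = sum(logits_all) / len(logits_all)
    var = sum((s - mean) ** 2 for s in logits_all) / len(logits_all)
    return var, max_w / trials, ent / trials


if __name__ == "__main__":
    for sigma in (1.0, 2.0, 4.0):
        var, peak, h = attention_stats(sigma)
        print(f"sigma={sigma}: logit var {var:.1f} (theory {sigma**4:.1f}), "
              f"max weight {peak:.2f}, entropy {h:.2f}")
```

Under these assumptions the logit variance tracks σ⁴ (here σ_q = σ_k = σ), so a 4× larger per-entry scale gives roughly a 256× larger logit variance and an attention distribution dominated by a single key.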