Geometric Dynamics of Signal Propagation Predict Trainability of Transformers

Aditya Cowsik, Tamra Nebabu, Xiao-Liang Qi, Surya Ganguli
2024-03-05
Abstract: We investigate forward signal propagation and gradient back propagation in deep, randomly initialized transformers, yielding simple necessary and sufficient conditions on initialization hyperparameters that ensure trainability of deep transformers. Our approach treats the evolution of the representations of $n$ tokens as they propagate through the transformer layers in terms of a discrete time dynamical system of $n$ interacting particles. We derive simple update equations for the evolving geometry of this particle system, starting from a permutation symmetric simplex. Our update equations show that without MLP layers, this system will collapse to a line, consistent with prior work on rank collapse in transformers. However, unlike prior work, our evolution equations can quantitatively track particle geometry in the additional presence of nonlinear MLP layers, and it reveals an order-chaos phase transition as a function of initialization hyperparameters, like the strength of attentional and MLP residual connections and weight variances. In the ordered phase the particles are attractive and collapse to a line, while in the chaotic phase the particles are repulsive and converge to a regular $n$-simplex. We analytically derive two Lyapunov exponents: an angle exponent that governs departures from the edge of chaos in this particle system, and a gradient exponent that governs the rate of exponential growth or decay of backpropagated gradients. We show through experiments that, remarkably, the final test loss at the end of training is well predicted just by these two exponents at the beginning of training, and that the simultaneous vanishing of these two exponents yields a simple necessary and sufficient condition to achieve minimal test loss.
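The claim that, without MLP layers, the token particles collapse to a line can be illustrated with an exactly solvable toy model. The sketch below is not the paper's update equations. It assumes, hypothetically, uniform attention weights, an identity value projection, an attention residual strength `alpha > 0`, and per-token renormalization to unit norm as a stand-in for layer normalization, starting from a permutation-symmetric configuration with pairwise cosine `rho0` (0 <= rho0 < 1). Under these assumptions each layer amplifies the token mean by a factor (1 + alpha) while leaving the deviations from the mean unchanged, so every pairwise cosine rises monotonically toward 1. The helper names (`symmetric_simplex`, `uniform_attention_layer`, `predicted_cosine`) are illustrative only.

```python
import numpy as np


def symmetric_simplex(n, d, rho0, rng):
    """Return n unit vectors in R^d whose pairwise cosines all equal rho0 (0 <= rho0 < 1)."""
    # QR of a Gaussian matrix yields n + 1 orthonormal columns (requires d >= n + 1).
    q, _ = np.linalg.qr(rng.standard_normal((d, n + 1)))
    shared, private = q[:, 0], q[:, 1:].T  # shapes (d,) and (n, d)
    return np.sqrt(rho0) * shared + np.sqrt(1.0 - rho0) * private


def mean_pairwise_cosine(x):
    """Average cosine similarity over all ordered pairs of distinct tokens (rows of x)."""
    u = x / np.linalg.norm(x, axis=1, keepdims=True)
    gram = u @ u.T
    n = x.shape[0]
    return (gram.sum() - np.trace(gram)) / (n * (n - 1))


def uniform_attention_layer(x, alpha):
    """Hypothetical block: residual + uniform attention with identity values, then unit-normalize each token."""
    x = x + alpha * x.mean(axis=0, keepdims=True)
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def predicted_cosine(rho0, n, alpha, depth):
    """Closed form for this toy model after `depth` layers."""
    k = (1.0 + alpha) ** (2 * depth)     # growth of |mean|^2 relative to |deviation|^2
    c0 = rho0 + (1.0 - rho0) / n         # initial squared norm of the token mean
    s0 = (1.0 - rho0) * (n - 1) / n      # initial squared norm of each deviation from the mean
    return (k * c0 - (1.0 - rho0) / n) / (k * c0 + s0)


rng = np.random.default_rng(0)
n, d, rho0, alpha, depth = 8, 64, 0.2, 0.1, 60
x = symmetric_simplex(n, d, rho0, rng)
cosines = [mean_pairwise_cosine(x)]
for _ in range(depth):
    x = uniform_attention_layer(x, alpha)
    cosines.append(mean_pairwise_cosine(x))

print(f"mean pairwise cosine: start {cosines[0]:.3f}, after {depth} layers {cosines[-1]:.5f}")
```

Setting `alpha = 0` freezes the geometry at the initial cosine, so in this toy the collapse rate is set entirely by the attention residual strength. By the abstract's account, the nonlinear MLP layers that this sketch omits are required for the chaotic phase, in which the particles repel instead of collapsing.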
Disordered Systems and Neural Networks, Machine Learning
What problem does this paper attempt to address?
This paper studies signal propagation and gradient backpropagation in deep, randomly initialized Transformer models and proposes a theoretical framework for predicting their trainability.

### Main Research Questions

The paper addresses two core questions:

- How can the forward propagation of signals (the passage of input tokens through successive layers) and the backpropagation of gradients in deep Transformers be described quantitatively, and how does this description depend on the initialization hyperparameters?
- How can this quantitative description be used to choose initialization hyperparameters that minimize the final test loss?

### Research Methods and Findings

1. **Signal Propagation Theory**:
   - The authors treat the propagation of n tokens through the Transformer layers as a discrete-time dynamical system in which the token representations form a system of n interacting particles.
   - They derive update equations for the evolving geometry of this particle system, starting from a permutation-symmetric simplex.
   - Without MLP layers, the system collapses to a line, consistent with previous studies of rank collapse in Transformers.
   - When nonlinear MLP layers are added, the equations still track the particle geometry quantitatively and reveal an order-chaos phase transition as a function of initialization hyperparameters (such as the strength of the attention and MLP residual connections and the weight variances).
   - In the ordered phase the particles attract one another and collapse to a line; in the chaotic phase they repel one another and converge to a regular n-simplex.
2. **Gradient Propagation Theory**:
   - Two Lyapunov exponents are derived: an angle exponent, which governs departures from the edge of chaos in the particle system, and a gradient exponent, which governs the rate of exponential growth or decay of backpropagated gradients.
   - Experiments show that these two exponents, measured at the beginning of training, predict the final test loss well.
   - The simultaneous vanishing of both exponents is a necessary and sufficient condition for achieving minimal test loss.
3. **Theoretical Extension**:
   - The work extends earlier signal-propagation results for deep multilayer perceptron (MLP) networks to the Transformer architecture.
   - The Transformer case is more complex because the geometry of n different inputs (tokens) propagating simultaneously through each Transformer block must be tracked jointly.
4. **Experimental Validation**:
   - Experiments confirm the theory's predictions for signal propagation, including the distribution of token norms and the angles between tokens.
   - The close agreement between experiment and theory supports the framework's validity and predictive power.

In summary, the paper builds a dynamical model of signal and gradient propagation that offers a new perspective on how deep Transformers work and provides a quantitative method for choosing initialization hyperparameters that improve training performance.
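To make the gradient exponent concrete, the sketch below estimates a gradient Lyapunov exponent by Monte Carlo on a hypothetical residual MLP stack, not on a Transformer. It assumes the exponent is the average per-layer log growth rate of the backpropagated gradient norm, $\frac{1}{L}\log(\|g_0\|/\|g_L\|)$; the paper derives its exponents analytically for Transformers, and its exact definition and normalization may differ. The block $h = Wx$, $x \leftarrow \beta x + \alpha\,\phi(h)$, the Gaussian weight scale `sigma_w`, and all function names are assumptions chosen for illustration.

```python
import numpy as np


def identity(z):
    return z


def ones(z):
    return np.ones_like(z)


def tanh_prime(z):
    return 1.0 - np.tanh(z) ** 2


def gradient_exponent(depth, d, sigma_w, alpha, beta, phi, dphi, seed=0):
    """Monte Carlo estimate of (1 / depth) * log(||g_0|| / ||g_depth||).

    Hypothetical block: h = W x, x_next = beta * x + alpha * phi(h),
    with fresh weights W_ij ~ N(0, sigma_w**2 / d) at every layer.
    A positive value means gradients grow as they travel toward the input.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(d)
    cache = []
    for _ in range(depth):  # forward pass, caching weights and pre-activations
        w = rng.standard_normal((d, d)) * (sigma_w / np.sqrt(d))
        h = w @ x
        cache.append((w, h))
        x = beta * x + alpha * phi(h)
    g = rng.standard_normal(d)  # random gradient arriving at the output
    log_top = np.log(np.linalg.norm(g))
    for w, h in reversed(cache):
        # Transposed layer Jacobian: J^T g = beta * g + alpha * W^T (phi'(h) * g)
        g = beta * g + alpha * (w.T @ (dphi(h) * g))
    return (np.log(np.linalg.norm(g)) - log_top) / depth


# Linear, non-residual sanity check: the estimate should be close to log(sigma_w).
for s in (0.7, 1.0, 1.5):
    est = gradient_exponent(40, 256, s, alpha=1.0, beta=0.0, phi=identity, dphi=ones)
    print(f"sigma_w={s}: estimate {est:+.3f}, log(sigma_w) {np.log(s):+.3f}")

# Residual tanh blocks: compare how the exponent changes with the weight scale.
for s in (0.5, 1.0, 2.0, 4.0):
    est = gradient_exponent(40, 256, s, alpha=0.5, beta=1.0, phi=np.tanh, dphi=tanh_prime)
    print(f"residual tanh, sigma_w={s}: estimate {est:+.3f}")
```

In the linear, non-residual check (`beta = 0`, identity `phi`), each transposed Jacobian is a fresh Gaussian matrix, so the estimate should sit near $\log \sigma_w$: negative (vanishing gradients) below $\sigma_w = 1$ and positive (exploding gradients) above it. The point $\sigma_w = 1$, where the exponent is close to zero, is this toy's analogue of the vanishing-exponent condition described above.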