Initialization of Large Language Models via Reparameterization to Mitigate Loss Spikes

Kosuke Nishida, Kyosuke Nishida, Kuniko Saito
2024-10-07
Abstract: Loss spikes, a phenomenon in which the loss value diverges suddenly, are a fundamental issue in the pre-training of large language models. This paper supposes that the non-uniformity of the norm of the parameters is one of the causes of loss spikes. In the training of neural networks, the scale of the gradients must be kept constant throughout the layers to avoid the vanishing and exploding gradients problem. However, to meet this requirement in the Transformer model, the norm of the model parameters must be non-uniform, and thus parameters whose norm is smaller are more sensitive to the parameter update. To address this issue, we propose a novel technique, weight scaling as reparameterization (WeSaR). WeSaR introduces a gate parameter per parameter matrix and adjusts it to the value satisfying the requirement. Because of the gate parameter, WeSaR sets the norm of the original parameters uniformly, which results in stable training. Experimental results with Transformer decoders consisting of 130 million, 1.3 billion, and 13 billion parameters showed that WeSaR stabilizes and accelerates training and that it outperforms the compared methods, including popular initialization methods.
Computation and Language
What problem does this paper attempt to address?
This paper addresses loss spikes, the sudden divergence of the loss value during the pre-training of large language models (LLMs). Loss spikes not only increase the final loss value but can also cause pre-training to fail. The authors hypothesize that non-uniform parameter norms are one cause of loss spikes and propose a new technique, Weight Scaling as Reparameterization (WeSaR), to address it.

### Main Contributions:

1. **Identifying a cause of loss spikes**: Current initialization methods, which are designed to avoid vanishing and exploding gradients, produce non-uniform parameter norms. The authors argue that this non-uniformity is one cause of loss spikes, because parameters with smaller norms are more sensitive to parameter updates.
2. **Proposing the WeSaR method**: WeSaR introduces a gate parameter \(\alpha\) for each parameter matrix. The gate carries the scale that each matrix needs, so the norms of the original parameters can be made uniform, which stabilizes training.
3. **Experimental validation**: In pre-training Transformer decoders with 130 million, 1.3 billion, and 13 billion parameters, WeSaR stabilizes training, accelerates it, and outperforms the compared initialization methods.

### Solution:

- **Reparameterization**: WeSaR reparameterizes each parameter matrix \(W\) with a gate parameter \(\alpha\), i.e., it uses \(\alpha W\) instead of \(W\). The gate \(\alpha\) is set to the value that keeps the gradient scale constant across layers, so \(W\) itself can be initialized with a uniform norm while vanishing and exploding gradients are still avoided.
- **Uniform standard deviation**: WeSaR allows all parameters to share the same small standard deviation, which helps stabilize training and accelerate convergence.
- **Experimental validation**: Across the three model scales, WeSaR reduces the occurrence of loss spikes and outperforms the compared initialization methods.
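
The reparameterization described under Solution can be illustrated with a minimal initialization sketch. It rests on assumptions that are not taken from the paper: the target scale of each matrix is set to \(\sqrt{1/d_{in}}\) as a stand-in for whatever a conventional initialization would prescribe, the shared standard deviation `COMMON_STD = 0.02` is a hypothetical value, and the function names are illustrative. Only initialization is covered; how \(\alpha\) is treated during training is not modeled.

```python
import math
import random
import statistics


def init_matrix(rows, cols, std, rng):
    """Sample a rows x cols matrix with i.i.d. N(0, std^2) entries."""
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def matrix_std(matrix):
    """Population standard deviation over all entries of a matrix."""
    return statistics.pstdev([w for row in matrix for w in row])


def wesar_init(d_in, d_out, target_std, common_std, rng):
    """Illustrative WeSaR-style init (names and signature are hypothetical).

    W is drawn with the shared standard deviation `common_std`; the gate
    alpha is chosen so that alpha * W has standard deviation `target_std`.
    """
    W = init_matrix(d_out, d_in, common_std, rng)
    alpha = target_std / common_std
    return alpha, W


def effective_weight(alpha, W):
    """The matrix actually used in the forward pass: alpha * W."""
    return [[alpha * w for w in row] for row in W]


rng = random.Random(0)
COMMON_STD = 0.02  # assumed shared std for every stored matrix
LAYERS = {"narrow": (64, 64), "wide": (256, 64)}  # name -> (d_in, d_out)

results = {}
for name, (d_in, d_out) in LAYERS.items():
    # Assumed target scale; a real model would take this from its init scheme.
    target_std = math.sqrt(1.0 / d_in)
    alpha, W = wesar_init(d_in, d_out, target_std, COMMON_STD, rng)
    results[name] = {
        "alpha": alpha,
        "target_std": target_std,
        "stored_std": matrix_std(W),
        "effective_std": matrix_std(effective_weight(alpha, W)),
    }
    print(name, results[name])
```

Under these assumptions, both stored matrices have approximately the same standard deviation, while the effective matrices \(\alpha W\) keep the layer-specific scales; the difference between layers is carried entirely by the scalar gates.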
### Experimental Results:

- **Stability**: WeSaR keeps parameter update ratios stable during training, which avoids loss spikes.
- **Accelerated training**: WeSaR not only stabilizes training but also accelerates it, especially for large-scale models.
- **Downstream task performance**: Evaluations on multiple downstream tasks indicate that models trained with WeSaR outperform those trained with other initialization methods.

In summary, the paper addresses loss spikes in the pre-training of large language models by proposing WeSaR, which gives every parameter matrix a uniform norm at initialization and moves the layer-specific scale into a gate parameter.