Abstract: A primary cost driver for training large models is wall-clock training time. We show that popular time estimates based on FLOPs are poor estimates, and construct a more accurate proxy based on memory copies. This allows us to accurately estimate the training speed of a transformer model from its hyperparameters. Combined with a scaling law curve like Chinchilla, this allows us to accurately predict the final loss of a model from a simple equation. We show that this expression is accurate across a wide range of model hyperparameter values, enabling us to analytically make architectural decisions and train models more efficiently. Crucially, this analysis predicts that in contrast to existing literature, models should be wider rather than deeper, as the benefits of speed outweigh the benefits of depth.
What problem does this paper attempt to address?
The main question this paper addresses is: **how to choose hyperparameters that maximize a model's final performance within a given training time**. Specifically, it focuses on estimating the training speed and final loss of a model more accurately, so that the architecture can be chosen to improve training efficiency.
### Analysis of the Main Problem
1. **Limitations of Existing Methods**:
- The traditional time estimate based on FLOPs (floating-point operations) is not accurate enough.
- The existing literature disagrees on how to trade off model depth against width. Deeper models are usually considered better, but this may not hold once training speed is taken into account.
2. **The New Method Proposed**:
- The paper introduces an estimate based on memory copies and finds that it predicts a model's running time more accurately than FLOPs do.
- By combining this estimate with an existing scaling law such as Chinchilla, the final loss of a model can be predicted directly from its hyperparameters, without actually training the model.
3. **Core Contributions**:
- Proposed a new framework that can estimate the training speed and the final loss of the model from its hyperparameters alone (such as embedding dimension, number of layers, MLP width, etc.).
- Found that for a given training time, a wider rather than a deeper model is more advantageous because the speed advantage outweighs the benefits brought by depth.
### Mathematical Formulas
- **Number of Model Parameters (PARAMS)**:
\[
\text{PARAMS}(d, n, v, w)=vd + nd(8 + 2w + 4d)+nw
\]
where:
- \(d\) is the embedding dimension.
- \(n\) is the number of layers.
- \(v\) is the vocabulary size.
- \(w\) is the MLP width.
- **Number of Memory Copies (MEMCPYS)**:
\[
\text{MEMCPYS}(d, n, s, v, w)=2vd + 2sv+ns(w + 2hs)+2nd(w + 4s + 2d)
\]
where:
- \(s\) is the sequence length.
- \(h\) is the number of attention heads.
- **Number of Floating-Point Operations (FLOPS)**:
\[
\text{FLOPS}(d, n, s, v, w)=2svd + 2dns(w + 2d + s)+nhs^2
\]
- **Training Time per Step (TIME)**:
\[
\text{TIME}(d, n, s, v, w)=c_1\cdot\text{MEMCPYS}(d, n, s, v, w)+c_2\cdot\text{FLOPS}(d, n, s, v, w)+c_3
\]
where \(c_1, c_2, c_3\) are coefficients determined by linear regression.
- **Final Loss Prediction (L)**:
\[
\hat{L}(d, n, s, v, w)=E + A\left(\frac{\text{PARAMS}(d, n, v, w)}{N}\right)^\alpha+ B\left(\frac{\text{TIME}(d, n, s, v, w)}{T}\right)^\beta
\]
where \(N\) is the total number of parameters, \(T\) is the total training time, and \(\alpha\) and \(\beta\) are the exponents in the scaling laws.
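The counting formulas above translate directly into code. The following is a minimal sketch that transcribes PARAMS, MEMCPYS, FLOPS, and TIME exactly as written in this summary. The function names, the explicit `h` argument (the formulas use \(h\) even though it is not listed in their signatures), and the two example configurations are assumptions made for illustration; they are not taken from the paper.

```python
def params(d, n, v, w):
    """PARAMS: parameter count, transcribed from the formula above."""
    return v * d + n * d * (8 + 2 * w + 4 * d) + n * w


def memcpys(d, n, s, v, w, h):
    """MEMCPYS: memory-copy count, transcribed from the formula above."""
    return (2 * v * d + 2 * s * v
            + n * s * (w + 2 * h * s)
            + 2 * n * d * (w + 4 * s + 2 * d))


def flops(d, n, s, v, w, h):
    """FLOPS: floating-point operation count, transcribed from the formula above."""
    return 2 * s * v * d + 2 * d * n * s * (w + 2 * d + s) + n * h * s ** 2


def step_time(d, n, s, v, w, h, c1, c2, c3):
    """TIME: linear combination of MEMCPYS and FLOPS plus a constant offset."""
    return c1 * memcpys(d, n, s, v, w, h) + c2 * flops(d, n, s, v, w, h) + c3


# Hypothetical configurations, chosen only to show how the counts move
# when depth is traded for width. They are not from the paper.
shared = dict(s=1024, v=32000, h=16)
wide = dict(d=2048, n=12, w=8192, **shared)
deep = dict(d=1024, n=48, w=4096, **shared)

for name, cfg in (("wide", wide), ("deep", deep)):
    p = params(cfg["d"], cfg["n"], cfg["v"], cfg["w"])
    m = memcpys(**cfg)
    f = flops(**cfg)
    print(f"{name}: params={p:.3e} memcpys={m:.3e} flops={f:.3e}")
```

Evaluating `step_time` additionally requires values for `c1`, `c2`, and `c3`, which depend on the hardware and have to be fitted from measured step times.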
With these formulas, model performance can be predicted for different hyperparameter configurations without actual training, which in turn guides model design and optimization.
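The coefficients \(c_1, c_2, c_3\) in the TIME formula are described as coming from a linear regression. A minimal sketch of such a fit, using ordinary least squares over measured step times, might look like the following. The helper names, the column rescaling, and the synthetic measurements are all assumptions for illustration; the paper's actual fitting procedure and data are not reproduced here.

```python
def solve_linear_system(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented matrix
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            factor = m[r][col] / m[col][col]
            for k in range(col, n + 1):
                m[r][k] -= factor * m[col][k]
    x = [0.0] * n
    for r in reversed(range(n)):
        tail = sum(m[r][k] * x[k] for k in range(r + 1, n))
        x[r] = (m[r][n] - tail) / m[r][r]
    return x


def fit_time_coefficients(memcpy_counts, flop_counts, times):
    """Least-squares fit of TIME = c1 * MEMCPYS + c2 * FLOPS + c3."""
    # Rescale both count columns to at most 1 so the normal equations
    # stay well conditioned despite the very different magnitudes.
    m_scale = max(memcpy_counts)
    f_scale = max(flop_counts)
    rows = [(m / m_scale, f / f_scale, 1.0)
            for m, f in zip(memcpy_counts, flop_counts)]
    # Normal equations: (X^T X) c = X^T t
    xtx = [[sum(r[i] * r[j] for r in rows) for j in range(3)] for i in range(3)]
    xtt = [sum(r[i] * t for r, t in zip(rows, times)) for i in range(3)]
    c1, c2, c3 = solve_linear_system(xtx, xtt)
    return c1 / m_scale, c2 / f_scale, c3


# Synthetic "measurements" generated from made-up coefficients, purely to
# exercise the fit. Real usage would pass measured seconds per step.
memcpy_counts = [1e6, 4e6, 9e6, 2e7, 5e7]
flop_counts = [3e9, 1e9, 8e9, 2e9, 6e9]
made_up = (2e-9, 1e-11, 5e-3)
times = [made_up[0] * m + made_up[1] * f + made_up[2]
         for m, f in zip(memcpy_counts, flop_counts)]

print(fit_time_coefficients(memcpy_counts, flop_counts, times))
```

In practice, the count columns would come from evaluating MEMCPYS and FLOPS for each benchmarked configuration, and a library routine such as a standard least-squares solver could replace the hand-written elimination.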