Weight Prediction Boosts the Convergence of AdamW

Lei Guan
2023-08-08
Abstract: In this paper, we introduce weight prediction into the AdamW optimizer to boost its convergence when training the deep neural network (DNN) models. In particular, ahead of each mini-batch training, we predict the future weights according to the update rule of AdamW and then apply the predicted future weights to do both forward pass and backward propagation. In this way, the AdamW optimizer always utilizes the gradients w.r.t. the future weights instead of current weights to update the DNN parameters, making the AdamW optimizer achieve better convergence. Our proposal is simple and straightforward to implement but effective in boosting the convergence of DNN training. We performed extensive experimental evaluations on image classification and language modeling tasks to verify the effectiveness of our proposal. The experimental results validate that our proposal can boost the convergence of AdamW and achieve better accuracy than AdamW when training the DNN models.
Machine Learning, Optimization and Control
What problem does this paper attempt to address?
This paper aims to improve the convergence speed and final accuracy of the AdamW optimizer when training deep neural network (DNN) models. The author introduces weight prediction: before each mini-batch, the future weights are predicted from AdamW's own update rule, and gradients are then computed with respect to those predicted weights rather than the current ones.

### Specific description of the problem

1. **Existing problems**:
   - AdamW is a commonly used optimizer, widely applied in deep learning tasks such as image classification and language modeling.
   - Although AdamW performs well on many tasks, in some cases there is still room for improvement in its convergence speed and final performance.
2. **Proposed method**:
   - Before each mini-batch training step, the future weights are predicted according to the update rule of AdamW, and these predicted weights are used for both forward propagation and backward propagation.
   - As a result, AdamW always updates the DNN parameters using gradients w.r.t. the future weights instead of the current weights, which is intended to accelerate convergence and improve model performance.

### Formula explanation

To better understand this method, we can review the update formula of AdamW:

\[
\theta_t = (1 - \gamma\lambda)\,\theta_{t-1} - \frac{\gamma\,\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\]

where:

- \(\theta_t\) is the weight after the \(t\)-th iteration.
- \(\gamma\) is the learning rate.
- \(\lambda\) is the weight decay coefficient.
- \(\hat{m}_t\) and \(\hat{v}_t\) are the bias-corrected first and second moment estimates, respectively.
- \(\epsilon\) is a smoothing term to prevent division by zero.

By introducing weight prediction, the author derived the following formula to predict future weights:

\[
\hat{\theta}_{t+s} = \theta_t - s\,\frac{\gamma\,\hat{m}_{t+1}}{\sqrt{\hat{v}_{t+1}} + \epsilon}
\]

where:

- \(\hat{\theta}_{t+s}\) is the predicted future weight.
- \(s\) is the number of prediction steps.

### Experimental verification

The author evaluated the method on image classification and language modeling tasks. The reported results show that it converges faster than standard AdamW and reaches better final performance.

### Conclusion

By introducing weight prediction, the paper improves the convergence speed and performance of AdamW when training deep neural networks. The method is simple to implement, and the reported experiments show improvements on both the image classification and the language modeling tasks.
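
To make the two formulas in the "Formula explanation" section concrete, here is a minimal NumPy sketch of an AdamW step and of the weight-prediction step, run on a toy quadratic loss. It is an illustration under stated assumptions, not the author's implementation: the function names (`adamw_step`, `predict_weights`, `train`), the toy loss, and all hyperparameter values are hypothetical. In particular, the sketch uses the bias-corrected moments already stored by the optimizer as a stand-in for \(\hat{m}_{t+1}\) and \(\hat{v}_{t+1}\), and it skips prediction before the first update, when no moments exist yet.

```python
import numpy as np

# Hypothetical toy problem: minimize 0.5 * ||theta - TARGET||^2.
TARGET = np.array([1.0, -2.0, 3.0])


def adamw_step(theta, grad, m, v, t, lr=1e-3, betas=(0.9, 0.999),
               eps=1e-8, weight_decay=1e-2):
    """One AdamW update with decoupled weight decay; t is the 1-based step index."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * grad          # first moment estimate
    v = b2 * v + (1 - b2) * grad ** 2     # second moment estimate
    m_hat = m / (1 - b1 ** t)             # bias correction
    v_hat = v / (1 - b2 ** t)
    theta = (1 - lr * weight_decay) * theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v


def predict_weights(theta, m, v, t, s=1, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """Predict weights s steps ahead: theta - s * lr * m_hat / (sqrt(v_hat) + eps).

    Assumption: m and v are the moments stored after t completed updates and
    serve as a proxy for the moments of the upcoming step.
    """
    if t == 0:                            # no moments yet, nothing to predict from
        return theta.copy()
    b1, b2 = betas
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return theta - s * lr * m_hat / (np.sqrt(v_hat) + eps)


def train(steps=200, s=1, lr=0.05, use_prediction=True):
    """Run AdamW on the toy loss, optionally evaluating gradients at predicted weights."""
    theta = np.zeros_like(TARGET)
    m = np.zeros_like(TARGET)
    v = np.zeros_like(TARGET)
    for t in range(1, steps + 1):
        # Forward/backward pass uses predicted weights; the update is applied
        # to the real weights theta.
        theta_fwd = predict_weights(theta, m, v, t - 1, s=s, lr=lr) if use_prediction else theta
        grad = theta_fwd - TARGET         # gradient of the toy quadratic loss
        theta, m, v = adamw_step(theta, grad, m, v, t, lr=lr)
    return theta


if __name__ == "__main__":
    print("plain AdamW:      ", train(use_prediction=False))
    print("weight prediction:", train(use_prediction=True))
```

The key design point is that the predicted weights are used only to compute the gradient; the optimizer state and the real weights are updated exactly as in plain AdamW. The toy problem is only meant to show the mechanics and says nothing about the convergence gains reported in the paper.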