Custom Gradient Estimators are Straight-Through Estimators in Disguise

Matt Schoenbauer, Daniele Moro, Lukasz Lew, Andrew Howard
2024-05-23
Abstract: Quantization-aware training comes with a fundamental challenge: the derivatives of quantization functions such as rounding are zero almost everywhere and nonexistent elsewhere. Various differentiable approximations of quantization functions have been proposed to address this issue. In this paper, we prove that when the learning rate is sufficiently small, a large class of weight gradient estimators is equivalent to the straight-through estimator (STE). Specifically, after swapping in the STE and adjusting both the weight initialization and the learning rate in SGD, the model will train in almost exactly the same way as it did with the original gradient estimator. Moreover, we show that for adaptive learning rate algorithms like Adam, the same result can be seen without any modifications to the weight initialization and learning rate. We experimentally show that these results hold for both a small convolutional model trained on the MNIST dataset and for a ResNet50 model trained on ImageNet.
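The SGD part of this claim can be illustrated on a toy problem. This is a minimal sketch under stated assumptions, not the paper's construction: a single scalar weight, a round-to-nearest quantizer with unit bins, the hypothetical loss `0.5 * (q - 10) ** 2` on the quantized weight `q`, and a hypothetical custom estimator `custom_grad(w) = C / (1 + B*cos(2*pi*w))` with illustrative constants `B = 0.5`, `C = 2`. The adjusted STE run starts from `M(w0)`, where `M' = C / custom_grad`, and uses learning rate `C * lr`. To first order, each SGD step on `w` then moves `M(w)` by exactly the STE step on `u`, because the estimator's factor cancels.

```python
import math

# Hypothetical toy setup (illustrative, not from the paper): one scalar weight,
# a round-to-nearest quantizer with unit bins, and the loss
# L(q) = 0.5 * (q - TARGET) ** 2 evaluated on the quantized weight q.
TARGET = 10.0
B, C = 0.5, 2.0  # hypothetical shape parameters of the custom estimator


def quantize(x: float) -> int:
    """Round to the nearest integer (ties round up)."""
    return math.floor(x + 0.5)


def dloss_dq(q: int) -> float:
    """Derivative of the toy loss with respect to the quantized weight."""
    return q - TARGET


def custom_grad(w: float) -> float:
    """Hypothetical custom estimator of dQ/dw: positive everywhere, period 1."""
    return C / (1.0 + B * math.cos(2.0 * math.pi * w))


def to_ste_coords(w: float) -> float:
    """Map M with M'(w) = C / custom_grad(w). It is increasing and fixes every
    integer and half-integer, so quantize(M(w)) == quantize(w)."""
    return w + B * math.sin(2.0 * math.pi * w) / (2.0 * math.pi)


def train(lr: float, steps: int, w0: float):
    w = w0                  # custom estimator, plain SGD
    u = to_ste_coords(w0)   # STE with adjusted init M(w0) and learning rate C * lr
    v = w0                  # STE with unadjusted init and learning rate
    max_gap, mismatched = 0.0, 0
    for _ in range(steps):
        if quantize(u) != quantize(w):  # forward passes disagree this step
            mismatched += 1
        w -= lr * custom_grad(w) * dloss_dq(quantize(w))
        u -= C * lr * dloss_dq(quantize(u))  # STE treats dQ/du as 1
        v -= lr * dloss_dq(quantize(v))
        max_gap = max(max_gap, abs(u - to_ste_coords(w)))
    return w, u, v, max_gap, mismatched


w, u, v, max_gap, mismatched = train(lr=1e-4, steps=1500, w0=0.3)
print(f"custom SGD w={w:.4f}, M(w)={to_ste_coords(w):.4f}, adjusted STE u={u:.4f}")
print(f"unadjusted STE v={v:.4f}, max |u - M(w)|={max_gap:.1e}, mismatched steps={mismatched}")
```

With a small learning rate, `u` tracks `M(w)` closely and both runs see the same quantized weight on almost every step, while the unadjusted STE run falls behind. The residual error is second order in the step size, which is why the equivalence needs a sufficiently small learning rate.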
Machine Learning
What problem does this paper attempt to address?
### What problem does this paper attempt to solve?

This paper addresses a core challenge in Quantization-Aware Training (QAT): the derivative of a quantization function (such as rounding) is zero almost everywhere and undefined elsewhere. The authors study weight gradient estimators for QAT and prove that, when the learning rate is sufficiently small, a large class of them is approximately equivalent to the Straight-Through Estimator (STE).

#### Specific problem description:

1. **The derivative of quantization functions**:
   - In QAT, the derivative of a quantization function (such as rounding) is zero almost everywhere and undefined elsewhere, so standard backpropagation cannot update the model parameters through it.
   - To work around this, researchers have proposed various gradient estimators that approximate the derivative of the quantization function during backpropagation.
2. **Choosing a gradient estimator**:
   - The STE is one of the most commonly used gradient estimators, but a common belief is that more elaborate estimators, which approximate the quantization function more closely, should train better.
   - This paper argues that this view is misleading: it shows theoretically and experimentally that, under the conditions above, such estimators are approximately equivalent to the STE.
3. **The burden of hyperparameter tuning**:
   - Switching between gradient estimators usually requires retuning hyperparameters such as the learning rate and weight initialization, which adds to the burden on QAT practitioners.
   - The paper shows that, under these conditions, using the STE avoids these extra adjustments and simplifies the QAT process.

#### Main contributions:

1. **Theoretical proof**:
   - For non-adaptive optimizers such as SGD, the paper proves that, with a sufficiently small learning rate, a large class of non-zero weight gradient estimators produce almost the same weight updates as the STE once the learning rate and weight initialization are adjusted appropriately.
   - For adaptive learning rate optimizers such as Adam, it proves that the same result holds without adjusting the learning rate or weight initialization.
2. **Experimental verification**:
   - Experiments with a small convolutional network on MNIST and a ResNet50 model on ImageNet confirm these theoretical results.

#### Practical significance:

- **Simpler QAT**: Practitioners can safely choose the STE as their gradient estimator without the extra hyperparameter tuning that more complex estimators bring.
- **Less hyperparameter tuning**: Less attention needs to go into selecting the learning rate, weight initialization, and optimization method, making QAT simpler and more efficient.

In conclusion, through theoretical analysis and experiments, this paper shows that the STE is effective for QAT, offering practical guidance that simplifies the training of quantized deep-learning models.
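Why Adam needs no adjustment can be illustrated with a standard property of Adam: multiplying a parameter's gradients by a positive constant leaves its updates unchanged up to the `eps` term, because the constant cancels in the ratio of the first- and second-moment estimates. This minimal sketch assumes the custom estimator's factor dQ/dw is exactly constant over the steps shown; `scale = 3.7` and the upstream gradient sequence are arbitrary illustrative values. The idea that a small learning rate keeps the weight, and hence this factor, nearly constant over Adam's averaging window is a hedged intuition, not the paper's proof.

```python
import math


def adam_steps(grads, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the updates of standard bias-corrected Adam for one scalar parameter."""
    m = v = 0.0
    updates = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        updates.append(-lr * m_hat / (math.sqrt(v_hat) + eps))
    return updates


def sgd_steps(grads, lr=1e-3):
    """Return plain SGD updates for the same gradient sequence."""
    return [-lr * g for g in grads]


# Upstream gradients dL/dq: an arbitrary deterministic sequence for illustration.
upstream = [math.sin(0.1 * t) + 0.5 for t in range(200)]
scale = 3.7  # hypothetical, locally constant dQ/dw of a custom estimator

ste_grads = upstream                          # STE: dQ/dw = 1
custom_grads = [scale * g for g in upstream]  # custom estimator: dQ/dw = scale

adam_gap = max(abs(a - b) for a, b in zip(adam_steps(ste_grads), adam_steps(custom_grads)))
sgd_gap = max(abs(a - b) for a, b in zip(sgd_steps(ste_grads), sgd_steps(custom_grads)))
print(f"max Adam update difference: {adam_gap:.2e}")
print(f"max SGD update difference:  {sgd_gap:.2e}")
```

Under Adam the two estimators produce essentially identical updates, while under SGD every update is multiplied by the estimator's factor. That difference is why the SGD case calls for adjusting the learning rate and weight initialization, whereas the Adam case does not.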