Direct Quantized Training of Language Models with Stochastic Rounding

Kaiyan Zhao, Tsuguchika Tabaru, Kenichi Kobayashi, Takumi Honda, Masafumi Yamazaki, Yoshimasa Tsuruoka
2024-12-06
Abstract: Although recent quantized Large Language Models (LLMs), such as BitNet, have paved the way for a significant reduction in memory usage during deployment with binary or ternary weights, training these models still demands a substantial memory footprint. This is partly because high-precision (i.e., unquantized) weight matrices required for straight-through estimation must be maintained throughout the whole training process. To address this, we explore the potential of directly updating the quantized low-precision weight matrices without relying on the straight-through estimator during backpropagation, thereby saving memory usage during training. Specifically, we employ a stochastic rounding technique to minimize information loss caused by the use of low-bit weights throughout training. Experimental results on our LLaMA-structured models indicate that (1) training with only low-precision weights is feasible even when they are constrained to ternary values, (2) extending the bit width to 8 bits results in only a 5% loss degradation compared to BitNet b1.58 while offering the potential for reduced memory usage during training, and (3) our models can also perform inference using ternary weights, showcasing their flexibility in deployment.
Machine Learning, Computation and Language
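The abstract credits stochastic rounding with limiting the information loss from low-bit weights. The sketch below illustrates the general technique and is not the authors' code. The helper names `stochastic_round` and `quantize_sr`, the symmetric integer grid with a single per-tensor `scale`, and the convention that "1.58 bits" means the ternary grid {-1, 0, 1} are all assumptions made for this example.

```python
import math
import random


def stochastic_round(x: float, rng: random.Random) -> int:
    """Round x down or up at random so that the expected result equals x.

    The value rounds up with probability equal to its fractional part.
    """
    lo = math.floor(x)
    return lo + (1 if rng.random() < x - lo else 0)


def quantize_sr(w: float, scale: float, bits: float, rng: random.Random) -> int:
    """Map a real weight to a signed integer level in [-qmax, qmax].

    Hypothetical convention: bits < 2 (e.g. 1.58) means the ternary grid {-1, 0, 1};
    otherwise qmax = 2**(bits - 1) - 1 (e.g. 127 for 8 bits).
    """
    qmax = 1 if bits < 2 else 2 ** (int(bits) - 1) - 1
    q = stochastic_round(w / scale, rng)
    return max(-qmax, min(qmax, q))  # clip to the representable range


rng = random.Random(0)
samples = [quantize_sr(0.3, 1.0, 1.58, rng) for _ in range(20_000)]
print(sum(samples) / len(samples))  # close to 0.3; round(0.3) would always give 0
```

With round-to-nearest, a weight sitting at 0.3 of a quantization step always collapses to 0. Under stochastic rounding it becomes 1 roughly 30% of the time, so its value is preserved in expectation. This is the property that makes small updates survive on a coarse grid.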
What problem does this paper attempt to address?
This paper addresses how to reduce memory consumption when training large language models (LLMs) by updating low-precision weight matrices directly, without maintaining high-precision copies. Existing quantization-aware training (QAT) methods rely on the straight-through estimator (STE): during backpropagation, gradients update a high-precision weight matrix that must be kept for the entire training run, which incurs a large memory overhead. To avoid this, the authors propose Direct Quantized Training (DQT), which uses stochastic rounding to update the low-precision weight matrices directly, thereby reducing memory usage during training.

### Key contributions of the paper

1. **Reduced memory consumption**: Because the low-precision weight matrices are updated directly, DQT avoids storing the high-precision weight matrices that conventional QAT requires.
2. **Potentially faster training**: DQT simplifies the training process and reduces the number of quantization operations per training step, which could in principle accelerate training.
3. **Maintained performance**: With 8-bit weights, DQT performs close to the unquantized FP32 model and incurs only about a 5% loss degradation relative to BitNet b1.58.
4. **Deployment flexibility**: DQT models are trained with low-precision weights and can also run inference with ternary weights.

### Experimental results

- **Convergence**: DQT models still converge even when the weights are restricted to ternary values.
- **Effect of bit width**: Performance improves steadily as the bit width increases. The 8-bit DQT model performs close to the FP32 model, while the 1.58-bit (ternary) DQT model performs slightly worse.
- **Inference with ternary weights**: When a DQT model runs inference with ternary weights, its performance drops slightly but stays within an acceptable range, which suggests it is practical to deploy.

### Conclusion

DQT can maintain or approach the performance of conventional high-precision training while reducing memory consumption during training. This matters most for researchers and organizations with limited resources who work with large language models. Future work includes extending DQT to larger models and datasets and improving hardware support for efficient low-precision operations.
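The contrast between STE-based QAT and DQT can be illustrated with a toy update rule. This is a hypothetical sketch, not the paper's training procedure. It assumes plain SGD, a fixed per-tensor `scale`, and integer weight levels clipped to [-qmax, qmax]. The names `dqt_step` and `weight_storage_bytes` are illustrative only. The memory count covers weight storage alone (no gradients, optimizer states, or activations) and assumes the STE baseline keeps an FP32 master copy.

```python
import math
import random


def stochastic_round(x: float, rng: random.Random) -> int:
    """Round down or up at random; rounds up with probability frac(x), so E[result] = x."""
    lo = math.floor(x)
    return lo + (1 if rng.random() < x - lo else 0)


def dqt_step(q: list[int], grads: list[float], scale: float, lr: float,
             qmax: int, rng: random.Random) -> list[int]:
    """Hypothetical DQT-style SGD step applied directly to integer weight levels.

    No high-precision master copy exists: the real-valued target q - lr * g / scale
    is stochastically rounded back onto the grid and clipped to [-qmax, qmax].
    """
    new_q = []
    for qi, gi in zip(q, grads):
        target = qi - lr * gi / scale
        new_q.append(max(-qmax, min(qmax, stochastic_round(target, rng))))
    return new_q


def weight_storage_bytes(n_params: int, bits: int, ste: bool) -> float:
    """Weight storage only (no gradients, optimizer states, or activations).

    Assumes STE-based QAT keeps an FP32 master copy (4 bytes/param) and quantizes
    it on the fly, whereas DQT stores only the low-bit levels (bits/8 bytes/param).
    """
    return 4.0 * n_params if ste else n_params * bits / 8


rng = random.Random(1)
# Each update moves a weight by 0.3 of a quantization step: round-to-nearest
# would discard it entirely, while stochastic rounding applies it in expectation.
q = dqt_step([0] * 10_000, [-1.0] * 10_000, scale=1.0, lr=0.3, qmax=1, rng=rng)
print(sum(q) / len(q))  # close to 0.3
print(weight_storage_bytes(10**9, 8, ste=True) / weight_storage_bytes(10**9, 8, ste=False))  # 4.0
```

Under these assumptions, 8-bit DQT weights take a quarter of the memory of an FP32 master copy. Ternary weights would shrink this further. Real savings depend on how the levels are packed and on how much memory optimizer states and activations consume.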