Training Deep Neural Networks with 8-bit Floating Point Numbers

Naigang Wang, Jungwook Choi, Daniel Brand, Chia-Yu Chen, Kailash Gopalakrishnan
DOI: https://doi.org/10.48550/arXiv.1812.08011
2018-12-19
Abstract: The state-of-the-art hardware platforms for training Deep Neural Networks (DNNs) are moving from traditional single precision (32-bit) computations towards 16 bits of precision -- in large part due to the high energy efficiency and smaller bit storage associated with using reduced-precision representations. However, unlike inference, training with numbers represented with less than 16 bits has been challenging due to the need to maintain fidelity of the gradient computations during back-propagation. Here we demonstrate, for the first time, the successful training of DNNs using 8-bit floating point numbers while fully maintaining the accuracy on a spectrum of Deep Learning models and datasets. In addition to reducing the data and computation precision to 8 bits, we also successfully reduce the arithmetic precision for additions (used in partial product accumulation and weight updates) from 32 bits to 16 bits through the introduction of a number of key ideas including chunk-based accumulation and floating point stochastic rounding. The use of these novel techniques lays the foundation for a new generation of hardware training platforms with the potential for 2-4x improved throughput over today's systems.
Machine Learning
What problem does this paper attempt to address?
This paper aims to train deep neural networks (DNNs) with 8-bit floating-point numbers (FP8) while matching the model accuracy of traditional 32-bit single precision (FP32), significantly reducing the demand for computing resources. Specifically, the paper addresses the following main challenges:

1. **Data and computation precision reduced to 8 bits**: When all operands (weights, activations, errors, and gradients) are reduced from 32 or 16 bits to 8 bits, most DNNs suffer a significant drop in accuracy. By designing a new FP8 floating-point format and drawing on insights into DNN training, the paper enables GEMM computations to run on 8-bit operands without losing model accuracy.
2. **Accumulation precision reduced from 32 bits to 16 bits**: In GEMM computations, reducing the accumulation precision from 32 bits to 16 bits significantly harms the convergence of DNN training. The paper introduces chunk-based accumulation: the dot product of long vectors is divided into small chunks, each chunk is accumulated separately, and the chunk sums are then added together. This limits the information lost when a small number is added to a much larger running sum, and so preserves convergence.
3. **Weight update precision reduced from 32 bits to 16 bits**: 32-bit weight updates require storing additional high-precision copies of the weights and gradients, which increases memory overhead. The paper applies floating-point stochastic rounding so that weight updates can be performed in 16-bit precision without loss of accuracy.

With these techniques, the paper demonstrates that DNNs can be trained with 8-bit floating-point numbers across a range of deep learning models and datasets while fully maintaining model accuracy. The techniques offer a potential 2-4x improvement in energy efficiency and throughput for future hardware platforms.
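To make the two 16-bit techniques concrete, here is a minimal sketch of chunk-based accumulation and floating-point stochastic rounding. It is an illustration under stated assumptions, not the authors' implementation. IEEE half precision (via Python's `struct` format `"e"`) stands in for the paper's 16-bit format, which may differ. The chunk size of 64, the helper names, and the toy inputs are all hypothetical choices made for this example.

```python
import random
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE binary16 value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def naive_sum_fp16(values):
    """Accumulate sequentially, rounding the running sum to FP16 after each add."""
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + v)
    return acc


def chunked_sum_fp16(values, chunk_size=64):
    """Sum each chunk in FP16, then accumulate the chunk sums in FP16.

    chunk_size=64 is an assumption made for this illustration.
    """
    total = 0.0
    for start in range(0, len(values), chunk_size):
        partial = naive_sum_fp16(values[start:start + chunk_size])
        total = to_fp16(total + partial)
    return total


def stochastic_round_fp16(x, rng):
    """Round x to one of its two FP16 neighbours, choosing the upper one
    with probability equal to x's fractional position between them."""
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    nearest = to_fp16(a)
    if nearest == a:
        return sign * nearest  # already representable, nothing to round
    bits = struct.unpack("<H", struct.pack("<e", nearest))[0]
    # For non-negative FP16 values, adjacent bit patterns are adjacent numbers.
    if nearest < a:
        lo, hi = nearest, struct.unpack("<e", struct.pack("<H", bits + 1))[0]
    else:
        lo, hi = struct.unpack("<e", struct.pack("<H", bits - 1))[0], nearest
    p_up = (a - lo) / (hi - lo)
    return sign * (hi if rng.random() < p_up else lo)


# Accumulation: add 4096 ones with a 16-bit accumulator.
ones = [1.0] * 4096
naive = naive_sum_fp16(ones)      # 2048.0: adding 1 to 2048 rounds back to 2048
chunked = chunked_sum_fp16(ones)  # 4096.0: every partial and chunk sum is exact

# Weight update: apply 1000 updates of 1e-4 to a 16-bit weight of 1.0.
rng = random.Random(0)
w_nearest = w_stochastic = 1.0
for _ in range(1000):
    w_nearest = to_fp16(w_nearest + 1e-4)  # update is always rounded away
    w_stochastic = stochastic_round_fp16(w_stochastic + 1e-4, rng)

print(naive, chunked, w_nearest, w_stochastic)
```

In this toy setup the sequential sum stalls at 2048 because the gap between adjacent FP16 values there is 2, so adding 1 is rounded away. The chunked sum keeps every addend comparable in size to its accumulator. Similarly, round-to-nearest discards every 1e-4 update to a weight of 1.0, since the FP16 spacing near 1.0 is about 9.8e-4. Stochastic rounding keeps the updates in expectation, so the weight drifts toward roughly 1.1.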