Balancing Speed and Stability: The Trade-offs of FP8 vs. BF16 Training in LLMs

Kazuki Fujii, Taishi Nakamura, Rio Yokota
2024-11-10
Abstract: Large Language Models (LLMs) have attracted significant attention due to their human-like language understanding and generation capabilities, as well as their applicability across various domains. These models, characterized by their massive scale and extensive training data, continue to push the boundaries of what is possible in natural language processing. The Llama 3 series, for instance, exemplifies this trend with its flagship model boasting 405 billion parameters trained on 15.6 trillion tokens. The immense computational demands associated with training such models have spurred ongoing research into optimizing the efficiency of the training process, particularly through the use of lower-precision formats. NVIDIA's H100 GPU, which introduces support for FP8 in addition to the more conventional FP16 and BF16 formats, has emerged as a focal point in this optimization effort. Preliminary studies suggest that FP8 could offer substantial reductions in training time without sacrificing model performance when compared to BF16, making it a promising candidate for large-scale model training. However, the broader implications of adopting FP8, particularly in terms of training stability and downstream task performance, have yet to be fully understood. In this study, we delve into the practical trade-offs involved in adopting FP8 over BF16 for training LLMs.
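The abstract contrasts FP8 with BF16; the practical difference lies in how many exponent and mantissa bits each format keeps. The following minimal Python sketch, which is illustrative and not taken from the paper, rounds a value to BF16 (8 exponent bits, 7 mantissa bits) and to the two FP8 encodings supported by H100, E4M3 and E5M2. The `FORMATS` table and the `quantize` helper are hypothetical names; the sketch rounds half to even and, as a simplifying assumption, saturates out-of-range magnitudes instead of producing inf or NaN.

```python
import math

# (mantissa_bits, min_normal_exponent, max_finite_value) for each format.
# E4M3 uses exponent bias 7 (max 448); E5M2 uses bias 15 (max 57344).
FORMATS = {
    "BF16": (7, -126, (2 - 2**-7) * 2.0**127),
    "FP8-E4M3": (3, -6, 448.0),
    "FP8-E5M2": (2, -14, 57344.0),
}


def quantize(x: float, fmt: str) -> float:
    """Round x to the nearest value representable in fmt (round-half-to-even).

    Assumption: out-of-range magnitudes saturate to the largest finite value;
    real hardware casts may produce inf/NaN instead.
    """
    mantissa_bits, min_exp, max_value = FORMATS[fmt]
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if math.isinf(a):
        return sign * max_value
    _, e = math.frexp(a)                    # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, min_exp)               # clamp so tiny values become subnormals
    quantum = 2.0 ** (exp - mantissa_bits)  # gap between neighboring values in this binade
    r = round(a / quantum) * quantum        # Python's round() is round-half-to-even
    return sign * min(r, max_value)


if __name__ == "__main__":
    for fmt in FORMATS:
        print(fmt, quantize(3.14159, fmt), quantize(1000.0, fmt), quantize(1e-4, fmt))
```

With these definitions, `quantize(3.14159, "BF16")` gives 3.140625, while E4M3 gives 3.25 and E5M2 gives 3.0. Near 1.0, neighboring BF16 values are 2^-7 apart versus 2^-3 for E4M3, and E4M3 saturates at 448 and flushes 1e-4 to zero, which is why FP8 training recipes generally pair the format with per-tensor scaling factors.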
Machine Learning
What problem does this paper attempt to address?
This paper explores the trade-offs between the FP8 (8-bit floating-point) and BF16 (16-bit brain floating-point) formats for training large language models (LLMs). Specifically, the researchers used the Megatron-LM framework to continue pre-training the Llama-3-70B model at scale on multilingual and code corpora, and compared BF16 and FP8 in terms of training throughput, training-loss behavior, and performance on multiple downstream tasks.

### Problems the paper attempts to solve:

1. **Balance between training efficiency and stability**: As large language models grow, the computational resources required to train them have become enormous, so optimizing the training process, particularly by using lower-precision formats to increase training speed, has become an important research direction. However, reducing precision risks training instability and performance degradation. The paper evaluates whether FP8 can match BF16 in training stability and downstream task performance while training faster.
2. **Sensitivity of different task types to precision formats**: The study also examines how different types of downstream tasks (such as question answering, code generation, and mathematical reasoning) respond to switching between FP8 and BF16. This clarifies which task types are most vulnerable to reduced precision and helps guide the choice of precision format.

### Research methods:

- **Experimental setup**: The Megatron-LM framework was used to continue pre-training the Llama-3-70B model on a multilingual and code corpus of approximately 100 billion tokens.
- **Comparison metrics**: Training throughput (measured in TFLOPS), training loss curves, and performance on Japanese and English downstream tasks.
- **Hardware environment**: The BF16 experiment was run on the AI Bridging Cloud Infrastructure (ABCI) with NVIDIA A100 GPUs, while the FP8 experiment was run on the TSUBAME 4.0 supercomputer with NVIDIA H100 GPUs.

### Main findings:

- **Training speed**: FP8 training substantially improves throughput, from 415 TFLOPS with BF16 to a maximum of 570 TFLOPS.
- **Training stability**: FP8 training produces less stable training loss, with frequent loss spikes.
- **Downstream task performance**: The impact of FP8 varies by task type. On Japanese tasks, for example, question answering is relatively insensitive to the precision format, whereas code generation and mathematical reasoning show more noticeable degradation. English tasks show a similar trend, but with smaller differences between task categories.

In summary, by comparing FP8 and BF16 for LLM training, this paper shows that the pursuit of higher training efficiency must be weighed against training stability and performance on specific tasks.
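The throughput figures can be turned into a rough time-to-train estimate. The sketch below is a back-of-envelope calculation, not a result reported in the paper: it assumes the common approximation of about 6N FLOPs per training token for an N-parameter dense model, treats the reported TFLOPS as achieved per-GPU throughput, uses a nominal 70B parameter count, and ignores attention FLOPs and hardware differences between the runs. The `gpu_hours` helper is a hypothetical name.

```python
def gpu_hours(n_params: float, n_tokens: float, tflops_per_gpu: float) -> float:
    """GPU-hours to process n_tokens, assuming ~6 * n_params FLOPs per token."""
    total_flops = 6.0 * n_params * n_tokens
    seconds = total_flops / (tflops_per_gpu * 1e12)  # TFLOPS -> FLOP/s
    return seconds / 3600.0


N_PARAMS = 70e9    # nominal Llama-3-70B size (assumption)
N_TOKENS = 100e9   # ~100B-token continual pre-training corpus

bf16 = gpu_hours(N_PARAMS, N_TOKENS, 415.0)
fp8 = gpu_hours(N_PARAMS, N_TOKENS, 570.0)  # peak FP8 figure, so this is optimistic

print(f"BF16: {bf16:,.0f} GPU-hours")
print(f"FP8:  {fp8:,.0f} GPU-hours")
print(f"throughput ratio: {570.0 / 415.0:.2f}x, GPU-hours saved: {1 - fp8 / bf16:.0%}")
```

Under these assumptions, the ratio 570/415 is about 1.37, so the same token budget needs roughly 27% fewer GPU-hours at peak FP8 throughput. Because 570 TFLOPS is a maximum, the real saving is likely smaller, and it has to be weighed against the loss spikes and task-specific degradation reported for FP8.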