Is Flash Attention Stable?

Alicia Golden, Samuel Hsia, Fei Sun, Bilge Acun, Basil Hosmer, Yejin Lee, Zachary DeVito, Jeff Johnson, Gu-Yeon Wei, David Brooks, Carole-Jean Wu
2024-05-05
Abstract: Training large-scale machine learning models poses distinct system challenges, given both the size and complexity of today's workloads. Recently, many organizations training state-of-the-art Generative AI models have reported cases of instability during training, often taking the form of loss spikes. Numeric deviation has emerged as a potential cause of this training instability, although quantifying this is especially challenging given the costly nature of training runs. In this work, we develop a principled approach to understanding the effects of numeric deviation, and construct proxies to put observations into context when downstream effects are difficult to quantify. As a case study, we apply this framework to analyze the widely-adopted Flash Attention optimization. We find that Flash Attention sees roughly an order of magnitude more numeric deviation as compared to Baseline Attention at BF16 when measured during an isolated forward pass. We then use a data-driven analysis based on the Wasserstein Distance to provide upper bounds on how this numeric deviation impacts model weights during training, finding that the numerical deviation present in Flash Attention is 2-5 times less significant than low-precision training.
Machine Learning; Distributed, Parallel, and Cluster Computing
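The abstract compares the numeric deviation of Flash Attention and Baseline Attention in BF16 during an isolated forward pass. The toy sketch below is not the paper's measurement harness. It only illustrates the idea under several assumptions: a single query vector; BF16 emulated in software by rounding every intermediate result (real kernels typically accumulate in FP32, so this likely exaggerates error); a float64 computation on the same BF16-rounded inputs as the reference; a hypothetical block size of 4 for the tiled, online-softmax path that Flash-Attention-style kernels use; and maximum absolute difference as the deviation metric. All helper names (`to_bf16`, `dot`, `baseline_attention`, `tiled_attention`, `max_abs_dev`) are hypothetical.

```python
import math
import random
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Goes through float32 first, so this is an emulation, not bit-exact
    hardware behaviour for every input (no special NaN handling).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000  # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def identity(x: float) -> float:
    """No rounding: used to compute a float64 reference."""
    return x


def dot(a, b, rnd):
    """Dot product that rounds after every multiply and every add."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = rnd(acc + rnd(x * y))
    return acc


def baseline_attention(q, keys, values, rnd):
    """softmax(q K^T / sqrt(d)) V for one query: full softmax first, then weighted sum."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [rnd(dot(q, k, rnd) * scale) for k in keys]
    m = max(scores)
    exps = [rnd(math.exp(s - m)) for s in scores]
    denom = 0.0
    for e in exps:
        denom = rnd(denom + e)
    out = [0.0] * len(values[0])
    for e, v in zip(exps, values):
        w = rnd(e / denom)  # normalised attention weight
        for j, vj in enumerate(v):
            out[j] = rnd(out[j] + rnd(w * vj))
    return out


def tiled_attention(q, keys, values, rnd, block=4):
    """Online-softmax attention over key/value blocks (Flash-Attention-style tiling)."""
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                 # running max of scores
    l = 0.0                       # running softmax denominator
    acc = [0.0] * len(values[0])  # running unnormalised output
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        scores = [rnd(dot(q, k, rnd) * scale) for k in ks]
        m_new = max(m, max(scores))
        alpha = rnd(math.exp(m - m_new))  # rescales statistics of earlier blocks
        exps = [rnd(math.exp(s - m_new)) for s in scores]
        l = rnd(alpha * l)
        for e in exps:
            l = rnd(l + e)
        acc = [rnd(alpha * a) for a in acc]
        for e, v in zip(exps, vs):
            for j, vj in enumerate(v):
                acc[j] = rnd(acc[j] + rnd(e * vj))
        m = m_new
    return [rnd(a / l) for a in acc]


def max_abs_dev(a, b):
    """Largest element-wise absolute difference between two outputs."""
    return max(abs(x - y) for x, y in zip(a, b))


rng = random.Random(0)
d, n = 16, 64
# Inputs are stored in BF16, so the reference sees exactly the same values.
q = [to_bf16(rng.gauss(0, 1)) for _ in range(d)]
keys = [[to_bf16(rng.gauss(0, 1)) for _ in range(d)] for _ in range(n)]
values = [[to_bf16(rng.gauss(0, 1)) for _ in range(d)] for _ in range(n)]

reference = baseline_attention(q, keys, values, identity)  # float64 "golden" output
base_dev = max_abs_dev(baseline_attention(q, keys, values, to_bf16), reference)
tiled_dev = max_abs_dev(tiled_attention(q, keys, values, to_bf16), reference)
print(f"baseline attention, BF16 emulation: max |dev| = {base_dev:.3e}")
print(f"tiled attention,    BF16 emulation: max |dev| = {tiled_dev:.3e}")
```

In exact arithmetic both functions return the same output; only rounding separates them. The tiled path adds extra rounded operations (the `alpha` rescaling of `l` and `acc` after each block), which is one plausible way tiling can change numeric behaviour. Numbers printed by this toy are not comparable to the paper's measurements.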
What problem does this paper attempt to address?
The paper addresses training instability in large-scale machine learning models, specifically instability linked to numeric deviation. It asks whether the Flash Attention optimization introduces enough numeric deviation to destabilize training, for example by triggering loss spikes (sudden increases in the loss). Such instability can interrupt a run and force a restart, which adds to the cost and time of training. The paper's goal is therefore to develop a principled method for quantifying numeric deviation and understanding its effects, to measure how much that deviation changes model weights, and from this to judge whether Flash Attention can lead to training instability.
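The abstract states that the effect of numeric deviation on model weights is bounded using the Wasserstein Distance. A minimal sketch of that distance between two flattened weight tensors follows, assuming equal-size samples with uniform weights, the case where the 1-D Wasserstein-1 distance reduces to comparing sorted values. The function `wasserstein_1d` and the synthetic weights are hypothetical; the paper's exact procedure for forming its upper bounds is not reproduced here.

```python
import random


def wasserstein_1d(a, b):
    """Wasserstein-1 distance between two equal-size 1-D empirical samples.

    With equal weights on every sample, the optimal transport plan pairs the
    i-th smallest value of a with the i-th smallest value of b, so W1 is the
    mean absolute difference of the sorted samples.
    """
    if len(a) != len(b) or not a:
        raise ValueError("samples must be non-empty and of equal size")
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


# Synthetic stand-ins for one flattened weight tensor (not data from the paper).
rng = random.Random(0)
base = [rng.gauss(0.0, 0.02) for _ in range(10_000)]
run_small = [w + rng.gauss(0.0, 1e-4) for w in base]  # small perturbation
run_large = [w + rng.gauss(0.0, 1e-3) for w in base]  # larger perturbation

d_small = wasserstein_1d(base, run_small)
d_large = wasserstein_1d(base, run_large)
print(f"W1(base, small) = {d_small:.2e}")
print(f"W1(base, large) = {d_large:.2e}")
print(f"ratio large/small = {d_large / d_small:.1f}")
```

Because W1 compares value distributions rather than element-by-element positions, it is insensitive to permuting the weights. A ratio such as `d_large / d_small` mirrors, in spirit only, the abstract's comparison of Flash Attention's deviation against that of low-precision training. SciPy's `scipy.stats.wasserstein_distance` computes the same 1-D quantity and also handles unequal sample sizes and weights.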