Efficient fine-tuning of 37-level GraphCast with the Canadian global deterministic analysis

Christopher Subich
2024-08-27
Abstract: This work describes a process for efficiently fine-tuning the GraphCast data-driven forecast model to simulate another analysis system, here the Global Deterministic Prediction System (GDPS) of Environment and Climate Change Canada (ECCC). Using two years of training data (July 2019 to December 2021) and 37 GPU-days of computation to tune the 37-level, quarter-degree version of GraphCast, the resulting model significantly outperforms both the unmodified GraphCast and operational forecast, showing significant forecast skill in the troposphere over lead times from 1 to 10 days. This fine-tuning is accomplished through abbreviating DeepMind's original training curriculum for GraphCast, relying on a shorter single-step forecast stage to accomplish the bulk of the adaptation work and consolidating the autoregressive stages into separate 12hr, 1d, 2d, and 3d stages with larger learning rates. Additionally, training over 3d forecasts is split into two sub-steps to conserve host memory while maintaining a strong correlation with training over the full period.
Machine Learning, Atmospheric and Oceanic Physics
What problem does this paper attempt to address?
The paper addresses how to efficiently fine-tune the GraphCast data-driven forecast model to simulate the Global Deterministic Prediction System (GDPS) of Environment and Climate Change Canada (ECCC). Specifically, it fine-tunes the 37-level, quarter-degree version of GraphCast using two years of training data (July 2019 to December 2021) and 37 GPU-days of computation. The fine-tuned model significantly outperforms both the unmodified GraphCast model and the operational forecast at lead times from 1 to 10 days, showing significant forecast skill in the troposphere.

### Background and Significance

1. **Background**:
   - Machine learning methods have started a revolution in medium-range weather forecasting. Models such as GraphCast, FourCastNet, and Pangu-Weather have surpassed the high-resolution model (HRES) of the European Centre for Medium-Range Weather Forecasts (ECMWF) in forecast skill, especially at lead times up to 10 days.
   - Because they run on accelerator cards (typically GPUs), these data-driven models need much less computation time than traditional models while maintaining comparable accuracy.
   - However, data-driven models can lose accuracy when the distribution of their input data differs from that of their training data. For example, GraphCast is trained primarily on the ERA5 dataset, while operational analysis systems such as ECCC's GDPS may differ systematically from ERA5.
2. **Significance**:
   - This study provides the first publicly available guide for adapting the full 37-level, quarter-degree GraphCast model to another analysis system.
   - The unmodified GraphCast model loses accuracy when given initial conditions from a different analysis system, so fine-tuning is crucial for getting the most out of GraphCast in operational use.

### Methods

1. **Datasets**:
   - The primary dataset is the "late" operational analysis produced by ECCC's Global Deterministic Prediction System from July 2019 to December 2021.
   - Data from 2022 form the validation set, and data from 2023 form the test set.
   - ERA5 data over the same date range are used for a control fine-tuning run.
2. **Fine-tuning Process**:
   - Fine-tuning proceeds in several stages, each with its own learning rate schedule.
   - The AdamW optimizer is used, with the momentum parameters held constant and a weight decay factor of 0.1.
   - The training batch size is reduced from 32 to 4 to fit ECCC's 4-GPU computing nodes.
   - The loss function is a single scalar that combines the mean squared errors across predicted variables, model levels, and forecast lead times.
   - Normalization factors are recalculated to match the distribution of the GDPS data.
   - Within each training stage, the learning rate follows a linear warm-up period followed by cosine annealing.

### Results

1. **Performance Improvement**:
   - The fine-tuned GraphCast model significantly outperforms both the unmodified GraphCast model and the operational forecast at lead times from 1 to 10 days.
   - It shows significant forecast skill in the troposphere, particularly at short to medium range.
2. **Validation and Testing**:
   - Results on the validation and test datasets indicate that the fine-tuned model has higher forecast skill.

### Conclusion

This study demonstrates how to adapt the GraphCast model to a specific operational analysis system through fine-tuning, thereby improving its forecast performance. The procedure is a useful reference for meteorological centers in other countries and regions.
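
### Illustrative Sketch: Learning Rate Schedule

The learning rate schedule named under Methods (a linear warm-up followed by cosine annealing) can be sketched in a few lines of Python. This is a minimal illustration, not the paper's code: the function name `warmup_cosine_lr`, its parameters, and every numeric value below are assumptions chosen for the example, since this summary gives no per-stage step counts or peak learning rates.

```python
import math


def warmup_cosine_lr(step, total_steps, warmup_steps, peak_lr, final_lr=0.0):
    """Learning rate at `step`: linear warm-up, then cosine annealing.

    All arguments are hypothetical; the summary does not state the values
    used in any fine-tuning stage.
    """
    if not 0 <= step <= total_steps:
        raise ValueError("step must lie in [0, total_steps]")
    if warmup_steps > 0 and step < warmup_steps:
        # Linear ramp from 0 up to peak_lr over the warm-up period.
        return peak_lr * step / warmup_steps
    decay_steps = total_steps - warmup_steps
    if decay_steps <= 0:
        # No room left for annealing: stay at the peak.
        return peak_lr
    # Fraction of the annealing period completed, in [0, 1].
    progress = (step - warmup_steps) / decay_steps
    # Half-cosine decay from peak_lr down to final_lr.
    return final_lr + 0.5 * (peak_lr - final_lr) * (1.0 + math.cos(math.pi * progress))


# Hypothetical stage: 1000 steps, 100 of them warm-up, peak rate 1e-3.
schedule = [warmup_cosine_lr(s, 1000, 100, 1e-3) for s in range(1001)]
```

In this sketch each training stage would call the function with its own `total_steps` and `peak_lr`, so the schedule restarts at the beginning of every stage. In a JAX training loop, a library schedule such as Optax's `warmup_cosine_decay_schedule` could serve the same purpose.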