Adam-mini: Use Fewer Learning Rates To Gain More

Yushun Zhang, Congliang Chen, Ziniu Li, Tian Ding, Chenwei Wu, Yinyu Ye, Zhi-Quan Luo, Ruoyu Sun
2024-07-04
Abstract: We propose Adam-mini, an optimizer that achieves on-par or better performance than AdamW with 45% to 50% less memory footprint. Adam-mini reduces memory by cutting down the learning rate resources in Adam (i.e., $1/\sqrt{v}$). We find that $\geq$ 90% of these learning rates in $v$ could be harmlessly removed if we (1) carefully partition the parameters into blocks following our proposed principle on Hessian structure; (2) assign a single but good learning rate to each parameter block. We further find that, for each of these parameter blocks, there exists a single high-quality learning rate that can outperform Adam, provided that sufficient resources are available to search it out. We then provide one cost-effective way to find good learning rates and propose Adam-mini. Empirically, we verify that Adam-mini performs on par with or better than AdamW on various language models sized from 125M to 7B for pre-training, supervised fine-tuning, and RLHF. The reduced memory footprint of Adam-mini also alleviates communication overheads among GPUs and CPUs, thereby increasing throughput. For instance, Adam-mini achieves 49.6% higher throughput than AdamW when pre-training Llama2-7B on $2\times$ A800-80GB GPUs, which saves 33% wall-clock time for pre-training.
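To make "a single learning rate per parameter block" concrete, here is a minimal NumPy sketch of an Adam-mini-style step. It is an illustration under stated assumptions, not the authors' implementation: `adam_mini_step` is a hypothetical name; every named array in `params` is treated as one block (the paper instead partitions by Hessian structure, which is not reproduced here); and Adam-style bias correction plus AdamW-style decoupled weight decay are assumed. As in Adam, the first moment `m` is kept per coordinate. The difference is that `v` is a single float per block, updated from the mean of the block's squared gradients, so every coordinate in the block shares the step size \(\eta/(\sqrt{\hat v}+\epsilon)\).

```python
import numpy as np


def adam_mini_step(params, grads, state, lr=1e-3, beta1=0.9, beta2=0.999,
                   eps=1e-8, weight_decay=0.0):
    """One Adam-mini-style update, applied in place to `params`.

    Hypothetical helper for illustration. Assumption: each entry of `params`
    (a dict of NumPy float arrays) is one parameter block.
    """
    state["t"] = state.get("t", 0) + 1
    t = state["t"]
    for name, p in params.items():
        g = grads[name]
        # First moment m: one value per coordinate, exactly as in Adam.
        m = state.setdefault(("m", name), np.zeros_like(p))
        m[...] = beta1 * m + (1 - beta1) * g
        # Second moment v: ONE scalar per block, from the mean squared gradient.
        v = beta2 * state.get(("v", name), 0.0) + (1 - beta2) * float(np.mean(g * g))
        state[("v", name)] = v
        # Adam-style bias correction (an assumption of this sketch).
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        # Decoupled (AdamW-style) weight decay.
        p -= lr * weight_decay * p
        # Every coordinate in the block shares the step size lr / (sqrt(v_hat) + eps).
        p -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return params
```

For a model with N parameters split into B blocks, this keeps N values for `m` but only B values for `v`, instead of N of each as in Adam.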
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The paper addresses the excessive memory consumption of the Adam optimizer in large language model (LLM) training. Specifically:

1. **Memory Consumption Issue**: Although Adam performs well in training large language models, it must store two optimizer states for every parameter, the first-order momentum \( m \) and the second-order momentum \( v \), so these states alone take at least twice the memory of the model itself. For example, when training a 7B-parameter model, Adam's states alone require about 56 GB (in float32: 2 states × 4 bytes × 7B parameters), and together with the memory for gradients the total reaches about 86 GB. This is a substantial burden even for high-end GPUs such as the A100-80GB.
2. **Scalability Issue**: The problem grows with model size. For instance, when training PaLM with 540B parameters, Adam's states alone occupy the memory of more than 50 GPUs and become the main cost of pre-training.
3. **Performance Optimization Issue**: Reducing memory consumption not only lessens the need for CPU offloading and parameter sharding but also improves training throughput and speeds up training. It also lets researchers train larger models with limited GPU resources, saving cost and energy.

To address these issues, the paper proposes a new optimizer, Adam-mini. Adam-mini substantially reduces memory consumption by using far fewer learning rates (the \( 1/\sqrt{v} \) terms) while matching or exceeding the performance of AdamW. It does so through two methods:

- **Parameter Partitioning**: The model parameters are divided into blocks following the approximately block-diagonal structure of the Hessian matrix, and each block is assigned a single learning rate instead of one learning rate per parameter.
- **Efficient Learning Rate Selection**: Instead of searching for each block's learning rate with a costly grid search, Adam-mini uses the average of the squared gradients within each block as that block's single second-moment estimate \( v \), which then sets the block's learning rate.

Experimental results show that Adam-mini performs on par with or better than AdamW on language models from 125M to 7B parameters across pre-training, supervised fine-tuning, and reinforcement learning from human feedback (RLHF), while reducing memory consumption by 45% to 50%. Adam-mini also improves training throughput: when pre-training Llama2-7B on 2× A800-80GB GPUs, it achieves 49.6% higher throughput than AdamW, saving 33% of wall-clock training time.
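The memory and time figures in this answer can be sanity-checked with back-of-the-envelope accounting. The sketch below assumes float32 optimizer states, 80 GB cards, and a hypothetical count of one million parameter blocks for Adam-mini. It counts only the optimizer states `m` and `v` (not weights, gradients, or activations), so it shows the scale of the saving rather than reproducing the measured 45% to 50%.

```python
BYTES_FP32 = 4  # assumption: optimizer states stored in float32


def adam_state_bytes(n_params: int) -> int:
    """Adam/AdamW keep two per-parameter states, m and v."""
    return 2 * n_params * BYTES_FP32


def adam_mini_state_bytes(n_params: int, n_blocks: int) -> int:
    """Adam-mini keeps per-parameter m but only one v scalar per block."""
    return (n_params + n_blocks) * BYTES_FP32


n = 7_000_000_000
blocks = 1_000_000  # hypothetical block count, for illustration only
adam = adam_state_bytes(n)
mini = adam_mini_state_bytes(n, blocks)
print(f"Adam state:      {adam / 1e9:.1f} GB")    # 56.0 GB
print(f"Adam-mini state: {mini / 1e9:.1f} GB")    # 28.0 GB
print(f"saving:          {1 - mini / adam:.1%}")  # 50.0%

# 540B-parameter model: 80 GB cards needed for Adam's states alone (assumed card size).
cards = adam_state_bytes(540_000_000_000) / 80e9
print(f"540B model: {cards:.0f} x 80 GB cards")   # 54

# 49.6% higher throughput -> fraction of wall-clock time saved.
print(f"time saved: {1 - 1 / 1.496:.1%}")         # 33.2%
```

The saving on optimizer states approaches 50% whenever the number of blocks is tiny compared with the number of parameters, and a 1.496× throughput gain corresponds to about one third less wall-clock time.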