LaMDA: Large Model Fine-Tuning via Spectrally Decomposed Low-Dimensional Adaptation

Seyedarmin Azizi, Souvik Kundu, Massoud Pedram
2024-06-19
Abstract: Low-rank adaptation (LoRA) has become the default approach to fine-tune large language models (LLMs) due to its significant reduction in trainable parameters. However, the number of trainable parameters LoRA requires grows with the model's embedding dimension, leading to high compute costs. Additionally, its backward updates require storing high-dimensional intermediate activations and optimizer states, demanding high peak GPU memory. In this paper, we introduce large model fine-tuning via spectrally decomposed low-dimensional adaptation (LaMDA), a novel approach to fine-tuning large language models that leverages low-dimensional adaptation to achieve significant reductions in trainable parameters and peak GPU memory footprint. LaMDA freezes a first projection matrix (PMA) in the adaptation path while introducing a low-dimensional trainable square matrix, resulting in substantial reductions in trainable parameters and peak GPU memory usage. LaMDA gradually freezes a second projection matrix (PMB) during the early fine-tuning stages, reducing the compute cost associated with weight updates to enhance parameter efficiency further. We also present an enhancement, LaMDA++, incorporating a "lite-weight" adaptive rank allocation for the LoRA path via normalized spectrum analysis of pre-trained model weights. We evaluate LaMDA/LaMDA++ across various tasks, including natural language understanding with the GLUE benchmark, text summarization, natural language generation, and complex reasoning on different LLMs. Results show that LaMDA matches or surpasses the performance of existing alternatives while requiring up to 17.7x fewer parameter updates and up to 1.32x lower peak GPU memory usage during fine-tuning. Code will be publicly available.
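The adaptation path described in the abstract can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation. Several details are our own assumptions:

- the class name `LowDimAdapterSketch`;
- the random initialization of PMA (`A`) and the zero initialization of the square matrix `S`;
- the random stand-in for the pre-trained weight `W0`;
- the hard cutoff `freeze_b_at`, whereas the paper freezes PMB gradually during early fine-tuning.

The sketch computes `W0 @ x + B @ S @ A @ x` and derives the gradients by hand. Doing so makes it visible which activations must be cached.

```python
import numpy as np


def lora_trainable_params(d_out, d_in, rank):
    """Per-layer trainable parameters of a standard LoRA path (B: d_out x r, A: r x d_in)."""
    return rank * (d_in + d_out)


class LowDimAdapterSketch:
    """Hypothetical LaMDA-style path: y = W0 x + B S A x (A frozen, S trainable, B frozen later)."""

    def __init__(self, d_out, d_in, rank, freeze_b_at, seed=0):
        rng = np.random.default_rng(seed)
        # Frozen pre-trained weight; a random stand-in here (assumption).
        self.W0 = rng.standard_normal((d_out, d_in)) / np.sqrt(d_in)
        # PMA: frozen first projection, assumed random init.
        self.A = rng.standard_normal((rank, d_in)) / np.sqrt(d_in)
        # Trainable r x r square matrix, assumed zero init so the output starts at W0 x.
        self.S = np.zeros((rank, rank))
        # PMB: second projection, trained only until step `freeze_b_at` (assumed hard cutoff).
        self.B = rng.standard_normal((d_out, rank)) / np.sqrt(rank)
        self.freeze_b_at = freeze_b_at
        self.step = 0

    def forward(self, x):
        h = self.A @ x  # r-dim projection of the input
        u = self.S @ h  # r-dim output of the square matrix
        # Only r-dim activations are cached: A is frozen, so x is never needed for dL/dA.
        self._cache = (h, u)
        return self.W0 @ x + self.B @ u

    def backward_and_update(self, grad_y, lr):
        """Plain SGD step given dL/dy for the most recent forward call."""
        h, u = self._cache
        grad_S = np.outer(self.B.T @ grad_y, h)  # computed with B before B is updated
        if self.step < self.freeze_b_at:
            self.B -= lr * np.outer(grad_y, u)
        self.S -= lr * grad_S
        self.step += 1

    def trainable_params(self):
        n = self.S.size
        if self.step < self.freeze_b_at:
            n += self.B.size
        return n


if __name__ == "__main__":
    adapter = LowDimAdapterSketch(d_out=8, d_in=32, rank=4, freeze_b_at=3)
    x, target = np.ones(32), np.zeros(8)
    for _ in range(5):
        y = adapter.forward(x)
        adapter.backward_and_update(y - target, lr=0.01)  # gradient of 0.5 * ||y - target||^2
    print(adapter.trainable_params(), lora_trainable_params(8, 32, 4))  # prints: 16 160
```

For comparison, a LoRA path of the same rank trains `rank * (d_in + d_out)` parameters per layer. With `d_in = d_out = 4096` and `rank = 16`, that is 131,072 parameters, versus 256 for `S` alone once PMB is frozen. This per-layer snapshot is illustrative arithmetic only. It is not comparable to the end-to-end figures reported in the paper.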
Computation and Language, Artificial Intelligence, Machine Learning
What problem does this paper attempt to address?
The paper primarily addresses several key issues in fine-tuning large language models (LLMs):

1. **High cost due to large parameter counts**: Directly fine-tuning LLMs requires substantial computational resources, which is often infeasible on edge devices.
2. **High memory demand**: Fine-tuning requires storing large amounts of intermediate activations and optimizer states, which drives up peak GPU memory.
3. **Risk of overfitting**: Full-model fine-tuning can easily overfit, especially when the number of parameters is large.
4. **Catastrophic forgetting**: Fine-tuning a large pre-trained model on a specific task can degrade its performance on the tasks it originally handled.

To address these issues, the researchers propose LaMDA (Large Model Fine-Tuning via Spectrally Decomposed Low-Dimensional Adaptation), a new framework for fine-tuning LLMs. It aims to significantly reduce both the number of trainable parameters and the activation memory, thereby lowering compute cost and GPU memory usage.

The paper also introduces an enhanced version, LaMDA++. It adaptively allocates rank across layers based on an "energy score" derived from the spectrum of the pre-trained weights, further optimizing how the parameter budget is spent.

The key innovations include:

- Freezing the first projection matrix (PMA) and training a low-dimensional square matrix in the adaptation path (the low-dimensional adapter, LDA), instead of training both projection matrices as in standard LoRA. This significantly reduces the number of trainable parameters. It also means the activations saved for backpropagation live in a lower-dimensional space, greatly reducing memory requirements.
- A progressive freezing strategy that gradually freezes the second projection matrix (PMB) during the early fine-tuning stages, further reducing the compute cost of weight updates.
- LaMDA++, which performs layer-wise adaptive rank allocation by analyzing the spectral decomposition of the pre-trained weights, so the limited fine-tuning budget is used more efficiently.

Experimental results show that LaMDA and LaMDA++ match or surpass existing methods on a range of natural language processing tasks. These include natural language understanding (GLUE), text summarization, natural language generation, and complex reasoning. At the same time, they require up to 17.7x fewer parameter updates and up to 1.32x lower peak GPU memory during fine-tuning.
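LaMDA++'s layer-wise rank allocation from the weight spectrum can be sketched as follows. This is a hypothetical illustration, not the paper's formula. The function names `energy_score` and `allocate_ranks` are made up. The score is assumed to be the share of squared singular-value energy captured by a layer's top `r0` singular values. We also assume that layers with a less concentrated spectrum receive more rank while the total rank budget stays fixed.

```python
import numpy as np


def energy_score(weight, r0):
    """Hypothetical score: share of squared singular-value energy in the top-r0 directions."""
    s = np.linalg.svd(weight, compute_uv=False)  # singular values, descending
    energy = s**2 / np.sum(s**2)                 # normalized spectrum (sums to 1)
    return float(np.sum(energy[:r0]))


def allocate_ranks(weights, base_rank, min_rank=1):
    """Split a fixed budget of base_rank * len(weights) ranks across layers.

    Assumed rule: layers whose spectrum is less concentrated (lower energy score)
    get proportionally more rank; every layer keeps at least min_rank.
    """
    n = len(weights)
    budget = base_rank * n
    scores = np.array([energy_score(w, base_rank) for w in weights])
    need = np.clip(1.0 - scores, 0.0, None)
    if need.sum() <= 0:                     # every spectrum fully captured: fall back to uniform
        need = np.ones(n)
    spare = budget - min_rank * n
    raw = spare * need / need.sum()         # fractional share of the spare budget
    ranks = min_rank + np.floor(raw).astype(int)
    leftover = int(budget - ranks.sum())    # hand out remaining ranks by largest remainder
    order = np.argsort(-(raw - np.floor(raw)))
    ranks[order[:leftover]] += 1
    return ranks.tolist()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    low_rank = rng.standard_normal((64, 2)) @ rng.standard_normal((2, 64))  # energy in 2 directions
    dense = rng.standard_normal((64, 64))                                   # energy spread out
    print(allocate_ranks([low_rank, dense], base_rank=4))
```

Largest-remainder rounding keeps the integer ranks summing exactly to the budget. A real allocator might use a different score or cap per-layer ranks, which this sketch does not attempt.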