ATM: Improving Model Merging by Alternating Tuning and Merging

Luca Zhou, Daniele Solombrino, Donato Crisostomi, Maria Sofia Bucarelli, Fabrizio Silvestri, Emanuele Rodolà
2024-11-05
Abstract: Model merging has recently emerged as a cost-efficient paradigm for multi-task learning. Among current approaches, task arithmetic stands out for its simplicity and effectiveness. In this paper, we motivate the effectiveness of task vectors by linking them to multi-task gradients. We show that in a single-epoch scenario, task vectors are mathematically equivalent to the gradients obtained via gradient descent in a multi-task setting, and still approximate these gradients in subsequent epochs. Furthermore, we show that task vectors perform optimally when equality is maintained, and their effectiveness is largely driven by the first epoch's gradient. Building on this insight, we propose viewing model merging as a single step in an iterative process that Alternates between Tuning and Merging (ATM). This method acts as a bridge between model merging and multi-task gradient descent, achieving state-of-the-art results with the same data and computational requirements. We extensively evaluate ATM across diverse settings, achieving up to 20% higher accuracy in computer vision and NLP tasks, compared to the best baselines. Finally, we provide both empirical and theoretical support for its effectiveness, demonstrating increased orthogonality between task vectors and proving that ATM minimizes an upper bound on the loss obtained by jointly finetuning all tasks.
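The single-epoch equivalence claimed in the abstract can be sketched in a few lines. This is a minimal derivation under stated assumptions, not the paper's proof: "one epoch" is modeled as a single full-batch gradient-descent step with learning rate $\eta$, $\theta_0$ is the pretrained model, $\theta_t$ the model fine-tuned on task $t$ with loss $\mathcal{L}_t$, $T$ the number of tasks, and $\alpha$ the task-arithmetic scaling coefficient.

$$
\begin{aligned}
\tau_t &= \theta_t - \theta_0 = -\eta \, \nabla \mathcal{L}_t(\theta_0), \\
\theta_{\text{TA}} &= \theta_0 + \alpha \sum_{t=1}^{T} \tau_t
  = \theta_0 - \alpha\eta \, \nabla\Big(\sum_{t=1}^{T} \mathcal{L}_t\Big)(\theta_0).
\end{aligned}
$$

Under these assumptions, task arithmetic is exactly one gradient step on the multi-task loss $\sum_t \mathcal{L}_t$ with learning rate $\alpha\eta$. With $K$ steps of fine-tuning,

$$
\tau_t = -\eta \sum_{k=0}^{K-1} \nabla \mathcal{L}_t\big(\theta_t^{(k)}\big), \qquad \theta_t^{(0)} = \theta_0,
$$

whereas $K$ steps of multi-task gradient descent would evaluate every gradient at the shared iterate $\theta_{\text{MT}}^{(k)}$. The two sums agree only at $k = 0$, which is why, for $K > 1$, task vectors approximate rather than equal the multi-task gradients.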
Machine Learning, Artificial Intelligence, Computer Vision and Pattern Recognition
What problem does this paper attempt to address?
The paper addresses how to make model merging both effective and efficient for multi-task learning. Specifically, the authors ask how model merging can be improved through the Alternating Tuning and Merging (ATM) method so that the merged multi-task model is more accurate at no extra cost. The paper argues that although existing task arithmetic methods are simple and effective, merging all task vectors in a single shot can over-optimize (overshoot) the multi-task objective when the task vectors have large norms, which degrades performance. To address this, the paper proposes the ATM framework, which iteratively fine-tunes and merges task-specific models, and achieves better multi-task results than existing methods with the same data and computational budget.

### Main contributions of the paper:

1. **Relationship between task vectors and gradients**: The authors prove that, under specific conditions, a task vector is equivalent to (up to sign and learning-rate scaling) or approximates the gradient of the corresponding task loss.
2. **Limitations of existing one-shot merging frameworks**: Existing one-shot merging frameworks often over-optimize the multi-task objective, especially when the task vectors have large norms.
3. **Alternating Tuning and Merging (ATM) framework**: A new model merging framework, ATM, alternates between tuning and merging over several iterations, so that task-specific knowledge is integrated gradually.
4. **Experimental verification**: Extensive experiments show that ATM outperforms existing methods on computer vision and natural language processing tasks. In particular, under a limited computing budget, ATM significantly improves the accuracy of the multi-task model.
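The alternating loop described above can be illustrated on a toy problem. The sketch below is a hypothetical, self-contained illustration, not the authors' code: each task's loss is assumed to be the quadratic `0.5 * ||theta - c_t||^2`, fine-tuning is plain full-batch gradient descent, merging is task arithmetic with a fixed coefficient `alpha`, and the function names (`finetune`, `task_arithmetic`, `atm`, `multitask_loss`) are made up for this example. With `rounds=1` it reduces to ordinary one-shot task arithmetic, so comparing `rounds=1` against several rounds at the same total number of gradient steps mirrors the equal-budget comparison in the paper.

```python
from typing import List

Vector = List[float]


def grad(theta: Vector, center: Vector) -> Vector:
    # Gradient of the toy task loss 0.5 * ||theta - center||^2 is (theta - center).
    return [p - c for p, c in zip(theta, center)]


def finetune(theta: Vector, center: Vector, lr: float, steps: int) -> Vector:
    # Plain full-batch gradient descent on a single task, starting from theta.
    theta = list(theta)
    for _ in range(steps):
        g = grad(theta, center)
        theta = [p - lr * gi for p, gi in zip(theta, g)]
    return theta


def task_arithmetic(base: Vector, finetuned: List[Vector], alpha: float) -> Vector:
    # theta_merged = base + alpha * sum_t (theta_t - base)
    merged = list(base)
    for theta_t in finetuned:
        merged = [m + alpha * (p - b) for m, p, b in zip(merged, theta_t, base)]
    return merged


def atm(theta0: Vector, centers: List[Vector], lr: float, alpha: float,
        rounds: int, steps_per_round: int) -> Vector:
    # Alternate: fine-tune every task from the current merged model,
    # then merge the resulting task vectors back into that model.
    theta = list(theta0)
    for _ in range(rounds):
        finetuned = [finetune(theta, c, lr, steps_per_round) for c in centers]
        theta = task_arithmetic(theta, finetuned, alpha)
    return theta


def multitask_loss(theta: Vector, centers: List[Vector]) -> float:
    # Sum of the toy quadratic losses over all tasks.
    return sum(0.5 * sum((p - c) ** 2 for p, c in zip(theta, ctr)) for ctr in centers)


if __name__ == "__main__":
    centers = [[1.0, 0.0], [0.0, 1.0]]
    theta0 = [0.0, 0.0]
    # Same budget (10 gradient steps per task), spent in two different ways.
    one_shot = atm(theta0, centers, lr=0.1, alpha=1.0, rounds=1, steps_per_round=10)
    iterative = atm(theta0, centers, lr=0.1, alpha=1.0, rounds=10, steps_per_round=1)
    print("one-shot loss:", multitask_loss(one_shot, centers))
    print("ATM (10 rounds x 1 step) loss:", multitask_loss(iterative, centers))
```

In this toy setup the iterative schedule ends closer to the multi-task optimum (the mean of the centers) than the one-shot merge, which sums two large task vectors and overshoots. Real ATM runs use neural networks, mini-batch optimizers and a tuned `alpha`, so the example only conveys the intuition.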
### Key findings:

- **Relationship between task vectors and gradients**: With a single epoch of training, a task vector is mathematically equivalent to the (negative, learning-rate-scaled) gradient obtained by gradient descent; over subsequent epochs, task vectors still approximate these gradients.
- **Importance of the first epoch**: The effectiveness of a task vector is largely determined by the gradient of the first epoch, and gradients in later epochs also tend to align with the first epoch's gradient direction.
- **Advantages of ATM**: By iteratively fine-tuning and merging task-specific models, ATM avoids the over-optimization that one-shot merging can cause, and thus achieves better multi-task performance.

### Experimental results:

- **Multi-task accuracy**: ATM performs well across different computing budgets. In particular, with a budget of 10 epochs, ATM's average multi-task accuracy is 21% higher than that of the best baseline.
- **Orthogonality of task vectors**: As the number of ATM iterations grows, the cosine similarity between task vectors decreases, meaning the task vectors become more orthogonal, which helps reduce task interference.
- **Setting without training data**: Even without task-specific training data, an ATM variant that fine-tunes only on validation data (valFT ATM) significantly outperforms the baselines.

### Conclusion:

Through theoretical analysis and experiments, the paper demonstrates the effectiveness of the ATM framework for multi-task learning. By iteratively fine-tuning and merging task-specific models, ATM improves the accuracy of the merged multi-task model and reduces task interference, making it a strong alternative to one-shot model merging.
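The orthogonality result above is stated in terms of cosine similarity between task vectors. The following is a minimal sketch of that measurement, assuming each model's parameters have already been flattened into a single list of floats (in practice, all parameter tensors would be concatenated). The helper names are hypothetical, and averaging over task pairs is just one convenient summary, not necessarily the exact statistic reported in the paper.

```python
import math
from itertools import combinations
from typing import Dict, List, Sequence


def task_vector(finetuned: Sequence[float], base: Sequence[float]) -> List[float]:
    # tau_t = theta_t - theta_base, computed on flattened parameters.
    return [f - b for f, b in zip(finetuned, base)]


def cosine_similarity(u: Sequence[float], v: Sequence[float]) -> float:
    # cos(u, v) = <u, v> / (||u|| * ||v||)
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_u * norm_v)


def mean_pairwise_cosine(task_vectors: Dict[str, Sequence[float]]) -> float:
    # Average cosine similarity over all unordered pairs of tasks;
    # values closer to 0 indicate more orthogonal task vectors.
    pairs = list(combinations(task_vectors.values(), 2))
    if not pairs:
        raise ValueError("need at least two task vectors")
    return sum(cosine_similarity(u, v) for u, v in pairs) / len(pairs)


if __name__ == "__main__":
    base = [0.0, 0.0, 0.0]
    vectors = {
        "task_a": task_vector([1.0, 1.0, 0.0], base),
        "task_b": task_vector([1.0, 0.0, 1.0], base),
        "task_c": task_vector([0.0, 1.0, 1.0], base),
    }
    print(mean_pairwise_cosine(vectors))
```

Tracking this value after each ATM iteration would show whether the task vectors drift toward orthogonality, which is the trend the paper reports.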