Dual-Balancing for Multi-Task Learning

Baijiong Lin, Weisen Jiang, Feiyang Ye, Yu Zhang, Pengguang Chen, Ying-Cong Chen, Shu Liu, James T. Kwok
2023-09-29
Abstract: Multi-task learning (MTL), a learning paradigm to learn multiple related tasks simultaneously, has achieved great success in various fields. However, the task balancing problem remains a significant challenge in MTL, with the disparity in loss/gradient scales often leading to performance compromises. In this paper, we propose a Dual-Balancing Multi-Task Learning (DB-MTL) method to alleviate the task balancing problem from both loss and gradient perspectives. Specifically, DB-MTL ensures loss-scale balancing by performing a logarithm transformation on each task loss, and guarantees gradient-magnitude balancing via normalizing all task gradients to the same magnitude as the maximum gradient norm. Extensive experiments conducted on several benchmark datasets consistently demonstrate the state-of-the-art performance of DB-MTL.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
This paper addresses the task-balancing problem in multi-task learning (MTL). In MTL, differences in loss and gradient scales across tasks often lead to performance compromises: some tasks perform well while others perform poorly. To alleviate this problem, the paper proposes a Dual-Balancing Multi-Task Learning (DB-MTL) method that balances tasks from both the loss and the gradient perspective.

### Specific Problem Description

1. **Task-Balancing Problem**: In multi-task learning, the loss or gradient scales of different tasks can differ significantly. A few tasks may then dominate the model's update direction, which hurts the performance of the other tasks. For example, on the NYUv2 dataset, the performance of the surface normal prediction task is usually suppressed by the semantic segmentation and depth estimation tasks.
2. **Limitations of Existing Methods**:
   - **Equal Weighting Method (EW)**: Simply assigns the same weight to all tasks, which usually leads to task-balancing problems.
   - **Methods for Dynamically Adjusting Task Weights**: Some methods alleviate the task-balancing problem by dynamically adjusting task weights, but they still fall short in practice; for example, they cannot ensure that the loss or gradient scales of all tasks are fully consistent.

### Proposed Solution

The paper proposes the DB-MTL method to address the task-balancing problem from two aspects:

1. **Loss-Scale Balancing**:
   - **Logarithmic Transformation**: Applying a logarithmic transformation to the loss of each task brings the losses of all tasks to the same scale. The transformed loss is:
     \[ \log(\ell_t(D_t; \theta, \psi_t)) \]
   - This transformation is non-parametric and can recover the loss transformation in the IMTL-L method.
2. **Gradient-Magnitude Balancing**:
   - **Gradient Normalization**: The gradients of all tasks are normalized to the same magnitude as the maximum gradient norm. The aggregated gradient is:
     \[ \tilde{g}_k = \alpha_k \sum_{t = 1}^T \frac{\hat{g}_{t,k}}{\|\hat{g}_{t,k}\|_2} \]
     where \(\alpha_k\) is a scaling factor that controls the update magnitude and is chosen as the maximum of the gradient norms of all tasks:
     \[ \alpha_k = \max_{1 \leq t \leq T} \|\hat{g}_{t,k}\|_2 \]

### Experimental Verification

The paper reports extensive experiments on multiple benchmark datasets, including NYUv2, Cityscapes, Office-31, Office-Home, and QM9. The results show that DB-MTL achieves state-of-the-art performance on these datasets, especially with respect to the task-balancing problem.

### Conclusion

By simultaneously balancing the loss scale and the gradient magnitude, the DB-MTL method proposed in the paper effectively alleviates the task-balancing problem in multi-task learning and improves the overall performance of the model across tasks.
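
As an illustration of the two balancing steps in the Proposed Solution section, here is a minimal sketch in plain Python. It is not the authors' implementation. The function names, the small `eps` constants added for numerical safety, and the use of plain lists as gradient vectors are all assumptions made for this example. The summary does not define \(\hat{g}_{t,k}\), so the sketch simply treats it as the gradient of task \(t\)'s log-transformed loss with respect to the shared parameters at step \(k\).

```python
import math


def log_transform_losses(losses, eps=1e-8):
    """Loss-scale balancing: map each task loss l_t to log(l_t + eps).

    eps is an assumption added here to avoid log(0); it is not in the summary's formula.
    """
    return [math.log(loss + eps) for loss in losses]


def grad_of_log_loss(loss, grad, eps=1e-8):
    """Chain rule: the gradient of log(l) is grad(l) / l, so each task's
    gradient is rescaled by the inverse of its own loss value."""
    return [g / (loss + eps) for g in grad]


def l2_norm(vec):
    """Euclidean norm of a gradient vector."""
    return math.sqrt(sum(x * x for x in vec))


def balance_gradients(task_grads, eps=1e-12):
    """Gradient-magnitude balancing.

    Computes alpha * sum_t g_t / ||g_t||_2 with alpha = max_t ||g_t||_2,
    following the two formulas in the summary. task_grads is a list of T
    gradient vectors of equal length (a hypothetical flat representation).
    """
    norms = [l2_norm(g) for g in task_grads]
    alpha = max(norms)  # update magnitude: the largest task gradient norm
    dim = len(task_grads[0])
    combined = [0.0] * dim
    for grad, norm in zip(task_grads, norms):
        for i in range(dim):
            combined[i] += grad[i] / (norm + eps)  # unit-norm contribution per task
    return [alpha * c for c in combined]


# Two tasks whose gradient norms differ by a factor of 10 (5.0 vs. 0.5).
example_grads = [[3.0, 4.0], [0.0, 0.5]]
balanced = balance_gradients(example_grads)  # approximately [3.0, 9.0]
log_losses = log_transform_losses([1.0, 100.0])
```

In the usage example, both tasks contribute a unit-norm direction before the sum is rescaled by the largest norm, so the small-gradient task is no longer drowned out by the large-gradient one. A real training loop would obtain the per-task gradients from an automatic differentiation framework rather than from hand-written lists.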