Bayesian Uncertainty for Gradient Aggregation in Multi-Task Learning

Idan Achituve, Idit Diamant, Arnon Netzer, Gal Chechik, Ethan Fetaya
2024-05-13
Abstract: As machine learning becomes more prominent, there is a growing demand to perform several inference tasks in parallel. Running a dedicated model for each task is computationally expensive, and therefore there is great interest in multi-task learning (MTL). MTL aims at learning a single model that solves several tasks efficiently. Optimizing MTL models is often achieved by computing a single gradient per task and aggregating them to obtain a combined update direction. However, these approaches do not consider an important aspect: the sensitivity in the gradient dimensions. Here, we introduce a novel gradient aggregation approach using Bayesian inference. We place a probability distribution over the task-specific parameters, which in turn induces a distribution over the gradients of the tasks. This additional valuable information allows us to quantify the uncertainty in each of the gradients' dimensions, which can then be factored in when aggregating them. We empirically demonstrate the benefits of our approach on a variety of datasets, achieving state-of-the-art performance.
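To make the abstract's central idea concrete, here is a minimal NumPy sketch under stated assumptions: each task has a linear head w with a Gaussian distribution N(mu_w, cov_w), the loss is squared error, and the gradient is taken with respect to a shared feature vector h. Sampling w then yields a distribution over that gradient, whose per-dimension mean and variance can be estimated. The aggregation rule, a per-dimension inverse-variance (precision) weighted average of the task gradient means, is a hypothetical illustration of "factoring in" uncertainty and is not claimed to be the paper's exact rule. The function names gradient_moments and precision_weighted_aggregate are hypothetical.

```python
import numpy as np


def gradient_moments(mu_w, cov_w, h, y, n_samples=2000, seed=0):
    """Monte Carlo mean and per-dimension variance of d/dh of
    0.5 * (w @ h - y) ** 2 when the task head w ~ N(mu_w, cov_w).
    (Illustrative assumption: linear head, squared-error loss.)"""
    rng = np.random.default_rng(seed)
    w = rng.multivariate_normal(mu_w, cov_w, size=n_samples)  # (S, d) sampled heads
    residual = w @ h - y                                       # (S,) prediction errors
    grads = residual[:, None] * w                              # (S, d) gradient w.r.t. h
    return grads.mean(axis=0), grads.var(axis=0)


def precision_weighted_aggregate(means, variances, eps=1e-8):
    """Hypothetical aggregation: average the task gradient means per
    dimension, weighting each task by the inverse of its variance there."""
    means = np.asarray(means)                        # (T, d)
    precisions = 1.0 / (np.asarray(variances) + eps)  # (T, d), eps avoids division by zero
    return (precisions * means).sum(axis=0) / precisions.sum(axis=0)


if __name__ == "__main__":
    h = np.array([1.0, -0.5, 2.0])
    # Task A has a confident head, task B a highly uncertain one.
    m_a, v_a = gradient_moments(np.array([0.5, 0.1, -0.2]), 0.01 * np.eye(3), h, y=1.0)
    m_b, v_b = gradient_moments(np.array([-0.3, 0.4, 0.6]), np.eye(3), h, y=0.0)
    direction = precision_weighted_aggregate([m_a, m_b], [v_a, v_b])
    print(direction)
```

For a linear head with squared-error loss these moments could also be computed in closed form; Monte Carlo sampling is used here only because it applies unchanged to other heads and losses. In this sketch, dimensions where a task's gradient is highly uncertain contribute less to the combined direction.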
Machine Learning
What problem does this paper attempt to address?
This paper proposes a new approach to the gradient aggregation problem in Multi-Task Learning (MTL). MTL trains a single model to handle several tasks simultaneously, which saves computational resources and can improve generalization. However, existing optimization methods overlook the sensitivity of individual gradient dimensions when aggregating the gradients of different tasks. The main contributions of this paper are:
1. Introducing Bayesian inference into MTL gradient aggregation: probability distributions are placed over the task-specific parameters to quantify the uncertainty in each gradient dimension.
2. Proposing a novel posterior approximation method based on a second-order Taylor expansion.
3. Designing a new MTL optimization algorithm based on this posterior estimate, which uses the full distribution of the gradients, rather than a single gradient per task, to determine the update direction.
4. Demonstrating state-of-the-art performance relative to existing methods on several datasets, including QM9, CIFAR-100, ChestX-ray14, and UTKFace.
The paper points out that standard MTL optimization methods consider only a single value for each parameter and therefore lose information in the aggregation step. In contrast, the proposed approach describes the gradient space more richly by accounting for all parameter configurations through their posterior distribution, which makes it possible to weigh the importance of each gradient dimension when searching for the update direction. In this way, the Bayesian treatment assigns a separate weight to each gradient dimension for each task, improving the accuracy of gradient aggregation. Experimental results support the effectiveness of this approach.
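The summary does not spell out the paper's posterior approximation. A standard example of a posterior approximation built from a second-order Taylor expansion is the Laplace approximation: expand the log posterior around its mode, so the curvature (Hessian) at the mode becomes the precision of a Gaussian. The sketch below applies it to a hypothetical binary logistic head with a Gaussian prior; the setting, the function names, and the use of Newton's method to find the mode are assumptions for illustration, not the paper's procedure.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def laplace_logistic_head(H, y, prior_precision=1.0, n_iter=50, tol=1e-10):
    """Gaussian approximation N(w_map, cov) to the posterior of a logistic
    head w, given features H (n, d), labels y in {0, 1}, and prior
    N(0, I / prior_precision). Hypothetical illustration of a Laplace
    approximation (second-order Taylor expansion of the log posterior)."""
    n, d = H.shape
    w = np.zeros(d)
    for _ in range(n_iter):
        p = sigmoid(H @ w)
        # Gradient and Hessian of the negative log posterior.
        grad = H.T @ (p - y) + prior_precision * w
        hess = (H.T * (p * (1.0 - p))) @ H + prior_precision * np.eye(d)
        step = np.linalg.solve(hess, grad)
        w = w - step  # Newton step toward the posterior mode
        if np.linalg.norm(step) < tol:
            break
    # Curvature at the mode gives the precision of the Gaussian approximation.
    p = sigmoid(H @ w)
    hess = (H.T * (p * (1.0 - p))) @ H + prior_precision * np.eye(d)
    return w, np.linalg.inv(hess)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    H = rng.normal(size=(40, 3))
    y = (H[:, 0] + 0.5 * rng.normal(size=40) > 0).astype(float)
    w_map, cov = laplace_logistic_head(H, y)
    # Sampled heads induce a distribution over the task's gradients.
    heads = rng.multivariate_normal(w_map, cov, size=5)
    print(w_map, heads.shape)
```

Samples drawn from N(w_map, cov) give plausible head configurations, and evaluating a task's gradient under each sample yields the kind of gradient distribution whose per-dimension uncertainty the summary describes.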