Meta-learning with an Adaptive Task Scheduler

Huaxiu Yao, Yu Wang, Ying Wei, Peilin Zhao, Mehrdad Mahdavi, Defu Lian, Chelsea Finn
DOI: https://doi.org/10.48550/arXiv.2110.14057
2021-10-27
Abstract: To benefit the learning of a new task, meta-learning has been proposed to transfer a well-generalized meta-model learned from various meta-training tasks. Existing meta-learning algorithms randomly sample meta-training tasks with a uniform probability, under the assumption that tasks are of equal importance. However, it is likely that tasks are detrimental with noise or imbalanced given a limited number of meta-training tasks. To prevent the meta-model from being corrupted by such detrimental tasks or dominated by tasks in the majority, in this paper, we propose an adaptive task scheduler (ATS) for the meta-training process. In ATS, for the first time, we design a neural scheduler to decide which meta-training tasks to use next by predicting the probability being sampled for each candidate task, and train the scheduler to optimize the generalization capacity of the meta-model to unseen tasks. We identify two meta-model-related factors as the input of the neural scheduler, which characterize the difficulty of a candidate task to the meta-model. Theoretically, we show that a scheduler taking the two factors into account improves the meta-training loss and also the optimization landscape. Under the setting of meta-learning with noise and limited budgets, ATS improves the performance on both miniImageNet and a real-world drug discovery benchmark by up to 13% and 18%, respectively, compared to state-of-the-art task schedulers.
Machine Learning
What problem does this paper attempt to address?
This paper addresses how to select meta-training tasks more effectively in meta-learning so that the meta-model generalizes better. Existing meta-learning algorithms usually sample meta-training tasks uniformly at random, which assumes that all tasks are equally important. In practice, tasks may be noisy or imbalanced, so the meta-model can be corrupted by detrimental tasks or dominated by the tasks in the majority. To address this, the authors propose an Adaptive Task Scheduler (ATS): a neural scheduler predicts the sampling probability of each candidate task and is trained to optimize the meta-model's generalization to unseen tasks. ATS feeds the scheduler two meta-model-related factors, the loss on the task's query set and the gradient similarity between the task's support and query sets, which together characterize how difficult the task is for the meta-model.

### Main contributions

1. **Adaptive task scheduler**: ATS adjusts the task selection strategy dynamically according to the state of the meta-model and limits the influence of detrimental tasks.
2. **Two meta-model-related factors**: The loss on the task's query set and the gradient similarity are used to guide task selection.
3. **Theoretical analysis**: The authors show that a scheduler taking these two factors into account improves the meta-training loss and the optimization landscape.
4. **Empirical results**: In meta-learning with noise and limited budgets, ATS improves performance on the image classification benchmark (miniImageNet) and on a real-world drug discovery benchmark by up to 13% and 18%, respectively, compared with state-of-the-art task schedulers.

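
The difference between uniform sampling and sampling with scheduler-assigned probabilities can be illustrated with a small sketch. Everything here is hypothetical: the task names, the probability values, and the `sample_tasks` helper are made up for illustration, and drawing with replacement via `random.choices` is an assumption of this sketch rather than the procedure used in the paper.

```python
import random


def sample_tasks(task_ids, probs, batch_size, rng):
    """Draw a batch of task ids with replacement, proportionally to probs."""
    return rng.choices(task_ids, weights=probs, k=batch_size)


rng = random.Random(0)  # fixed seed so the sketch is reproducible

# Hypothetical candidate pool: three clean tasks and one noisy task.
task_ids = ["clean_a", "clean_b", "clean_c", "noisy_d"]

# Uniform sampling treats every task as equally important.
uniform_probs = [0.25, 0.25, 0.25, 0.25]

# Hypothetical scheduler output that down-weights the noisy task.
scheduled_probs = [0.32, 0.32, 0.32, 0.04]

uniform_batch = sample_tasks(task_ids, uniform_probs, 1000, rng)
scheduled_batch = sample_tasks(task_ids, scheduled_probs, 1000, rng)

# Count how often the noisy task enters meta-training under each scheme.
noisy_uniform = uniform_batch.count("noisy_d")
noisy_scheduled = scheduled_batch.count("noisy_d")
print("noisy task draws (uniform):  ", noisy_uniform)
print("noisy task draws (scheduled):", noisy_scheduled)
```

With the down-weighted probability, the noisy task is drawn far less often than under uniform sampling, which is the effect the scheduler is meant to achieve once it has learned suitable probabilities.
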
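### Formula summary

- Sampling probability of a task:

  \[ w_i^{(k)} = g(T_i, \theta_0^{(k)}; \phi^{(k)}) \]

  where \(w_i^{(k)}\) is the sampling probability of the \(i\)-th task in the \(k\)-th iteration, \(T_i\) is the task, \(\theta_0^{(k)}\) is the current meta-model parameter, and \(\phi^{(k)}\) is the parameter of the neural scheduler \(g\).

- Sampling probability expressed through the two factors:

  \[ w_i^{(k)} = g\left(L(D_q^i; \theta_i^{(k)}),\ \left\langle \nabla_{\theta_0^{(k)}} L(D_s^i; \theta_0^{(k)}),\ \nabla_{\theta_0^{(k)}} L(D_q^i; \theta_0^{(k)}) \right\rangle;\ \phi^{(k)}\right) \]

  where \(D_s^i\) and \(D_q^i\) are the support and query sets of task \(T_i\), and \(\theta_i^{(k)}\) is the task-specific parameter adapted from \(\theta_0^{(k)}\). The first argument is the query-set loss; the second is the gradient similarity between the support and query sets.

- Meta-model update (a gradient step on the query-set losses of the \(B\) sampled tasks, with meta learning rate \(\beta\)):

  \[ \theta_0^{(k + 1)} = \theta_0^{(k)} - \beta \frac{1}{B} \sum_{i = 1}^{B} \nabla_{\theta_0} L(D_q^i; \theta_i^{(k)}) \]

- Bi-level optimization:

  \[ \min_\phi \frac{1}{N_v} \sum_{v = 1}^{N_v} L_{val}(T_v; \theta_0^*(\phi)), \quad \text{where} \quad \theta_0^*(\phi) = \arg\min_{\theta_0} \frac{1}{B} \sum_{i = 1}^{B} L_{tr}(T_i; \theta_0, \phi) \]

  The scheduler parameter \(\phi\) is optimized on \(N_v\) validation tasks, while the meta-model \(\theta_0\) is optimized on the meta-training tasks selected by the scheduler.

With this scheduler, ATS selects tasks in a way that adapts to the current meta-model, which improves the generalization ability and robustness of the meta-model.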
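
As a rough illustration of the two factors in the sampling-probability formula, the sketch below computes the query-set loss after one adaptation step and the support/query gradient inner product for a toy linear regression model. Several pieces are assumptions made for the example and are not taken from the paper: the model \(y = wx + b\) with squared error, the single inner-loop gradient step with step size `alpha`, the toy data, and above all `scheduler_probs`, which stands in for the neural scheduler \(g\) with a hand-set linear score followed by a softmax.

```python
import math


def mse_loss(params, data):
    """Mean squared error of y = w * x + b on a list of (x, y) pairs."""
    w, b = params
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def mse_grad(params, data):
    """Gradient of mse_loss with respect to [w, b]."""
    w, b = params
    n = len(data)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in data) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in data) / n
    return [grad_w, grad_b]


def task_factors(theta0, support, query, alpha=0.1):
    """Return (query loss after adaptation, support/query gradient similarity)."""
    g_support = mse_grad(theta0, support)
    g_query = mse_grad(theta0, query)
    # Assumed inner loop: one gradient step on the support set.
    theta_i = [p - alpha * g for p, g in zip(theta0, g_support)]
    query_loss = mse_loss(theta_i, query)
    # Inner product of the two gradients taken at theta0.
    grad_sim = sum(a * b for a, b in zip(g_support, g_query))
    return query_loss, grad_sim


def scheduler_probs(factors, phi):
    """Hypothetical stand-in for g: linear score per task, then softmax."""
    scores = [phi[0] * loss + phi[1] * sim for loss, sim in factors]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


theta0 = [0.0, 0.0]
support = [(1.0, 2.0), (2.0, 4.0)]            # y = 2x
clean_query = [(3.0, 6.0), (4.0, 8.0)]        # consistent with the support set
noisy_query = [(3.0, -6.0), (4.0, -8.0)]      # labels flipped

clean_loss, clean_sim = task_factors(theta0, support, clean_query)
noisy_loss, noisy_sim = task_factors(theta0, support, noisy_query)

# Hand-set scheduler parameters: penalize loss, reward gradient agreement.
phi = [-0.01, 0.001]
probs = scheduler_probs([(clean_loss, clean_sim), (noisy_loss, noisy_sim)], phi)
print("clean task:", clean_loss, clean_sim, probs[0])
print("noisy task:", noisy_loss, noisy_sim, probs[1])
```

In this toy setup the flipped-label task has a higher query loss and a negative gradient similarity, so the hand-set scheduler assigns it a lower sampling probability than the clean task. In ATS the mapping from the two factors to a probability is learned through the bi-level objective rather than fixed by hand.
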