Continual Learning with Weight Interpolation

Jędrzej Kozal, Jan Wasilewski, Bartosz Krawczyk, Michał Woźniak
2024-04-09
Abstract: Continual learning poses a fundamental challenge for modern machine learning systems, requiring models to adapt to new tasks while retaining knowledge from previous ones. Addressing this challenge necessitates the development of efficient algorithms capable of learning from data streams and accumulating knowledge over time. This paper proposes a novel approach to continual learning utilizing the weight consolidation method. Our method, a simple yet powerful technique, enhances robustness against catastrophic forgetting by interpolating between old and new model weights after each novel task, effectively merging two models to facilitate exploration of local minima emerging after the arrival of new concepts. Moreover, we demonstrate that our approach can complement existing rehearsal-based replay approaches, improving their accuracy and further mitigating the forgetting phenomenon. Additionally, our method provides an intuitive mechanism for controlling the stability-plasticity trade-off. Experimental results show that the proposed weight consolidation approach significantly improves the performance of state-of-the-art experience replay algorithms. Our algorithm can be downloaded from
Machine Learning
### Problems the paper attempts to solve

The paper addresses a fundamental challenge in continual learning: how to let a model retain knowledge of previous tasks while it adapts to new ones. It proposes a continual learning method that uses weight interpolation to improve robustness against catastrophic forgetting. After each new task, the method interpolates between the weights of the old and the new model, effectively merging the two networks and promoting the exploration of local minima that may emerge after new concepts appear. The paper also shows that the method can be combined with existing replay-based methods to improve their accuracy and further reduce forgetting, and that it offers an intuitive mechanism for controlling the trade-off between stability and plasticity.

### Main contributions

1. **Necessary conditions**: The paper identifies the conditions required for weight interpolation to be applied successfully to continual learning and verifies them experimentally.
2. **Simple algorithm**: It proposes a novel, simple continual learning algorithm that is compatible with popular replay-based methods.
3. **Experimental evaluation**: Extensive experiments show that the method can significantly improve the performance of existing experience replay algorithms.
4. **Stability-plasticity trade-off**: The method has a built-in, intuitive mechanism for controlling the trade-off between stability and plasticity.

### Method overview

The core idea is to interpolate between the weights of the network before and after training on each new task, which consolidates knowledge and suppresses forgetting. The steps are as follows:

1. **Notation**:
   - Each task \( t \) is represented by a dataset \( D_t=\{(x_i, y_i)\}_{i=1}^{n_t} \), where \( x_i \) is an image, \( y_i \) is its label, and \( n_t=|D_t| \).
   - On each task, the parameters \( \theta \) of the neural network \( f \) are trained with access only to the most recent data, that is, by solving \( \min_{\theta} L(\theta, D_t) \), where \( L(\theta, D_t) \) is the loss of \( f \) with parameters \( \theta \) on \( D_t \).
   - Replay-based algorithms mitigate forgetting with a small buffer \( M=\{(x_j, y_j)\}_{j=1}^{m} \), with \( m=|M| \), that stores data from previous tasks.
2. **Motivation**:
   - The main goal of continual learning is to minimize the joint test loss over all tasks in the stream. Only the training data of the current task is accessible, yet the network should achieve a low loss on all tasks.
   - The joint loss is defined as
     \[ L_D(\theta)=\sum_{t=1}^{T} L(\theta, D_t), \]
     where \( D = D_1\cup\cdots\cup D_T \).
   - The joint loss splits into the loss on the last task and the loss on all previous tasks:
     \[ L_D(\theta)=L(\theta, D_T)+\sum_{t=1}^{T-1} L(\theta, D_t). \]
3. **Weight interpolation**:
   - For every task after the first (\( t>1 \)), the previously trained weights \( \theta_P \) are interpolated with the new parameters \( \theta \) trained on the current data distribution.
   - First, a permutation \( \pi \) of the weights is found such that the activations of \( \pi(\theta_P) \) and \( \theta \) are aligned.
   - The memory buffer \( M \) is used to obtain the activations of \( \theta \) and \( \theta_P \) and to update the batch normalization statistics.
   - The weights are then linearly interpolated:
     \[ \theta \leftarrow (1-\alpha)\theta+\alpha\pi(\theta_P), \]
     where the interpolation coefficient \( \alpha \) is a hyperparameter. Larger values of \( \alpha \) keep the merged model closer to the previous weights (stability), while smaller values favor the newly trained weights (plasticity).
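The text does not specify how the buffer \( M \) is filled. A common choice in experience replay is reservoir sampling, which keeps every example seen so far with equal probability. The minimal Python sketch below assumes that policy. The class name `ReservoirBuffer` and its methods are hypothetical and are not taken from the paper.

```python
import random
from typing import Any, List, Tuple


class ReservoirBuffer:
    """Hypothetical fixed-capacity memory M of (x, y) pairs from past tasks.

    Assumes reservoir sampling: after `seen` examples, each one is stored
    with probability capacity / seen.
    """

    def __init__(self, capacity: int, seed: int = 0) -> None:
        self.capacity = capacity
        self.seen = 0  # number of examples offered to the buffer so far
        self.items: List[Tuple[Any, Any]] = []
        self.rng = random.Random(seed)

    def add(self, x: Any, y: Any) -> None:
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append((x, y))  # fill free slots first
            return
        # Overwrite a random slot with probability capacity / seen.
        j = self.rng.randrange(self.seen)
        if j < self.capacity:
            self.items[j] = (x, y)

    def sample(self, k: int) -> List[Tuple[Any, Any]]:
        # Draw a replay mini-batch without replacement.
        return self.rng.sample(self.items, min(k, len(self.items)))
```

During training on \( D_t \), each incoming example would be passed to `add`, and `sample` would supply replayed examples from earlier tasks alongside the current batch.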
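The alignment step can be illustrated on a one-hidden-layer ReLU network. The sketch computes hidden activations of \( \theta \) and \( \theta_P \) on buffer inputs and matches units by the Pearson correlation of those activations. It then permutes \( \theta_P \) so that unit \( i \) of \( \pi(\theta_P) \) corresponds to unit \( i \) of \( \theta \). Several choices here are assumptions for illustration, since the text only states that \( \pi \) aligns activations: the correlation criterion, the exhaustive search over permutations, the parameter layout, and all function names. For realistic layer widths, the assignment would instead be solved with the Hungarian algorithm (for example `scipy.optimize.linear_sum_assignment(cost, maximize=True)`) and repeated layer by layer.

```python
import itertools
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]
Params = Dict[str, list]  # hypothetical layout: "W1", "b1", "W2", "b2"


def hidden_activations(W1: Matrix, b1: Vector, xs: Sequence[Vector]) -> Matrix:
    """ReLU hidden activations, one row per input sample."""
    return [
        [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
        for x in xs
    ]


def forward(p: Params, x: Vector) -> Vector:
    """Output of the MLP f(x) = W2 relu(W1 x + b1) + b2."""
    h = hidden_activations(p["W1"], p["b1"], [x])[0]
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(p["W2"], p["b2"])]


def correlation(u: Vector, v: Vector) -> float:
    """Pearson correlation; 0.0 if either vector is constant."""
    n = len(u)
    mu, mv = sum(u) / n, sum(v) / n
    cov = sum((a - mu) * (b - mv) for a, b in zip(u, v))
    su = math.sqrt(sum((a - mu) ** 2 for a in u))
    sv = math.sqrt(sum((b - mv) ** 2 for b in v))
    return cov / (su * sv) if su > 0 and sv > 0 else 0.0


def match_units(act_new: Matrix, act_old: Matrix) -> Tuple[int, ...]:
    """Return perm maximizing sum_i corr(new unit i, old unit perm[i]).

    Exhaustive search is exact but costs O(H!), so it only suits tiny widths.
    """
    h = len(act_new[0])
    cols_new = [[row[i] for row in act_new] for i in range(h)]
    cols_old = [[row[j] for row in act_old] for j in range(h)]
    corr = [[correlation(cols_new[i], cols_old[j]) for j in range(h)] for i in range(h)]
    return max(
        itertools.permutations(range(h)),
        key=lambda p: sum(corr[i][p[i]] for i in range(h)),
    )


def permute_old(p: Params, perm: Sequence[int]) -> Params:
    """Apply pi to theta_P by reordering hidden units (function is unchanged)."""
    return {
        "W1": [list(p["W1"][j]) for j in perm],            # rows of W1 are hidden units
        "b1": [p["b1"][j] for j in perm],
        "W2": [[row[j] for j in perm] for row in p["W2"]],  # columns of W2 follow suit
        "b2": list(p["b2"]),
    }
```

Reordering the rows of `W1` and `b1` together with the columns of `W2` only relabels hidden units. As a result, `forward(permute_old(p, perm), x)` equals `forward(p, x)` up to floating-point rounding: the permutation changes where units sit, not what the network computes, which is what makes interpolating with \( \pi(\theta_P) \) meaningful.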
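Once \( \theta_P \) has been permuted, the update \( \theta \leftarrow (1-\alpha)\theta+\alpha\pi(\theta_P) \) is an element-wise convex combination. The first function below is a minimal sketch of it, with each tensor flattened to a list of floats. The text does not say exactly when the batch normalization statistics are updated. The second function therefore assumes they are recomputed as per-channel means and variances of the merged network's pre-normalization activations on buffer samples. Both function names are hypothetical.

```python
from typing import Dict, List, Tuple

Params = Dict[str, List[float]]  # hypothetical: each tensor flattened to a list


def interpolate(theta: Params, theta_p_aligned: Params, alpha: float) -> Params:
    """theta <- (1 - alpha) * theta + alpha * pi(theta_P), element-wise.

    alpha = 0 keeps the new weights (plasticity);
    alpha = 1 restores the aligned previous weights (stability).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    return {
        name: [(1.0 - alpha) * w + alpha * wp
               for w, wp in zip(theta[name], theta_p_aligned[name])]
        for name in theta
    }


def batchnorm_stats(pre_bn: List[List[float]]) -> Tuple[List[float], List[float]]:
    """Per-channel mean and biased variance of pre-BN activations.

    pre_bn holds one row per buffer sample and one column per channel,
    assumed to come from the merged network evaluated on the buffer M.
    """
    n = len(pre_bn)
    channels = list(zip(*pre_bn))
    means = [sum(c) / n for c in channels]
    variances = [sum((v - m) ** 2 for v in c) / n for c, m in zip(channels, means)]
    return means, variances
```

For example, with \( \alpha = 0.25 \), a weight of 1.0 in \( \theta \) and 3.0 in \( \pi(\theta_P) \) merges to \( 0.75 \cdot 1.0 + 0.25 \cdot 3.0 = 1.5 \).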