Tao Li, Zhehao Huang, Yingwen Wu, Zhengbao He, Qinghua Tao, Xiaolin Huang, Chih-Jen Lin
Abstract: Weight averaging is a widely used technique for accelerating training and improving the generalization of deep neural networks (DNNs). While existing approaches like stochastic weight averaging (SWA) rely on pre-set weighting schemes, they can be suboptimal when handling diverse weights. We introduce Trainable Weight Averaging (TWA), a novel optimization method that operates within a reduced subspace spanned by candidate weights and learns optimal weighting coefficients through optimization. TWA offers greater flexibility and can be applied to different training scenarios. For large-scale applications, we develop a distributed training framework that combines parallel computation with low-bit compression for the projection matrix, effectively managing memory and computational demands. TWA can be implemented using either training data (TWA-t) or validation data (TWA-v), with the latter providing more effective averaging. Extensive experiments showcase TWA's advantages: (i) it consistently outperforms SWA in generalization performance and flexibility, (ii) when applied during early training, it reduces training time by over 40% on CIFAR datasets and 30% on ImageNet while maintaining comparable performance, and (iii) during fine-tuning, it significantly enhances generalization by weighted averaging of model checkpoints. In summary, we present an efficient and effective framework for trainable weight averaging. The code is available at https://github.com/nblt/TWA.
What problem does this paper attempt to address?
This paper addresses how to accelerate the training of deep neural networks (DNNs) and improve their generalization. Existing weight averaging techniques, such as Stochastic Weight Averaging (SWA), rely on preset weighting schemes, which can be suboptimal when the candidate weights are diverse. To address this, the authors propose Trainable Weight Averaging (TWA), an optimization method that learns the averaging coefficients instead of fixing them in advance.
### Main problems and solutions
1. **Limitations of existing methods**:
- Existing weight averaging methods (such as SWA, LAWA, and EMA) rely on preset weighting strategies, which may not be flexible enough to handle diverse weights across different training configurations.
- In the early stages of training, model parameters are not yet well optimized, so fixed or preset averaging strategies may produce suboptimal solutions.
2. **The proposed new method - TWA**:
- **The core idea of TWA**: The weighting coefficients are learned by optimization in a low-dimensional subspace spanned by the candidate weights, which makes the averaging more flexible and efficient.
- **Subspace training**: Each candidate weight is treated as a point in the full parameter space; a subspace containing these points is constructed, and optimization is carried out within it.
- **Distributed training framework**: To meet the memory and compute requirements of large-scale applications, the authors develop a distributed training framework that combines parallel computation with low-bit compression of the projection matrix.
3. **Application scenarios**:
- **Accelerating training**: Averaging historical solutions from the early stages of training significantly reduces training time (by more than 40% on CIFAR datasets and by 30% on ImageNet) while maintaining comparable performance.
- **Improving generalization performance**: During fine-tuning, taking a weighted average of model checkpoints significantly improves generalization.
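The subspace idea described above can be illustrated with a minimal NumPy sketch on a toy linear regression problem. This is not the authors' implementation: the function names, the toy data, and the use of a QR factorization to obtain an orthonormal basis of the checkpoint span are all assumptions made for illustration. The sketch starts from the uniform average (the preset scheme) and then runs gradient descent on the subspace coordinates only.

```python
import numpy as np

def loss_and_grad(w, X, y):
    """Mean squared error of a linear model and its gradient w.r.t. w."""
    residual = X @ w - y
    loss = float(np.mean(residual ** 2))
    grad = 2.0 * X.T @ residual / len(y)
    return loss, grad

def train_in_subspace(checkpoints, X, y, lr=0.1, steps=200):
    """Optimize a weight vector restricted to the span of the checkpoints.

    checkpoints: array of shape (k, d), one candidate weight vector per row.
    Returns the basis Q (d, k) and the learned weight vector (d,).
    """
    # Orthonormal basis of the checkpoint span (an assumption of this sketch;
    # it keeps a fixed step size well behaved).
    Q, _ = np.linalg.qr(checkpoints.T)
    # Start from the uniform average, expressed in subspace coordinates.
    beta = Q.T @ checkpoints.mean(axis=0)
    for _ in range(steps):
        _, grad_w = loss_and_grad(Q @ beta, X, y)
        # Chain rule: the gradient w.r.t. beta is the projected full gradient.
        beta -= lr * (Q.T @ grad_w)
    return Q, Q @ beta

# Hypothetical toy data: checkpoints are noisy copies of the true weights.
rng = np.random.default_rng(0)
d, k, n = 20, 5, 200
X = rng.normal(size=(n, d))
w_true = rng.normal(size=d)
y = X @ w_true
noise_scales = np.array([2.0, 1.5, 1.0, 0.5, 0.25])
checkpoints = w_true + noise_scales[:, None] * rng.normal(size=(k, d))

uniform_loss, _ = loss_and_grad(checkpoints.mean(axis=0), X, y)
Q, w_learned = train_in_subspace(checkpoints, X, y)
learned_loss, _ = loss_and_grad(w_learned, X, y)
print(f"uniform average loss: {uniform_loss:.4f}")
print(f"learned average loss: {learned_loss:.4f}")
```

Only `k` coordinates are trained here, regardless of the model dimension `d`. In the source's naming, passing training data for `X, y` would correspond to TWA-t and passing held-out validation data to TWA-v.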
### Key contributions
1. **Proposing TWA**: An effective method that can accelerate training and improve the generalization performance of DNNs, allowing learnable weight coefficients between layers.
2. **Efficient handling of large-scale problems**: Subspace training enables multi-node parallel training that distributes the memory and compute load evenly, and a compression strategy further reduces memory usage.
3. **Validation-set-based optimization**: TWA-v uses a small validation set to supervise the optimization of the weighting coefficients, which makes the averaging more efficient and effective and is especially suitable for Transformer architectures.
4. **Extensive experimental verification**: Experiments across multiple architectures (CNNs, ViTs, GPT-2), tasks (image classification, machine translation, language modeling), and training scenarios (from training from scratch to fine-tuning) demonstrate the effectiveness and efficiency of TWA.
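To give a sense of how low-bit compression can shrink a stored projection matrix, here is a minimal sketch of per-row uniform 8-bit quantization in NumPy. The summary does not specify the paper's actual compression scheme, so this is a generic stand-in; the function names and the random matrix are hypothetical.

```python
import numpy as np

def quantize_rows_uint8(matrix):
    """Uniformly quantize each row of a float32 matrix to 8-bit codes."""
    lo = matrix.min(axis=1, keepdims=True)
    hi = matrix.max(axis=1, keepdims=True)
    scale = (hi - lo) / 255.0
    scale[scale == 0] = 1.0  # constant rows: avoid division by zero
    codes = np.round((matrix - lo) / scale).astype(np.uint8)
    return codes, lo, scale

def dequantize_rows(codes, lo, scale):
    """Recover an approximate float32 matrix from the 8-bit codes."""
    return codes.astype(np.float32) * scale + lo

# Hypothetical projection matrix: 5 basis vectors of dimension 10,000.
rng = np.random.default_rng(0)
P = rng.normal(size=(5, 10_000)).astype(np.float32)

codes, lo, scale = quantize_rows_uint8(P)
P_hat = dequantize_rows(codes, lo, scale)
max_error = float(np.abs(P - P_hat).max())
print(f"bytes before: {P.nbytes}, after: {codes.nbytes}")
print(f"max absolute error: {max_error:.5f}")
```

Storing one byte per entry instead of four cuts the matrix's memory to a quarter, at the cost of a rounding error of at most half a quantization step per entry.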
In conclusion, TWA and its related techniques address the inflexibility and suboptimal performance of existing weight averaging methods and provide a more efficient and effective way to train deep learning models.