Better Generalization in Fast Training: Flat Trainable Weight in Subspace

Zehao Lei, Yingwen Wu, Tao Li
DOI: https://doi.org/10.1145/3651671.3651740
2024-01-01
Abstract: Reducing the training time of deep neural networks (DNNs) is a critical task given the huge scale of modern data and models. Unlike most previous works, which use large batch sizes to shorten training time, we aim to reduce the number of training epochs through a dedicated training algorithm. It is well known that simply shortening the learning rate schedule causes a significant loss of generalization. In this paper, we propose to maintain test accuracy while compressing training epochs by optimizing in an extended subspace spanned by historical model parameters along the SGD training trajectory. Although the use of historical information has been studied in Trainable Weight Averaging (TWA), we design a new algorithm, Flat Trainable Weights (FTW), which optimizes the weight coefficients with an explicit sharpness loss function in the extended low-dimensional subspace and thereby achieves better generalization. We show that FTW significantly improves model generalization over both TWA and SGD. In fast training, FTW accelerates convergence and saves 15% of training time compared with TWA and 35% compared with SGD on the CIFAR datasets.
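To make the subspace idea concrete, here is a minimal sketch under stated assumptions; it is not the authors' FTW implementation. A toy least-squares problem stands in for a DNN loss, SGD checkpoints saved every few steps play the role of the historical parameters, and the candidate weight is restricted to w = Q c, where Q is an orthonormal basis of the checkpoints' span (obtained with a QR decomposition, an assumption made for numerical stability). Because the abstract does not give FTW's explicit sharpness loss, the coefficients c are trained with a SAM-style (sharpness-aware minimization) two-step update, swapped in here as a stand-in: first perturb c by rho times the normalized gradient, then descend using the gradient measured at the perturbed point. All function names (loss_and_grad, sgd_checkpoints, flat_subspace_weights) and hyperparameters (rho, lr, steps, save_every) are hypothetical.

```python
import numpy as np


def loss_and_grad(w, X, y):
    """Toy least-squares loss 0.5 * mean((Xw - y)^2) and its gradient (hypothetical stand-in for a DNN loss)."""
    r = X @ w - y
    return 0.5 * np.mean(r ** 2), X.T @ r / len(y)


def sgd_checkpoints(X, y, lr=0.1, steps=50, save_every=10, batch=16, seed=0):
    """Run minibatch SGD from w = 0 and return every `save_every`-th iterate as rows of a (k, d) array."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    saved = []
    for t in range(steps):
        idx = rng.choice(len(y), size=batch, replace=False)
        _, g = loss_and_grad(w, X[idx], y[idx])
        w = w - lr * g
        if (t + 1) % save_every == 0:
            saved.append(w.copy())
    return np.stack(saved)


def flat_subspace_weights(W, X, y, rho=0.05, lr=0.1, steps=200):
    """SAM-style sharpness-aware descent over coefficients c, with w(c) = Q @ c in span(W)."""
    # Orthonormal basis whose span contains the historical weights (assumption: QR for stability).
    Q, _ = np.linalg.qr(W.T)          # Q has shape (d, k)
    c = Q.T @ W.mean(axis=0)          # start at the plain average of the checkpoints
    for _ in range(steps):
        _, g_w = loss_and_grad(Q @ c, X, y)
        g_c = Q.T @ g_w               # chain rule: dL/dc = Q^T dL/dw
        # Step to the approximate worst-case neighbour within an L2 ball of radius rho ...
        eps = rho * g_c / (np.linalg.norm(g_c) + 1e-12)
        _, g_w_adv = loss_and_grad(Q @ (c + eps), X, y)
        # ... then descend from c using the gradient measured at the perturbed point.
        c = c - lr * (Q.T @ g_w_adv)
    return Q @ c, c


rng = np.random.default_rng(42)
X = rng.standard_normal((200, 10))
w_true = rng.standard_normal(10)
y = X @ w_true + 0.1 * rng.standard_normal(200)

W = sgd_checkpoints(X, y)                      # 5 checkpoints in a 10-dimensional weight space
w_ftw, coeffs = flat_subspace_weights(W, X, y)
print("earliest checkpoint loss:", loss_and_grad(W[0], X, y)[0])
print("subspace solution loss:  ", loss_and_grad(w_ftw, X, y)[0])
```

Only k coefficients (here 5) are trained in the second phase rather than all d weights, which is the usual motivation for optimizing in a low-dimensional subspace; the resulting weight always stays inside the span of the saved checkpoints.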