Occam Gradient Descent

B.N. Kausik
2024-08-01
Abstract: Deep learning neural network models must be large enough to adapt to their problem domain, while small enough to avoid overfitting training data during gradient descent. To balance these competing demands, overprovisioned deep learning models such as transformers are trained for a single epoch on large data sets, and are hence inefficient with both computing resources and training data. In response to these inefficiencies, we exploit learning theory to derive Occam Gradient Descent, an algorithm that interleaves adaptive reduction of model size to minimize generalization error with gradient descent on model weights to minimize fitting error. In contrast, traditional gradient descent greedily minimizes fitting error without regard to generalization error. Our algorithm simultaneously descends the space of weights and the topological size of any neural network without modification. With respect to loss, compute and model size, our experiments show (a) on image classification benchmarks, linear and convolutional neural networks trained with Occam Gradient Descent outperform traditional gradient descent with or without post-training pruning; (b) on a range of tabular data classification tasks, neural networks trained with Occam Gradient Descent outperform traditional gradient descent, as well as Random Forests; (c) on natural language transformers, Occam Gradient Descent outperforms traditional gradient descent.
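A textbook Occam-style bound makes the abstract's trade-off between fitting error and generalization error concrete. It is offered as background, not as the bound derived in the paper, and it assumes a finite hypothesis class $\mathcal{H}$, $m$ i.i.d. training samples, and a loss bounded in $[0,1]$. With probability at least $1-\delta$, every $h \in \mathcal{H}$ satisfies

$$
\operatorname{err}(h) \;\le\; \widehat{\operatorname{err}}(h) + \sqrt{\frac{\ln|\mathcal{H}| + \ln(1/\delta)}{2m}},
\qquad
\ln|\mathcal{H}| \;\le\; bW\ln 2,
$$

where the second inequality additionally assumes a network with $W$ weights, each stored in $b$ bits. Gradient descent lowers the empirical (fitting) error $\widehat{\operatorname{err}}(h)$, while reducing the model size $W$ shrinks the complexity term.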
Machine Learning
What problem does this paper attempt to address?
This paper primarily investigates the conflict between adapting deep learning neural network models to their problem domains and avoiding overfitting. Specifically, it explores how to make a model large enough to fit its problem domain while keeping it small enough to avoid overfitting the training data during gradient descent. To address this conflict, the author proposes the Occam Gradient Descent algorithm.

### Main Issues

1. **Balancing Model Capacity and Overfitting**: Large deep learning models (such as Transformers) often require a significant number of parameters to fit complex datasets, but this also makes them prone to overfitting the training data during training.
2. **Inefficient Resource Utilization**: To avoid overfitting, these large models are typically trained for only one epoch, which wastes computational resources and fails to make full use of the training data.

### Goals of the Occam Gradient Descent Algorithm

- **Reduce Overfitting**: Minimize generalization error by adaptively reducing the model size.
- **Optimize the Training Process**: Interleave this size reduction with gradient descent on the model weights, which minimizes fitting error.
- **Improve Efficiency**: Descend simultaneously over the weights and the topological size of any neural network, without modifying the network, so that training yields smaller models at lower computational cost.

### Method Overview

- **Theoretical Foundation**: The Occam Gradient Descent algorithm is derived from learning theory. During training it alternates between adaptive reduction of model size and gradient descent updates of the weights, balancing model size against fitting error.
- **Algorithm Characteristics**: Unlike traditional gradient descent, which greedily minimizes fitting error alone, Occam Gradient Descent accounts for both fitting error and generalization error, thereby controlling model complexity and avoiding overfitting.
- **Experimental Validation**: The algorithm is evaluated on image classification (linear and convolutional networks), tabular data classification, and natural language transformers. The results show advantages in loss, computational cost, and model size.

### Conclusion

The paper proposes a new algorithm, Occam Gradient Descent, that addresses the trade-off between adaptability and overfitting in deep learning models by adaptively adjusting model size during training. Experimental results indicate that it improves loss while reducing computational cost and model size, outperforming traditional gradient descent with or without post-training pruning and, on tabular tasks, Random Forests.
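The interleaving described under Method Overview can be illustrated with a toy sketch. It is not the paper's algorithm: it assumes a linear model, full-batch gradient descent, magnitude pruning as the size-reduction step, and a held-out validation set as the stand-in for generalization error. The function names (`occam_gd_sketch`, `gd_epochs`, `prune`, `mse`, `predict`) and the rule of halving the prune fraction whenever pruning hurts validation loss are hypothetical choices made for illustration.

```python
import random

# Hypothetical toy sketch, NOT the paper's Occam Gradient Descent: a masked
# linear model where magnitude pruning (an assumed size-reduction step) is
# interleaved with full-batch gradient descent, and a held-out validation
# set stands in for the generalization signal.

def predict(w, mask, x):
    """Output of the linear model using only active (unmasked) weights."""
    return sum(wi * mi * xi for wi, mi, xi in zip(w, mask, x))

def mse(w, mask, data):
    """Mean squared error on a list of (x, y) pairs."""
    return sum((predict(w, mask, x) - y) ** 2 for x, y in data) / len(data)

def gd_epochs(w, mask, data, lr=0.5, epochs=100):
    """Full-batch gradient descent on MSE; pruned weights stay frozen at zero."""
    n = len(data)
    for _ in range(epochs):
        grad = [0.0] * len(w)
        for x, y in data:
            err = predict(w, mask, x) - y
            for j, xj in enumerate(x):
                grad[j] += 2.0 * err * xj / n
        w = [wi - lr * g * mi for wi, g, mi in zip(w, grad, mask)]
    return w

def prune(w, mask, frac):
    """Deactivate the smallest-magnitude fraction `frac` of the active weights."""
    active = [j for j, m in enumerate(mask) if m]
    k = int(len(active) * frac)
    new_mask = list(mask)
    for j in sorted(active, key=lambda i: abs(w[i]))[:k]:
        new_mask[j] = 0
    return new_mask

def occam_gd_sketch(train, val, dim, frac=0.5, min_frac=0.05, rounds=20):
    """Alternate pruning with retraining; shrink the prune step when it hurts."""
    mask = [1] * dim
    w = gd_epochs([0.0] * dim, mask, train)
    best = mse(w, mask, val)
    for _ in range(rounds):
        cand_mask = prune(w, mask, frac)
        if frac < min_frac or cand_mask == mask:
            break  # nothing left to prune at this step size
        cand_w = gd_epochs([wi * mi for wi, mi in zip(w, cand_mask)], cand_mask, train)
        loss = mse(cand_w, cand_mask, val)
        if loss <= best:
            w, mask, best = cand_w, cand_mask, loss  # accept the smaller model
        else:
            frac /= 2  # pruning hurt validation loss: take a smaller step
    return w, mask, best

# Usage on synthetic data where only the first 2 of 10 features matter.
random.seed(0)

def make_data(n, dim=10):
    data = []
    for _ in range(n):
        x = [random.uniform(-1.0, 1.0) for _ in range(dim)]
        data.append((x, 3.0 * x[0] - 2.0 * x[1] + random.gauss(0.0, 0.1)))
    return data

train, val = make_data(60), make_data(40)
w, mask, best = occam_gd_sketch(train, val, dim=10)
print(sum(mask), "active weights, validation MSE", round(best, 4))
```

The validation check is what makes the size reduction adaptive in this sketch: a pruning step is kept only if retraining the smaller model does not raise validation loss, so weights carrying real signal (here the first two) survive while the step size shrinks until no further pruning is possible.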