Towards optimal hierarchical training of neural networks

Michael Feischl, Alexander Rieder, Fabian Zehetgruber
2024-10-30
Abstract: We propose a hierarchical training algorithm for standard feed-forward neural networks that adaptively extends the network architecture as soon as the optimization reaches a stationary point. By solving small (low-dimensional) optimization problems, the extended network provably escapes any local minimum or stationary point. Under some assumptions on the approximability of the data with stable neural networks, we show that the algorithm achieves an optimal convergence rate s in the sense that the loss is bounded by the number of parameters to the power -s. As a byproduct, we obtain computable indicators which judge the optimality of the training state of a given network and derive a new notion of generalization error.
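The growth mechanism described in the abstract can be illustrated on a toy problem. The Python sketch below is a minimal illustration under stated assumptions, not the authors' algorithm: it uses a one-hidden-layer ReLU network in one dimension, plain gradient descent with a step cap as the stationarity test, and, as a hypothetical stand-in for the paper's low-dimensional subproblems, a greedy search over a fixed grid of candidate neurons that picks the one most correlated with the current residual. The helper names (`forward`, `loss_and_grad`, `add_neuron`, `hierarchical_train`) and all hyperparameters are illustrative.

```python
import numpy as np

def forward(params, x):
    # Shallow ReLU network: u(x) = sum_k c_k * relu(w_k * x + b_k)
    w, b, c = params
    h = np.maximum(0.0, np.outer(x, w) + b)  # activations, shape (n, width)
    return h @ c, h

def loss_and_grad(params, x, y):
    # Mean squared error and its gradient with respect to (w, b, c).
    w, b, c = params
    pre = np.outer(x, w) + b                  # pre-activations, shape (n, width)
    h = np.maximum(0.0, pre)                  # ReLU activations
    r = h @ c - y                             # residual of the prediction
    n = len(x)
    loss = np.mean(r ** 2)
    gc = (2.0 / n) * (h.T @ r)                # dL/dc
    gpre = (2.0 / n) * r[:, None] * (pre > 0) * c  # dL/d(pre-activation)
    gw = gpre.T @ x                           # dL/dw
    gb = gpre.sum(axis=0)                     # dL/db
    return loss, (gw, gb, gc)

def add_neuron(params, x, y, candidates):
    # Hypothetical low-dimensional subproblem: choose (w, b) from a grid that
    # maximizes |<residual, relu(w x + b)>| / ||relu(w x + b)||, and set the
    # new output weight to its least-squares value for the current residual.
    w, b, c = params
    pred, _ = forward(params, x)
    r = y - pred
    best, best_score = None, -1.0
    for wn, bn in candidates:
        phi = np.maximum(0.0, wn * x + bn)
        norm = np.linalg.norm(phi)
        if norm == 0.0:
            continue  # neuron is inactive on every data point
        score = abs(phi @ r) / norm
        if score > best_score:
            best, best_score = (wn, bn, (phi @ r) / norm ** 2), score
    wn, bn, cn = best
    return np.append(w, wn), np.append(b, bn), np.append(c, cn)

def hierarchical_train(x, y, max_width=6, lr=0.01, tol=1e-6, max_steps=2000):
    # Start from a single neuron relu(x) with zero output weight.
    params = (np.array([1.0]), np.array([0.0]), np.array([0.0]))
    # Candidate neurons relu(s * (x - t)): breakpoint t, orientation s.
    candidates = [(s, -s * t) for s in (-1.0, 1.0) for t in np.linspace(-1, 1, 41)]
    history = []  # (number of parameters, loss) after each training phase
    while True:
        for _ in range(max_steps):
            loss, grads = loss_and_grad(params, x, y)
            if sum(np.linalg.norm(g) for g in grads) < tol:
                break  # approximately stationary
            params = tuple(p - lr * g for p, g in zip(params, grads))
        loss, _ = loss_and_grad(params, x, y)
        history.append((3 * len(params[0]), loss))
        if len(params[0]) >= max_width:
            return params, history
        params = add_neuron(params, x, y, candidates)

x = np.linspace(-1.0, 1.0, 200)
y = np.sin(np.pi * x)
params, history = hierarchical_train(x, y)
```

Because the new output weight is the least-squares coefficient for the current residual, each extension strictly lowers the loss whenever the chosen activation is not orthogonal to the residual, so training cannot remain at the previous stationary point. This mirrors, in a much simpler setting, the escape property stated in the abstract.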
Numerical Analysis
What problem does this paper attempt to address?
The paper addresses how to train standard feed-forward neural networks so that, for a given number of trainable parameters, the loss is close to the best achievable value. To this end, it proposes a hierarchical training algorithm that adaptively extends the network architecture whenever the optimization reaches a stationary point. Each extension requires solving only a small (low-dimensional) optimization problem, and the extended network provably escapes any local minimum or stationary point of the smaller network. Under certain assumptions on how well the data can be approximated by stable neural networks, the algorithm achieves the optimal convergence rate \( s \), meaning that \( \text{loss} \lesssim \# \text{parameters}^{-s} \), where \( \lesssim \) hides a constant independent of the number of parameters. As byproducts, the paper derives computable indicators that judge whether the training state of a given network is close to optimal, and it introduces a new notion of generalization error. Together, these results aim to make training more efficient and to provide a theoretical basis for evaluating and understanding the performance of neural networks.
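The rate statement \( \text{loss} \lesssim \# \text{parameters}^{-s} \) can be probed empirically. Assuming one has recorded (parameter count, loss) pairs for networks of increasing size, a least-squares fit of \( \log \text{loss} \) against \( \log \# \text{parameters} \) estimates \( s \) as the negative slope. The Python sketch below is a generic diagnostic under that assumption; the function name `empirical_rate` is hypothetical, and this is not the computable optimality indicator derived in the paper.

```python
import numpy as np

def empirical_rate(n_params, losses):
    # Fit log(loss) = log(C) - s * log(N) by least squares and return (s, C).
    # Hypothetical helper: a generic log-log regression, not the paper's indicator.
    log_n = np.log(np.asarray(n_params, dtype=float))
    log_l = np.log(np.asarray(losses, dtype=float))
    slope, intercept = np.polyfit(log_n, log_l, 1)  # coefficients, highest degree first
    return -slope, np.exp(intercept)

# Synthetic data that follows loss = 5 * N^(-1.5) exactly.
N = [10, 20, 40, 80]
L = [5.0 * n ** -1.5 for n in N]
s, C = empirical_rate(N, L)
```

The fitted constant \( C \) plays the role of the constant hidden in \( \lesssim \). An estimated \( s \) that stalls or decreases as the network grows may indicate that training is not reaching the rate the theory predicts, although such a fit says nothing about which rate is optimal for the data.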