Arijit Sehanobish, Krzysztof Choromanski, Yunfan Zhao, Avinava Dubey, Valerii Likhosherstov
Abstract: We introduce the concept of scalable neural network kernels (SNNKs), the replacements of regular feedforward layers (FFLs), capable of approximating the latter, but with favorable computational properties. SNNKs effectively disentangle the inputs from the parameters of the neural network in the FFL, only to connect them in the final computation via the dot-product kernel. They are also strictly more expressive, as they can model complicated relationships beyond functions of the dot-products of parameter-input vectors. We also introduce the neural network bundling process that applies SNNKs to compactify deep neural network architectures, resulting in additional compression gains. In its extreme version, it leads to the fully bundled network whose optimal parameters can be expressed via explicit formulae for several loss functions (e.g. mean squared error), opening a possibility to bypass backpropagation. As a by-product of our analysis, we introduce the mechanism of the universal random features (or URFs), applied to instantiate several SNNK variants, and interesting on its own in the context of scalable kernel methods. We provide rigorous theoretical analysis of all these concepts as well as an extensive empirical evaluation, ranging from point-wise kernel estimation to Transformers' fine-tuning with novel adapter layers inspired by SNNKs. Our mechanism provides up to 5x reduction in the number of trainable parameters, while maintaining competitive accuracy.
What problem does this paper attempt to address?
The paper aims to make neural networks cheaper to compute and easier to compress while maintaining or improving their expressiveness and accuracy. To that end, the authors introduce Scalable Neural Network Kernels (SNNKs) as a replacement for regular Feedforward Layers (FFLs). The main objectives of SNNKs are:
1. **Network Compression**: SNNKs reduce the number of trainable parameters by redefining how the parameters and inputs of a feedforward layer interact. In the SNNK module, the input and the parameters are mapped separately and meet only in the final dot-product, so the parameter-side embeddings can be kept small.
2. **Computational Savings**: If Random Features (RFs) can be constructed quickly and \(m \ll d\), the time complexity of the SNNK module drops from \(O(dl)\) for a regular FFL to \(o(dl)\). Here \(d\) is the input dimension, \(l\) the number of output neurons, and \(m\) the number of random features.
3. **Deep Neural Network Bundling Process**: Multiple feedforward layers are compactified iteratively using the two-tower representation, a process called Neural Network Bundling. This reduces the number of parameters and also brings computational benefits.
4. **Deep Neural Network as a Scalable Kernel**: In the extreme case, bundling all layers yields a two-tower decomposition of the entire network. For certain loss functions (such as mean squared error), this decomposition gives an explicit formula for the optimal parameters, potentially bypassing backpropagation.
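The two-tower idea and the closed-form solution under mean squared error can be illustrated with a small NumPy sketch. It is only an illustration under stated assumptions: the input tower `phi` uses generic random cosine features rather than the paper's construction, the "teacher" dense layer and all sizes are made up, and a tiny ridge term is added for numerical stability. The point it shows is that once the output is `phi(X) @ Psi`, the parameter-side matrix `Psi` that minimizes the squared error solves a linear system, so no gradient descent is needed.

```python
import numpy as np

def phi(X, G, c):
    """Hypothetical input tower: m random cosine features per input row."""
    m = G.shape[0]
    return np.sqrt(2.0 / m) * np.cos(X @ G.T + c)

def fit_parameter_tower(Phi, Y, ridge=1e-6):
    """Closed-form minimizer of ||Phi @ Psi - Y||^2 + ridge * ||Psi||^2."""
    m = Phi.shape[1]
    return np.linalg.solve(Phi.T @ Phi + ridge * np.eye(m), Phi.T @ Y)

rng = np.random.default_rng(0)
n, d, m, l = 500, 256, 32, 8          # assumed sizes, with m << d

# Toy data: targets come from a made-up dense layer with d * l weights.
X = rng.normal(size=(n, d))
W_teacher = rng.normal(size=(l, d)) / np.sqrt(d)
Y = np.tanh(X @ W_teacher.T)

# Fixed random projection for the input tower (not trained).
G = rng.normal(size=(m, d)) / np.sqrt(d)
c = rng.uniform(0.0, 2.0 * np.pi, size=m)

Phi = phi(X, G, c)                     # shape (n, m), depends on inputs only
ridge = 1e-6
Psi = fit_parameter_tower(Phi, Y, ridge)   # shape (m, l), the trainable part

pred = Phi @ Psi                       # inputs and parameters meet in one product
mse = float(np.mean((pred - Y) ** 2))
print("trainable parameters:", Psi.size, "vs dense layer:", d * l)
print("training MSE:", mse)
```

In this toy setup the trainable part has `m * l` entries instead of the `d * l` weights of the dense layer, which mirrors the compression argument in item 1. How well such a model fits depends entirely on the choice of feature map, which this sketch does not try to optimize.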
To achieve these goals, the authors propose the following key techniques:
- **Universal Random Features (URFs)**: Used to efficiently construct the mappings \(\Phi_f\) and \(\Psi_f\) and thereby instantiate the SNNK module. URFs give an unbiased estimate of \(f(w^\top x + b)\) as long as \(f\) has a well-defined Fourier Transform (FT).
- **ReLU-SNNK Layers**: A specific instantiation of SNNK that is especially suitable for downstream tasks. The authors show that these layers are related to arc-cosine kernels and can compute functions of the input and the parameters that are not just point-wise transformations of their dot-product.
- **Neural Network Compactification Process**: SNNKs are used to compactify neural networks, further reducing the number of parameters and the computational cost.
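The link between ReLU random features and arc-cosine kernels can be checked numerically. The sketch below assumes the standard random-feature estimator of the order-1 arc-cosine kernel, with the same Gaussian matrix `G` applied to the input `x` and to the parameter vector `w`; the paper's exact ReLU-SNNK parametrization may differ, and the function names and sizes here are illustrative. It compares the dot-product of the two feature vectors against the known closed form of \(\mathbb{E}_g[\mathrm{ReLU}(g^\top x)\,\mathrm{ReLU}(g^\top w)]\) for \(g \sim \mathcal{N}(0, I)\).

```python
import numpy as np

def relu_features(u, G):
    """Map a vector u to m ReLU random features, scaled by 1/sqrt(m)."""
    m = G.shape[0]
    return np.maximum(G @ u, 0.0) / np.sqrt(m)

def arc_cosine_expectation(x, w):
    """Closed form of E_g[ReLU(g.x) * ReLU(g.w)] for g ~ N(0, I)."""
    nx, nw = np.linalg.norm(x), np.linalg.norm(w)
    cos_t = np.clip(x @ w / (nx * nw), -1.0, 1.0)
    theta = np.arccos(cos_t)
    return nx * nw * (np.sin(theta) + (np.pi - theta) * cos_t) / (2.0 * np.pi)

rng = np.random.default_rng(0)
d, m = 8, 20000                       # assumed sizes; large m for a tight estimate

x = rng.normal(size=d)                # stands in for a layer input
w = rng.normal(size=d)                # stands in for one neuron's weights
G = rng.normal(size=(m, d))           # shared Gaussian projection

phi_x = relu_features(x, G)           # input tower, never sees w
psi_w = relu_features(w, G)           # parameter tower, never sees x

estimate = float(phi_x @ psi_w)       # the two towers meet only here
exact = float(arc_cosine_expectation(x, w))
print("random-feature estimate:", estimate)
print("closed-form value:      ", exact)
```

The closed form depends on the norms of `x` and `w` as well as the angle between them, so the resulting function is not simply a point-wise transformation of \(w^\top x\).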
The paper demonstrates the effectiveness of SNNKs through extensive experiments, ranging from point-wise kernel estimation to fine-tuning of Transformers with SNNK-inspired adapter layers. The reported results show up to a 5x reduction in the number of trainable parameters while maintaining competitive accuracy.