Normalized Narrow Jump To Conclusions: Normalized Narrow Shortcuts for Parameter Efficient Early Exit Transformer Prediction

Amrit Diggavi Seshadri
2024-10-03
Abstract: With the size and cost of large transformer-based language models growing, there has recently been interest in shortcut casting of early transformer hidden-representations to final-representations for cheaper model inference. In particular, shortcutting pre-trained transformers with linear transformations over early layers has been shown to improve precision in early inference. However, for large language models, even this becomes computationally expensive. In this work, we propose Narrow Jump to Conclusions (NJTC) and Normalized Narrow Jump to Conclusions (N-NJTC): parameter-efficient alternatives to standard linear shortcutting that reduce shortcut parameter count by over 97%. We show that N-NJTC reliably outperforms Identity shortcuts at early stages and offers stable precision from all transformer block levels for GPT-2-XL, Phi3-Mini and Llama2-7B transformer models, demonstrating the viability of more parameter-efficient shortcutting approaches.
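To make the "over 97%" figure concrete, the sketch below counts shortcut weights under two stated assumptions: a full linear (JTC-style) shortcut costs \(H\times H\) weights, and a narrow shortcut is a rank-\(r\) factorization costing \(2Hr\) weights, with biases ignored. The hidden sizes are those of the public GPT-2-XL, Phi3-Mini and Llama2-7B configurations; the rank of 16 is a hypothetical choice, not the value used in the paper, and all helper names are illustrative.

```python
"""Shortcut parameter counts: full linear (JTC-style) vs. rank-r (narrow) maps.

Assumptions: biases are ignored, and the narrow shortcut is a rank-r
factorization B @ A with A of shape (r, H) and B of shape (H, r).
"""

# Hidden sizes from the public model configurations.
HIDDEN_SIZES = {"GPT-2-XL": 1600, "Phi3-Mini": 3072, "Llama2-7B": 4096}


def full_shortcut_params(hidden: int) -> int:
    """An H x H linear map between hidden spaces."""
    return hidden * hidden


def narrow_shortcut_params(hidden: int, rank: int) -> int:
    """A (rank x H) down-projection plus an (H x rank) up-projection."""
    return 2 * hidden * rank


def reduction(hidden: int, rank: int) -> float:
    """Fraction of shortcut weights saved relative to the full map: 1 - 2r/H."""
    return 1.0 - narrow_shortcut_params(hidden, rank) / full_shortcut_params(hidden)


def max_rank_for_reduction(hidden: int, percent: int) -> int:
    """Largest rank whose reduction is strictly above `percent` percent.

    1 - 2r/H > p/100  <=>  200 * r < (100 - p) * H, solved in integers.
    """
    return ((100 - percent) * hidden - 1) // 200


if __name__ == "__main__":
    rank = 16  # hypothetical rank, not the value used in the paper
    for name, hidden in HIDDEN_SIZES.items():
        print(
            f"{name}: {full_shortcut_params(hidden):,} -> "
            f"{narrow_shortcut_params(hidden, rank):,} weights "
            f"({reduction(hidden, rank):.1%} fewer); "
            f"max rank for >97%: {max_rank_for_reduction(hidden, 97)}"
        )
```

Under this counting the reduction is \(1 - 2r/H\), so exceeding 97% only requires \(r < 0.015H\) (at most 23 for GPT-2-XL and 61 for Llama2-7B). `max_rank_for_reduction` does that check in integer arithmetic to avoid floating-point ties exactly at the boundary.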
Artificial Intelligence
### What problem does this paper attempt to solve?

This paper addresses the high computational cost of inference in large Transformer language models. As these models grow in size and complexity, inference needs more GPU memory and more time per input. One way to reduce this cost is shortcutting: casting an early-layer hidden representation directly to an approximation of the final-layer representation, so that inference can exit early instead of running every remaining block. Existing shortcutting methods such as Jump to Conclusions (JTC) improve the precision of early-exit predictions, but they add many parameters: JTC learns a full linear transformation for each shortcut, which costs \(H\times H\) parameters per shortcut, where \(H\) is the Transformer's hidden dimension. For large language models such as Phi3-Mini and Llama2-7B, this becomes expensive.

The paper therefore proposes two parameter-efficient shortcutting methods, **Narrow Jump to Conclusions (NJTC)** and **Normalized Narrow Jump to Conclusions (N-NJTC)**:

1. **NJTC** replaces the \(H\times H\) linear map with a low-rank approximation, factoring it into two much smaller matrices \(A\) and \(B\). With an inner dimension \(r \ll H\), the shortcut needs \(2Hr\) parameters instead of \(H^2\).
2. **N-NJTC** adds batch normalization on top of NJTC to improve the precision of the narrow shortcut.

According to the paper, these shortcuts reduce the shortcut parameter count by more than 97%.
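The two shortcut variants can be sketched as follows. This is a minimal, dependency-free illustration, not the authors' implementation: the class and function names are hypothetical, the weights are random rather than trained, and placing batch normalization on the rank-\(r\) bottleneck (between \(A\) and \(B\)) is an assumption, since the summary does not say where N-NJTC applies it. In an early-exit setting, the shortcut output would typically be passed to the model's final normalization and unembedding head to predict tokens.

```python
"""Dependency-free sketch of NJTC-style and N-NJTC-style shortcuts.

Hypothetical names throughout; weights are random, not trained. Placing batch
normalization on the rank-r bottleneck is an assumption made for illustration.
"""
import math
import random

Vector = list[float]
Matrix = list[list[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def random_matrix(rows: int, cols: int, rng: random.Random, std: float = 0.02) -> Matrix:
    """Gaussian-initialized matrix standing in for trained weights."""
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


def batch_norm(batch: list[Vector], gamma: Vector, beta: Vector, eps: float = 1e-5) -> list[Vector]:
    """Normalize each feature with batch statistics, then scale and shift.

    Training-mode behaviour only; inference would normally use running averages.
    """
    n, dim = len(batch), len(batch[0])
    mean = [sum(x[j] for x in batch) / n for j in range(dim)]
    var = [sum((x[j] - mean[j]) ** 2 for x in batch) / n for j in range(dim)]
    return [
        [gamma[j] * (x[j] - mean[j]) / math.sqrt(var[j] + eps) + beta[j] for j in range(dim)]
        for x in batch
    ]


class NarrowShortcut:
    """NJTC-style shortcut: h_final ~= B @ (A @ h_early), replacing one H x H map."""

    def __init__(self, hidden: int, rank: int, rng: random.Random) -> None:
        self.A = random_matrix(rank, hidden, rng)  # down-projection H -> r
        self.B = random_matrix(hidden, rank, rng)  # up-projection r -> H

    def num_params(self) -> int:
        return sum(len(row) for row in self.A) + sum(len(row) for row in self.B)

    def __call__(self, batch: list[Vector]) -> list[Vector]:
        return [matvec(self.B, matvec(self.A, h)) for h in batch]


class NormalizedNarrowShortcut(NarrowShortcut):
    """N-NJTC-style sketch: batch-normalize the r-dim bottleneck before B (assumed)."""

    def __init__(self, hidden: int, rank: int, rng: random.Random) -> None:
        super().__init__(hidden, rank, rng)
        self.gamma = [1.0] * rank  # learnable scale
        self.beta = [0.0] * rank   # learnable shift

    def num_params(self) -> int:
        return super().num_params() + len(self.gamma) + len(self.beta)

    def __call__(self, batch: list[Vector]) -> list[Vector]:
        bottleneck = [matvec(self.A, h) for h in batch]
        bottleneck = batch_norm(bottleneck, self.gamma, self.beta)
        return [matvec(self.B, z) for z in bottleneck]


if __name__ == "__main__":
    rng = random.Random(0)
    hidden, rank = 64, 4  # toy sizes; real hidden sizes are in the thousands
    early_states = [[rng.gauss(0.0, 1.0) for _ in range(hidden)] for _ in range(8)]
    for shortcut in (NarrowShortcut(hidden, rank, rng), NormalizedNarrowShortcut(hidden, rank, rng)):
        out = shortcut(early_states)
        print(type(shortcut).__name__, len(out), len(out[0]), shortcut.num_params(), "vs", hidden * hidden)
```

Because the normalization in this sketch only touches the \(r\)-dimensional bottleneck, it adds just \(2r\) parameters, so the N-NJTC-style variant keeps essentially the same parameter count as the NJTC-style one.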
The authors evaluate N-NJTC on GPT-2-XL, Phi3-Mini and Llama2-7B, where it reliably outperforms Identity shortcuts at early stages and offers stable precision from all transformer block levels. In summary, the paper's main contribution is a shortcutting method with very few parameters, which the authors present as evidence that more parameter-efficient shortcutting approaches are viable for cheaper early-exit inference.