Jump to Conclusions: Short-Cutting Transformers With Linear Transformations

Alexander Yom Din, Taelin Karidi, Leshem Choshen, Mor Geva
2024-06-19
Abstract: Transformer-based language models create hidden representations of their inputs at every layer, but only use final-layer representations for prediction. This obscures the internal decision-making process of the model and the utility of its intermediate representations. One way to elucidate this is to cast the hidden representations as final representations, bypassing the transformer computation in-between. In this work, we suggest a simple method for such casting, using linear transformations. This approximation far exceeds the prevailing practice of inspecting hidden representations from all layers, in the space of the final layer. Moreover, in the context of language modeling, our method produces more accurate predictions from hidden layers, across various model scales, architectures, and data distributions. This allows "peeking" into intermediate representations, showing that GPT-2 and BERT often predict the final output already in early layers. We then demonstrate the practicality of our method to recent early exit strategies, showing that when aiming, for example, at retention of 95% accuracy, our approach saves additional 7.9% layers for GPT-2 and 5.4% layers for BERT. Last, we extend our method to linearly approximate sub-modules, finding that attention is most tolerant to this change. Our code and learned mappings are publicly available at https://github.com/sashayd/mat.
Computation and Language
What problem does this paper attempt to address?
The paper addresses how to make better use of the intermediate-layer representations in Transformer models, for both interpretability and efficiency. Transformer models compute a hidden representation of the input at every layer but use only the final-layer representation for prediction. This makes the model's internal decision process opaque and makes it hard to assess how useful the intermediate representations are.

The paper proposes a simple method: learn a linear transformation that maps an intermediate-layer representation into the final-layer representation space, bypassing the Transformer computation in between. According to the paper, this approximates the final representation far better than the prevailing practice of inspecting intermediate representations directly in the final-layer space, and in language modeling it yields more accurate predictions from hidden layers across model scales, architectures, and data distributions.

The paper also applies the method to early-exit strategies, in which the model predicts from an earlier layer to save computation. When aiming to retain 95% accuracy, the linear mapping saves an additional 7.9% of layers for GPT-2 and 5.4% for BERT. Finally, the paper linearly approximates individual submodules of the Transformer (attention, the feed-forward network, and layer normalization) and finds that attention is the most tolerant of this replacement. This suggests that the non-contextual parts of the computation could be parallelized further to reduce computation time.
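
The core idea, mapping an intermediate representation to the final-layer space with a learned linear transformation, can be illustrated with a minimal sketch. This is not the authors' implementation (their code is at the repository linked in the abstract). The function names `fit_linear_shortcut` and `mean_r2`, the synthetic stand-in data, and the choice of an ordinary least-squares fit scored by the coefficient of determination are all assumptions made for illustration; with a real model, the arrays would hold hidden states collected at an intermediate layer and at the final layer for the same token positions.

```python
import numpy as np


def fit_linear_shortcut(h_mid, h_final):
    """Least-squares fit of a matrix A that minimizes ||h_mid @ A - h_final||.

    h_mid, h_final: arrays of shape (n_samples, hidden_dim).
    """
    A, *_ = np.linalg.lstsq(h_mid, h_final, rcond=None)
    return A


def mean_r2(pred, target):
    """Coefficient of determination, averaged over hidden dimensions."""
    ss_res = ((target - pred) ** 2).sum(axis=0)
    ss_tot = ((target - target.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


# Synthetic stand-in for hidden states (assumption: no real model is loaded).
# The "final" representation is a mostly linear function of the "intermediate"
# one, plus a small nonlinearity and a little noise.
rng = np.random.default_rng(0)
d, n_train, n_val = 16, 2000, 500
W_true = rng.normal(size=(d, d)) / np.sqrt(d)


def fake_final(h):
    noise = 0.01 * rng.normal(size=h.shape)
    return h @ W_true + 0.1 * np.tanh(h) + noise


h_mid_train = rng.normal(size=(n_train, d))
h_final_train = fake_final(h_mid_train)
h_mid_val = rng.normal(size=(n_val, d))
h_final_val = fake_final(h_mid_val)

# Fit on the training split, evaluate on held-out data.
A = fit_linear_shortcut(h_mid_train, h_final_train)
r2_linear = mean_r2(h_mid_val @ A, h_final_val)

# Baseline: read the intermediate state directly in the final-layer space,
# i.e. use the identity mapping.
r2_identity = mean_r2(h_mid_val, h_final_val)

print(f"learned linear map R^2: {r2_linear:.3f}")
print(f"identity baseline R^2:  {r2_identity:.3f}")
```

On this toy data the learned map fits the held-out "final" states much more closely than the identity baseline. That only shows the mechanics of the comparison and says nothing about the magnitudes reported in the paper.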