Towards Signal Processing In Large Language Models

Prateek Verma, Mert Pilanci
2024-06-10
Abstract: This paper introduces the idea of applying signal processing inside a Large Language Model (LLM). With the recent explosion of generative AI, our work can help bridge two fields together, namely the field of signal processing and large language models. We draw parallels between classical Fourier-Transforms and Fourier Transform-like learnable time-frequency representations for every intermediate activation signal of an LLM. Once we decompose every activation signal across tokens into a time-frequency representation, we learn how to filter and reconstruct them, with all components learned from scratch, to predict the next token given the previous context. We show that for GPT-like architectures, our work achieves faster convergence and significantly increases performance by adding a minuscule number of extra parameters when trained for the same epochs. We hope this work paves the way for algorithms exploring signal processing inside the signals found in neural architectures like LLMs and beyond.
Computation and Language, Machine Learning, Sound, Audio and Speech Processing
What problem does this paper attempt to address?
This paper attempts to address the problem of introducing signal processing techniques into large language models (LLMs) to improve their performance. Specifically, the authors propose applying classical Fourier Transform (FT) and FT-like learnable time-frequency representations to the intermediate activation signals of each layer in LLMs. By decomposing each activation signal into time-frequency representations and learning how to filter and reconstruct these signals, the next word prediction task is optimized. This approach not only accelerates the model's convergence speed but also significantly enhances the model's performance while adding only a small number of additional parameters.

### Main Contributions:

1. **Introducing Core Signal Processing Concepts in LLM Pre-training**: The authors demonstrate how to introduce the concept of filtering from signal processing into LLM pre-training by learning causal time-frequency representations and filters to optimize the next word prediction task.
2. **Learning Varying Time-Frequency Masks**: Further complexity is introduced by mimicking the time-frequency masking methods used in audio source separation, allowing the learned filter characteristics to vary over time (or token dimension), thereby further enhancing the performance of pre-trained LLMs.
3. **Application in Non-Causal Settings**: For non-causal settings, the authors showcase the adaptability of the method through a simple audio classification task. By applying Discrete Cosine Transform (DCT) to the intermediate embedding dimension signals, learning filters from scratch, reconstructing signals, and learning in an end-to-end manner driven by a classification loss function, classification accuracy is improved.

### Method Overview:

1. **Finding Signals in LLMs**: The authors identify 1-dimensional signals in the intermediate embedding signals of each layer in LLMs and ensure that all signal processing operations adhere to the causal assumption, i.e., not leaking future information.
2. **Learnable Time-Frequency Representations and Filtering**: Time-frequency representations are simulated through causal convolution blocks, similar to learning Short-Time Fourier Transform (STFT), decomposing signal components, and learning which frequency components are important. Signals are filtered and reconstructed through convolutional filters and weight learning.
3. **Multi-Scale Time-Frequency Representations and Filtering**: Multi-scale filters are introduced, using convolutional filters of different lengths to improve time-frequency resolution, further optimizing signal processing effects.
4. **Adaptive Weights**: Drawing from classical signal processing methods in source separation and time-frequency masking, the filter weights adaptively change with the token dimension, better fitting the characteristics of the input signals.

### Experimental Results:

- **Faster Convergence Speed**: Compared to the baseline model, this method achieves faster convergence speed within the same training time, approximately 40-45% faster.
- **Performance Improvement**: Within the same training time, the validation loss is reduced by about 0.02.
- **Visualization of Multi-Scale Time-Frequency Representations**: By visualizing the multi-scale time-frequency representations learned in the first layer, it is found that the basis functions are not sine waves, emphasizing the necessity of learning time-frequency representations from scratch.

Overall, this paper not only improves the performance of LLMs by introducing signal processing techniques but also provides new ideas for future related research.
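
The causal filtering step described in the Method Overview can be illustrated with a small sketch. This is not the authors' implementation: the function names (`causal_analysis`, `filter_and_reconstruct`, `causal_filter_layer`), the kernel length, and the random values standing in for learned parameters are all assumptions made for illustration. It treats one embedding dimension as a 1-dimensional signal across tokens, projects it onto a bank of kernels using only current and past tokens, scales each component by a gain, and mixes the components back into one value per token.

```python
import random


def causal_analysis(signal, kernels):
    """Project a 1-D signal onto each kernel using only current and past samples."""
    k_len = len(kernels[0])
    # Left-pad with zeros so the window ending at token t never reaches past t.
    padded = [0.0] * (k_len - 1) + list(signal)
    coeffs = []
    for t in range(len(signal)):
        window = padded[t:t + k_len]  # samples t-k_len+1 .. t
        coeffs.append([sum(w * x for w, x in zip(k, window)) for k in kernels])
    return coeffs  # one list of per-kernel coefficients per token


def filter_and_reconstruct(coeffs, gains, synthesis):
    """Scale each coefficient by its gain, then mix back to one value per token."""
    return [
        sum(g * c * s for g, c, s in zip(gains, frame, synthesis))
        for frame in coeffs
    ]


def causal_filter_layer(signal, kernels, gains, synthesis):
    """Analysis, filtering, and reconstruction chained together (hypothetical helper)."""
    return filter_and_reconstruct(causal_analysis(signal, kernels), gains, synthesis)


if __name__ == "__main__":
    rng = random.Random(0)
    num_kernels, k_len, num_tokens = 4, 3, 8
    # Random values stand in for parameters that training would learn.
    kernels = [[rng.uniform(-1, 1) for _ in range(k_len)] for _ in range(num_kernels)]
    gains = [rng.uniform(0, 1) for _ in range(num_kernels)]
    synthesis = [rng.uniform(-1, 1) for _ in range(num_kernels)]
    signal = [rng.uniform(-1, 1) for _ in range(num_tokens)]

    out = causal_filter_layer(signal, kernels, gains, synthesis)
    # Causality check: changing tokens 5.. must leave outputs 0..4 unchanged.
    changed = signal[:5] + [x + 1.0 for x in signal[5:]]
    out_changed = causal_filter_layer(changed, kernels, gains, synthesis)
    print(out[:5] == out_changed[:5])
```

In this sketch the gains are fixed across tokens. The adaptive variant in the summary would presumably replace `gains` with a per-token set of gains, and the multi-scale variant would combine several kernel banks of different lengths; both are left out here to keep the example short.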