Fourier Head: Helping Large Language Models Learn Complex Probability Distributions

Nate Gillman, Daksh Aggarwal, Michael Freeman, Saurabh Singh, Chen Sun
2024-10-30
Abstract: As the quality of large language models has improved, there has been increased interest in using them to model non-linguistic tokens. For example, the Decision Transformer recasts agentic decision making as a sequence modeling problem, using a decoder-only LLM to model the distribution over the discrete action space for an Atari agent. However, when adapting LLMs to non-linguistic domains, it remains unclear if softmax over discrete bins captures the continuous structure of the tokens and the potentially complex distributions needed for high quality token generation. We introduce a neural network layer, constructed using Fourier series, which we can easily substitute for any linear layer if we want the outputs to have a more continuous structure. We perform extensive analysis on synthetic datasets, as well as on large-scale decision making and time series forecasting tasks. We also provide theoretical evidence that this layer can better learn signal from data while ignoring high-frequency noise. All of our results support the effectiveness of our proposed Fourier head in scenarios where the underlying data distribution has a natural continuous structure. For example, the Fourier head improves a Decision Transformer agent's returns by 46% on the Atari Seaquest game, and increases a state-of-the-art time series foundation model's forecasting performance by 3.5% across 20 benchmarks unseen during training.
Machine Learning, Artificial Intelligence, Computation and Language
What problem does this paper attempt to address?
### The Problem the Paper Attempts to Solve

The paper addresses a key issue that large language models (LLMs) face on non-linguistic tasks: how to model complex, continuous probability distributions effectively. When LLMs are applied to non-linguistic domains such as decision making and time series forecasting, the standard softmax over discrete bins may fail to capture the continuous structure and complexity of the data. This can degrade the quality of the non-linguistic tokens the model generates.

### Solution

To solve this problem, the authors introduce a new neural network layer called the Fourier Head. The Fourier Head uses a Fourier series to learn a continuous probability density function and then discretizes it into a categorical distribution. This approach can better capture low-frequency signal in the data while avoiding overfitting to high-frequency noise. As a result, the Fourier Head can improve model performance across various tasks, especially those where the output dimension has a natural continuous structure.

### Main Contributions

1. **Theoretical Analysis**: The authors characterize the trade-off between the expressive power of the Fourier Head and the smoothness of its predicted distribution. They prove a theorem showing that as the number of Fourier coefficients increases, the Fourier Head can model more complex distributions but also fits more high-frequency noise.
2. **Practical Implementation**: The authors propose a practical implementation of the Fourier Head and provide improvement strategies, including Fourier coefficient norm regularization, weight initialization, and selecting an appropriate number of Fourier frequencies.
3. **Experimental Validation**: The authors validate the effectiveness of the Fourier Head on two large-scale tasks. The first task is modeling the distribution of the next action in the Atari game "Seaquest" using a decoder-only Transformer model, where the Fourier Head improves the agent's returns by 46%. The second task is zero-shot time series forecasting, where the Fourier Head improves forecasting performance by 3.5% across 20 test datasets unseen during training.

### Experimental Example

To illustrate the applicability of the Fourier Head in a simple problem setting, the authors use it to replace the linear classification head in an Audio Spectrogram Transformer for a beats per minute (BPM) classification task. The experimental results show that the Fourier Head improves the F1 score by 118% compared to the standard linear classification head and learns a smoother probability mass function (PMF).

### Conclusion

By introducing an inductive bias toward continuity, the Fourier Head significantly improves the performance of large language models on non-linguistic tasks. In tasks where the output dimension has a continuous structure, the Fourier Head better captures low-frequency signal in the data, thereby improving the quality of the model's generations.
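
The core idea described under Solution, learning a continuous density from a Fourier series and then discretizing it into a categorical distribution, can be illustrated with a small standalone sketch. This is not the authors' implementation: the function name `fourier_pmf`, the choice of the interval [-1, 1], and the use of a squared modulus to keep the density nonnegative are all assumptions made for illustration. In a real model the coefficients would be produced by a learned layer rather than passed in by hand.

```python
import cmath
import math


def fourier_pmf(coeffs, num_bins):
    """Discretize a Fourier-series density on [-1, 1] into bin probabilities.

    coeffs   : complex (or real) coefficients a_0, ..., a_N. Hypothetical
               parameterization chosen for this sketch.
    num_bins : number of equal-width bins covering [-1, 1].

    The unnormalized density is |sum_k a_k * exp(i*pi*k*x)|^2, which is
    nonnegative by construction (an assumption of this sketch).
    """
    if num_bins <= 0:
        raise ValueError("num_bins must be positive")

    # Centers of num_bins equal-width bins covering [-1, 1].
    centers = [-1.0 + (2 * j + 1) / num_bins for j in range(num_bins)]

    # Evaluate the unnormalized density at each bin center.
    values = []
    for x in centers:
        series = sum(
            a * cmath.exp(1j * math.pi * k * x) for k, a in enumerate(coeffs)
        )
        values.append(abs(series) ** 2)

    total = sum(values)
    if total == 0.0:
        # Degenerate case (all coefficients zero): fall back to uniform.
        return [1.0 / num_bins] * num_bins

    # Normalize so the bin probabilities sum to 1.
    return [v / total for v in values]


if __name__ == "__main__":
    # One low-frequency term gives a smooth bump centered at x = 0.
    pmf = fourier_pmf([1.0, 0.5], num_bins=8)
    print([round(p, 4) for p in pmf])
```

With only a few coefficients the resulting PMF varies smoothly across neighboring bins; adding more coefficients permits sharper, more complex shapes, which mirrors the expressiveness versus smoothness trade-off noted in the theoretical analysis.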