Learning positional encodings in transformers depends on initialization

Takuya Ito, Luca Cocchi, Tim Klinger, Parikshit Ram, Murray Campbell, Luke Hearne
2024-11-09
Abstract: The attention mechanism is central to the transformer's ability to capture complex dependencies between tokens of an input sequence. Key to the successful application of the attention mechanism in transformers is its choice of positional encoding (PE). The PE provides essential information that distinguishes the position and order amongst tokens in a sequence. Most prior investigations of PE effects on generalization were tailored to 1D input sequences, such as those presented in natural language, where adjacent tokens (e.g., words) are highly related. In contrast, many real-world tasks involve datasets with highly non-trivial positional arrangements, such as datasets organized in multiple spatial dimensions, or datasets for which ground truth positions are not known, such as in biological data. Here we study the importance of learning accurate PE for problems which rely on a non-trivial arrangement of input tokens. Critically, we find that the choice of initialization of a learnable PE greatly influences its ability to discover accurate PEs that lead to enhanced generalization. We empirically demonstrate our findings in a 2D relational reasoning task and a real-world 3D neuroscience dataset, applying interpretability analyses to verify the learning of accurate PEs. Overall, we find that a learned PE initialized from a small-norm distribution can 1) uncover interpretable PEs that mirror ground truth positions, 2) learn non-trivial and modular PEs in a real-world neuroscience dataset, and 3) lead to improved downstream generalization in both datasets. Importantly, choosing an ill-suited PE can be detrimental to both model interpretability and generalization. Together, our results illustrate the feasibility of discovering accurate PEs for enhanced generalization.
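The setup described above (a learnable PE whose initial scale matters, checked against ground-truth positions) can be illustrated with a minimal NumPy sketch. It assumes a learned absolute PE table, one vector per position, added to token embeddings and drawn i.i.d. from N(0, std^2). The abstract does not state the distribution family or scale actually used, so the `std` values and the helper names (`init_learnable_pe`, `pe_to_token_norm_ratio`, `pe_grid_alignment`) are hypothetical. The alignment score, a correlation between pairwise PE distances and ground-truth 2D grid distances, is one possible interpretability check and not necessarily the authors' analysis. No training is performed here.

```python
import numpy as np

def init_learnable_pe(n_positions, d_model, std, seed=0):
    """Hypothetical helper: one trainable vector per position, sampled from N(0, std^2)."""
    rng = np.random.default_rng(seed)
    return rng.normal(0.0, std, size=(n_positions, d_model))

def pe_to_token_norm_ratio(pe, tokens):
    """Mean L2 norm of the PE rows divided by the mean L2 norm of the token embeddings."""
    return np.linalg.norm(pe, axis=1).mean() / np.linalg.norm(tokens, axis=1).mean()

def pe_grid_alignment(pe, coords):
    """Pearson correlation between pairwise PE distances and ground-truth
    Euclidean distances, using each unordered pair of positions once."""
    iu = np.triu_indices(len(coords), k=1)
    pe_d = np.linalg.norm(pe[:, None, :] - pe[None, :, :], axis=-1)[iu]
    gt_d = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)[iu]
    return np.corrcoef(pe_d, gt_d)[0, 1]

# Hypothetical 8x8 grid of tokens, as in a 2D task; d_model is arbitrary here.
n_side, d_model = 8, 64
n_pos = n_side * n_side
coords = np.array([(r, c) for r in range(n_side) for c in range(n_side)], dtype=float)
tokens = np.random.default_rng(1).normal(0.0, 1.0, size=(n_pos, d_model))

pe_small = init_learnable_pe(n_pos, d_model, std=0.02)  # assumed "small-norm" scale
pe_large = init_learnable_pe(n_pos, d_model, std=1.0)   # assumed larger scale for contrast
x = tokens + pe_small  # what the first attention layer would receive

# A reference PE that encodes the grid coordinates exactly (zero-padded to d_model).
pe_gt = np.hstack([coords, np.zeros((n_pos, d_model - 2))])

print("norm ratio, small init:", pe_to_token_norm_ratio(pe_small, tokens))
print("norm ratio, large init:", pe_to_token_norm_ratio(pe_large, tokens))
print("grid alignment, random init:", pe_grid_alignment(pe_small, coords))
print("grid alignment, coordinate PE:", pe_grid_alignment(pe_gt, coords))
```

With the small `std`, the PE starts as a small perturbation of the token embeddings rather than imposing an arbitrary random positional geometry of comparable magnitude. At initialization, neither random table reflects the grid (the alignment score is scale-invariant and near zero for both). The abstract's claim concerns what the small-norm PE can learn during training, which this sketch does not simulate.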
Machine Learning