Wilfried Bounsi, Borja Ibarz, Andrew Dudzik, Jessica B. Hamrick, Larisa Markeeva, Alex Vitvitskyi, Razvan Pascanu, Petar Veličković
Abstract: Transformers have revolutionized machine learning with their simple yet effective architecture. Pre-training Transformers on massive text datasets from the Internet has led to unmatched generalization for natural language understanding (NLU) tasks. However, such language models remain fragile when tasked with algorithmic forms of reasoning, where computations must be precise and robust. To address this limitation, we propose a novel approach that combines the Transformer's language understanding with the robustness of graph neural network (GNN)-based neural algorithmic reasoners (NARs). Such NARs proved effective as generic solvers for algorithmic tasks, when specified in graph form. To make their embeddings accessible to a Transformer, we propose a hybrid architecture with a two-phase training procedure, allowing the tokens in the language model to cross-attend to the node embeddings from the NAR. We evaluate our resulting TransNAR model on CLRS-Text, the text-based version of the CLRS-30 benchmark, and demonstrate significant gains over Transformer-only models for algorithmic reasoning, both in and out of distribution.
What problem does this paper attempt to address?
This paper asks how to make Transformers more robust at algorithmic reasoning and better at generalizing, especially to input sizes not seen during training. The authors propose TransNAR, a hybrid architecture that combines the language understanding of a Transformer with a graph neural network (GNN)-based neural algorithmic reasoner (NAR). The aim is to overcome the fragility of Transformer-only models on algorithmic reasoning tasks, particularly in out-of-distribution (OOD) settings.
### Main Problems and Solutions
1. **Limitations of Transformers**:
- **Strong language understanding but weak algorithmic reasoning**: Transformers excel at natural language understanding tasks but are fragile on tasks that require precise, robust computation.
- **Poor out-of-distribution generalization**: Their performance drops significantly on inputs outside the training distribution, such as input sizes not seen during training.
2. **Advantages of NAR**:
- **Robust algorithmic reasoning ability**: NARs perform well on algorithmic tasks, handle varying input sizes, and generalize strongly in out-of-distribution scenarios.
- **Requirement for structured input**: NARs require the task to be specified in graph form, which prevents applying them directly to tasks stated in natural language.
3. **Proposal of TransNAR**:
- **Combining Transformer and NAR**: TransNAR lets the Transformer's token embeddings cross-attend to the NAR's node embeddings, giving the language model access to the NAR's computation on the graph form of the problem.
- **Two-phase training**: The NAR is first pre-trained to execute a range of algorithmic tasks on graph inputs. Its node embeddings are then exposed to the Transformer through cross-attention, and the Transformer is trained on the text version of the tasks while the pre-trained NAR is generally kept frozen rather than updated jointly.
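
The cross-attention step described above can be illustrated with a minimal NumPy sketch. This is not the authors' implementation: it assumes a single attention head, random stand-in embeddings and weights, and a plain residual update, and every name in it (`cross_attend`, `w_q`, `w_k`, `w_v`, the dimensions) is hypothetical. It only shows the data flow, with tokens supplying the queries and NAR node embeddings supplying the keys and values.

```python
import numpy as np


def softmax(x, axis=-1):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attend(tokens, nodes, w_q, w_k, w_v):
    """Single-head cross-attention from token embeddings to node embeddings.

    tokens: (T, d_tok) token embeddings from the language model.
    nodes:  (N, d_node) node embeddings from the (pre-trained) NAR.
    Returns the residually updated tokens and the (T, N) attention weights.
    """
    q = tokens @ w_q                      # queries come from the tokens
    k = nodes @ w_k                       # keys come from the NAR nodes
    v = nodes @ w_v                       # values come from the NAR nodes
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = softmax(scores, axis=-1)    # each token distributes attention over nodes
    return tokens + weights @ v, weights  # residual connection keeps token content


# Illustrative sizes only; none of these come from the paper.
rng = np.random.default_rng(0)
num_tokens, num_nodes, d_tok, d_node, d_attn = 6, 4, 8, 5, 8

tokens = rng.normal(size=(num_tokens, d_tok))
nodes = rng.normal(size=(num_nodes, d_node))   # treated as fixed NAR outputs
w_q = rng.normal(size=(d_tok, d_attn))
w_k = rng.normal(size=(d_node, d_attn))
w_v = rng.normal(size=(d_node, d_tok))

updated, weights = cross_attend(tokens, nodes, w_q, w_k, w_v)
print(updated.shape, weights.shape)
```

In this sketch the node embeddings are plain inputs, so only the projection matrices would receive gradients if the function were trained, which mirrors the idea of reusing a pre-trained reasoner without modifying it.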
### Experimental Validation
- **Benchmark Testing**: The paper conducts experiments on the CLRS-Text benchmark, a text version of the CLRS-30 benchmark, used to evaluate algorithmic reasoning tasks.
- **Performance Improvement**: Experimental results show that TransNAR significantly outperforms models using only Transformers in various algorithmic tasks, with particularly notable performance improvements in out-of-distribution scenarios.
- **Specific Metrics**: The paper uses multiple metrics to evaluate model performance, including shape score, parse score, and CLRS score, which respectively measure the correctness of output shape, parsing correctness, and the degree of match with the true answer.
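
To make the three metrics concrete, here is a simplified sketch of how they could be computed for a single answer. It assumes outputs are flat, space-separated lists of numbers, that an unparseable output also gets a shape score of zero, and that the CLRS score is set to zero whenever the shape is wrong. The function names and these simplifications are assumptions for illustration, not the benchmark's actual evaluation code.

```python
def parse_output(text):
    """Parse a space-separated list of numbers; return None if any token is illegal."""
    try:
        return [float(tok) for tok in text.split()]
    except ValueError:
        return None


def score(prediction, target):
    """Return shape, parse, and CLRS scores for one prediction against one target."""
    pred = parse_output(prediction)
    true = parse_output(target)
    if not true:
        raise ValueError("target must be a non-empty list of numbers")

    # Parse score: is the output free of illegal tokens?
    parse_score = 1.0 if pred is not None else 0.0
    # Shape score: does the output have the same number of elements as the target?
    shape_score = 1.0 if pred is not None and len(pred) == len(true) else 0.0
    # CLRS score: fraction of elements matching the target (zero if shape is wrong).
    if shape_score == 1.0:
        clrs_score = sum(p == t for p, t in zip(pred, true)) / len(true)
    else:
        clrs_score = 0.0
    return {"shape": shape_score, "parse": parse_score, "clrs": clrs_score}


target = "1 2 4 3"
for prediction in ["1 2 3 4", "1 2 3", "1 2 x 3"]:
    print(repr(prediction), score(prediction, target))
```

The three example predictions cover the main cases: a well-formed answer that is partly wrong, an answer of the wrong length, and an answer containing an illegal token.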
### Conclusion
- **Main Contribution**: By combining the language understanding of Transformers with the robust algorithmic reasoning of NARs, TransNAR significantly improves performance on algorithmic reasoning tasks, especially on input sizes not seen during training.
- **Future Work**: The authors note that TransNAR performs well on many tasks, but on some algorithmic tasks, particularly those involving index searching, it does not fully surpass the baseline models. Future research could further optimize the cross-attention mechanism or explore other ways to improve generalization.
### Summary
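With TransNAR, this paper shows that giving a Transformer access to a pre-trained NAR substantially reduces its fragility and improves its out-of-distribution generalization on algorithmic reasoning tasks, although some task families remain unsolved. The result points to hybrid Transformer-NAR architectures as a direction for future research.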