Enhancing Out-of-distribution Generalization on Graphs via Causal Attention Learning

Yongduo Sui, Wenyu Mao, Shuyao Wang, Xiang Wang, Jiancan Wu, Xiangnan He, Tat-Seng Chua
DOI: https://doi.org/10.1145/3644392
IF: 4.157
2024-02-05
ACM Transactions on Knowledge Discovery from Data
Abstract: In graph classification, attention- and pooling-based graph neural networks (GNNs) predominate to extract salient features from the input graph and support the prediction. They mostly follow the paradigm of “learning to attend”, which maximizes the mutual information between the attended graph and the ground-truth label. However, this paradigm causes GNN classifiers to indiscriminately absorb all statistical correlations between input features and labels in the training data, without distinguishing the causal and noncausal effects of features. Rather than emphasizing causal features, the attended graphs tend to rely on noncausal features as shortcuts to predictions. These shortcut features may easily change outside the training distribution, thereby leading to poor generalization for GNN classifiers. In this paper, we take a causal view on GNN modeling. Under our causal assumption, the shortcut feature serves as a confounder between the causal feature and prediction. It misleads the classifier into learning spurious correlations that facilitate prediction in in-distribution (ID) test evaluation, while causing a significant performance drop on out-of-distribution (OOD) test data. To address this issue, we employ the backdoor adjustment from causal theory, combining each causal feature with various shortcut features, to identify causal patterns and mitigate the confounding effect. Specifically, we employ attention modules to estimate the causal and shortcut features of the input graph. Then, a memory bank collects the estimated shortcut features, enhancing the diversity of shortcut features for combination. Simultaneously, we apply the prototype strategy to improve the consistency of intra-class causal features. We term our method CAL+, which can promote stable relationships between causal estimation and prediction, regardless of distribution changes. Extensive experiments on synthetic and real-world OOD benchmarks demonstrate our method’s effectiveness in improving OOD generalization. Our codes are released at https://github.com/shuyao-wang/CAL-plus.
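For reference, the backdoor adjustment mentioned in the abstract has the following standard form. The symbols C (causal feature), S (shortcut feature, acting as the confounder) and Y (prediction) are notation assumed here for illustration; the abstract itself does not fix a notation.

```latex
P\bigl(Y \mid do(C)\bigr) \;=\; \sum_{s} P\bigl(Y \mid C,\, S = s\bigr)\, P(S = s)
```

The sum weights each shortcut value by its marginal probability P(S = s) rather than by P(S = s | C), which is what breaks the dependence between shortcut and causal features. This corresponds to the abstract's description of combining each causal feature with various shortcut features.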
computer science, information systems, software engineering
What problem does this paper attempt to address?
This paper addresses the generalization problem of graph neural networks (GNNs) in graph classification, in particular their ability to generalize when the test data distribution differs from the training data distribution. Specifically:

1. **Problems with existing methods**:
   - Current attention- and pooling-based GNNs mainly follow the "learning to attend" paradigm, that is, maximizing the mutual information between the attended graph and the true label.
   - This paradigm causes the model to indiscriminately absorb all statistical correlations between input features and labels, so it cannot distinguish causal from non-causal features.
   - The model often relies on non-causal features (shortcut features) for prediction. These features are prone to change outside the training distribution, which significantly degrades performance on out-of-distribution (OOD) test data.
2. **Research objectives**:
   - Re-examine GNN modeling from a causal perspective and propose a new framework, CAL+ (Causal Attention Learning Plus), to reduce the impact of shortcut features on prediction.
   - Apply backdoor adjustment from causal inference so that the model predicts from causal features while the confounding effect of shortcut features is mitigated.
   - Improve the generalization of GNNs on OOD test data, so that performance stays stable across different data distributions.

### Specific measures:
- **GNN modeling from a causal perspective**: Construct a structural causal model (SCM), identify shortcut features as confounding variables, and cut off their influence on the causal path.
- **Backdoor adjustment strategy**: Mitigate the confounding effect by combining each causal feature with multiple shortcut features, so that the relationship between causal features and prediction stays stable.
- **Memory bank module**: Collect estimated shortcut features to increase their diversity, so that each causal feature can be combined with a sufficiently varied set of shortcut features.
- **Prototype strategy**: Keep causal features consistent within the same class to improve the accuracy and stability of their estimation.

### Summary:
This paper proposes CAL+, a framework that draws on causal inference to improve the generalization of GNNs on OOD data in graph classification, so that the model captures and uses causal features instead of relying on volatile shortcut features.
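The memory bank and backdoor-style combination described above can be illustrated with a small NumPy sketch. This is not the authors' implementation: the class `ShortcutMemoryBank`, the function `intervened_features`, the first-in-first-out eviction, and the use of element-wise addition to combine features are all assumptions made for illustration. In the paper the features come from attention modules over a GNN; plain arrays stand in for them here.

```python
import numpy as np


class ShortcutMemoryBank:
    """Hypothetical FIFO store of estimated shortcut features."""

    def __init__(self, capacity, dim):
        self.capacity = capacity
        self.features = np.empty((0, dim))

    def push(self, batch):
        # Append the new batch and keep only the newest `capacity` rows.
        stacked = np.concatenate([self.features, batch], axis=0)
        self.features = stacked[-self.capacity:]

    def sample(self, n, rng):
        # Draw n stored shortcut features uniformly, with replacement.
        idx = rng.integers(0, len(self.features), size=n)
        return self.features[idx]


def intervened_features(causal, bank, rng):
    """Pair every causal feature with a randomly drawn shortcut feature.

    Addition is an assumed combination operator; a classifier trained on
    these mixed features sees each causal pattern under many shortcuts.
    """
    shortcut = bank.sample(len(causal), rng)
    return causal + shortcut


rng = np.random.default_rng(0)
bank = ShortcutMemoryBank(capacity=4, dim=3)
bank.push(np.ones((3, 3)))        # shortcut features from one batch
bank.push(2 * np.ones((3, 3)))    # a later batch evicts the oldest rows
mixed = intervened_features(np.zeros((5, 3)), bank, rng)
```

Sampling from the bank rather than from the current batch alone is what gives each causal feature a wider range of shortcut partners, which is the role the text attributes to the memory bank.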