DIVE: Subgraph Disagreement for Graph Out-of-Distribution Generalization

Xin Sun, Liang Wang, Qiang Liu, Shu Wu, Zilei Wang, Liang Wang
2024-08-08
Abstract: This paper addresses the challenge of out-of-distribution (OOD) generalization in graph machine learning, a field rapidly advancing yet grappling with the discrepancy between source and target data distributions. Traditional graph learning algorithms, based on the assumption that training and test data share the same distribution, falter in real-world scenarios where this assumption fails, resulting in suboptimal performance. A principal factor contributing to this suboptimal performance is the inherent simplicity bias of neural networks trained through Stochastic Gradient Descent (SGD), which prefer simpler features over more complex yet equally or more predictive ones. This bias leads to a reliance on spurious correlations, adversely affecting OOD performance in various tasks such as image recognition, natural language understanding, and graph classification. Current methodologies, including subgraph-mixup and information bottleneck approaches, have achieved partial success but struggle to overcome simplicity bias, often reinforcing spurious correlations. To tackle this, we propose DIVE, which trains a collection of models to focus on all label-predictive subgraphs by encouraging the models to diverge on the subgraph mask, circumventing the limitation of a single model focusing solely on the subgraph corresponding to simple structural patterns. Specifically, we employ a regularizer to penalize overlap in extracted subgraphs across models, thereby encouraging different models to concentrate on distinct structural patterns. Model selection for robust OOD performance is achieved through validation accuracy. Tested across four datasets from the GOOD benchmark and one dataset from the DrugOOD benchmark, our approach demonstrates significant improvement over existing methods, effectively addressing the simplicity bias and enhancing generalization in graph machine learning.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
This paper addresses the out-of-distribution (OOD) generalization problem in graph machine learning. Most current graph learning algorithms assume that training and test data follow the same distribution. In real-world scenarios this assumption often fails, and model performance degrades. Under distribution shift in particular, neural networks trained via Stochastic Gradient Descent (SGD) exhibit a "simplicity bias": they prefer simple features and ignore complex features that are equally or more predictive. This bias causes models to rely on spurious correlations, which harms their OOD generalization.
To address this, the authors propose a new learning paradigm, DIVE (Subgraph Disagreement for Graph Out-of-Distribution Generalization). DIVE trains a collection of models that together cover all label-predictive subgraphs and encourages the models to disagree on their subgraph masks, so that no model is limited to the subgraph corresponding to simple structural patterns. Concretely, DIVE adds a regularization term that penalizes overlap between the subgraphs extracted by different models, pushing each model toward a distinct structural pattern. The final model is selected by validation accuracy.
Experiments on four datasets from the GOOD benchmark and one dataset from the DrugOOD benchmark show that DIVE outperforms existing methods.
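The overlap penalty described above can be illustrated with a small sketch. The summary does not give the exact form of the regularizer, so everything here is an assumption made for illustration: masks are treated as per-edge scores in [0, 1], overlap is measured as the mean elementwise product of two masks, and the names `pairwise_overlap`, `overlap_penalty`, `total_loss`, and the `weight` parameter are hypothetical rather than taken from the paper.

```python
from itertools import combinations
from typing import Sequence


def pairwise_overlap(mask_a: Sequence[float], mask_b: Sequence[float]) -> float:
    """Mean elementwise product of two soft edge masks with values in [0, 1].

    Assumed overlap measure for illustration; the paper may define it differently.
    """
    if len(mask_a) != len(mask_b):
        raise ValueError("masks must cover the same set of edges")
    if not mask_a:
        return 0.0
    return sum(a * b for a, b in zip(mask_a, mask_b)) / len(mask_a)


def overlap_penalty(masks: Sequence[Sequence[float]]) -> float:
    """Average pairwise overlap over every pair of models in the collection."""
    pairs = list(combinations(masks, 2))
    if not pairs:
        return 0.0
    return sum(pairwise_overlap(a, b) for a, b in pairs) / len(pairs)


def total_loss(task_losses: Sequence[float],
               masks: Sequence[Sequence[float]],
               weight: float = 1.0) -> float:
    """Sum of per-model task losses plus the weighted overlap penalty.

    `weight` is a hypothetical trade-off hyperparameter.
    """
    return sum(task_losses) + weight * overlap_penalty(masks)


# Two models on a graph with four edges.
disjoint_masks = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
identical_masks = [[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]

penalty_disjoint = overlap_penalty(disjoint_masks)    # models select different edges
penalty_identical = overlap_penalty(identical_masks)  # models select the same edges
```

Under these assumptions, models that select disjoint edge sets incur no penalty, while models that select the same edges incur the largest penalty for those masks, so minimizing the combined loss favors models that attend to different structural patterns. In an actual training setup the masks would be produced by learnable subgraph extractors and the penalty would be computed with a differentiable tensor library.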