Abstract: Recent years have witnessed the great potential of attention mechanism in graph representation learning. However, while variants of attention-based GNNs are setting new benchmarks for numerous real-world datasets, recent works have pointed out that their induced attentions are less robust and generalizable against noisy graphs due to lack of direct supervision. In this paper, we present a new framework which utilizes the tool of causality to provide a powerful supervision signal for the learning process of attention functions. Specifically, we estimate the direct causal effect of attention to the final prediction, and then maximize such effect to guide attention attending to more meaningful neighbors. Our method can serve as a plug-and-play module for any canonical attention-based GNNs in an end-to-end fashion. Extensive experiments on a wide range of benchmark datasets illustrated that, by directly supervising attention functions, the model is able to converge faster with a clearer decision boundary, and thus yields better performances.
What problem does this paper attempt to address?
### Problems the paper attempts to solve
This paper addresses the limited robustness and generalization of attention mechanisms in Graph Neural Networks (GNNs). Although attention-based GNNs have achieved significant performance gains on many real-world datasets, their learned attention is unstable and unreliable on noisy graphs because it receives no direct supervision signal. To address this, the authors propose a new framework that uses causal inference tools to provide a strong supervision signal for learning the attention function.
### Main contributions
1. **Explore a new perspective**: This paper is the first attempt to directly enhance the attention mechanism of GNNs using causal inference tools, which is an underexplored direction.
2. **Propose the CSA framework**: The authors propose Causal Supervision for Attention (CSA), a new supervision framework based on causal inference, which can be applied as a plug-in module to multiple models and tasks to improve the quality of attention.
3. **Experimental verification**: Through extensive experiments and analyses, the authors demonstrate the effectiveness and generality of CSA on standard benchmark datasets.
### Method overview
#### 4.1 Derivation of causal effects
1. **Representation of causal models**: Figure 3 shows a schematic diagram of CSA combined with the graph attention method as a plug-in module. By introducing counterfactual analysis, the authors define the causal effect of attention. Specifically, given node features \(X\) and an attention graph \(A\), the model prediction \(Y\) can be represented as:
\[
Y_{x,a}=Y(X = x, A = a)
\]
2. **Counterfactual intervention**: Through the intervention operation \(do(A = a^*)\), the counterfactual model prediction \(Y_{x,a^*}\) can be obtained:
\[
Y_{x,a^*}=Y(X = x, do(A = a^*))
\]
3. **Total direct effect**: The Total Direct Effect (TDE) of attention is the difference between the factual prediction \(Y_{x,a}\) and the counterfactual prediction \(Y_{x,a^*}\):
\[
\text{TDE}=Y_{x,a}-Y_{x,a^*}
\]
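
The three steps above can be illustrated with a minimal Python sketch. It is not the authors' implementation: the toy `predict` function (an attention-weighted sum of per-class neighbor scores), the function names, and the example numbers are all assumptions made for illustration. The sketch only shows that the effect is the factual prediction minus the prediction obtained after swapping in a counterfactual attention vector.

```python
def predict(features, attention):
    """Toy stand-in for Y(X = x, A = a): attention-weighted sum of
    per-class neighbor scores. A real GNN layer would be used here."""
    num_classes = len(features[0])
    return [
        sum(a * f[c] for a, f in zip(attention, features))
        for c in range(num_classes)
    ]


def total_direct_effect(features, factual_att, counterfactual_att):
    """Factual prediction minus the prediction under do(A = counterfactual)."""
    y_factual = predict(features, factual_att)
    y_counterfactual = predict(features, counterfactual_att)
    return [p - q for p, q in zip(y_factual, y_counterfactual)]


# Hypothetical example: three neighbors, two classes.
x = [[2.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
a = [0.8, 0.1, 0.1]            # learned (factual) attention
a_cf = [1 / 3, 1 / 3, 1 / 3]   # counterfactual attention, here uniform weights

tde = total_direct_effect(x, a, a_cf)
```

In this toy setup the learned attention concentrates on the first neighbor, so the effect is positive for class 0 and negative for class 1. If the counterfactual equals the learned attention, the effect is zero for every class.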
#### 4.2 Supervision using causal effects
1. **Supervision signal**: The causal effect of attention is used as a supervision signal to guide the learning of the attention mechanism. Specifically, the effect at each layer \(l\) is treated as a prediction and trained against the label \(y\), so that maximizing the causal effect becomes minimizing an additional loss term:
\[
L=\sum_l\lambda_lL_{\text{ce}}(Y_l^{\text{effect}}, y)+L_{\text{others}}
\]
where \(Y_l^{\text{effect}}\) is the causal effect of attention at layer \(l\), \(y\) is the label, \(L_{\text{ce}}\) is the cross-entropy loss, \(\lambda_l\) is a balancing coefficient, and \(L_{\text{others}}\) represents the original objective (such as the standard classification loss).
2. **Counterfactual schemes**: To generate the counterfactual attention \(a^*\), the authors propose three heuristic schemes:
- **Scheme I**: Draw the counterfactual attention from a uniform distribution over an interval \([e, f]\):
\[
a^*\sim U(e, f)
\]
- **Scheme II**: Use an identity mapping as the counterfactual attention, so that each node attends only to itself (the central node):
\[
a^*\sim I
\]
- **Scheme III**: Use a historical attention graph as the counterfactual:
\[
a^*\sim A_{\text{hist}}
\]
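
The loss and the three counterfactual schemes above can be sketched in plain Python as follows. This is a hedged illustration rather than the paper's code: every function name is hypothetical, normalizing the uniform draws so they sum to 1 is an assumption, and "historical attention" is modeled simply as a list of attention vectors saved at earlier training steps.

```python
import math
import random


def uniform_counterfactual(num_neighbors, low=0.0, high=1.0, rng=random):
    """Scheme I: draw each weight from U(low, high).
    Normalizing to sum to 1 is an assumption of this sketch."""
    raw = [rng.uniform(low, high) for _ in range(num_neighbors)]
    total = sum(raw) or 1.0  # guard against an all-zero draw
    return [r / total for r in raw]


def identity_counterfactual(num_neighbors, center_index=0):
    """Scheme II: all attention mass on the central node itself."""
    return [1.0 if i == center_index else 0.0 for i in range(num_neighbors)]


def historical_counterfactual(history, rng=random):
    """Scheme III: reuse an attention vector saved at an earlier step."""
    return list(rng.choice(history))


def cross_entropy(logits, label):
    """Cross-entropy of softmax(logits) against an integer class label."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[label]


def csa_loss(effects_per_layer, label, lambdas, other_loss):
    """L = sum_l lambda_l * L_ce(Y_l^effect, y) + L_others."""
    causal = sum(
        lam * cross_entropy(effect, label)
        for lam, effect in zip(lambdas, effects_per_layer)
    )
    return causal + other_loss


# Hypothetical per-layer effects (treated as class scores) for a two-layer model.
effects = [[0.9, -0.5], [0.4, -0.1]]
loss = csa_loss(effects, label=0, lambdas=[0.5, 0.5], other_loss=0.3)
```

Each per-layer effect is treated as a vector of class scores, so minimizing its cross-entropy against the true label pushes the learned attention to change the prediction in the correct direction relative to the counterfactual.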
### Experimental results
1. **Heterophilic graph datasets**: In Wisconsin, Cornell, Texas, Actor, Squirrel