On the Initialization of Graph Neural Networks

Jiahang Li, Yakun Song, Xiang Song, David Paul Wipf
2023-12-05
Abstract: Graph Neural Networks (GNNs) have displayed considerable promise in graph representation learning across various applications. The core learning process requires the initialization of model weight matrices within each GNN layer, which is typically accomplished via classic initialization methods such as Xavier initialization. However, these methods were originally motivated to stabilize the variance of hidden embeddings and gradients across layers of Feedforward Neural Networks (FNNs) and Convolutional Neural Networks (CNNs) to avoid vanishing gradients and maintain steady information flow. In contrast, within the GNN context classical initializations disregard the impact of the input graph structure and message passing on variance. In this paper, we analyze the variance of forward and backward propagation across GNN layers and show that the variance instability of GNN initializations comes from the combined effect of the activation function, hidden dimension, graph structure and message passing. To better account for these influence factors, we propose a new initialization method for Variance Instability Reduction within GNN Optimization (Virgo), which naturally tends to equate forward and backward variances across successive layers. We conduct comprehensive experiments on 15 datasets to show that Virgo can lead to superior model performance and more stable variance at initialization on node classification, link prediction and graph classification tasks. Code is available at https://github.com/LspongebobJH/virgo_icml2023.
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
This paper addresses the problem of unstable variance at initialization in graph neural networks (GNNs). Traditional initialization methods (such as Xavier and LeCun initialization) were originally designed for feedforward neural networks (FNNs) and convolutional neural networks (CNNs). They aim to avoid vanishing gradients and keep information flowing steadily by stabilizing the variance of hidden embeddings and gradients across layers. When these classical methods are applied to GNNs, however, they ignore the influence of the input graph structure and the message-passing mechanism on the variance.

To make the problem concrete, the forward propagation of node \(i\) at layer \(l\) in a GNN can be written as:

\[ h_i^l=\sigma\left(\sum_{j \in N(i)} d_{ij} h_j^{l - 1}W^{l - 1}\right) \]

where:
- \(N(i)\) is the set of first-order neighbors of node \(i\).
- \(\sigma\) is the activation function, assumed to be ReLU.
- \(W^{l-1}\) is the weight matrix of layer \(l - 1\).
- \(d_{ij}\) is the normalization coefficient, defined as \(\frac{1}{\sqrt{(d_i + 1)(d_j + 1)}}\), where \(d_i\) and \(d_j\) are the degrees of nodes \(i\) and \(j\) respectively.

Because the message-passing layers and the graph structure of a GNN affect the initial variance in more complex ways, for example through the dependence introduced by the differing receptive-field sizes of different nodes, the assumptions behind classical initialization may no longer hold. For this reason, the authors propose a new initialization method, Virgo, to reduce variance instability in GNN optimization. Virgo derives the variance of the weight-initialization distribution by minimizing the difference in overall variance between consecutive layers, thereby achieving better variance stability.

In summary, the main contributions of this paper are:
1. Deriving and analyzing the variance expressions of GNN embeddings (forward propagation) and gradients (backward propagation), showing how these quantities are jointly affected by the hidden dimension, activation function, graph structure and GNN message-passing mechanism.
2. Proposing a new initialization method, Virgo, which mitigates the influence of these factors by minimizing the overall variance difference between consecutive layers.
3. Conducting comprehensive experiments on 15 datasets to verify the performance of Virgo on node classification, link prediction and graph classification tasks, where it improves prediction accuracy by up to 7% and significantly stabilizes the variance at initialization.
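
As a rough illustration of the forward-propagation formula above, the following NumPy sketch applies it layer by layer with Xavier-initialized weights and records the empirical variance of the embeddings after each layer. It is not the Virgo method and is not taken from the paper's code. The function names, the random toy graph, the Gaussian input features, and the use of the variance over all embedding entries as a proxy are assumptions made for illustration only.

```python
import numpy as np


def normalized_adjacency(adj):
    """Return S with S[i, j] = 1 / sqrt((d_i + 1)(d_j + 1)) for each edge (i, j).

    Follows the formula literally: the sum runs over first-order neighbors
    only, so no self-loops are added (an assumption of this sketch).
    """
    deg = adj.sum(axis=1)                 # node degrees d_i
    inv_sqrt = 1.0 / np.sqrt(deg + 1.0)   # (d_i + 1)^(-1/2)
    return adj * inv_sqrt[:, None] * inv_sqrt[None, :]


def xavier_uniform(fan_in, fan_out, rng):
    """Xavier (Glorot) uniform init: U(-b, b) with b = sqrt(6 / (fan_in + fan_out))."""
    bound = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-bound, bound, size=(fan_in, fan_out))


def forward_variances(S, X, num_layers, rng):
    """Apply h^l = ReLU(S h^{l-1} W^{l-1}) and record the embedding variance per layer."""
    h = X
    variances = [float(h.var())]          # variance of the input features
    for _ in range(num_layers):
        W = xavier_uniform(h.shape[1], h.shape[1], rng)
        h = np.maximum(S @ h @ W, 0.0)    # message passing, linear map, ReLU
        variances.append(float(h.var()))
    return variances


# Hypothetical toy setup: a random undirected graph with Gaussian input features.
rng = np.random.default_rng(0)
num_nodes, hidden_dim = 200, 64
upper = np.triu(rng.random((num_nodes, num_nodes)) < 0.05, k=1)
adj = (upper | upper.T).astype(float)     # symmetric 0/1 adjacency, no self-loops
S = normalized_adjacency(adj)
X = rng.standard_normal((num_nodes, hidden_dim))

variances = forward_variances(S, X, num_layers=5, rng=rng)
for layer, var in enumerate(variances):
    print(f"layer {layer}: variance = {var:.6f}")
```

The printed values depend on the random graph, the hidden dimension and the seed. The sketch only shows one way to track per-layer variance under a classical initialization; swapping in a different graph or weight distribution makes it possible to see how strongly the result depends on those choices.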