Kaidi Cao, Phitchaya Mangpo Phothilimthana, Sami Abu-El-Haija, Dustin Zelle, Yanqi Zhou, Charith Mendis, Jure Leskovec, Bryan Perozzi
Abstract: Learning to predict properties of large graphs is challenging because each prediction requires the knowledge of an entire graph, while the amount of memory available during training is bounded. Here we propose Graph Segment Training (GST), a general framework that utilizes a divide-and-conquer approach to allow learning large graph property prediction with a constant memory footprint. GST first divides a large graph into segments and then backpropagates through only a few segments sampled per training iteration. We refine the GST paradigm by introducing a historical embedding table to efficiently obtain embeddings for segments not sampled for backpropagation. To mitigate the staleness of historical embeddings, we design two novel techniques. First, we finetune the prediction head to fix the input distribution shift. Second, we introduce Stale Embedding Dropout to drop some stale embeddings during training to reduce bias. We evaluate our complete method GST-EFD (with all the techniques together) on two large graph property prediction benchmarks: MalNet and TpuGraphs. Our experiments show that GST-EFD is both memory-efficient and fast, while offering a slight boost on test accuracy over a typical full graph training regime.
### What problems does this paper attempt to solve?
This paper addresses the computational resource limits of large-scale graph property prediction, especially for extremely large graphs with millions or even billions of nodes and edges. Traditional Graph Neural Networks (GNNs) need memory that grows linearly with the size of the graph during training, so even the most powerful GPUs struggle with graphs of this scale. Specifically, the paper targets the following problems:
1. **Excessively high memory consumption**: For extremely large graphs, traditional full-graph training runs Out-of-Memory (OOM), because the intermediate activations of all nodes and edges must be stored for backpropagation.
2. **Low computational efficiency**: Because every training step processes the entire graph, training is very time-consuming, and the cost grows with graph size.
To solve these problems, the paper proposes the **Graph Segment Training (GST)** framework. GST divides a large graph into multiple small segments and backpropagates through only a few sampled segments per training iteration, which keeps memory usage low while still producing a prediction for the whole graph. To speed up training further and to alleviate the staleness caused by reusing old embeddings, the authors introduce a historical embedding table, prediction head fine-tuning, and Stale Embedding Dropout.
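The constant-memory argument can be illustrated with a minimal sketch. It assumes a hypothetical contiguous chunking of node ids (`split_into_segments`) in place of a real graph partitioner, which would also take edges into account, and it counts only the nodes whose activations must be kept for backpropagation. The function names and the parameters `max_size` and `k` are illustrative and do not come from the paper.

```python
def split_into_segments(node_ids, max_size):
    """Chunk node ids into segments of at most max_size nodes.

    Assumption: contiguous chunking stands in for a real graph partitioner.
    """
    return [node_ids[i:i + max_size] for i in range(0, len(node_ids), max_size)]


def nodes_with_gradients(segments, k):
    """Upper bound on nodes whose activations are stored when k segments
    are sampled for backpropagation in one training step."""
    return k * max(len(segment) for segment in segments)


small_graph = split_into_segments(list(range(10)), max_size=4)
large_graph = split_into_segments(list(range(1000)), max_size=4)

# The bound depends on k and max_size only, not on the number of nodes.
small_bound = nodes_with_gradients(small_graph, k=2)
large_bound = nodes_with_gradients(large_graph, k=2)
```

Under these assumptions both graphs give the same bound, which is the sense in which the per-step memory footprint is independent of graph size.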
### Main contributions of the GST framework
- **Constant memory usage**: GST backpropagates through only a few randomly sampled segments per training step, so the memory used in each step stays constant and does not depend on the size of the original graph.
- **Accelerated training**: A historical embedding table supplies embeddings for the segments that do not require gradients, so no forward pass is needed for them and training is faster.
- **Alleviation of the staleness problem**: Two techniques reduce the staleness introduced by historical embeddings:
  - **Prediction head fine-tuning**: Fine-tune the prediction head at the end of training to correct the shift in its input distribution.
  - **Stale Embedding Dropout (SED)**: Drop some stale embeddings during training to reduce bias.
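The way these pieces fit together in one training iteration can be sketched as follows. This is a simplified illustration, not the authors' implementation: `encode_segment` is a hypothetical stand-in for the GNN backbone (here a plain feature mean), the table is a Python dict, and the rules for refreshing the table and for dropping stale entries with probability `drop_p` are assumptions made for the example.

```python
import random


def encode_segment(segment):
    """Hypothetical backbone: average the node feature vectors of a segment."""
    dim = len(segment[0])
    return [sum(node[d] for node in segment) / len(segment) for d in range(dim)]


def gst_step(segments, table, k, drop_p, rng):
    """Collect segment embeddings for one iteration.

    Sampled segments are re-encoded (in real training these carry gradients)
    and written back to the historical table. Other segments reuse their
    table entry, unless it is missing or dropped with probability drop_p.
    """
    sampled = set(rng.sample(range(len(segments)), k))
    embeddings = {}
    for j, segment in enumerate(segments):
        if j in sampled:
            fresh = encode_segment(segment)
            table[j] = fresh
            embeddings[j] = fresh
        elif j in table and rng.random() >= drop_p:
            embeddings[j] = table[j]  # stale embedding, no forward pass
    return embeddings, sampled


rng = random.Random(0)
segments = [[[float(j), float(j + 1)]] * 3 for j in range(5)]
table = {}
first_embeddings, first_sampled = gst_step(segments, table, k=2, drop_p=0.5, rng=rng)
```

On the first call the table is empty, so only the sampled segments contribute. Later calls mix fresh embeddings with whichever stale entries survive the dropout.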
### Experimental results
The paper reports experiments on two large-scale graph property prediction benchmarks, MalNet and TpuGraphs. The results show that the complete method, GST-EFD, uses memory more efficiently, trains 3 times faster than traditional methods, and slightly improves test accuracy.
### Formula summary
- **Graph segmentation**: \[ G^{(i)} \approx \sum_{j=1}^{J^{(i)}} G^{(i)}_j \]
- **Segment embedding aggregation**: \[ h^{(i)} = L\left(\{h^{(i)}_s\}_{s \in S^{(i)}} \oplus \{\bar{h}^{(i)}_j\}_{j \notin S^{(i)}}\right) \]
- **Loss function**: \[ L\left(F'\left(h^{(i)}_s \oplus \tilde{h}^{(i)}_j\right), y^{(i)}\right) \]
These formulas show how GST achieves efficient graph property prediction by segmenting the graph, aggregating segment embeddings, and using historical embedding tables.
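A small numeric sketch may make the aggregation and loss steps concrete. It assumes mean pooling as the way fresh and historical segment embeddings are combined, a linear prediction head standing in for F', and a squared-error loss. All three are illustrative choices for this example, not details confirmed by the summary above.

```python
def aggregate(segment_embeddings):
    """Mean-pool a list of equal-length segment embeddings (assumed pooling)."""
    dim = len(segment_embeddings[0])
    count = len(segment_embeddings)
    return [sum(e[d] for e in segment_embeddings) / count for d in range(dim)]


def prediction_head(h, weights, bias):
    """Hypothetical linear head standing in for F'."""
    return sum(w * x for w, x in zip(weights, h)) + bias


def squared_error(prediction, target):
    """Assumed loss: squared difference between prediction and label."""
    return (prediction - target) ** 2


fresh = [[1.0, 2.0]]                    # embeddings of sampled segments
historical = [[3.0, 4.0], [5.0, 6.0]]   # embeddings read from the table
graph_embedding = aggregate(fresh + historical)
prediction = prediction_head(graph_embedding, weights=[0.5, -0.25], bias=0.1)
loss = squared_error(prediction, target=1.0)
```

In a real training setup only the fresh embeddings would receive gradients, while the historical ones act as constants in the loss.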
In conclusion, the paper addresses the memory and computation bottlenecks of large-scale graph property prediction by training on sampled graph segments and reusing historical embeddings for the rest, which gives related work a practical way to scale to very large graphs.