Knowledge Distillation on Spatial-Temporal Graph Convolutional Network for Traffic Prediction

Mohammad Izadi, Mehran Safayani, Abdolreza Mirzaei
2024-09-24
Abstract: Efficient real-time traffic prediction is crucial for reducing transportation time. To predict traffic conditions, we employ a spatio-temporal graph neural network (ST-GNN) to model our real-time traffic data as temporal graphs. Despite its capabilities, an ST-GNN often struggles to deliver efficient real-time predictions on real-world traffic data. Because timely prediction matters for such dynamic data, we employ knowledge distillation (KD) to reduce the execution time of ST-GNNs for traffic prediction. In this paper, we introduce a cost function designed to train a network with fewer parameters (the student) using distilled data from a complex network (the teacher) while keeping its accuracy close to that of the teacher. Our knowledge distillation incorporates spatial-temporal correlations from the teacher network so that the student can learn the complex patterns the teacher perceives. However, a further challenge is determining the student network's architecture deliberately rather than choosing it arbitrarily. To address this challenge, we propose an algorithm that uses the cost function to calculate pruning scores, addressing the search for a small network architecture, and jointly fine-tunes the network resulting from each pruning stage using KD. Finally, we evaluate the proposed ideas on two real-world datasets, PeMSD7 and PeMSD8. The results indicate that our method keeps the student's accuracy close to that of the teacher even when only 3% of the network parameters are retained.
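The abstract describes pruning scores computed from the cost function, with KD fine-tuning after each pruning stage. The exact scoring rule is not given here, so the following is only a minimal sketch under stated assumptions: it swaps in a first-order Taylor importance score |w · ∂L/∂w| (a standard criterion, not confirmed as the paper's), prunes individual scalar weights rather than whatever structural units the paper prunes, and leaves gradient computation and fine-tuning to caller-supplied callbacks. The names `taylor_scores`, `prune_step`, and `iterative_prune` are hypothetical.

```python
from typing import Callable, List, Sequence


def taylor_scores(weights: Sequence[float], grads: Sequence[float]) -> List[float]:
    """Assumed importance score |w * dL/dw| (first-order Taylor estimate)."""
    return [abs(w * g) for w, g in zip(weights, grads)]


def prune_step(scores: Sequence[float], mask: Sequence[int], fraction: float) -> List[int]:
    """Switch off `fraction` of the still-active weights with the lowest scores."""
    active = [i for i, m in enumerate(mask) if m]
    n_drop = int(len(active) * fraction)
    drop = set(sorted(active, key=lambda i: scores[i])[:n_drop])
    return [0 if i in drop else m for i, m in enumerate(mask)]


def iterative_prune(
    weights: List[float],
    grad_fn: Callable[[List[float]], List[float]],
    fine_tune: Callable[[List[float], List[int]], List[float]],
    target_keep: float = 0.03,
    fraction: float = 0.5,
) -> List[int]:
    """Alternate pruning and fine-tuning until at most `target_keep`
    of the weights remain active; returns the final 0/1 mask."""
    mask = [1] * len(weights)
    while sum(mask) > target_keep * len(weights):
        masked = [w * m for w, m in zip(weights, mask)]
        # Hypothetical: grad_fn returns gradients of the (distillation) cost.
        new_mask = prune_step(taylor_scores(masked, grad_fn(masked)), mask, fraction)
        if new_mask == mask:  # fraction too small to remove another weight
            break
        mask = new_mask
        # Recover accuracy before the next stage (e.g. KD fine-tuning).
        weights = fine_tune(weights, mask)
    return mask
```

With 100 weights valued 1 to 100, unit gradients, and a no-op `fine_tune`, the loop halves the active set each stage (100, 50, 25, 13, 7, 4, 2) and keeps the two largest weights. Halving per stage is an arbitrary schedule chosen for the sketch.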
Machine Learning, Artificial Intelligence
What problem does this paper attempt to address?
The problem this paper addresses is how to improve the execution efficiency of Spatio-Temporal Graph Neural Networks (ST-GNNs) for real-time traffic prediction while maintaining high prediction accuracy. Specifically, the authors point out that although ST-GNNs perform well on spatio-temporal data, practical deployments must process a large number of nodes, which leads to high computational cost and heavy hardware requirements. To overcome this challenge, the paper combines Knowledge Distillation (KD) with network pruning to train a student network with fewer parameters that can predict traffic quickly and accurately.

### Main contributions:
1. **Combination of knowledge distillation and network pruning**: The paper designs a new loss function and uses it to apply knowledge distillation and network pruning jointly, optimizing both the structure and the performance of the student network.
2. **Dynamically adjusting the student network architecture**: An algorithm uses the loss function to calculate pruning scores and progressively adjusts the architecture of the student network, reducing parameters while keeping prediction accuracy high.
3. **Experimental verification**: Experiments on two real-world traffic datasets (PeMSD7 and PeMSD8) show that the method keeps prediction accuracy comparable to that of the teacher network while retaining only 3% of the network parameters, improving execution efficiency.

### Key technical points:
- **Knowledge distillation**: Knowledge is transferred from the teacher network to the student network through two methods, response-based distillation and feature-based distillation.
- **Network pruning**: A pruning algorithm based on importance scores gradually optimizes the structure of the student network through iterative pruning and fine-tuning.
- **Loss function**: A combined loss function \( L_{STCD} \) merges several sub-losses from response distillation and feature distillation so that the student network stays consistent with the teacher network at the spatio-temporal level.

### Formula analysis:
- **Response distillation loss function**:
\[ L_{RD}(KL)_{bi} = \beta \cdot KL(y_s^{bi}, y_t^{bi}) + (1 - \beta) \cdot \|y_s^{bi} - T^{bi}\|^2 \]
\[ L_{RD}(L2)_{bi} = \beta \cdot \|y_s^{bi} - y_t^{bi}\|^2 + (1 - \beta) \cdot \|y_s^{bi} - T^{bi}\|^2 \]
where \( y_s^{bi} \) and \( y_t^{bi} \) are the outputs of the student and teacher networks respectively, \( T^{bi} \) is the target data, and \( \beta \) is an adjustable coefficient that weights matching the teacher against fitting the target.
- **Time-correlation distillation loss function**:
\[ TCD_{bnij} = \frac{1}{C}\sum_{c=1}^{C} |F_{binc} - F_{bjnc}| \]
\[ L_{TCD} = \frac{1}{B \cdot N \cdot \binom{T}{2}} \sum_{b=1}^{B} \sum_{n=1}^{N} \sum_{1 \le i < j \le T} \|TCD_s^{bnij} - TCD_t^{bnij}\|^2 \]
where \( F \) is an intermediate feature map indexed by batch \( b \), time step, node, and channel \( c \) (with \( C \) channels), and the subscripts \( s \) and \( t \) on \( TCD \) denote student and teacher.
- **Space-correlation distillation loss function**:
\[ SCD_{btij} = \frac{1}{C}\sum_{c=1}^{C} |F_{btic} - F_{btjc}| \]
\[ L_{SCD}=\frac{1}{B\cdot T\cdot \binom{N}{2}}\sum_{b = 1}^B\sum_{t = 1}^T\cdots
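As a rough illustration of the distillation formulas, the following pure-Python sketch computes the L2 variant of the response-distillation loss for a single (b, i) element, the pairwise correlations TCD and SCD, and the loss L_TCD. It assumes feature maps stored as nested lists indexed [b][t][n][c] and that student and teacher share B, T, and N; all function names are hypothetical. The KL variant is omitted because the summary does not say how outputs are turned into distributions, and averaging L_RD over b and i is left to the caller.

```python
from itertools import combinations
from typing import List, Sequence

# Assumed layout: feature maps as nested lists indexed F[b][t][n][c]
# (batch, time step, node, channel).
Features = List[List[List[List[float]]]]


def squared_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance ||a - b||^2."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def response_distillation_l2(y_s, y_t, target, beta: float = 0.5) -> float:
    """L_RD(L2) for one (b, i) element:
    beta * ||y_s - y_t||^2 + (1 - beta) * ||y_s - T||^2."""
    return beta * squared_l2(y_s, y_t) + (1.0 - beta) * squared_l2(y_s, target)


def _mean_abs_diff(u: Sequence[float], v: Sequence[float]) -> float:
    """(1/C) * sum_c |u_c - v_c| over the C channels."""
    return sum(abs(x - y) for x, y in zip(u, v)) / len(u)


def tcd(F: Features, b: int, n: int, i: int, j: int) -> float:
    """TCD_{bnij}: channel-averaged distance between time steps i and j at node n."""
    return _mean_abs_diff(F[b][i][n], F[b][j][n])


def scd(F: Features, b: int, t: int, i: int, j: int) -> float:
    """SCD_{btij}: channel-averaged distance between nodes i and j at time step t."""
    return _mean_abs_diff(F[b][t][i], F[b][t][j])


def loss_tcd(F_s: Features, F_t: Features) -> float:
    """L_TCD: squared student/teacher TCD gaps averaged over
    B * N * C(T, 2) terms (all time pairs i < j)."""
    B, T, N = len(F_s), len(F_s[0]), len(F_s[0][0])
    pairs = list(combinations(range(T), 2))
    if not pairs:
        raise ValueError("need at least two time steps")
    total = 0.0
    for b in range(B):
        for n in range(N):
            for i, j in pairs:
                total += (tcd(F_s, b, n, i, j) - tcd(F_t, b, n, i, j)) ** 2
    return total / (B * N * len(pairs))
```

For example, `response_distillation_l2([1.0, 2.0], [1.0, 0.0], [0.0, 2.0], beta=0.25)` returns 1.75 (0.25 · 4 + 0.75 · 1). Because each correlation value averages over the channel axis, the formulas permit the student's feature maps to have fewer channels than the teacher's; `loss_tcd` works the same way when `F_t` has a different channel count from `F_s`.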