STDA-Meta: A Meta-Learning Framework for Few-Shot Traffic Prediction

Maoxiang Sun, Weilong Ding, Tianpu Zhang, Zijian Liu, Mengda Xing
2023-10-31
Abstract: As cities develop, traffic congestion becomes an increasingly pressing issue, and traffic prediction is a classic means of relieving it. Traffic prediction is one specific application of spatio-temporal prediction learning, alongside taxi scheduling, weather prediction, and ship trajectory prediction. For these problems, classical spatio-temporal prediction learning methods, including deep learning, require large amounts of training data. In reality, newly developed cities with insufficient sensors do not satisfy that assumption, and the resulting data scarcity degrades predictive performance. Learning from insufficient data in this setting is known as few-shot learning (FSL), and FSL for traffic prediction remains challenging. On the one hand, the irregular and dynamic nature of graph structures limits the performance of spatio-temporal learning methods. On the other hand, conventional domain adaptation methods do not work well with insufficient training data when transferring knowledge from different domains to the intended target domain. To address these challenges, we propose a novel spatio-temporal domain adaptation (STDA) method that learns transferable spatio-temporal meta-knowledge from data-sufficient cities in an adversarial manner. This learned meta-knowledge can improve the prediction performance of data-scarce cities. Specifically, we train the STDA model using an episode learning process based on Model-Agnostic Meta-Learning (MAML), a meta-learning framework that enables a model to solve new learning tasks using only a small number of training samples. We conduct numerous experiments on four traffic prediction datasets, and our results show that the prediction performance of our model improves by 7% compared to baseline models on the two metrics of MAE and RMSE.
Machine Learning
What problem does this paper attempt to address?
The problem that this paper attempts to solve is traffic prediction in data-scarce cities. Traditional spatio-temporal prediction learning methods (such as deep learning methods) require a large amount of training data to train an effective model. However, some newly developed cities lack sufficient sensors and therefore cannot provide enough data, which leads to poor prediction performance. The paper therefore focuses on improving the accuracy of traffic prediction under data scarcity through few-shot learning (FSL) techniques.

### Main contributions of the paper

1. **Application of domain adaptation in spatio-temporal few-shot prediction**:
   - This method makes full use of the inherent structural similarities in the graph representations of different urban traffic datasets.
   - The data augmentation technique based on generative adversarial networks (GANs) significantly improves the prediction results, as verified through extensive ablation experiments.
2. **Incorporating model-agnostic meta-learning (MAML) into domain adaptation**:
   - MAML has been shown to be effective for few-shot learning. By integrating MAML, the model can quickly adapt to new tasks with a small amount of data, which improves the efficiency of domain adaptation.
   - This method has been verified in extensive experiments on real-world datasets from different cities, showing its potential to address few-shot challenges in various fields.
3. **Effectiveness on traffic speed datasets in different cities**:
   - The experimental results show that this model improves prediction performance by 7% compared to the baseline models on MAE and RMSE.

### Method overview

#### 1. Problem definition

- **Traffic prediction**: The goal is to predict future traffic states, such as traffic speed, flow, demand, and travel time, based on historical traffic data.
- **Spatio-temporal graph few-shot learning**: The source city has a large amount of traffic data, while the target city has little. The paper proposes a few-shot learning method that extracts meta-knowledge from the source-city data to improve prediction accuracy in the data-scarce target city.

#### 2. Method framework

- **STDA-Meta framework**: Consists of two modules, the spatio-temporal domain adaptation (STDA) module and the inference module.
- **STDA module**: Includes a spatio-temporal embedding (ST-E) sub-module and an adversarial adaptation sub-module. Adversarial classification of source-city versus target-city features is used to capture transferable spatio-temporal features.
- **Inference module**: Includes an ST-E sub-module and an output layer. The output layer uses a prediction loss \( L_p \).

#### 3. Key components

- **Spatio-temporal embedding module (ST-E)**:
  - **Time feature extractor (TF)**: Uses a gated recurrent unit (GRU) to extract temporal features.
  - **Space feature extractor (SF)**: Uses a graph attention network (GAT) to extract spatial features.
- **Spatio-temporal feature discrimination module (\( G_{st} \))**: Distinguishes the spatio-temporal features of the source city from those of the target city via adversarial classification.
- **Inference module**: Makes predictions on the time-series data of data-scarce cities and uses the root-mean-square error (RMSE) as the prediction loss function.

#### 4. Loss function

- **Overall loss function** \( L_{\text{overall}} \): Combines the spatio-temporal domain loss \( L_{\text{st}} \) and the prediction loss \( L_p \), with \( \lambda \) weighting the domain loss:

\[ L_{\text{overall}} = \lambda L_{\text{st}} + L_p \]

#### 5. Learning process

- **MAML-based meta-learning process**: Divided into a base-model stage and an adaptation stage; the training parameters for the target domain are updated by gradient descent.
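The space feature extractor (SF) in the key components is described only as a graph attention network. As a rough illustration, the sketch below implements one single-head GAT layer in plain Python, following the general GAT formulation (shared projection, LeakyReLU attention scores, softmax over each node's neighbourhood, ELU output). The toy sensor graph, the explicit self-loops, the LeakyReLU slope of 0.2 and the single head are assumptions for illustration, not details taken from STDA-Meta.

```python
import math
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def matvec(w: Matrix, x: List[float]) -> List[float]:
    """Multiply matrix w (out_dim x in_dim) by vector x."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def leaky_relu(x: float, slope: float = 0.2) -> float:
    return x if x > 0 else slope * x


def gat_layer(h: List[List[float]], neighbors: Dict[int, List[int]],
              w: Matrix, a: List[float]) -> Tuple[List[List[float]], Dict[int, List[float]]]:
    """Single-head graph attention layer (illustrative sketch).

    h: one feature vector per sensor node
    neighbors: adjacency list; each node lists itself to keep a self-loop (assumption)
    w: shared linear projection; a: attention vector of length 2 * out_dim
    Returns the new node features and the attention weights of each node.
    """
    z = [matvec(w, hi) for hi in h]  # projected features W h_i
    out: List[List[float]] = []
    alphas: Dict[int, List[float]] = {}
    for i in range(len(h)):
        nbrs = neighbors[i]
        # unnormalised scores e_ij = LeakyReLU(a^T [W h_i || W h_j])
        scores = [leaky_relu(sum(ak * zk for ak, zk in zip(a, z[i] + z[j])))
                  for j in nbrs]
        # softmax over the neighbourhood, shifted by the max for stability
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        alpha = [e / total for e in exps]
        alphas[i] = alpha
        # attention-weighted sum of neighbour features, then ELU
        agg = [sum(al * z[j][k] for al, j in zip(alpha, nbrs))
               for k in range(len(z[i]))]
        out.append([v if v > 0 else math.exp(v) - 1.0 for v in agg])
    return out, alphas


# hypothetical 3-sensor road graph with self-loops
h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
neighbors = {0: [0, 1], 1: [1, 0, 2], 2: [2, 1]}
w = [[1.0, 0.5], [0.0, 1.0]]
a = [0.3, -0.2, 0.1, 0.4]
features, attention = gat_layer(h, neighbors, w, a)
print(attention[1])  # weights sensor 1 assigns to sensors 1, 0 and 2
```

In STDA-Meta this spatial step is paired with a GRU-based time feature extractor; a trainable version would normally use an autodiff framework rather than plain lists.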
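The overall loss in the Loss function section can be written out directly. In the sketch below, \( L_p \) is the RMSE named for the inference module. The exact form of the domain loss \( L_{\text{st}} \) is not given in this summary, so a binary cross-entropy for a discriminator that outputs the probability that a sample comes from the source city is assumed, and \( \lambda = 0.1 \) is an arbitrary illustrative value. How the adversarial min-max between the ST-E embedding and the discriminator is optimized (for example with a gradient-reversal layer) is not modeled here.

```python
import math
from typing import Sequence


def rmse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Prediction loss L_p: root-mean-square error over all predicted values."""
    n = len(pred)
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)) / n)


def domain_bce(probs: Sequence[float], is_source: Sequence[int],
               eps: float = 1e-12) -> float:
    """Assumed form of L_st: binary cross-entropy of a source-vs-target
    discriminator whose outputs are P(sample comes from the source city)."""
    total = 0.0
    for p, y in zip(probs, is_source):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)


def overall_loss(l_st: float, l_p: float, lam: float) -> float:
    """L_overall = lambda * L_st + L_p."""
    return lam * l_st + l_p


# toy values for illustration only, not results from the paper
l_p = rmse([52.0, 48.5, 60.0], [50.0, 50.0, 58.0])     # predicted vs observed speeds
l_st = domain_bce([0.9, 0.2, 0.6, 0.4], [1, 0, 1, 0])  # discriminator outputs, labels
print(overall_loss(l_st, l_p, lam=0.1))
```

A larger \( \lambda \) puts more weight on making source and target features indistinguishable relative to raw prediction accuracy.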
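The MAML-based learning process is described only at a high level: a base-model stage followed by an adaptation stage driven by gradient descent. The sketch below illustrates the general MAML pattern on a toy problem, under several assumptions: each hypothetical "city" task is a noisy 1-D linear relation, the model is \( \hat{y} = w x + b \), and for brevity it uses the first-order MAML approximation, which drops the second-order terms of the meta-gradient. Meta-training over data-rich source tasks plausibly corresponds to the base-model stage, and a few gradient steps on the scarce target data to the adaptation stage. Task construction, learning rates and step counts are illustrative; STDA-Meta itself meta-trains its ST-E and adversarial modules rather than a linear model.

```python
import random
from typing import List, Tuple

Params = Tuple[float, float]      # (w, b) of a toy model y_hat = w * x + b
Data = List[Tuple[float, float]]  # (x, y) samples


def mse(params: Params, data: Data) -> float:
    """Mean squared error of the toy linear model on a data set."""
    w, b = params
    return sum((w * x + b - y) ** 2 for x, y in data) / len(data)


def mse_grad(params: Params, data: Data) -> Params:
    """Analytic gradient of mse() with respect to (w, b)."""
    w, b = params
    n = len(data)
    gw = sum(2 * (w * x + b - y) * x for x, y in data) / n
    gb = sum(2 * (w * x + b - y) for x, y in data) / n
    return gw, gb


def sgd_step(params: Params, grad: Params, lr: float) -> Params:
    return params[0] - lr * grad[0], params[1] - lr * grad[1]


def make_task(slope: float, intercept: float, n: int, rng: random.Random) -> Data:
    """Hypothetical 'city': a noisy linear relation with its own parameters."""
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    return [(x, slope * x + intercept + rng.gauss(0.0, 0.05)) for x in xs]


def fomaml(tasks: List[Tuple[Data, Data]], inner_lr: float = 0.1,
           outer_lr: float = 0.05, meta_steps: int = 200) -> Params:
    """First-order MAML over (support, query) pairs from data-rich source tasks."""
    theta: Params = (0.0, 0.0)
    for _ in range(meta_steps):
        gw_sum, gb_sum = 0.0, 0.0
        for support, query in tasks:
            # inner loop: one gradient step on the task's support set
            adapted = sgd_step(theta, mse_grad(theta, support), inner_lr)
            # first-order meta-gradient: query-set gradient at the adapted params
            gw, gb = mse_grad(adapted, query)
            gw_sum += gw
            gb_sum += gb
        # outer loop: update the shared initialisation
        theta = sgd_step(theta, (gw_sum / len(tasks), gb_sum / len(tasks)), outer_lr)
    return theta


def adapt(theta: Params, support: Data, lr: float = 0.1, steps: int = 5) -> Params:
    """Adaptation stage: a few gradient steps on the scarce target-city data."""
    params = theta
    for _ in range(steps):
        params = sgd_step(params, mse_grad(params, support), lr)
    return params


rng = random.Random(0)
source_tasks = [(make_task(s, c, 20, rng), make_task(s, c, 20, rng))
                for s, c in [(1.0, 0.5), (1.5, 0.2), (0.8, 0.7)]]
theta = fomaml(source_tasks)

target_support = make_task(1.2, 0.4, 5, rng)  # only 5 labelled target samples
target_test = make_task(1.2, 0.4, 50, rng)
adapted = adapt(theta, target_support)
from_scratch = adapt((0.0, 0.0), target_support)
print("test MSE, adapted from zero init:", mse(from_scratch, target_test))
print("test MSE, adapted from meta-learned init:", mse(adapted, target_test))
```

The design choice here is the one MAML is built around: the source tasks are used to learn an initialisation from which a handful of target samples suffice, rather than a model that is directly accurate on any single city.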
### Experimental results

- **Datasets**: Four publicly available traffic datasets are used: METR-LA, PEMS-BAY, Didi-Chengdu, and Didi-Shenzhen.
- **Performance comparison**: Compared with classical spatio-temporal graph learning methods and transfer learning methods, STDA-Meta achieves better prediction accuracy than the baseline models, especially under data scarcity.

### Conclusion

This paper proposes STDA-Meta, a novel spatio-temporal domain-adaptation meta-learning framework, and shows experimentally that it improves traffic prediction in data-scarce cities.