
PyTorch implementation of Spatial-Temporal Sequence to Sequence model.


xlwang233/STSeq2Seq


Spatial-Temporal Sequence to Sequence Model for Traffic Forecasting

A PyTorch implementation of the Spatial-Temporal Sequence to Sequence (STSeq2Seq) model proposed in the paper: Forecast Network-Wide Traffic States for Multiple Steps Ahead: A Deep Learning Approach Considering Dynamic Non-Local Spatial Correlation and Non-Stationary Temporal Dependency (preprint: https://arxiv.org/abs/2004.02391; journal article: https://doi.org/10.1016/j.trc.2020.102763).

Requirements

  • pytorch >= 1.2.0
  • tensorboard >= 1.14.0
  • scikit-learn
  • statsmodels
  • tqdm
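To check an environment against the list above, a small script such as the following may help. It is a hypothetical helper, not part of this repo; it assumes Python 3.8+ (for importlib.metadata) and that PyTorch is installed under its PyPI distribution name `torch`.

```python
# Hypothetical helper (not part of the repo): reports whether the packages listed
# under Requirements are installed and meet the stated minimum versions.
import re
from importlib import metadata

# Assumption: PyTorch is installed under its PyPI distribution name "torch".
REQUIREMENTS = {
    "torch": "1.2.0",
    "tensorboard": "1.14.0",
    "scikit-learn": None,  # None means any installed version is accepted
    "statsmodels": None,
    "tqdm": None,
}


def parse_version(text):
    """Return the leading numeric release as a tuple, e.g. '2.1.0+cu118' -> (2, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(installed, minimum):
    """True if `installed` is at least `minimum`; shorter versions are zero-padded."""
    if minimum is None:
        return True
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements(requirements=REQUIREMENTS):
    """Map each distribution name to 'ok', 'missing' or 'too old (<version>)'."""
    report = {}
    for name, minimum in requirements.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if meets_minimum(installed, minimum) else f"too old ({installed})"
    return report


if __name__ == "__main__":
    for name, status in check_requirements().items():
        print(f"{name}: {status}")
```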

Data Preparation

Speed Data

Step 1: Download the METR-LA and PEMS-BAY data from the Google Drive or Baidu Yun links provided by DCRNN.

Step 2: Follow DCRNN's scripts to preprocess the data.

Run the following commands to generate the train/val/test datasets at data/{METR-LA,PEMS-BAY}/{train,val,test}.npz.

# Create data directories
mkdir -p data/{METR-LA,PEMS-BAY}

# METR-LA
python -m scripts.generate_training_data --output_dir=data/METR-LA --traffic_df_filename=data/metr-la.h5

# PEMS-BAY
python -m scripts.generate_training_data --output_dir=data/PEMS-BAY --traffic_df_filename=data/pems-bay.h5
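For orientation, the sketch below reflects our simplified understanding of what DCRNN's generate_training_data script produces: 12-step input windows and 12-step target windows of speed plus a time-of-day feature, split chronologically 70/10/20 into train/val/test. The function names (`make_windows`, `split_70_10_20`) are hypothetical, and the sketch works on plain NumPy arrays rather than the .h5 files; refer to DCRNN's script for the exact behavior.

```python
import numpy as np


def make_windows(speed, time_of_day, n_in=12, n_out=12):
    """Slice a (T, N) speed matrix into overlapping input/target windows.

    speed:       (T, N) array, T time steps (5-minute intervals) for N sensors.
    time_of_day: (T,) array with the fraction of the day elapsed at each step.
    Returns x of shape (S, n_in, N, 2) and y of shape (S, n_out, N, 2), where the
    last axis holds [speed, time_of_day] and S = T - n_in - n_out + 1.
    """
    num_steps, num_nodes = speed.shape
    tod = np.tile(time_of_day[:, None], (1, num_nodes))  # (T, N)
    data = np.stack([speed, tod], axis=-1)                # (T, N, 2)
    x_offsets = np.arange(-(n_in - 1), 1)                 # e.g. -11 ... 0
    y_offsets = np.arange(1, n_out + 1)                   # e.g. 1 ... 12
    anchors = range(n_in - 1, num_steps - n_out)          # last observed step of each window
    x = np.stack([data[t + x_offsets] for t in anchors])
    y = np.stack([data[t + y_offsets] for t in anchors])
    return x, y


def split_70_10_20(x, y):
    """Chronological split: first 70% train, next 10% val, last 20% test."""
    num_samples = x.shape[0]
    num_test = round(num_samples * 0.2)
    num_train = round(num_samples * 0.7)
    num_val = num_samples - num_test - num_train
    val_end = num_train + num_val
    return {
        "train": (x[:num_train], y[:num_train]),
        "val": (x[num_train:val_end], y[num_train:val_end]),
        "test": (x[val_end:], y[val_end:]),
    }
```

Each split could then be written with np.savez_compressed under the keys x and y; DCRNN's files additionally store the offsets, which we omit here.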

Graph Data

The graph adjacency matrix is constructed from pre-calculated road-network distances between sensors. Here we simply use the adjacency matrices provided by DCRNN, which are available in the data/ folder.
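DCRNN builds these matrices with a thresholded Gaussian kernel over the road-network distances. The sketch below reproduces that idea on a NumPy distance matrix. The function name is hypothetical, and the threshold 0.1 is the value we recall from DCRNN's gen_adj_mx script, so treat it as an assumption rather than a guaranteed match.

```python
import numpy as np


def gaussian_kernel_adjacency(dist_mx, normalized_k=0.1):
    """Thresholded Gaussian kernel over pairwise road-network distances.

    dist_mx: (N, N) array of distances; np.inf where no distance is known.
    Weights are exp(-(d / sigma)^2), with sigma the standard deviation of all
    finite distances; weights below `normalized_k` are zeroed to keep the graph sparse.
    """
    finite = dist_mx[np.isfinite(dist_mx)]
    sigma = finite.std()
    adj = np.exp(-np.square(dist_mx / sigma))  # unknown (inf) distances give weight 0
    adj[adj < normalized_k] = 0.0
    return adj
```

Because road distances can differ by direction, the resulting matrix is not necessarily symmetric.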

Model Training and Testing

This repo follows the PyTorch Template, which uses a .json file for parameter configuration.
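In PyTorch Template, config sections of the form {"type": ..., "args": {...}} are turned into objects by looking up the class named in "type" and calling it with "args". The sketch below mimics that pattern (roughly what the template's ConfigParser.init_obj does) with toy stand-in classes so it runs without PyTorch. The config values shown are illustrative only, not the contents of this repo's config.json.

```python
import json
import types

# Illustrative config in the PyTorch Template layout; the values are hypothetical.
CONFIG_TEXT = """
{
    "name": "METR-LA_STSeq2Seq",
    "arch": {"type": "ToyModel", "args": {"hidden_size": 64}},
    "optimizer": {"type": "ToyOptimizer", "args": {"lr": 0.001}}
}
"""


def init_obj(config, name, module, *args, **kwargs):
    """Instantiate config[name]["type"] from `module`, passing config[name]["args"]."""
    spec = config[name]
    module_args = dict(spec["args"])
    # Keyword arguments supplied in code must not silently override the config file.
    overlap = set(kwargs) & set(module_args)
    if overlap:
        raise ValueError(f"arguments {sorted(overlap)} are already set in the config")
    module_args.update(kwargs)
    return getattr(module, spec["type"])(*args, **module_args)


# Stand-ins so the sketch runs on its own; in a real project `module` would be the
# project's model package or torch.optim.
class ToyModel:
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def parameters(self):
        return []


class ToyOptimizer:
    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


registry = types.SimpleNamespace(ToyModel=ToyModel, ToyOptimizer=ToyOptimizer)

config = json.loads(CONFIG_TEXT)
model = init_obj(config, "arch", registry)
optimizer = init_obj(config, "optimizer", registry, model.parameters())
```

Keeping class names and hyperparameters in JSON lets a new experiment (such as the FNN and GRU baselines below) be defined by a different config file rather than code changes.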

Training

Run the following command to train the model.

# train STSeq2Seq 
python train.py -c config.json

Each training epoch on METR-LA takes about 100 seconds on a machine with one Core i7-7700K CPU and a single NVIDIA RTX 2080 Ti GPU. The training logs and models are saved in saved/METR-LA_STSeq2Seq/.
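The model_best.pth files referenced in the test commands come from the trainer keeping the checkpoint with the best validation metric. A minimal sketch of that bookkeeping is shown below, assuming a PyTorch-Template-style "min val_loss" monitor with early stopping; BestTracker is a hypothetical name and the actual checkpoint writing is left as a comment.

```python
class BestTracker:
    """Track a monitored metric to decide when to save a best checkpoint or stop early.

    mode: "min" (e.g. validation loss) or "max".
    early_stop: number of consecutive epochs without improvement before stopping.
    """

    def __init__(self, mode="min", early_stop=10):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.early_stop = early_stop
        self.best = float("inf") if mode == "min" else float("-inf")
        self.not_improved = 0

    def update(self, value):
        """Return (is_best, should_stop) after observing this epoch's metric."""
        improved = value < self.best if self.mode == "min" else value > self.best
        if improved:
            self.best = value
            self.not_improved = 0
        else:
            self.not_improved += 1
        return improved, self.not_improved >= self.early_stop


# Usage with made-up validation losses.
tracker = BestTracker(mode="min", early_stop=2)
saved_epochs, stopped_at = [], None
for epoch, val_loss in enumerate([3.1, 2.8, 2.9, 2.85, 2.7], start=1):
    is_best, should_stop = tracker.update(val_loss)
    if is_best:
        saved_epochs.append(epoch)  # a trainer would write model_best.pth here
    if should_stop:
        stopped_at = epoch
        break
```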

Testing

Run the following command to evaluate your trained model.

# test STSeq2Seq 
python test.py -r saved/METR-LA_STSeq2Seq/models/{time stamp}/model_best.pth

A pre-trained model for METR-LA is provided and can be evaluated with:

# run pre-trained STSeq2Seq for METR-LA
python test.py -r pretrained/METR-LA/metr-la.pth
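Forecasting results on METR-LA are usually reported as MAE, RMSE and MAPE at several horizons. The sketch below computes these with zero readings treated as missing values, as DCRNN does. Whether test.py applies the same masking and horizons is an assumption, so check it against the repo's metric code; the function names are hypothetical.

```python
import numpy as np


def masked_metrics(pred, true, null_val=0.0):
    """MAE, RMSE and MAPE (%) over entries whose ground truth is not `null_val`."""
    pred = np.asarray(pred, dtype=float)
    true = np.asarray(true, dtype=float)
    mask = true != null_val  # assumption: null_val marks missing readings
    if not mask.any():
        raise ValueError("no valid ground-truth entries")
    err = pred[mask] - true[mask]
    mae = np.abs(err).mean()
    rmse = np.sqrt(np.square(err).mean())
    mape = np.abs(err / true[mask]).mean() * 100.0
    return mae, rmse, mape


def metrics_per_horizon(pred, true, horizons=(3, 6, 12)):
    """Metrics at 1-based forecast steps; with 5-minute steps, 3/6/12 = 15/30/60 minutes.

    pred, true: arrays of shape (samples, steps, sensors).
    """
    return {h: masked_metrics(pred[:, h - 1], true[:, h - 1]) for h in horizons}
```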

Evaluate Baseline Methods

For the neural network baselines, i.e. FNN and GRU, the training and testing procedures are similar to those of STSeq2Seq.

# train FNN/GRU
python train.py -c config_FNN.json
python train.py -c config_GRU.json

# test FNN/GRU
python test.py -r saved/METR-LA_FNN/models/{time stamp}/model_best.pth
python test.py -r saved/METR-LA_GRU/models/{time stamp}/model_best.pth

For HA, ARIMA and SVR, go to the scripts/ directory and run:

# METR-LA
python eval_baselines.py
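As a reference point for what HA does, here is one common variant: predict each test step as the average of the training values observed at the same 5-minute slot of the week. This is a sketch under that assumption, not necessarily the definition used in eval_baselines.py; the function and constants are hypothetical.

```python
import numpy as np

STEPS_PER_DAY = 288  # 5-minute intervals per day
STEPS_PER_WEEK = 7 * STEPS_PER_DAY


def historical_average(train, test_start_index, num_test_steps,
                       period=STEPS_PER_WEEK, null_val=0.0):
    """Predict each test step as the mean training value at the same slot of `period`.

    train: (T, N) array of training speeds, starting at global time index 0.
    test_start_index: global time index of the first test step.
    Entries equal to `null_val` are treated as missing and ignored.
    """
    train = np.asarray(train, dtype=float)
    num_steps, num_nodes = train.shape
    slots = np.arange(num_steps) % period
    valid = train != null_val
    sums = np.zeros((period, num_nodes))
    counts = np.zeros((period, num_nodes))
    np.add.at(sums, slots, np.where(valid, train, 0.0))
    np.add.at(counts, slots, valid.astype(float))
    # Slots never observed in training fall back to 0 (a deliberate simplification).
    profile = np.divide(sums, counts, out=np.zeros_like(sums), where=counts > 0)
    test_slots = (test_start_index + np.arange(num_test_steps)) % period
    return profile[test_slots]
```

Since the profile depends only on the time slot, HA gives the same prediction regardless of the forecast horizon.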

Note that ARIMA and SVR are fitted independently for each sensor, so running them on a single machine can take an intolerably long time. Possible workarounds include:

  • using parallel and distributed computing tools (e.g. Apache Spark, although I have not tested its feasibility), or
  • using simpler models (e.g. LinearSVR instead of SVR).
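Both workarounds can be combined on one machine. The sketch below fits one LinearSVR per sensor on lagged windows and parallelizes over sensors with joblib, which is installed alongside scikit-learn. The lag length, the feature construction and the function names are assumptions for illustration, not this repo's baseline code.

```python
import numpy as np
from joblib import Parallel, delayed
from sklearn.svm import LinearSVR


def lagged_features(series, n_lags=12, horizon=1):
    """Rows of X hold n_lags consecutive values; y is the value `horizon` steps after each row."""
    rows = len(series) - n_lags - horizon + 1
    X = np.stack([series[i:i + n_lags] for i in range(rows)])
    start = n_lags + horizon - 1
    y = series[start:start + rows]
    return X, y


def fit_one_sensor(series, n_lags=12, horizon=1, C=1.0):
    """Fit a single LinearSVR on one sensor's lagged history."""
    X, y = lagged_features(series, n_lags, horizon)
    model = LinearSVR(C=C, max_iter=10000)
    model.fit(X, y)
    return model


def fit_all_sensors(speed, n_jobs=-1, **kwargs):
    """Fit one independent model per sensor column of a (T, N) array, in parallel."""
    return Parallel(n_jobs=n_jobs)(
        delayed(fit_one_sensor)(speed[:, j], **kwargs) for j in range(speed.shape[1])
    )
```

Because each sensor's model is independent, the work splits cleanly across processes; n_jobs=-1 uses all available cores.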

Citation

This method has been published as a journal paper. If you find this repo useful for your research, please consider citing the following paper:

@article{wang2020forecast,
  title={Forecast network-wide traffic states for multiple steps ahead: A deep learning approach considering dynamic non-local spatial correlation and non-stationary temporal dependency},
  author={Wang, Xinglei and Guan, Xuefeng and Cao, Jun and Zhang, Na and Wu, Huayi},
  journal={Transportation research part C: emerging technologies},
  volume={119},
  pages={102763},
  year={2020},
  publisher={Elsevier}
}
