
Pytorch Implementation for Stepwise Goal-Driven Networks for Trajectory Prediction (RA-L/ICRA2022)

Installation

Cloning

We use part of the dataloader in Trajectron++, so we include Trajectron++ as a submodule.

git clone --recurse-submodules git@github.com:ChuhuaW/SGNet.pytorch.git

Environment

  • Create the conda environment from the provided yml file:
conda env create --file SGNet_env.yml

Data

  • JAAD/PIE: create symlinks from the dataset path to ./data:
ln -s path/to/dataset/ ./data/
  • ETH/UCY: In this repository, we follow Trajectron++ to preprocess the data splits for the ETH and UCY datasets. Please refer to their repository for instructions. After the data is generated, create symlinks from the dataset path to ./data:
ln -s path/to/dataset/ ./data/
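If you would rather script this step, the following is a minimal Python sketch using only the standard library. The helper name `link_dataset` and the example path are hypothetical, and the sketch assumes a POSIX system where creating symlinks needs no special privileges. It resolves the dataset path to an absolute path first, because `ln -s` stores a relative target as-is and interprets it relative to the directory that holds the link.

```python
import os
from pathlib import Path


def link_dataset(dataset_path: str, data_dir: str = "./data") -> Path:
    """Create data_dir/<dataset dir name> as a symlink to dataset_path.

    Hypothetical helper: mirrors `ln -s path/to/dataset/ ./data/` when ./data
    is a directory, i.e. the link is named after the dataset directory.
    """
    # Absolute target, so the link does not dangle when the source is relative.
    src = Path(dataset_path).expanduser().resolve()
    if not src.is_dir():
        raise FileNotFoundError(f"dataset directory not found: {src}")
    dst_dir = Path(data_dir)
    dst_dir.mkdir(parents=True, exist_ok=True)
    link = dst_dir / src.name
    if link.is_symlink() or link.exists():
        # Leave an existing link or directory alone instead of overwriting it.
        return link
    os.symlink(src, link, target_is_directory=True)
    return link
```

For example, `link_dataset("/datasets/JAAD")` would create `./data/JAAD` pointing at that directory (the path is a placeholder).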

Training

Stochastic prediction

  • Training on the JAAD dataset:
cd SGNet.pytorch
python tools/jaad/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset JAAD --model SGNet_CVAE
  • Training on the PIE dataset:
cd SGNet.pytorch
python tools/pie/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset PIE --model SGNet_CVAE
  • Training on the ETH/UCY datasets:
cd SGNet.pytorch
python tools/ethucy/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset ETH --model SGNet_CVAE
python tools/ethucy/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset HOTEL --model SGNet_CVAE
python tools/ethucy/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset UNIV --model SGNet_CVAE
python tools/ethucy/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA1 --model SGNet_CVAE
python tools/ethucy/train_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA2 --model SGNet_CVAE
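The five ETH/UCY commands above differ only in `--dataset`. Below is a minimal Python sketch that generates and optionally runs them. The helper names `ethucy_commands` and `run_all` and the fallback GPU id `"0"` are assumptions; only the script path and the `--gpu`, `--dataset` and `--model` flags come from this README. One practical reason to script it: if `CUDA_VISIBLE_DEVICES` is unset, the unquoted `$CUDA_VISIBLE_DEVICES` in the shell commands expands to nothing, which would likely leave `--gpu` without a value.

```python
import os
import shlex
import subprocess
from typing import List, Optional

# The five ETH/UCY splits listed in this README.
ETHUCY_SPLITS = ["ETH", "HOTEL", "UNIV", "ZARA1", "ZARA2"]


def ethucy_commands(script: str, model: str, gpu: Optional[str] = None) -> List[List[str]]:
    """Build one argv list per ETH/UCY split, using only flags shown in the README."""
    if gpu is None:
        # Assumption: fall back to GPU "0" when CUDA_VISIBLE_DEVICES is unset or empty.
        gpu = os.environ.get("CUDA_VISIBLE_DEVICES") or "0"
    return [
        ["python", script, "--gpu", gpu, "--dataset", split, "--model", model]
        for split in ETHUCY_SPLITS
    ]


def run_all(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; if dry_run is False, also run it."""
    for argv in commands:
        print(shlex.join(argv))
        if not dry_run:
            # check=True raises CalledProcessError at the first failing split.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_all(ethucy_commands("tools/ethucy/train_cvae.py", "SGNet_CVAE"))
```

With `dry_run=True` (the default) the sketch only prints the commands; `dry_run=False` runs them one after another from the repository root. The same pattern applies to `tools/ethucy/train_deterministic.py` with `--model SGNet`.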

Deterministic prediction

  • Training on the JAAD dataset:
cd SGNet.pytorch
python tools/jaad/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset JAAD --model SGNet
  • Training on the PIE dataset:
cd SGNet.pytorch
python tools/pie/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset PIE --model SGNet
  • Training on the ETH/UCY datasets:
cd SGNet.pytorch
python tools/ethucy/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ETH --model SGNet
python tools/ethucy/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset HOTEL --model SGNet
python tools/ethucy/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset UNIV --model SGNet
python tools/ethucy/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA1 --model SGNet
python tools/ethucy/train_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA2 --model SGNet

Evaluation

Stochastic prediction

  • Evaluating on the JAAD dataset:
cd SGNet.pytorch
python tools/jaad/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset JAAD --model SGNet_CVAE --checkpoint path/to/checkpoint
  • Evaluating on the PIE dataset:
cd SGNet.pytorch
python tools/pie/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset PIE --model SGNet_CVAE --checkpoint path/to/checkpoint
  • Evaluating on the ETH/UCY datasets:
cd SGNet.pytorch
python tools/ethucy/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset ETH --model SGNet_CVAE --checkpoint path/to/checkpoint
python tools/ethucy/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset HOTEL --model SGNet_CVAE --checkpoint path/to/checkpoint
python tools/ethucy/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset UNIV --model SGNet_CVAE --checkpoint path/to/checkpoint
python tools/ethucy/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA1 --model SGNet_CVAE --checkpoint path/to/checkpoint
python tools/ethucy/eval_cvae.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA2 --model SGNet_CVAE --checkpoint path/to/checkpoint
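Because each ETH/UCY split is trained with its own command, each evaluation run presumably needs the checkpoint trained on that split. Below is a minimal Python sketch that builds the five evaluation commands from a split-to-checkpoint mapping and fails early when a checkpoint is missing. The function name `ethucy_eval_commands` and any checkpoint paths you pass in are placeholders, since this README does not state where checkpoints are saved or how they are named.

```python
from pathlib import Path
from typing import Dict, List

SPLITS = ["ETH", "HOTEL", "UNIV", "ZARA1", "ZARA2"]


def ethucy_eval_commands(checkpoints: Dict[str, str], gpu: str = "0") -> List[List[str]]:
    """Build one eval_cvae.py command per ETH/UCY split.

    `checkpoints` maps split name -> checkpoint path (placeholder paths; the
    repository's actual checkpoint layout is not described in the README).
    """
    missing = [s for s in SPLITS if s not in checkpoints]
    if missing:
        raise KeyError(f"no checkpoint given for: {', '.join(missing)}")
    commands = []
    for split in SPLITS:
        ckpt = Path(checkpoints[split])
        if not ckpt.is_file():
            raise FileNotFoundError(f"{split}: checkpoint not found at {ckpt}")
        commands.append([
            "python", "tools/ethucy/eval_cvae.py", "--gpu", gpu,
            "--dataset", split, "--model", "SGNet_CVAE",
            "--checkpoint", str(ckpt),
        ])
    return commands
```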

Deterministic prediction

  • Evaluating on the ETH/UCY datasets:
cd SGNet.pytorch
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ETH --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset HOTEL --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset UNIV --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA1 --model SGNet --checkpoint path/to/checkpoint
python tools/ethucy/eval_deterministic.py --gpu $CUDA_VISIBLE_DEVICES --dataset ZARA2 --model SGNet --checkpoint path/to/checkpoint
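All training and evaluation commands in this README follow one pattern: a script chosen by dataset family, mode and stage, plus `--model SGNet_CVAE` for stochastic or `--model SGNet` for deterministic prediction. The lookup table below is a sketch transcribed from the commands listed here, not from the repository's code; the names `COMMANDS`, `FAMILY` and `build_command` are hypothetical. Deterministic evaluation is listed only for ETH/UCY, so the sketch raises `KeyError` for JAAD/PIE deterministic evaluation rather than guessing a script name.

```python
from typing import List, Optional

# (dataset family, mode, stage) -> (script, --model value), copied from the README.
COMMANDS = {
    ("jaad", "stochastic", "train"): ("tools/jaad/train_cvae.py", "SGNet_CVAE"),
    ("pie", "stochastic", "train"): ("tools/pie/train_cvae.py", "SGNet_CVAE"),
    ("ethucy", "stochastic", "train"): ("tools/ethucy/train_cvae.py", "SGNet_CVAE"),
    ("jaad", "deterministic", "train"): ("tools/jaad/train_deterministic.py", "SGNet"),
    ("pie", "deterministic", "train"): ("tools/pie/train_deterministic.py", "SGNet"),
    ("ethucy", "deterministic", "train"): ("tools/ethucy/train_deterministic.py", "SGNet"),
    ("jaad", "stochastic", "eval"): ("tools/jaad/eval_cvae.py", "SGNet_CVAE"),
    ("pie", "stochastic", "eval"): ("tools/pie/eval_cvae.py", "SGNet_CVAE"),
    ("ethucy", "stochastic", "eval"): ("tools/ethucy/eval_cvae.py", "SGNet_CVAE"),
    # Deterministic evaluation appears in the README for ETH/UCY only.
    ("ethucy", "deterministic", "eval"): ("tools/ethucy/eval_deterministic.py", "SGNet"),
}

# --dataset value -> script family directory under tools/.
FAMILY = {"JAAD": "jaad", "PIE": "pie", "ETH": "ethucy", "HOTEL": "ethucy",
          "UNIV": "ethucy", "ZARA1": "ethucy", "ZARA2": "ethucy"}


def build_command(dataset: str, mode: str, stage: str, gpu: str = "0",
                  checkpoint: Optional[str] = None) -> List[str]:
    """Return the argv for one README command; raise KeyError if none is listed."""
    key = (FAMILY[dataset], mode, stage)
    if key not in COMMANDS:
        raise KeyError(f"this README lists no command for {key}")
    script, model = COMMANDS[key]
    argv = ["python", script, "--gpu", gpu, "--dataset", dataset, "--model", model]
    if stage == "eval":
        if checkpoint is None:
            raise ValueError("evaluation needs a --checkpoint path")
        argv += ["--checkpoint", checkpoint]
    return argv
```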

JAAD/PIE checkpoints

Citation

@ARTICLE{9691856,
  author={Wang, Chuhua and Wang, Yuchen and Xu, Mingze and Crandall, David J.},
  journal={IEEE Robotics and Automation Letters}, 
  title={Stepwise Goal-Driven Networks for Trajectory Prediction}, 
  year={2022}}
- Ranked 3rd on the nuScenes prediction task at the 6th AI Driving Olympics, ICRA 2021

The source code and pretrained models will be made available. Stay tuned.
