This repository contains the code for the paper:
Social-PatteRNN: Socially-Aware Trajectory Prediction Guided by Motion Patterns
Ingrid Navarro and Jean Oh
Social-PatteRNN is an algorithm for recurrent, multi-modal trajectory prediction in multi-agent settings. Our approach guides long-term trajectory prediction by learning to predict short-term motion patterns. It then extracts sub-goal information from the patterns and aggregates it as social context.
This version of the code now supports context information.
Set up a conda environment:
conda create --name sprnn python=3.7
conda activate sprnn
Download the repository and install requirements:
git clone --branch sprnn git@github.com:cmubig/social-patternn.git
cd social-patternn
pip install -e .
We have tested our algorithm on three different datasets:
- TrajAir (111days)
We provide instructions and the dataloaders to set up the data here.
This repository provides four baselines:
- VRNN: a recurrent C-VAE for trajectory prediction
- PATTERNN: a VRNN with a context module for pattern learning
- SOCPATTERNN-MLP: a VRNN with a context module for pattern learning and interaction encoding
- SOCPATTERNN-MHA: the full model; a VRNN with a context module for pattern learning and interaction encoding with multi-head attention
All of the parameters related to the trajectory specifications, training
details and model architectures are provided in the configuration files of each
baseline and experiment. These configuration files can be found in
social-patternn/config/dataset-name.
The run.py script controls the training, validation and testing for all
experiments and datasets. An experiment is specified with the flag --exp-config,
and the type of process is specified with the flag --run-type:
python run.py --exp-config path/to/exp-config.json --run-type [trainval | train | eval | test]
For example, to train and validate the Social-PatteRNN model on the 111days dataset, execute:
python run.py --exp config/111days/socpatternn-mha.json
To test a trained model, run:
python run.py --exp config/111days/socpatternn-mha.json --run test --best
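The example commands shorten --exp-config to --exp and --run-type to --run. If run.py builds its command line with Python's argparse (an assumption, not confirmed here), this works because argparse accepts any unambiguous prefix of a long option by default. The sketch below illustrates that behavior; the parser layout, the help strings, and the trainval default (inferred from the training example, which omits the run type) are assumptions rather than the repository's actual code.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flags described in this README."""
    parser = argparse.ArgumentParser(description="Run an experiment (sketch).")
    parser.add_argument(
        "--exp-config", required=True,
        help="path to the experiment configuration file")
    parser.add_argument(
        "--run-type", default="trainval",  # assumed default
        choices=["trainval", "train", "eval", "test"],
        help="type of process to run")
    parser.add_argument(
        "--best", action="store_true",
        help="flag taken from the test example; its exact meaning is assumed")
    return parser


if __name__ == "__main__":
    # "--exp" and "--run" are unambiguous prefixes, so argparse expands them.
    args = build_parser().parse_args(
        ["--exp", "config/111days/socpatternn-mha.json", "--run", "test", "--best"])
    print(args.exp_config, args.run_type, args.best)
```

Spelling the flags out in full (--exp-config, --run-type) is the safer choice in scripts, since a prefix stops working as soon as another option with the same prefix is added.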
To test or evaluate a specific checkpoint, specify the checkpoint number
ckpt_num if the checkpoint is in the default path, or the checkpoint path
ckpt_path if it is not.
For example, the following command loads ckpt_10.pth:
python run.py --exp config/111days/socpatternn-mha.json --run test --ckpt_num 10
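The choice between ckpt_num and ckpt_path can be pictured with a small helper. This is only a sketch: the function name, the ckpt_dir argument standing in for the "default path", and the rule that an explicit path wins over a number are all assumptions. Only the ckpt_<number>.pth file name pattern comes from the example above.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_checkpoint(ckpt_dir: Union[str, Path],
                       ckpt_num: Optional[int] = None,
                       ckpt_path: Optional[str] = None) -> Path:
    """Return the checkpoint file to load (hypothetical helper)."""
    if ckpt_path is not None:
        # Assumed precedence: an explicit path overrides the number.
        return Path(ckpt_path)
    if ckpt_num is not None:
        # Pattern from the README example: number 10 -> ckpt_10.pth
        return Path(ckpt_dir) / f"ckpt_{ckpt_num}.pth"
    raise ValueError("provide either ckpt_num or ckpt_path")


if __name__ == "__main__":
    # "out/ckpts" is a placeholder for the default checkpoint directory.
    print(resolve_checkpoint("out/ckpts", ckpt_num=10))
```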
For each experiment, we provide the configuration files for all the ablations performed in our paper. They are organized as follows:
config/
├─ 111days
| ├─ base_config.json
| ├─ vrnn.json
| ├─ patternn.json
| ├─ socpatternn_mlp.json
| ├─ socpatternn_mha.json
| ├─ ...
├─ ...
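Each dataset folder holds a base_config.json next to the per-model files. One plausible reading, stated here as an assumption because the repository may combine the files differently, is that a model file overrides values from the base file. The sketch below shows such an overlay with a recursive dictionary merge; the function names and the keys in the usage example are hypothetical.

```python
import json
from pathlib import Path


def deep_merge(base: dict, override: dict) -> dict:
    """Return a new dict where values in `override` replace those in `base`.

    Nested dicts are merged recursively; all other values are replaced.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_experiment_config(exp_path: str) -> dict:
    """Load an experiment file and overlay it on a sibling base_config.json."""
    exp_file = Path(exp_path)
    with open(exp_file) as f:
        exp = json.load(f)
    base_file = exp_file.parent / "base_config.json"
    base = {}
    if base_file.exists():
        with open(base_file) as f:
            base = json.load(f)
    return deep_merge(base, exp)


if __name__ == "__main__":
    # Hypothetical keys, for illustration only.
    base = {"trajectory": {"hist_len": 8, "fut_len": 12}, "batch_size": 32}
    exp = {"trajectory": {"fut_len": 20}}
    print(deep_merge(base, exp))
```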
Baselines | MinADE | MinFDE |
---|---|---|
VRNN | 0.647 | 1.392 |
PATTERNN | 0.619 | 1.385 |
SOCPATTERNN-MLP | 0.608 | 1.203 |
SOCPATTERNN-MHA | 0.541 | 1.192 |
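MinADE and MinFDE are best-of-K metrics for multi-modal predictors: the model samples several candidate futures, and the smallest average (ADE) or final (FDE) displacement error against the ground truth is reported. The sketch below follows that common definition in plain Python. It assumes the two minima are taken independently over the samples, which may differ from the evaluation code in this repository, and the function names are illustrative.

```python
import math
from typing import Sequence, Tuple

Point = Sequence[float]
Trajectory = Sequence[Point]


def _dist(p: Point, q: Point) -> float:
    """Euclidean distance between two points of equal dimension."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(p, q)))


def ade(pred: Trajectory, gt: Trajectory) -> float:
    """Average displacement error over all time steps."""
    return sum(_dist(p, g) for p, g in zip(pred, gt)) / len(gt)


def fde(pred: Trajectory, gt: Trajectory) -> float:
    """Displacement error at the final time step."""
    return _dist(pred[-1], gt[-1])


def min_ade_fde(samples: Sequence[Trajectory], gt: Trajectory) -> Tuple[float, float]:
    """Best-of-K ADE and FDE, each minimized independently over the samples."""
    return min(ade(s, gt) for s in samples), min(fde(s, gt) for s in samples)


if __name__ == "__main__":
    gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
    samples = [
        [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)],
        [(0.0, 0.0), (1.0, 0.0), (2.0, 1.0)],
    ]
    print(min_ade_fde(samples, gt))
```

Because the points are generic sequences, the same functions work for 2D and 3D trajectories.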
Baselines | MinADE | MinFDE |
---|---|---|
VRNN | | |
PATTERNN | | |
SOCPATTERNN-MLP | | |
SOCPATTERNN-MHA | | |
Baselines | MinADE | MinFDE |
---|---|---|
VRNN | | |
PATTERNN | | |
SOCPATTERNN-MLP | | |
SOCPATTERNN-MHA | | |
@article{navarro2022social,
title={Social-PatteRNN: Socially-Aware Trajectory Prediction Guided by Motion Patterns},
author={Navarro, Ingrid and Oh, Jean},
journal={arXiv preprint arXiv:2209.05649},
year={2022}
}