DiSCVAE

Disentangled Sequence Clustering Variational Autoencoder (DiSCVAE) -- a deep generative model that simultaneously clusters and disentangles latent representations of sequences. This repository contains code for the DiSCVAE applied to two synthetic video datasets: Moving MNIST and Sprites.

Please refer to the following paper for a description of its formulation and application:

Disentangled Sequence Clustering for Human Intention Inference
Mark Zolotas, Yiannis Demiris

The paper has been accepted for publication at the 2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS). Whilst this paper presents experimental findings on the problem of human intention inference for robotic wheelchairs, you can also refer to v3 of the arXiv source for DiSCVAE results on the Moving MNIST and Sprites datasets.

Directory Layout

  • bin: A set of shell scripts illustrating examples of how to train and evaluate different models on the chosen datasets.
  • checkpoints: Checkpoint states of trained models and the directory for storing evaluated results (quantitative and qualitative).
  • data: Where the Moving MNIST dataset (.npz format) and the Sprites dataset (npy/ directory of .npy files) should be stored.
  • scripts: Python modules related to running the DiSCVAE and other sequence models.

Dependencies

Functionality has been tested in a Python 3.8.10 environment using TensorFlow 2.9.1 and TensorFlow Probability. GPU toolkit dependencies are configured for CUDA 11.2 and cuDNN 8.1.0. In general, the dependencies can be installed with:

pip install -r requirements.txt
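For reference, a requirements.txt consistent with the versions above might look as follows (only the TensorFlow pin is stated in this README; the tensorflow-probability pin and the extra packages are assumptions, so check the repository's actual requirements.txt):

```text
tensorflow==2.9.1
tensorflow-probability==0.17.0  # assumed pin compatible with TF 2.9.x
numpy                           # assumed; used for the .npz/.npy datasets
matplotlib                      # assumed; for qualitative results
```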

Dataset Preparation

The Moving MNIST dataset utilised for this work can be generated using scripts/data/moving_mnist.py. This can be run from the top-level directory:

python scripts/data/moving_mnist.py --dataset_path ./data --filename moving_mnist
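Since the generator writes the dataset as a single .npz archive, a quick sanity check on the saved file can be sketched as below. The array name "sequences" and the shape (sequences, frames, height, width) are illustrative assumptions, not the generator's actual output; inspect the file's keys to find the real names.

```python
import os
import tempfile

import numpy as np

# Stand-in for a generated Moving MNIST archive: 100 sequences of
# 20 frames, each a 64x64 grayscale image (assumed layout).
sequences = np.random.rand(100, 20, 64, 64).astype(np.float32)

with tempfile.TemporaryDirectory() as data_dir:
    path = os.path.join(data_dir, "moving_mnist.npz")
    np.savez_compressed(path, sequences=sequences)

    # Reload and inspect, as a loader script might do.
    with np.load(path) as archive:
        print(list(archive.keys()))      # array names stored in the archive
        print(archive["sequences"].shape)
```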

Likewise, the Sprites dataset is generated by following the instructions from this repository and creating a folder of .npy files. This folder can then be moved into the data directory, where scripts/data/lpc.py will load the dataset.
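A minimal sketch of loading such a folder of .npy files is shown below. The helper, the filenames, and the array shapes are all hypothetical; scripts/data/lpc.py defines the repository's actual loading logic.

```python
import os
import tempfile

import numpy as np


def load_npy_dir(npy_dir):
    """Load every .npy file in a directory into a dict keyed by filename stem.

    Hypothetical helper for illustration only.
    """
    arrays = {}
    for name in sorted(os.listdir(npy_dir)):
        if name.endswith(".npy"):
            stem = os.path.splitext(name)[0]
            arrays[stem] = np.load(os.path.join(npy_dir, name))
    return arrays


# Demonstrate on a throwaway directory holding fake sprite sequences
# (10 sequences of 8 frames, 64x64 RGB; true shapes depend on the generator).
with tempfile.TemporaryDirectory() as root:
    npy_dir = os.path.join(root, "npy")
    os.makedirs(npy_dir)
    np.save(os.path.join(npy_dir, "sprites_train.npy"),
            np.zeros((10, 8, 64, 64, 3), dtype=np.float32))

    data = load_npy_dir(npy_dir)
    print(sorted(data))                  # loaded filename stems
    print(data["sprites_train"].shape)
```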

Training and Evaluation

The shell scripts in the bin directory contain numerous configurations for training different models on the two datasets. Before running these scripts, please change the $WS_DIR environment variable to your corresponding workspace directory.

For example, the following will train the DiSCVAE on Moving MNIST across 10 random seeds:

./bin/run_discvae_train.sh 

Or the VRNN can be trained on Sprites using:

./bin/run_vrnn_train.sh 

Similarly, the evaluation scripts can be run as follows:

./bin/run_discvae_eval.sh 
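The multi-seed training loop inside those scripts can be sketched roughly as follows. This is a hypothetical simplification: the $WS_DIR default, the Python entry point, and the flag names are assumptions based on the description above, not the scripts' actual contents.

```shell
#!/usr/bin/env bash
# Hypothetical sketch of a multi-seed training loop.
WS_DIR="${WS_DIR:-$HOME/workspace/discvae}"   # change to your workspace

for seed in $(seq 0 9); do
  echo "Training DiSCVAE on Moving MNIST with seed ${seed}"
  # python "${WS_DIR}/scripts/run_discvae.py" \
  #   --dataset moving_mnist --seed "${seed}"   # assumed entry point and flags
done
```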

Tests have also been conducted on the Disentangled Sequential Autoencoder using its TensorFlow Probability (TFP) implementation.

Contact

If you have any questions about the code, please feel free to reach out to the author Mark Zolotas!
