Contrastively Disentangled Sequential Variational Autoencoder (C-DSVAE)

Overview

This repository contains the implementation of C-DSVAE, a self-supervised method for learning disentangled sequential representations.

A PyTorch Lightning implementation (with a Docker environment) can be found here.

Requirements

  • Python 3
  • PyTorch 1.7
  • NumPy 1.18.5

Dataset

Sprites

We provide the raw Sprites .npy files. The dataset is also available in a third-party repo.

For each split (train/test), we expect the following components for each sequence sample:

  • x: raw sample of shape [8, 3, 64, 64]
  • c_aug: content augmentation of shape [8, 3, 64, 64]
  • m_aug: motion augmentation of shape [8, 3, 64, 64]
  • motion factors: action (3 classes), direction (3 classes)
  • content factors: skin, tops, pants, hair (each with 6 classes)
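
As a rough illustration of the layout above, the following sketch checks that a single sample has the expected shapes and label ranges. It is not taken from this repository: the dictionary keys, the helper names `check_sample` and `make_dummy_sample`, and the assumption that each factor is stored as an integer in `[0, num_classes)` are all hypothetical, so adapt them to how your copy of the data is actually stored.

```python
import numpy as np

# Shape of one sequence: 8 frames, 3 channels, 64 x 64 pixels.
SEQ_SHAPE = (8, 3, 64, 64)

# Number of classes per factor, as listed above.
# The key names are hypothetical; the real files may use different ones.
NUM_CLASSES = {
    "action": 3, "direction": 3,                  # motion factors
    "skin": 6, "tops": 6, "pants": 6, "hair": 6,  # content factors
}


def check_sample(sample):
    """Raise ValueError if a sample does not match the expected layout."""
    for key in ("x", "c_aug", "m_aug"):
        arr = np.asarray(sample[key])
        if arr.shape != SEQ_SHAPE:
            raise ValueError(f"{key}: expected {SEQ_SHAPE}, got {arr.shape}")
    for name, n in NUM_CLASSES.items():
        # Assumption: labels are integer class indices starting at 0.
        label = int(sample[name])
        if not 0 <= label < n:
            raise ValueError(f"{name}: label {label} outside [0, {n})")
    return True


def make_dummy_sample(seed=0):
    """Build a random sample with the right shapes, for testing only."""
    rng = np.random.default_rng(seed)
    sample = {
        key: rng.random(SEQ_SHAPE, dtype=np.float32)
        for key in ("x", "c_aug", "m_aug")
    }
    for name, n in NUM_CLASSES.items():
        sample[name] = int(rng.integers(n))
    return sample
```

Calling `check_sample(make_dummy_sample())` returns `True`; a sample whose `x` has the wrong number of frames raises `ValueError`.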

The pre-processed dataset: data.pkl

Running

Train

./run_cdsvae.sh

Test

./run_test_sprite.sh

Classification Judge

The judge classifiers are pretrained separately with full supervision.

C-DSVAE Checkpoints

We provide a sample Sprites checkpoint. Checkpoint parameters can be found in ./run_test_sprite.sh.

Paper

If you are inspired by our work, please cite the following paper:

@article{bai2021contrastively,
  title={Contrastively disentangled sequential variational autoencoder},
  author={Bai, Junwen and Wang, Weiran and Gomes, Carla P},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  pages={10105--10118},
  year={2021}
}