This is the PyTorch implementation of CMVAE.
The code is written in Python. Dependencies:
- Python >= 3.6
- PyTorch 1.4 or 1.7
- tqdm, wandb
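Before downloading data, it can help to confirm that the interpreter and PyTorch match the versions listed above. The sketch below is not part of the repository. It reads "1.4 or 1.7" as "any 1.4.x or 1.7.x release", and the helper names (`parse_major_minor`, `check_environment`) are hypothetical.

```python
import sys
from typing import Tuple

# Constraints taken from the dependency list: Python >= 3.6 and PyTorch 1.4 or 1.7.
# Assumption: any patch release of 1.4.x or 1.7.x is acceptable.
MIN_PYTHON: Tuple[int, int] = (3, 6)
SUPPORTED_TORCH = {(1, 4), (1, 7)}


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Return (major, minor) from a version string such as '1.7.1+cu110'."""
    core = version.split("+")[0]  # drop local build tags like '+cu110'
    major, minor = core.split(".")[:2]
    return int(major), int(minor)


def check_environment(python_version: Tuple[int, int], torch_version: str) -> bool:
    """True if both versions fall inside the ranges listed in this README."""
    python_ok = tuple(python_version) >= MIN_PYTHON
    torch_ok = parse_major_minor(torch_version) in SUPPORTED_TORCH
    return python_ok and torch_ok


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        ok = check_environment(sys.version_info[:2], torch.__version__)
        print("environment OK" if ok else "unsupported Python/PyTorch version")
```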
- Download Omniglot data from here.
- Download pretrained features for Mini-ImageNet from here.
- Download pretrained features for CelebA from here.
- (Optional) To train SimCLR from scratch, download the ImageNet images from here and the CelebA images from here.
The data directory should look like this:
data/
├── omniglot/
│   ├── train.npy
│   ├── val.npy
│   └── test.npy
├── mimgnet/
│   ├── train_features.npy
│   ├── val_features.npy
│   └── test_features.npy
├── celeba/
│   ├── train_features.npy
│   ├── val_features.npy
│   └── test_features.npy
└── imgnet or celeba_imgs/   -> (optional) only needed to train SimCLR from scratch
    ├── images/
    │   ├── n0210891500001298.jpg
    │   ├── n0287152500001298.jpg
    │   ├── ...
    │   └── n0236282200001298.jpg
    ├── train.csv
    ├── val.csv
    └── test.csv
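A quick way to catch a misplaced file before a long training run is to check the required part of the layout above. The following is a minimal sketch, not a script shipped with the code: it assumes the Omniglot folder is named `omniglot` and skips the optional SimCLR image folders. The helper `missing_files` is hypothetical; adjust `EXPECTED_FILES` if the loaders expect different names.

```python
from pathlib import Path
from typing import Dict, List

# Required files, copied from the directory tree (optional SimCLR folders omitted).
# Assumption: folder names match what each main.py loader expects.
EXPECTED_FILES: Dict[str, List[str]] = {
    "omniglot": ["train.npy", "val.npy", "test.npy"],
    "mimgnet": ["train_features.npy", "val_features.npy", "test_features.npy"],
    "celeba": ["train_features.npy", "val_features.npy", "test_features.npy"],
}


def missing_files(data_root: str) -> List[str]:
    """Return the expected dataset files (relative paths) absent under data_root."""
    root = Path(data_root)
    missing = []
    for folder, names in EXPECTED_FILES.items():
        for name in names:
            path = root / folder / name
            if not path.is_file():
                missing.append(str(path.relative_to(root)))
    return missing


if __name__ == "__main__":
    import sys

    root = sys.argv[1] if len(sys.argv) > 1 else "data"
    problems = missing_files(root)
    if problems:
        print("missing:", *problems, sep="\n  ")
    else:
        print("all expected files found under", root)
```

Run it as `python check_data.py /path/to/data` (the script name is arbitrary).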
To reproduce the Omniglot 5-way experiment for CMVAE, run:
cd omniglot
python main.py --data-dir DATA_DIRECTORY --save-dir SAVE_DIRECTORY --way 5 --sample-size 200
To reproduce the Omniglot 20-way experiment for CMVAE, run:
cd omniglot
python main.py --data-dir DATA_DIRECTORY --save-dir SAVE_DIRECTORY --way 20 --sample-size 300
To reproduce the Mini-ImageNet 5-way experiment for CMVAE, run:
cd mimgnet
python main.py --data-dir DATA_DIRECTORY --save-dir SAVE_DIRECTORY
To reproduce the CelebA 5-way experiment for CMVAE, run:
cd celeba
python main.py --data-dir DATA_DIRECTORY --save-dir SAVE_DIRECTORY
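The four CMVAE experiments differ only in the working directory and a few extra flags, so they can be driven from a single script. This is a hypothetical sketch, not part of the repository: it uses only the flags shown in the commands above, converts paths to absolute form (because `cwd=` changes how relative paths resolve, just as `cd` does), and only prints the commands unless `dry_run=False`.

```python
import os
import subprocess
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Experiment:
    workdir: str                      # directory containing this experiment's main.py
    extra_args: Tuple[str, ...] = ()  # flags beyond --data-dir / --save-dir


# Settings copied from the README commands; Mini-ImageNet and CelebA use main.py defaults.
EXPERIMENTS: Dict[str, Experiment] = {
    "omniglot-5way": Experiment("omniglot", ("--way", "5", "--sample-size", "200")),
    "omniglot-20way": Experiment("omniglot", ("--way", "20", "--sample-size", "300")),
    "mimgnet-5way": Experiment("mimgnet"),
    "celeba-5way": Experiment("celeba"),
}


def build_command(name: str, data_dir: str, save_dir: str) -> List[str]:
    """Return the argument list for one experiment, mirroring the README command."""
    exp = EXPERIMENTS[name]
    return ["python", "main.py", "--data-dir", data_dir, "--save-dir", save_dir, *exp.extra_args]


def run(name: str, data_dir: str, save_dir: str, dry_run: bool = True) -> None:
    """Print (and optionally execute) one experiment from inside its workdir."""
    # Absolute paths stay valid after switching into the experiment directory.
    cmd = build_command(name, os.path.abspath(data_dir), os.path.abspath(save_dir))
    workdir = EXPERIMENTS[name].workdir
    print(f"[{workdir}] {' '.join(cmd)}")
    if not dry_run:
        # cwd= plays the role of the `cd <workdir>` step.
        subprocess.run(cmd, cwd=workdir, check=True)


if __name__ == "__main__":
    for exp_name in EXPERIMENTS:
        run(exp_name, "data", os.path.join("results", exp_name))  # dry run only
```

Using one save directory per experiment keeps runs from overwriting each other's checkpoints.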
(Optional) To reproduce the SimCLR features for Mini-ImageNet, run:
cd simclr
python main.py --data-dir DATA_DIRECTORY --save-dir SAVE_DIRECTORY --feature-save-dir FEATURE_SAVE_DIRECTORY
Our work and code build on two existing works, to which we are very grateful:
- Meta-GMVAE
- notears