This is the official implementation of SHOT-VAE, the model proposed in the paper "SHOT-VAE: Semi-supervised Deep Generative Models With Label-aware ELBO Approximations".
Figure: the schematic of SHOT-VAE, which has clear advantages in interpretability by capturing semantics-disentangled latent variables.
Requirements:

- Python >= 3.6
- torch = 1.2.0
- torchvision = 0.4.0
- scikit-learn >= 0.2
- tensorboard >= 2.0.0
Users need to declare a base path that stores both the datasets and the training logs. The directory structure should be as follows (a minimal sketch for pre-creating it appears after the tree):
```
base_path
│
└───dataset
│   │   cifar
│   │   │   cifar-10-batches-py
│   │   │   │   ...
│   │   │   cifar-100-python
│   │   │   │   ...
│   │   svhn
│   │   │   ...
│   │   mnist
│   │   │   ...
└───trained_model_1
│   │   parameter
│   │   runs
└───trained_model_2
│   │   parameter
│   │   runs
...
└───trained_model_n
│   │   parameter
│   │   runs
```
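If you want to lay out the dataset folders up front, here is a minimal sketch (an illustration, not part of the repo; the torchvision download calls below also create these folders on demand):

```python
import os

# pre-create the dataset folders under base_path
base_path = "./"
for sub in ("cifar", "svhn", "mnist"):
    os.makedirs(os.path.join(base_path, "dataset", sub), exist_ok=True)
```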
We recommend using the following torchvision functions to download the datasets:
```python
from os import path

import torchvision

# set base_path
base_path = "./"

# download mnist, svhn, cifar10, and cifar100
torchvision.datasets.MNIST(path.join(base_path, "dataset", "mnist"), download=True)
torchvision.datasets.CIFAR10(path.join(base_path, "dataset", "cifar"), download=True)
torchvision.datasets.CIFAR100(path.join(base_path, "dataset", "cifar"), download=True)
# SVHN goes into its own folder, matching the directory structure above
torchvision.datasets.SVHN(path.join(base_path, "dataset", "svhn"), download=True)
```
Alternatively, you can put the datasets into the appropriate folders manually; the optional check below verifies the layout.
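As an optional sanity check (an assumption for illustration, not repo code), you can confirm that the extracted archives landed where the tree above expects:

```python
from os import path

# verify the dataset folders against the directory structure above
base_path = "./"
expected = [
    path.join("dataset", "cifar", "cifar-10-batches-py"),
    path.join("dataset", "cifar", "cifar-100-python"),
    path.join("dataset", "svhn"),
    path.join("dataset", "mnist"),
]
for folder in expected:
    full = path.join(base_path, folder)
    print(full, "ok" if path.isdir(full) else "MISSING")
```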
Note that we have implemented 3 categories of backbones: WideResNet, PreActResNet, and DenseNet. The examples below use WideResNet; to use another backbone, change the `--net-name` parameter (e.g. `--net-name preactresnet18`).
For CUDA computation, set the `--gpu` parameter (e.g. `--gpu "0,1"` uses gpu0 and gpu1 together), as sketched below.
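Here is a minimal sketch of how a comma-separated `--gpu` string typically becomes a device list for `torch.nn.DataParallel`; the real argument handling lives in the `main_*.py` scripts, so the names here are assumptions:

```python
import torch

gpu = "0,1"  # the value passed via --gpu
device_ids = [int(i) for i in gpu.split(",")]  # "0,1" -> [0, 1]

# stand-in model; one of the repo's backbones (e.g. wideresnet-28-2) goes here
model = torch.nn.Linear(32, 10)
if torch.cuda.is_available():
    # replicate the module across the requested devices
    model = torch.nn.DataParallel(model, device_ids=device_ids).cuda(device_ids[0])
```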
The following table lists several important parameters that need to be set manually; the sketch after the table shows how they typically enter training.

Parameter | Meaning |
---|---|
`br` | Whether to use BCE loss in the reconstruction term. |
`annotated_ratio` | The annotated (labeled) ratio of the dataset. |
`ad` | The milestone list for adjusting the learning rate. |
`epochs` | The total number of training epochs. |
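As a hedged illustration (the flag names mirror the commands below, but the defaults and parsing are assumptions, not the repo's actual code), these parameters typically enter training like this:

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--br", action="store_true",
                    help="use BCE loss for the reconstruction term")
parser.add_argument("--annotated-ratio", type=float, default=0.1,
                    help="labeled fraction of the training set")
parser.add_argument("-ad", "--adjust-lr", type=str, default="[500,600,650]",
                    help="epoch milestones for learning-rate decay")
parser.add_argument("--epochs", type=int, default=700)
args = parser.parse_args([])  # keep the defaults for this illustration

# "[500,600,650]" -> [500, 600, 650]
milestones = [int(m) for m in args.adjust_lr.strip("[]").split(",")]

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones)
for epoch in range(args.epochs):
    # ... one epoch of training would run here ...
    scheduler.step()  # decays the lr by 10x at each milestone
```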
To train SHOT-VAE:

- For Cifar10 (4k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_shot_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --br
  # for wideresnet-28-10
  python main_shot_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --br
  ```
- For Cifar100 (4k), use the following commands (the `--annotated-ratio` values here and below correspond to the 4k and 10k label budgets):

  ```bash
  # for wideresnet-28-2
  python main_shot_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700 --br
  # for wideresnet-28-10
  python main_shot_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700
  ```
- For Cifar100 (10k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_shot_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700 --br
  # for wideresnet-28-10
  python main_shot_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700
  ```
The test-set performance during training is plotted for each dataset and backbone:

- Cifar10 (4k)
  - WideResNet-28-2
  - WideResNet-28-10
- Cifar100 (4k)
  - WideResNet-28-2
  - WideResNet-28-10
- Cifar100 (10k)
  - WideResNet-28-2
  - WideResNet-28-10
To reproduce the M2-VAE baseline, run `main_M2_vae.py` with the same parameters:

- For Cifar10 (4k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_M2_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --br
  # for wideresnet-28-10
  python main_M2_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --br
  ```
- For Cifar100 (4k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_M2_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700 --br
  # for wideresnet-28-10
  python main_M2_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700
  ```
- For Cifar100 (10k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_M2_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700 --br
  # for wideresnet-28-10
  python main_M2_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700
  ```
To train the plain classifier baseline, run `main_classifier.py`:

- For Cifar10 (4k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_classifier.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid
  # for wideresnet-28-10
  python main_classifier.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid
  ```
- For Cifar100 (4k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_classifier.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1
  # for wideresnet-28-10
  python main_classifier.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1
  ```
- For Cifar100 (10k), use the following commands:

  ```bash
  # for wideresnet-28-2
  python main_classifier.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25
  # for wideresnet-28-10
  python main_classifier.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25
  ```
Use the following commands to reproduce the Smooth-ELBO results:

```bash
# run Smooth-ELBO VAE on MNIST (100)
python main_smooth_ELBO_mnist.py -bp basepath --gpu gpuid
# run One-stage SSL VAE on SVHN (1k)
python main_smooth_ELBO_svhn.py -bp basepath --gpu gpuid
```
If you find this work useful, please consider citing:
```bibtex
@article{DBLP:journals/corr/abs-2011-10684,
  author        = {Haozhe Feng and
                   Kezhi Kong and
                   Minghao Chen and
                   Tianye Zhang and
                   Minfeng Zhu and
                   Wei Chen},
  title         = {{SHOT-VAE:} Semi-supervised Deep Generative Models With Label-aware
                   {ELBO} Approximations},
  journal       = {CoRR},
  volume        = {abs/2011.10684},
  year          = {2020},
  url           = {https://arxiv.org/abs/2011.10684},
  archivePrefix = {arXiv},
  eprint        = {2011.10684}
}
```