SHOT-VAE: Semi-supervised Deep Generative Models With Label-aware ELBO Approximations

Here is the official implementation of the SHOT-VAE model from the paper "SHOT-VAE: Semi-supervised Deep Generative Models With Label-aware ELBO Approximations".

Model Overview

Figure: the schematic of SHOT-VAE.

SHOT-VAE offers strong interpretability by capturing semantically disentangled latent variables: $\mathbf{z}$ represents the image style and $\mathbf{y}$ represents the image class. The smooth-ELBO adopts a more flexible assumption on $\hat{p}(\mathbf{y}\vert\mathbf{X})$ via the label-smoothing technique, and the optimal interpolation performs data augmentation on the input pairs with the most similar continuous representations. Together, these two components break the ELBO bottleneck.
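As a concrete illustration of the label-smoothing assumption behind the smooth-ELBO, the sketch below softens one-hot targets into distributions over classes; smooth_label, num_classes, and epsilon are illustrative names and values, not the repository's code or the paper's exact hyper-parameters.

import torch

# Minimal sketch: soften one-hot labels into the smoothed distribution
# that the smooth-ELBO assumes for p(y|X) on labeled data.
def smooth_label(y, num_classes=10, epsilon=0.1):
    """Turn integer labels y of shape [N] into smoothed distributions [N, C]."""
    one_hot = torch.zeros(y.size(0), num_classes)
    one_hot.scatter_(1, y.unsqueeze(1), 1.0)
    # keep (1 - epsilon) mass on the true class, spread epsilon uniformly
    return one_hot * (1.0 - epsilon) + epsilon / num_classes

y = torch.tensor([3, 7])
print(smooth_label(y))  # each row sums to 1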

Setup

Install Package Dependencies

Python Environment: >= 3.6
torch == 1.2.0
torchvision == 0.4.0
scikit-learn >= 0.2
tensorboard >= 2.0.0
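For example, the pinned versions above can typically be installed with pip (assuming a Python >= 3.6 environment; a CUDA build of torch 1.2.0 may require a platform-specific wheel):

pip install torch==1.2.0 torchvision==0.4.0 "scikit-learn>=0.2" "tensorboard>=2.0.0"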

Install Datasets

Users need to declare a base path to store the datasets as well as the logs of the training procedure. The directory structure should be

base_path
├── dataset
│   ├── cifar
│   │   ├── cifar-10-batches-py
│   │   │   └── ...
│   │   └── cifar-100-python
│   │       └── ...
│   ├── svhn
│   │   └── ...
│   └── mnist
│       └── ...
├── trained_model_1
│   ├── parmater
│   └── runs
├── trained_model_2
│   ├── parmater
│   └── runs
├── ...
└── trained_model_n
    ├── parmater
    └── runs

We recommend using the following torchvision functions to download the datasets:

from os import path
import torchvision

# set base_path
base_path = "./"
# download mnist, svhn, cifar10 and cifar100 into the structure above
torchvision.datasets.MNIST(path.join(base_path, "dataset", "mnist"), download=True)
torchvision.datasets.CIFAR10(path.join(base_path, "dataset", "cifar"), download=True)
torchvision.datasets.CIFAR100(path.join(base_path, "dataset", "cifar"), download=True)
torchvision.datasets.SVHN(path.join(base_path, "dataset", "svhn"), download=True)

Alternatively, you can manually place the datasets in the appropriate folders.
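Before launching training, a quick sanity check such as the following (an illustrative snippet, not part of the repository) can confirm that the expected dataset folders exist under the base path:

from os import path

base_path = "./"
# check the dataset layout described above
for sub in [("cifar", "cifar-10-batches-py"), ("cifar", "cifar-100-python"),
            ("svhn",), ("mnist",)]:
    p = path.join(base_path, "dataset", *sub)
    print(p, "OK" if path.isdir(p) else "MISSING")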

Running

Note that we have implemented 3 categories of backbones: WideResNet, PreActResNet and DenseNet. Here we give examples for WideResNet; to use another network backbone, change the --net-name parameter (e.g. --net-name preactresnet18).

For CUDA computation, set the --gpu parameter (e.g. --gpu "0,1" means using GPU 0 and GPU 1 together for the computation), as sketched below.
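The following is a minimal sketch of one common way a comma-separated --gpu string is wired to multi-GPU training with torch.nn.DataParallel; the actual handling inside the training scripts may differ.

import os
import torch

gpu = "0,1"  # value passed via --gpu
os.environ["CUDA_VISIBLE_DEVICES"] = gpu  # restrict visible devices before any CUDA call
device_ids = list(range(len(gpu.split(","))))  # devices are renumbered after masking

model = torch.nn.Linear(10, 10)  # placeholder for the real backbone
if torch.cuda.is_available():
    model = torch.nn.DataParallel(model.cuda(), device_ids=device_ids)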

Semi-supervised Learning

SHOT-VAE in Cifar10 (4k) and Cifar100 (4k and 10k)

Several important parameters that need to be set manually are listed in the following table (a hypothetical argparse sketch follows the table):

| Parameter | Meaning |
| --- | --- |
| `--br` | Whether to use BCE loss in $E_{p,q}\log p(X\vert z,y)$; default is False. |
| `--annotated-ratio` | The annotated ratio of the dataset. |
| `-ad` | The milestone list for adjusting the learning rate. |
| `--epochs` | The total number of epochs in the training process. |
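As a hypothetical sketch (the actual parser in main_shot_vae.py may use different names and defaults), these flags could be declared roughly as follows, which also shows how a bracketed milestone string such as "[500,600,650]" can be parsed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("-bp", "--base-path", required=True)      # base_path described above
parser.add_argument("--net-name", default="wideresnet-28-2")  # backbone name
parser.add_argument("--gpu", default="0")                     # comma-separated GPU ids
parser.add_argument("--dataset", default="Cifar10")
parser.add_argument("--br", action="store_true")              # BCE loss in E_{p,q} log p(X|z,y)
parser.add_argument("--annotated-ratio", type=float)          # annotated ratio of the dataset
parser.add_argument("-ad", "--adjust-lr")                     # e.g. "[500,600,650]"
parser.add_argument("--epochs", type=int)                     # total training epochs
args = parser.parse_args()

# turn the bracketed milestone string into a list of ints
milestones = [int(m) for m in args.adjust_lr.strip("[]").split(",")] if args.adjust_lr else []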
  1. For Cifar10 (4k), please use the following command

    # for wideresnet-28-2
    python main_shot_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --br
    # for wideresnet-28-10
    python main_shot_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --br
  2. For Cifar100 (4k), please use the following command

    # for wideresnet-28-2
    python main_shot_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700 --br
    # for wideresnet-28-10
    python main_shot_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700
  3. For Cifar100 (10k), please use the following command

    # for wideresnet-28-2
    python main_shot_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700 --br
    # for wideresnet-28-10
    python main_shot_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700

The test-set performance during the training process is shown below for each setting:

- Cifar10 (4k): WideResNet-28-2 (figure: Cifar10-4K-WRN-28-2), WideResNet-28-10 (figure: Cifar10-4K-WRN-28-10)
- Cifar100 (4k): WideResNet-28-2 (figure: Cifar100-4k-WRN-28-2), WideResNet-28-10 (figure: Cifar100-4k-WRN-28-10)
- Cifar100 (10k): WideResNet-28-2 (figure: Cifar100-WRN-28-2), WideResNet-28-10 (figure: Cifar100-WRN-28-10)

M2-VAE in Cifar10 (4k) and Cifar100 (4k and 10k)

  1. For Cifar10 (4k), please use the following command

    # for wideresnet-28-2
    python main_M2_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --br
    # for wideresnet-28-10
    python main_M2_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --br
  2. For Cifar100 (4k), please use the following command

    # for wideresnet-28-2
    python main_M2_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700 --br
    # for wideresnet-28-10
    python main_M2_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 -ad [500,600,650] --epochs 700
  3. For Cifar100 (10k), please use the following command

    # for wideresnet-28-2
    python main_M2_vae.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700 --br
    # for wideresnet-28-10
    python main_M2_vae.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25 -ad [500,600,650] --epochs 700

Classifier only in Cifar10 (4k) and Cifar100 (4k and 10k)

  1. For Cifar10 (4k), please use the following command

    # for wideresnet-28-2
    python main_classifier.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid
    # for wideresnet-28-10
    python main_classifier.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid 
  2. For Cifar100 (4k), please use the following command

    # for wideresnet-28-2
    python main_classifier.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 
    # for wideresnet-28-10
    python main_classifier.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.1 
  3. For Cifar100 (10k), please use the following command

    # for wideresnet-28-2
    python main_classifier.py -bp basepath --net-name wideresnet-28-2 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25
    # for wideresnet-28-10
    python main_classifier.py -bp basepath --net-name wideresnet-28-10 --gpu gpuid --dataset "Cifar100" --annotated-ratio 0.25

Smooth-ELBO VAE in MNIST (100) and SVHN (1k) [Table 1]

Use the following commands to reproduce our results:

# run Smooth-ELBO VAE on MNIST (100)
python main_smooth_ELBO_mnist.py -bp basepath --gpu gpuid

# run One-stage SSL VAE on SVHN (1k)
python main_smooth_ELBO_svhn.py -bp basepath --gpu gpuid

References

If you find this useful in your work, please consider citing:

@article{DBLP:journals/corr/abs-2011-10684,
  author    = {Haozhe Feng and
               Kezhi Kong and
               Minghao Chen and
               Tianye Zhang and
               Minfeng Zhu and
               Wei Chen},
  title     = {{SHOT-VAE:} Semi-supervised Deep Generative Models With Label-aware
               {ELBO} Approximations},
  journal   = {CoRR},
  volume    = {abs/2011.10684},
  year      = {2020},
  url       = {https://arxiv.org/abs/2011.10684},
  archivePrefix = {arXiv},
  eprint    = {2011.10684}
}
