- Viewing this README.md in the GitHub repo may suffer from formatting issues. We recommend viewing index.html instead (it can also be viewed locally).
<p style="text-align: left; width: 60%; margin: auto;">
MIM and VAE models with 2D inputs, and 2D latent space.
<br>
Top row: <b>Black</b> contours depict level sets of P(x); <span style="color: red">red</span> dots are reconstructed test points.
<br>
Bottom row: <span style="color: green">Green</span> contours are one standard deviation ellipses of q(z|x) for test points. Dashed black circles depict one standard deviation of P(z).
<br>
<br>
</p>
<ul style="text-align: left; width: 60%; margin: auto;">
<li>AE (auto-encoder) produces zero predictive variance (i.e., a delta function) and lower reconstruction errors, consistent with high mutual information. The structure in the latent space is the result of the architecture's inductive bias. The lack of a prior leads to an undetermined alignment with P(z) (i.e., an arbitrary structure in the latent space).</li>
<li>MIM produces lower predictive variance and lower reconstruction errors, consistent with high mutual information, alongside alignment with P(z) (i.e., structured latent space).</li>
<li>VAE is optimized with annealing of beta, as in beta-VAE (see the sketch below this list). Once annealing is complete (i.e., beta = 1), the VAE posteriors show high predictive variance, which is indicative of partial posterior collapse. The increased variance leads to reduced mutual information and worse reconstruction error as a result of a strong alignment with P(z) (i.e., an overly structured/regularized latent space).</li>
</ul>
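To make the annealing mentioned in the last item concrete, here is a minimal sketch of linear beta warm-up for a beta-VAE objective. It is illustrative only: the function names are hypothetical, and the schedule actually controlled by `--warmup-steps` in the commands below may differ.

```python
def beta_at(step, warmup_steps):
    """Linear beta annealing: ramp beta from 0 up to 1 over warmup_steps optimization steps."""
    return min(1.0, step / float(warmup_steps))

def beta_vae_loss(recon_nll, kl_term, step, warmup_steps=25):
    """beta-VAE objective: reconstruction NLL + beta * KL(q(z|x) || p(z))."""
    return recon_nll + beta_at(step, warmup_steps) * kl_term
```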
The code has been tested on CPU and on an NVIDIA Titan Xp GPU, using Anaconda, Python 3.6, and zsh:

```
# tools
zsh 5.4.2
Cuda compilation tools, release 9.0, V9.0.176
conda 4.6.14
Python 3.6.8
# python packages (see requirements.txt for complete list)
scipy==1.1.0
matplotlib==3.0.3
numpy==1.15.4
torchvision==0.2.1
torch==1.0.0
scikit_learn==0.21.3
```
Please follow the [PyTorch installation instructions](https://pytorch.org/), then install the remaining Python packages:

```
pip install -r requirements.txt
```
The experiments can be run on synthetic GMM toy data, MNIST, Fashion-MNIST, and Omniglot (see the commands below). All datasets are included as part of the repo for convenience; download links are provided as a workaround (i.e., in case of issues).

Directory structure (if the code fails due to a missing directory, please create it manually):

- `src/` - experiments are assumed to be executed from this directory.
- `data/assets` - datasets will be saved here.
- `data/torch-generated` - results will be saved here.
NOTE (if the code fails due to CUDA/GPU issues): to prevent the use of CUDA/GPU and force CPU-only computation, please add the following flag to the command lines supplied below:

```
--no-cuda
```

Otherwise, CUDA will be used by default whenever it is detected by PyTorch.
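For reference, the behaviour described above corresponds to the standard PyTorch device-selection pattern sketched below (an illustrative sketch, not the repo's actual argument handling):

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--no-cuda", action="store_true",
                    help="disable CUDA and run on CPU")
args = parser.parse_args()

# Use the GPU only if it is available and not explicitly disabled.
use_cuda = torch.cuda.is_available() and not args.no_cuda
device = torch.device("cuda" if use_cuda else "cpu")
model = torch.nn.Linear(2, 2).to(device)  # e.g., move any module to the chosen device
```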
For a detailed explanation of the plots below, please see the paper.
To produce the animation at the top:
```
# MIM
./vae-as-mim-dataset.py \
--dataset toyMIM \
--z-dim 2 \
--mid-dim 50 \
--min-logvar 6 \
--seed 1 \
--batch-size 128 \
--epochs 49 \
--warmup-steps 25 \
--vis-progress \
--mim-loss \
--mim-samp
# VAE
./vae-as-mim-dataset.py \
--dataset toyVAE \
--z-dim 2 \
--mid-dim 50 \
--min-logvar 6 \
--seed 1 \
--batch-size 128 \
--epochs 49 \
--warmup-steps 25 \
--vis-progress
# AE
./vae-as-mim-dataset.py \
--dataset toyAE \
--z-dim 2 \
--mid-dim 50 \
--min-logvar 6 \
--seed 1 \
--batch-size 128 \
--epochs 49 \
--warmup-steps 25 \
--vis-progress \
--ae-loss
```
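As a side note on the visualization: the one-standard-deviation ellipses in the bottom row (and the dashed prior circle) can be drawn with a few lines of matplotlib. The sketch below is not the repo's plotting code; `mu` and `logvar` stand in for hypothetical encoder outputs on test points.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Circle

# Hypothetical encoder outputs for a batch of test points: 2D means and log-variances.
rng = np.random.default_rng(0)
mu = rng.normal(scale=0.8, size=(64, 2))
logvar = rng.normal(loc=-2.0, scale=0.3, size=(64, 2))
sigma = np.exp(0.5 * logvar)

fig, ax = plt.subplots(figsize=(4, 4))
for m, s in zip(mu, sigma):
    # One-standard-deviation ellipse of a diagonal Gaussian q(z|x).
    ax.add_patch(Ellipse(xy=m, width=2 * s[0], height=2 * s[1],
                         fill=False, edgecolor="green"))
# One standard deviation of the prior P(z) = N(0, I).
ax.add_patch(Circle((0, 0), radius=1.0, fill=False, edgecolor="black", linestyle="--"))
ax.set_aspect("equal")
ax.set_xlim(-3, 3); ax.set_ylim(-3, 3)
plt.show()
```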
Experimenting with the expressiveness of MIM and VAE:

```
for seed in 1 2 3 4 5 6 7 8 9 10; do
for mid_dim in 5 20 50 100 200 300 400 500; do
# MIM
./vae-as-mim-dataset.py \
--dataset toy4 \
--z-dim 2 \
--mid-dim ${mid_dim} \
--min-logvar 6 \
--seed ${seed} \
--batch-size 128 \
--epochs 200 \
--warmup-steps 3 \
--mim-loss \
--mim-samp
# VAE
./vae-as-mim-dataset.py \
--dataset toy4 \
--z-dim 2 \
--mid-dim ${mid_dim} \
--min-logvar 6 \
--seed ${seed} \
--batch-size 128 \
--epochs 200 \
--warmup-steps 3
done
done
```
The results below demonstrate posterior collapse in VAE, and its absence in MIM.
<div style="text-align: left; width: 9%; display: inline-block"><span style="color: blue;">MIM</span><br><span style="color: red;">VAE</span></div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="MI" src="images/toy4/stats/fig.MI_ksg.png">
<p style="text-align: center;">MI</p>
</div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="NLL" src="images/toy4/stats/fig.H_q_x.png">
<p style="text-align: center;">NLL</p>
</div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="RMSE" src="images/toy4/stats/fig.x_recon_err.png">
<p style="text-align: center;">Recon. RMSE</p>
</div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="Cls. Acc." src="images/toy4/stats/fig.clf_acc_KNN5.png">
<p style="text-align: center;">Classification Acc.</p>
</div>
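The MI panel above (fig.MI_ksg) presumably reports a k-nearest-neighbour mutual information estimate in the style of Kraskov-Stoegbauer-Grassberger (KSG). For readers unfamiliar with it, below is a self-contained illustrative sketch of such an estimator; it is not the repo's implementation, and `x`/`z` stand for arrays of paired observations and latent codes.

```python
import numpy as np
from scipy.spatial import cKDTree
from scipy.special import digamma

def ksg_mutual_information(x, z, k=5):
    """KSG (Kraskov et al., 2004) k-NN estimate of I(x; z) from paired samples.

    x: array of shape (N, dx); z: array of shape (N, dz).
    """
    x = np.asarray(x, dtype=float)
    z = np.asarray(z, dtype=float)
    n = x.shape[0]
    joint = np.hstack([x, z])
    # Distance to the k-th nearest neighbour in the joint space (max-norm),
    # excluding the query point itself (hence k + 1).
    eps = cKDTree(joint).query(joint, k=k + 1, p=np.inf)[0][:, -1]
    x_tree, z_tree = cKDTree(x), cKDTree(z)
    # Count neighbours strictly closer than eps in each marginal space.
    nx = np.array([len(x_tree.query_ball_point(x[i], eps[i] - 1e-12, p=np.inf)) - 1
                   for i in range(n)])
    nz = np.array([len(z_tree.query_ball_point(z[i], eps[i] - 1e-12, p=np.inf)) - 1
                   for i in range(n)])
    return digamma(k) + digamma(n) - np.mean(digamma(nx + 1) + digamma(nz + 1))
```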
Experimenting with the effect of an entropy prior on MIM and VAE:

```
for seed in 1 2 3 4 5 6 7 8 9 10; do
for mid_dim in 5 20 50 100 200 300 400 500; do
# MIM
./vae-as-mim-dataset.py \
--dataset toy4 \
--z-dim 2 \
--mid-dim ${mid_dim} \
--min-logvar 6 \
--seed ${seed} \
--batch-size 128 \
--epochs 200 \
--warmup-steps 3 \
--mim-loss \
--mim-samp \
--inv-H-loss
# VAE
./vae-as-mim-dataset.py \
--dataset toy4 \
--z-dim 2 \
--mid-dim ${mid_dim} \
--min-logvar 6 \
--seed ${seed} \
--batch-size 128 \
--epochs 200 \
--warmup-steps 3 \
--inv-H-loss
done
done
```
The results below demonstrate how adding the joint entropy as a regularizer can prevent posterior collapse in VAE, while subtracting the joint entropy can induce a strong collapse in MIM.
<div style="text-align: left; width: 9%; display: inline-block"><span style="color: blue;">MIM</span><br><span style="color: red;">VAE</span></div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="MI" src="images/toy4/stats-inv_H/fig.MI_ksg.png">
<p style="text-align: center;">MI</p>
</div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="NLL" src="images/toy4/stats-inv_H/fig.H_q_x.png">
<p style="text-align: center;">NLL</p>
</div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="RMSE" src="images/toy4/stats-inv_H/fig.x_recon_err.png">
<p style="text-align: center;">Recon. RMSE</p>
</div>
<div style="width: 20%; display:inline-block;">
<img width="100%" alt="Cls. Acc." src="images/toy4/stats-inv_H/fig.clf_acc_KNN5.png">
<p style="text-align: center;">Classification Acc.</p>
</div>
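For intuition about the entropy term involved: with a diagonal-Gaussian posterior, the conditional entropy H(q(z|x)) has a closed form, and the joint entropy under the encoding distribution is that term plus a data-entropy constant. The sketch below shows only this closed form and a hypothetical signed regularizer; it is not what `--inv-H-loss` implements, and `base_loss`/`lam` are made-up names.

```python
import math
import torch

def joint_entropy_estimate(logvar):
    """Estimate of H_q(x, z) = H(x) + E_x[ H(q(z|x)) ], up to the data-entropy constant H(x).

    For a diagonal-Gaussian posterior, H(q(z|x)) = 0.5 * sum(log(2*pi*e) + logvar).
    logvar: tensor of shape (batch, z_dim).
    """
    per_sample = 0.5 * (logvar + math.log(2 * math.pi * math.e)).sum(dim=-1)
    return per_sample.mean()

def entropy_regularized_loss(base_loss, logvar, lam=1.0):
    """Hypothetical entropy-regularized objective.

    lam > 0 adds the joint-entropy estimate (penalizing broad posteriors, pushing against collapse);
    lam < 0 subtracts it (rewarding broad posteriors, encouraging collapse).
    """
    return base_loss + lam * joint_entropy_estimate(logvar)
```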
Experimenting with the effect of the bottleneck (latent dimensionality) on VAE and MIM.

A synthetic 5-component GMM dataset with 20D x:

```
for seed in 1 2 3 4 5 6 7 8 9 10; do
for z_dim in 2 4 6 8 10 12 14 16 18 20; do
# MIM
./vae-as-mim-dataset.py \
--dataset toy4_20 \
--z-dim ${z_dim} \
--mid-dim 50 \
--seed ${seed} \
--epochs 200 \
--min-logvar 6 \
--warmup-steps 3 \
--mim-loss \
--mim-samp
# VAE
./vae-as-mim-dataset.py \
--dataset toy4_20 \
--z-dim ${z_dim} \
--mid-dim 50 \
--seed ${seed} \
--epochs 200 \
--min-logvar 6 \
--warmup-steps 3
done
done
```
The results below demonstrate posterior collapse in VAE, which worsens as the latent dimensionality increases, and its absence in MIM.
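A simple way to see such collapse directly (an illustrative diagnostic, not part of the repo) is the per-dimension KL divergence between the posterior and the standard-normal prior, averaged over test points; dimensions with near-zero KL carry no information about x. `mu` and `logvar` below are hypothetical encoder outputs.

```python
import torch

def per_dim_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) per latent dimension, averaged over the batch.

    mu, logvar: tensors of shape (batch, z_dim). Dimensions with KL ~ 0 are collapsed.
    """
    kl = 0.5 * (mu.pow(2) + logvar.exp() - logvar - 1.0)
    return kl.mean(dim=0)

# Example: a collapsed dimension (mu ~ 0, sigma ~ 1) versus an informative one.
mu = torch.tensor([[0.0, 2.0], [0.0, -1.5]])
logvar = torch.tensor([[0.0, -4.0], [0.0, -4.0]])
print(per_dim_kl(mu, logvar))  # first entry ~ 0 (collapsed), second is large
```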
A PCA reduction of Fashion-MNIST to 20D x:
```
for seed in 1 2 3 4 5 6 7 8 9 10; do
for z_dim in 2 4 6 8 10 12 14 16 18 20; do
# MIM
./vae-as-mim-dataset.py \
--dataset pca-fashion-mnist20 \
--z-dim ${z_dim} \
--mid-dim 50 \
--seed ${seed} \
--epochs 200 \
--min-logvar 6 \
--warmup-steps 3 \
--mim-loss \
--mim-samp
# VAE
./vae-as-mim-dataset.py \
--dataset pca-fashion-mnist20 \
--z-dim ${z_dim} \
--mid-dim 50 \
--seed ${seed} \
--epochs 200 \
--min-logvar 6 \
--warmup-steps 3
done
done
```
The results below again demonstrate posterior collapse in VAE, which worsens as the latent dimensionality increases, and its absence in MIM, here for real-world data observations.
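For context on the pca-fashion-mnist20 observations: the sketch below shows one way to build a 20D PCA reduction of Fashion-MNIST with torchvision and scikit-learn, assuming a recent torchvision (the pinned 0.2.1 exposes the raw arrays under slightly different attribute names). It is not the repo's dataset-generation code.

```python
import numpy as np
from sklearn.decomposition import PCA
from torchvision import datasets

# Download Fashion-MNIST and flatten the images to 784-D vectors in [0, 1].
train = datasets.FashionMNIST("data/assets", train=True, download=True)
x = train.data.numpy().reshape(len(train), -1).astype(np.float32) / 255.0

# Project to a 20D observation space, in the spirit of the pca-fashion-mnist20 experiments.
pca = PCA(n_components=20)
x20 = pca.fit_transform(x)
print(x20.shape, "explained variance:", pca.explained_variance_ratio_.sum())
```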
Experimenting with high-dimensional image data, where mutual information cannot be measured reliably:

```
for seed in 1 2 3 4 5 6 7 8 9 10; do
for dataset_name in dynamic_mnist dynamic_fashion_mnist omniglot; do
for model_name in convhvae_2level convhvae_2level-smim pixelhvae_2level pixelhvae_2level-amim; do
for prior in vampprior standard; do
./vae-as-mim-image.py \
--dataset_name ${dataset_name} \
--model_name ${model_name} \
--prior ${prior} \
--seed ${seed} \
--use_training_data_init
done
done
done
done
```
The results below demonstrate comparable sampling and reconstruction quality for VAE and MIM, and better unsupervised clustering for MIM as a result of higher mutual information.
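A simple way to quantify latent-space clustering of this kind (consistent in spirit with the KNN5 classification accuracy reported for the toy experiments above) is k-nearest-neighbour classification on latent codes. The sketch below is illustrative; `z_train`, `y_train`, `z_test`, `y_test` are hypothetical arrays of latent codes and class labels.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def knn5_accuracy(z_train, y_train, z_test, y_test):
    """5-NN classification accuracy in the latent space (higher = better-separated clusters)."""
    clf = KNeighborsClassifier(n_neighbors=5)
    clf.fit(z_train, y_train)
    return clf.score(z_test, y_test)

# Toy usage with random codes (replace with encoder means and true labels for real data).
rng = np.random.default_rng(0)
z_train, y_train = rng.normal(size=(500, 2)), rng.integers(0, 5, size=500)
z_test, y_test = rng.normal(size=(100, 2)), rng.integers(0, 5, size=100)
print(knn5_accuracy(z_train, y_train, z_test, y_test))
```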
The code for this experiment is based on the VampPrior paper:

```
@article{TW:2017,
title={{VAE with a VampPrior}},
author={Tomczak, Jakub M and Welling, Max},
journal={arXiv},
year={2017}
}
```
Please cite our paper if you use this code in your research:
```
@misc{livne2019mim,
title={MIM: Mutual Information Machine},
author={Micha Livne and Kevin Swersky and David J. Fleet},
year={2019},
eprint={1910.03175},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```
Many thanks to Ethan Fetaya, Jacob Goldberger, Roger Grosse, Chris Maddison, and Daniel Roy for interesting discussions and helpful comments. We are especially grateful to Sajad Nourozi for extensive discussions and for his help in empirically validating the formulation and the experimental work. This work was financially supported in part by the Canadian Institute for Advanced Research (Program on Learning in Machines and Brains) and NSERC Canada.
If you find this paper and/or repo to be useful, we would love to hear back! Tell us your success stories, and we will include them in this README.