The official repository for <Neighborhood Reconstructing Autoencoders> (Lee, Kwon, and Park, NeurIPS 2021).
Figure 1: De-noising property of the NRAE (Left: Vanilla AE, Middle: NRAE-L, Right: NRAE-Q).

Figure 2: Correct local connectivity learned by the NRAE (Left: Vanilla AE, Middle: NRAE-L, Right: NRAE-Q).

Figure 3: Generated sequences of rotated images obtained by travelling the 1D latent spaces (Top: Vanilla AE, Middle: NRAE-L, Bottom: NRAE-Q).

Figure 4: Generated sequences of shifted images obtained by travelling the 1D latent spaces (Top: Vanilla AE, Middle: NRAE-L, Bottom: NRAE-Q).

This paper proposes Neighborhood Reconstructing Autoencoders (NRAE), a graph-based autoencoder that explicitly accounts for the local connectivity and geometry of the data and, as a result, learns a more accurate data manifold and representation.
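As a rough illustration of the idea behind NRAE-L, the sketch below scores a decoder by how well its first-order (linear) approximation around each center point's latent code reconstructs that point's graph neighbors, with each neighbor weighted by a kernel. This is a minimal NumPy sketch under stated assumptions, not the repository's implementation: the Gaussian kernel form `exp(-||x_i - x_j||^2 / lambda^2)`, the finite-difference Jacobian, and all function names are hypothetical choices for illustration. NRAE-Q would use a second-order approximation instead, which is not sketched here.

```python
import numpy as np


def numerical_jacobian(f, z, eps=1e-5):
    """Central finite-difference Jacobian of f: R^m -> R^D at z, shape (D, m)."""
    z = np.asarray(z, dtype=float)
    cols = []
    for k in range(z.size):
        dz = np.zeros_like(z)
        dz[k] = eps
        cols.append((f(z + dz) - f(z - dz)) / (2 * eps))
    return np.stack(cols, axis=1)


def gaussian_kernel(xi, xj, lam):
    """Assumed kernel exp(-||xi - xj||^2 / lam^2); `lam` plays the role of `lambda`."""
    return np.exp(-np.sum((xi - xj) ** 2) / lam ** 2)


def nrae_linear_loss(X, neighbors, encode, decode, lam):
    """Kernel-weighted neighborhood reconstruction loss with a linearized decoder.

    X:         (N, D) array of data points
    neighbors: neighbors[i] lists the graph-neighbor indices of X[i]
    encode:    maps a point in R^D to a latent code in R^m
    decode:    maps a latent code in R^m back to R^D
    """
    total = 0.0
    for i, nbrs in enumerate(neighbors):
        zi = encode(X[i])
        fi = decode(zi)
        Ji = numerical_jacobian(decode, zi)  # decoder Jacobian at the center's code
        for j in nbrs:
            zj = encode(X[j])
            # First-order Taylor expansion of the decoder around zi, evaluated at zj.
            xj_hat = fi + Ji @ (zj - zi)
            total += gaussian_kernel(X[i], X[j], lam) * np.sum((X[j] - xj_hat) ** 2)
    return total / len(X)


# Toy usage: points on a parabola, 1-D latent codes, chain-graph neighbors.
def encode(x):
    return np.array([x[0]])


def decode(z):
    return np.array([z[0], z[0] ** 2])


t = np.linspace(-1.0, 1.0, 5)
X = np.stack([t, t ** 2], axis=1)
neighbors = [[j for j in (i - 1, i + 1) if 0 <= j < len(X)] for i in range(len(X))]
print(nrae_linear_loss(X, neighbors, encode, decode, lam=1.0))
```

Because the decoder in this toy example is curved, its linearization cannot reproduce the neighbors exactly, so the loss is positive; with an affine decoder and data lying on its image, the same function returns (numerically) zero. A vanilla autoencoder, by contrast, only penalizes `x_i - decode(encode(x_i))` point by point and never looks at neighbors.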
The project is developed under a standard PyTorch environment.
- python 3.8.8
- numpy
- matplotlib
- imageio
- argparse
- yaml
- omegaconf
- torch 1.8.0
- CUDA 11.1
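A quick way to confirm that the packages listed above are importable is a small check like the following. This is a hypothetical helper script, not part of the repository; the module names are the import names for the listed packages (PyYAML is imported as `yaml`, and `argparse` ships with the Python standard library).

```python
import importlib.util
import sys

# Import names for the packages listed above.
REQUIRED_MODULES = ["numpy", "matplotlib", "imageio", "argparse", "yaml", "omegaconf", "torch"]


def find_missing(modules):
    """Return the subset of `modules` that cannot be found on the current Python path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    missing = find_missing(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        import torch

        # Report the installed torch version and whether a CUDA device is visible.
        print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```

`find_spec` only checks that a module can be located; it does not import it, so the check is cheap and does not initialize CUDA unless everything is present.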
To train a model, run:

```
python train_{X}.py --config configs/{A}_{B}_{C}.yml --device 0
```
- `X` is either `synthetic` or `MNIST`.
- `A` is either `AE`, `NRAEL`, or `NRAEQ`.
- `B` is either `toy` or `mnist`.
- If `B` is `toy`, then `C` is either `denoising` or `geometry_preserving`. Else if `B` is `mnist`, then `C` is either `rotated` or `shifted`.
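The naming scheme above yields 12 configurations (3 models × 2 settings × 2 tasks each). The sketch below enumerates the corresponding commands. It assumes, as the naming suggests but the text does not state outright, that `toy` configs are run with `train_synthetic.py` and `mnist` configs with `train_MNIST.py`; the helper name `training_commands` is hypothetical.

```python
from itertools import product

# Valid values of A, B, and C as listed above.
MODELS = ["AE", "NRAEL", "NRAEQ"]
TASKS = {"toy": ["denoising", "geometry_preserving"], "mnist": ["rotated", "shifted"]}
# Assumption: toy configs go with train_synthetic.py, mnist configs with train_MNIST.py.
SCRIPT_FOR = {"toy": "synthetic", "mnist": "MNIST"}


def training_commands(device=0):
    """Yield every command line allowed by the naming scheme above."""
    for model, (data, tasks) in product(MODELS, TASKS.items()):
        for task in tasks:
            yield (f"python train_{SCRIPT_FOR[data]}.py "
                   f"--config configs/{model}_{data}_{task}.yml --device {device}")


for cmd in training_commands():
    print(cmd)
```

Keeping the valid values in one table makes it easy to spot an invalid combination (for example `toy` with `rotated`) before launching a run.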
- The most important parameters to tune are i) `num_nn`, the number of nearest neighbors used for graph construction, and ii) the kernel parameter `lambda` (see `configs/NRAEL_toy_denoising.yml` for an example).
- We empirically observe that setting `include_center=True` when defining the data loader improves performance.
- You can add a new type of 2D synthetic dataset in `loader.synthetic_dataset.SyntheticData.get_data` (currently, `sincurve` and `swiss_roll` are available).
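To make `num_nn`, `lambda`, and `include_center` concrete, here is a minimal NumPy sketch of building a k-nearest-neighbor graph over a toy 2D dataset. Everything in it is an illustrative assumption rather than the repository's loader: `sincurve` below is a hypothetical stand-in for the built-in `sincurve` option, `include_center=True` is interpreted as adding each point to its own neighborhood, and edge weights use an assumed Gaussian kernel `exp(-d^2 / lambda^2)` (the parameter is named `lam` because `lambda` is a Python keyword).

```python
import numpy as np


def sincurve(n=100, noise=0.05, seed=0):
    """Hypothetical noisy 2-D sine-curve sample (not the repository's generator)."""
    rng = np.random.default_rng(seed)
    t = np.sort(rng.uniform(-np.pi, np.pi, n))
    X = np.stack([t, np.sin(t)], axis=1)
    return X + noise * rng.standard_normal(X.shape)


def knn_graph(X, num_nn, lam, include_center=True):
    """Return (idx, w): neighbor indices and kernel weights, both of shape (N, k).

    k = num_nn + 1 if include_center else num_nn. Weights use the assumed
    Gaussian form exp(-d^2 / lam^2), so the center itself gets weight 1.
    """
    d2 = np.sum((X[:, None, :] - X[None, :, :]) ** 2, axis=-1)  # (N, N) squared distances
    ranked = d2.copy()
    np.fill_diagonal(ranked, np.inf)  # exclude each point when ranking its neighbors
    idx = np.argsort(ranked, axis=1)[:, :num_nn]
    if include_center:
        # Assumed meaning of include_center: prepend the point itself.
        idx = np.concatenate([np.arange(len(X))[:, None], idx], axis=1)
    w = np.exp(-np.take_along_axis(d2, idx, axis=1) / lam ** 2)
    return idx, w


X = sincurve(n=200)
idx, w = knn_graph(X, num_nn=5, lam=0.5, include_center=True)
print(idx.shape, w.shape)  # (200, 6) (200, 6)
```

With `include_center=True`, each row of `idx` starts with the point itself (weight 1). A larger `lam` flattens the weights toward 1, while a smaller one concentrates them on the closest neighbors. How a new generator would be registered in `SyntheticData.get_data` depends on that method's signature, which is not shown here.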
If you found this library useful in your research, please consider citing:
```bibtex
@article{lee2021neighborhood,
  title={Neighborhood Reconstructing Autoencoders},
  author={Lee, Yonghyeon and Kwon, Hyeokjun and Park, Frank},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}
```