Scripts

We also provide an example training script that lets you train VAE models on well-known benchmark datasets (mnist, cifar10, celeba ...). The script can be launched with the following command line:

python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'
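As an illustration of how the four flags in this command could be read, here is a minimal sketch using `argparse`. The `build_parser` function is hypothetical and only mirrors the flag names shown above; the actual `training.py` may define defaults, choices, or additional options.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the flags in the command above;
    # the real training.py may differ (defaults, choices, extra options).
    parser = argparse.ArgumentParser(description="Train a model on a benchmark dataset")
    parser.add_argument("--dataset", required=True, help="dataset folder name under data/")
    parser.add_argument("--model_name", required=True, help="model identifier, e.g. ae")
    parser.add_argument("--model_config", help="path to the model config JSON file")
    parser.add_argument("--training_config", help="path to the training config JSON file")
    return parser


# Parse the same arguments as the example command, passed as a list
args = build_parser().parse_args([
    "--dataset", "mnist",
    "--model_name", "ae",
    "--model_config", "configs/ae_config.json",
    "--training_config", "configs/base_training_config.json",
])
print(args.dataset, args.model_name)
```

Passing the arguments as a list keeps the sketch runnable without a shell; calling `parse_args()` with no argument would read them from the command line instead.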

The folder structure should be as follows:

.
├── configs # the model & training config files (you can amend these files as desired or specify the location of your own with '--model_config' and '--training_config')
│   ├── ae_config.json
│   ├── base_training_config.json
│   ├── beta_vae_config.json
│   ├── hvae_config.json
│   ├── rhvae_config.json
│   ├── vae_config.json
│   └── vamp_config.json
├── data # the dataset with train_data.npz and eval_data.npz files
│   ├── celeba
│   │   ├── eval_data.npz
│   │   └── train_data.npz
│   ├── cifar10
│   │   ├── eval_data.npz
│   │   └── train_data.npz
│   └── mnist
│       ├── eval_data.npz
│       └── train_data.npz
├── my_models # trained models are saved here
│   ├── AE_training_2021-10-15_16-07-04 
│   └── RHVAE_training_2021-10-15_15-54-27
├── README.md
└── training.py
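Before launching a run, it can help to confirm that the data folder matches this layout. The following is a minimal sketch; `missing_data_files` is a hypothetical helper, not part of the repository, and the temporary directory only stands in for the project root.

```python
import os
import tempfile


def missing_data_files(root, dataset):
    """Return the expected .npz paths that do not exist under root/data/dataset."""
    expected = [
        os.path.join(root, "data", dataset, name)
        for name in ("train_data.npz", "eval_data.npz")
    ]
    return [path for path in expected if not os.path.isfile(path)]


# Demonstration in a temporary directory with only train_data.npz present
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "data", "mnist"))
open(os.path.join(root, "data", "mnist", "train_data.npz"), "wb").close()

missing = missing_data_files(root, "mnist")
print(missing)  # lists the path to the absent eval_data.npz
```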

Note: The data in the train_data.npz and eval_data.npz files must be loadable as follows:

train_data = np.load(os.path.join(PATH, f'data/{args.dataset}', 'train_data.npz'))['data']
eval_data = np.load(os.path.join(PATH, f'data/{args.dataset}', 'eval_data.npz'))['data']

where train_data and eval_data now have the shape (n_img x im_channel x height x width).
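To prepare your own dataset in this format, one possible approach is to save each split with `np.savez` under the `data` key. The sketch below uses random placeholder arrays; `save_dataset`, the folder name `my_dataset`, and the temporary root directory are assumptions made for illustration, not part of the repository.

```python
import os
import tempfile

import numpy as np


def save_dataset(root, name, train, evaluation):
    """Write train/eval arrays to root/data/name, each stored under the 'data' key."""
    for split, array in (("train", train), ("eval", evaluation)):
        if array.ndim != 4:
            raise ValueError(f"{split} array must be 4-D, got shape {array.shape}")
    folder = os.path.join(root, "data", name)
    os.makedirs(folder, exist_ok=True)
    np.savez(os.path.join(folder, "train_data.npz"), data=train)
    np.savez(os.path.join(folder, "eval_data.npz"), data=evaluation)
    return folder


# Placeholder arrays standing in for real images:
# (n_img, im_channel, height, width)
rng = np.random.default_rng(0)
train = rng.random((8, 1, 28, 28), dtype=np.float32)
evaluation = rng.random((2, 1, 28, 28), dtype=np.float32)

root = tempfile.mkdtemp()  # stand-in for the project root
folder = save_dataset(root, "my_dataset", train, evaluation)

# Reload the same way the note above requires
loaded = np.load(os.path.join(folder, "train_data.npz"))["data"]
print(loaded.shape)  # (8, 1, 28, 28)
```

The 4-D check catches the common case of grayscale images saved without a channel axis, which would otherwise load with shape (n_img x height x width).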