# [CS726-2023] Programming assignment exploring diffusion models
Steps to get started with the code:
- Install Anaconda on your system; download it from https://www.anaconda.com.
- Clone the GitHub repo with `git clone https://github.com/ashutoshbsathe/explore-diffusion.git` into some convenient folder of your choice, then `cd explore-diffusion`.
- Run `conda env create --file environment.yaml`. This will set up all the required dependencies.
- Activate the environment using `source activate cs726-env` or `conda activate cs726-env`.

You are done with the setup.
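As a quick sanity check, you can confirm the environment exists and that PyTorch imports cleanly (the Lightning-based trainer depends on it, so the import should succeed inside `cs726-env`):

```sh
conda env list                                      # cs726-env should appear here
conda activate cs726-env
python -c "import torch; print(torch.__version__)"  # dependencies resolved correctly
```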
Once you code up your model in the `model.py` file, you can use the provided trainer in the `train.py` file to train it: `python train.py`.
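If you are unsure where to start in `model.py`, the sketch below shows one minimal shape a DDPM-style noise predictor could take. It is only an illustration: the class name, constructor arguments, and method signatures here are hypothetical, and the actual interface expected by `train.py` is defined by the skeleton in the repo.

```python
import torch
import torch.nn as nn

class NoisePredictor(nn.Module):
    """Hypothetical sketch of a DDPM noise-prediction network.

    The model.py skeleton in the repo defines the required class name
    and interface; this only illustrates the general idea.
    """

    def __init__(self, n_dim: int = 3, n_steps: int = 50, hidden: int = 128):
        super().__init__()
        # A learned embedding conditions the network on the diffusion timestep.
        self.t_embed = nn.Embedding(n_steps, hidden)
        self.net = nn.Sequential(
            nn.Linear(n_dim + hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_dim),  # predicts the noise added at step t
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_dim) noisy samples; t: (batch,) integer timesteps
        return self.net(torch.cat([x, self.t_embed(t)], dim=-1))
```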
You can use various command-line arguments to tweak the number of epochs, batch size, etc. Please check the `train.py` file for details. You can get the full list of available hyperparameters with `python train.py -h`.
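For example, the run-directory name shown below suggests that the hyperparameters map to flags of the same name, so an explicit training run might look like the following (these flag names are inferred, not confirmed; verify them with `python train.py -h`):

```sh
# Flag names inferred from the run-directory naming scheme; verify with -h.
python train.py --n_dim 3 --n_steps 50 --lbeta 1e-5 --ubeta 1.28e-2 \
    --batch_size 1024 --n_epochs 500
```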
After training completes, you can find the checkpoint and hyperparameters under the `runs` directory. A demo directory structure is shown below.
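The exact run name encodes your hyperparameters; this layout follows the checkpoint and hparams paths used in the eval command further down:

```
runs/
└── n_dim=3,n_steps=50,lbeta=1.000e-05,ubeta=1.280e-02,batch_size=1024,n_epochs=500/
    ├── last.ckpt
    └── lightning_logs/
        └── version_0/
            └── hparams.yaml
```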
Of interest are the `last.ckpt` and `hparams.yaml` files, which will be used while evaluating the trained model.
Once the trained model is available, you can use the `eval.py` file to generate the metrics and visualizations. Refer to the command-line arguments to understand further. A demo run is as follows:
```sh
python eval.py --ckpt_path runs/n_dim=3,n_steps=50,lbeta=1.000e-05,ubeta=1.280e-02,batch_size=1024,n_epochs=500/last.ckpt \
    --hparams_path runs/n_dim=3,n_steps=50,lbeta=1.000e-05,ubeta=1.280e-02,batch_size=1024,n_epochs=500/lightning_logs/version_0/hparams.yaml \
    --eval_nll --vis_diffusion --vis_overlay
```
This evaluates the trained model and reports the negative log likelihood of the generated samples (`--eval_nll`). It also generates neat visualizations of the diffusion process as GIF animations.
Example plot generated with `--vis_overlay`: yellow-magenta points represent the original distribution, and blue-purple points indicate samples generated from the trained DDPM.
Example animation produced with `--vis_diffusion`: again, yellow-magenta points represent the original distribution, and blue-purple points indicate samples generated from the trained DDPM. Notice how the blue-purple points slowly move closer and closer to the original distribution as the reverse process progresses.
Special thanks to Kanad Pardeshi for generating the `3d_sin_5_5` and `helix` distributions and for helping with the implementation of several evaluation metrics.