Diffusion Models With Learned Adaptive Noise (NeurIPS 2024, spotlight)
By Subham Sekhar Sahoo, Aaron Gokaslan, Chris De Sa, Volodymyr Kuleshov
We introduce MuLAN (MUltivariate Learned Adaptive Noise), which learns the forward noising process from the data. In this work, we dispel the widely held assumption that the ELBO is invariant to the noise process. Empirically, MuLAN sets a new state of the art in density estimation on CIFAR-10 and ImageNet and reduces the number of training steps by 50%, as summarized in the table below (likelihood in bits per dimension, lower is better):
| Model | CIFAR-10 | ImageNet-32 |
|---|---|---|
| PixelCNN | 3.03 | 3.83 |
| Image Transformer | 2.90 | 3.77 |
| DDPM |  | / |
| ScoreFlow | 2.83 | 3.76 |
| VDM |  |  |
| Flow Matching | 2.99 | / |
| Reflected Diffusion Models | 2.68 | 3.74 |
| MuLAN (Ours) | 2.55 | 3.67 |
Note: We only compare with results achieved without data augmentation.
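Bits per dimension (BPD) normalizes a negative log-likelihood (NLL) by the number of data dimensions and converts it from nats to bits. The sketch below illustrates the conversion under the assumption that the NLL is given in nats for a single 32x32 RGB image; the helper names are hypothetical and not part of this repository.

```python
import math

def nats_to_bpd(nll_nats: float, num_dims: int) -> float:
    """Convert a per-image negative log-likelihood in nats to bits per dimension."""
    return nll_nats / (num_dims * math.log(2.0))

def bpd_to_nats(bpd: float, num_dims: int) -> float:
    """Inverse conversion: bits per dimension back to total nats per image."""
    return bpd * num_dims * math.log(2.0)

# A 32x32 RGB image (CIFAR-10 or ImageNet-32) has 32 * 32 * 3 = 3072 dimensions.
CIFAR10_DIMS = 32 * 32 * 3

if __name__ == "__main__":
    total_nats = bpd_to_nats(2.55, CIFAR10_DIMS)
    print(f"2.55 BPD corresponds to {total_nats:.1f} nats per image")
    print(f"round trip: {nats_to_bpd(total_nats, CIFAR10_DIMS):.2f} BPD")
```

For example, the 0.13 BPD gap between Reflected Diffusion Models (2.68) and MuLAN (2.55) on CIFAR-10 corresponds to roughly 0.13 × 3072 ≈ 399 bits per image.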
Install the dependencies via pip using the following commands:
pip install -U "jax[cuda12_pip]<=0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
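The install command pins JAX to version 0.4.23 or earlier. A small, hypothetical check like the one below (not part of the repository) can confirm that the installed jax package satisfies that pin. It uses only the standard library and a simplified version parser that ignores pre-release and local-version suffixes, so treat it as a sanity check rather than a full PEP 440 comparison.

```python
from importlib.metadata import PackageNotFoundError, version

MAX_JAX = "0.4.23"  # upper bound taken from the pip command above

def version_tuple(v: str) -> tuple:
    """Parse 'X.Y.Z' into a tuple of ints, stopping at the first non-numeric part."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def satisfies_pin(installed: str, maximum: str = MAX_JAX) -> bool:
    """True if `installed` <= `maximum` under the simplified parsing above."""
    return version_tuple(installed) <= version_tuple(maximum)

if __name__ == "__main__":
    try:
        jax_version = version("jax")
    except PackageNotFoundError:
        print("jax is not installed")
    else:
        status = "OK" if satisfies_pin(jax_version) else "too new"
        print(f"jax {jax_version}: {status} (pin is <= {MAX_JAX})")
```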
The experiments were conducted on the CIFAR-10 and ImageNet32 datasets. We used the dataloader provided by tensorflow_datasets. To maintain consistency with previous baselines, we used the older version of ImageNet32, which is no longer publicly available. We therefore provide the dataset, which can be downloaded from this Google Drive link. To use it, download the tar file and extract it into the ~/tensorflow_datasets directory. The final structure should look like the following:
~/tensorflow_datasets/downsampled_imagenet/32x32/2.0.0/downsampled_imagenet-train.tfrecord-000*-of-00032
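Before training or evaluating on ImageNet32, it can help to confirm that the extracted shards are where tensorflow_datasets expects them. The sketch below is a hypothetical helper, not part of the repository; it assumes the directory layout and file-name pattern shown above and that all 32 training shards should be present, as the -of-00032 suffix suggests.

```python
from pathlib import Path

# Layout and pattern taken from the expected structure shown above.
DATA_DIR = Path.home() / "tensorflow_datasets" / "downsampled_imagenet" / "32x32" / "2.0.0"
SHARD_PATTERN = "downsampled_imagenet-train.tfrecord-000*-of-00032"
EXPECTED_SHARDS = 32  # implied by the "-of-00032" suffix

def find_train_shards(data_dir: Path = DATA_DIR) -> list:
    """Return the sorted training shard paths that match the expected pattern."""
    return sorted(data_dir.glob(SHARD_PATTERN))

def check_train_shards(data_dir: Path = DATA_DIR) -> bool:
    """Report whether exactly the expected number of training shards is present."""
    shards = find_train_shards(data_dir)
    if len(shards) != EXPECTED_SHARDS:
        print(f"found {len(shards)} of {EXPECTED_SHARDS} training shards in {data_dir}")
        return False
    print(f"all {EXPECTED_SHARDS} training shards found in {data_dir}")
    return True

if __name__ == "__main__":
    check_train_shards()
```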
The implementation of MuLAN can be found in ldm/model_mulan_epsilon.py. The denoising model uses noise parameterization, as described in Suppl. 11.1.1 of the paper. The file ldm/model_mulan_velocity.py implements velocity parameterization, as detailed in Suppl. 11.1.2 of the paper.
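The noise and velocity parameterizations are related by a fixed change of variables. The sketch below assumes the standard variance-preserving convention z_t = α_t x + σ_t ε with α_t² + σ_t² = 1 and the velocity target v = α_t ε − σ_t x of Salimans & Ho (2022), applied elementwise so that α_t and σ_t may differ per dimension, as with a multivariate schedule. It uses plain Python lists for self-containment; the function names are hypothetical and the repository's exact conventions may differ. The --config.model.velocity_from_epsilon=True flag in the ImageNet commands below suggests a conversion of this kind, but that is an assumption about its behavior.

```python
import math

def noisy_sample(x, eps, alpha, sigma):
    """Forward process z_t = alpha * x + sigma * eps, elementwise."""
    return [a * xi + s * e for xi, e, a, s in zip(x, eps, alpha, sigma)]

def to_velocity(x, eps, alpha, sigma):
    """Velocity target v = alpha * eps - sigma * x, elementwise."""
    return [a * e - s * xi for xi, e, a, s in zip(x, eps, alpha, sigma)]

def x_from_velocity(z, v, alpha, sigma):
    """Recover x = alpha * z - sigma * v (requires alpha**2 + sigma**2 == 1)."""
    return [a * zi - s * vi for zi, vi, a, s in zip(z, v, alpha, sigma)]

def eps_from_velocity(z, v, alpha, sigma):
    """Recover eps = sigma * z + alpha * v (requires alpha**2 + sigma**2 == 1)."""
    return [s * zi + a * vi for zi, vi, a, s in zip(z, v, alpha, sigma)]

def velocity_from_eps(z, eps_hat, alpha, sigma):
    """Turn an epsilon prediction into a velocity prediction.

    First estimate x = (z - sigma * eps) / alpha, then apply v = alpha * eps - sigma * x.
    """
    x_hat = [(zi - s * e) / a for zi, e, a, s in zip(z, eps_hat, alpha, sigma)]
    return to_velocity(x_hat, eps_hat, alpha, sigma)

if __name__ == "__main__":
    # Per-dimension noise levels: alpha = cos(t), sigma = sin(t), so alpha**2 + sigma**2 = 1.
    ts = [0.2, 0.7, 1.1]
    alpha = [math.cos(t) for t in ts]
    sigma = [math.sin(t) for t in ts]
    x = [0.5, -0.3, 0.9]
    eps = [1.2, 0.4, -0.8]
    z = noisy_sample(x, eps, alpha, sigma)
    v = to_velocity(x, eps, alpha, sigma)
    # Both should match x and eps up to floating-point error.
    print(x_from_velocity(z, v, alpha, sigma))
    print(eps_from_velocity(z, v, alpha, sigma))
```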
Download the checkpoints and TensorBoard logs from the Google Drive folder. Note that the eval likelihood (BPD, bits per dimension) in the TensorBoard logs was computed on a partial dataset, which is why it is worse than the numbers reported in the paper. To compute the BPD accurately, use one of the two methods below.
To compute the exact likelihood as per Suppl. 15.2 of the paper, use the following commands:
JAX_DEFAULT_MATMUL_PRECISION=float32 XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 python -m ldm.eval_bpd --config=ldm/configs/cifar10-conditioned.py --config.vdm_type=mulan_velocity --checkpoint_directory=/path/to/checkpoints/cifar10 --checkpoint=223
JAX_DEFAULT_MATMUL_PRECISION=float32 XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 python -m ldm.eval_bpd --config=ldm/configs/imagenet32.py --config.vdm_type=mulan_velocity --config.model.velocity_from_epsilon=True --checkpoint_directory=/path/to/checkpoints/imagenet --checkpoint=220
The code for exact likelihood estimation supports multi-GPU evaluation.
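The two environment variables in these commands are standard JAX settings: JAX_DEFAULT_MATMUL_PRECISION=float32 requests full float32 precision for matrix multiplications instead of the backend default, which may use reduced-precision modes on some hardware, and XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 tells XLA to preallocate 85% of each GPU's memory. If you prefer to launch evaluation from Python, a hypothetical wrapper (not part of the repository) could assemble the same command as follows; the checkpoint path is a placeholder.

```python
import os
import shlex
import sys

def build_eval_command(config, checkpoint_dir, checkpoint, extra_flags=()):
    """Assemble a `python -m ldm.eval_bpd` invocation like the commands above."""
    cmd = [
        sys.executable, "-m", "ldm.eval_bpd",
        f"--config={config}",
        f"--checkpoint_directory={checkpoint_dir}",
        f"--checkpoint={checkpoint}",
    ]
    cmd.extend(extra_flags)
    return cmd

def build_eval_env(base_env=None):
    """Copy an environment and add the two JAX/XLA settings used above."""
    env = dict(os.environ if base_env is None else base_env)
    env["JAX_DEFAULT_MATMUL_PRECISION"] = "float32"  # full-precision matmuls
    env["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.85"   # preallocate 85% of GPU memory
    return env

if __name__ == "__main__":
    cmd = build_eval_command(
        config="ldm/configs/cifar10-conditioned.py",
        checkpoint_dir="/path/to/checkpoints/cifar10",  # placeholder path
        checkpoint=223,
        extra_flags=["--config.vdm_type=mulan_velocity"],
    )
    # Only print the command here; see the note below for actually running it.
    print(shlex.join(cmd))
```

To execute it from the repository root, pass the list to subprocess.run(cmd, env=build_eval_env(), check=True).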
To compute the likelihood using the variational lower bound (VLB) as per Suppl. 15.1 of the paper, use the following commands:
JAX_DEFAULT_MATMUL_PRECISION=float32 XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 python -m ldm.eval_bpd --config=ldm/configs/cifar10-conditioned.py --checkpoint_directory=/path/to/checkpoints/cifar10 --checkpoint=223 --bpd_eval_method=dense --config.training.batch_size_eval=16
JAX_DEFAULT_MATMUL_PRECISION=float32 XLA_PYTHON_CLIENT_MEM_FRACTION=0.85 python -m ldm.eval_bpd --config=ldm/configs/imagenet32.py --config.vdm_type=mulan_velocity --config.model.velocity_from_epsilon=True --checkpoint_directory=/path/to/checkpoints/imagenet --checkpoint=200 --bpd_eval_method=dense --config.training.batch_size_eval=16
The code for VLB estimation does not support multi-GPU evaluation and must therefore be run on a single GPU.
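For orientation, the VLB in the VDM framework that this repository builds on splits into a prior term, a reconstruction term, and a continuous-time diffusion term. The block below writes the diffusion term with a per-dimension noise schedule γ(t) ∈ ℝ^d, an assumed generalization shown for illustration only (for a scalar schedule it reduces to VDM's ½ E[γ'(t) ‖ε − ε̂_θ‖²]); the exact objective used by MuLAN, including how its schedule adapts to the data, is given in Suppl. 15.1 of the paper.

```latex
\begin{aligned}
-\log p(\mathbf{x}) \;\le\;& \underbrace{D_{\mathrm{KL}}\big(q(\mathbf{z}_1 \mid \mathbf{x}) \,\|\, p(\mathbf{z}_1)\big)}_{\text{prior}}
+ \underbrace{\mathbb{E}_{q(\mathbf{z}_0 \mid \mathbf{x})}\big[-\log p(\mathbf{x} \mid \mathbf{z}_0)\big]}_{\text{reconstruction}} \\
&+ \underbrace{\tfrac{1}{2}\,\mathbb{E}_{\boldsymbol{\epsilon} \sim \mathcal{N}(0, I),\, t \sim \mathcal{U}(0,1)}\Big[(\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t))^\top \operatorname{diag}\big(\gamma'(t)\big)\,(\boldsymbol{\epsilon} - \hat{\boldsymbol{\epsilon}}_\theta(\mathbf{z}_t; t))\Big]}_{\text{diffusion}}, \\
\mathbf{z}_t &= \boldsymbol{\alpha}_t \odot \mathbf{x} + \boldsymbol{\sigma}_t \odot \boldsymbol{\epsilon}, \quad
\boldsymbol{\alpha}_t^2 = \operatorname{sigmoid}\big(-\gamma(t)\big), \quad
\boldsymbol{\sigma}_t^2 = \operatorname{sigmoid}\big(\gamma(t)\big), \\
\text{BPD} &= \frac{\text{VLB in nats}}{d \ln 2}.
\end{aligned}
```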
For CIFAR-10, we trained our models on four V100 GPUs using the following Slurm command:
sbatch -J cifar --partition=kuleshov --gres=gpu:4 run.sh -m ldm.main --mode train --config=ldm/configs/cifar10-conditioned.py --workdir /path/to/experiment_dir
For ImageNet-32, we trained our models on four A100 GPUs using the following Slurm command:
sbatch -J img --partition=gpu --gres=gpu:4 run.sh -m ldm.main --mode train --config=ldm/configs/imagenet32.py --workdir /path/to/experiment_dir
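Both training commands request four GPUs through --gres=gpu:4. The contents of run.sh are not shown here, so the following is only a hypothetical pre-flight check (not part of the repository): it counts the GPU ids in CUDA_VISIBLE_DEVICES, which Slurm typically sets for jobs that request GPUs through --gres, and reports a mismatch with the four requested.

```python
import os

REQUESTED_GPUS = 4  # matches --gres=gpu:4 in the sbatch commands above

def count_visible_gpus(env=None):
    """Count GPU ids listed in CUDA_VISIBLE_DEVICES; return None if it is unset."""
    env = os.environ if env is None else env
    value = env.get("CUDA_VISIBLE_DEVICES")
    if value is None:
        return None
    return len([d for d in value.split(",") if d.strip()])

def check_gpu_allocation(requested=REQUESTED_GPUS, env=None):
    """Return True only if exactly `requested` GPUs are visible to the job."""
    visible = count_visible_gpus(env)
    if visible is None:
        print("CUDA_VISIBLE_DEVICES is not set; cannot verify the GPU allocation")
        return False
    if visible != requested:
        print(f"expected {requested} GPUs but {visible} are visible")
        return False
    return True

if __name__ == "__main__":
    check_gpu_allocation()
```

Inside the training process itself, jax.device_count() reports how many devices JAX actually sees, which is the more direct check once JAX is imported.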
This repository is built on top of VDM.
@inproceedings{
sahoo2024diffusion,
title={Diffusion Models With Learned Adaptive Noise},
author={Subham Sekhar Sahoo and Aaron Gokaslan and Christopher De Sa and Volodymyr Kuleshov},
booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
year={2024},
url={https://openreview.net/forum?id=loMa99A4p8}
}