Skip to content

Learning Diffusion Priors from Observations by Expectation Maximization

Notifications You must be signed in to change notification settings

francois-rozet/diffusion-priors

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

80 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

Learning Diffusion Priors from Observations by Expectation Maximization

This repository contains the official implementation of the paper Learning Diffusion Priors from Observations by Expectation Maximization by François Rozet, Gérôme Andry, François Lanusse and Gilles Louppe.

Abstract

Diffusion models recently proved to be remarkable priors for Bayesian inverse problems. However, training these models typically requires access to large amounts of clean data, which could prove difficult in some settings. In this work, we present a novel method based on the expectation-maximization algorithm for training diffusion models from incomplete and noisy observations only. Unlike previous works, our method leads to proper diffusion models, which is crucial for downstream tasks. As part of our method, we propose and motivate a new posterior sampling scheme for unconditional diffusion models. We present empirical evidence supporting the effectiveness of our method.

Code

The majority of the code is written in Python. Neural networks are built and trained using the JAX automatic differentiation framework and the Inox library. We also rely on torch-fidelity to compute image quality metrics and POT to compute Sinkhorn divergences. All dependencies are provided as a conda environment file.

conda env create -f environment.yml
conda activate priors

To run the experiments, it is necessary to have access to a Slurm cluster, to login to a Weights & Biases account and to install the priors module as a package.

pip install -e .

Organization

The priors directory contains the implementations of the neural networks, the diffusion models and various helpers.

The manifold, cifar and fastmri directories contain the scripts for the experiments (data generation, training and evaluation) as well as the notebooks that produced the figures of the paper. For example,

~/experiments/cifar $ python data.py

launches the CIFAR-10 data processing/generation jobs and

~/experiments/cifar $ python train.py

launches the training jobs with the default config (specified within the file).

For the accelerated MRI experiment, the knee_singlecoil_{train|val}.tar.xz archives should be obtained from the NYU fastMRI Initiative database.

The code for the evaluation of MMPS (in Appendix E) is provided in the separate mmps-benchmark repository.

Citation

If you find this project useful for your research, please consider citing

@inproceedings{rozet2024learning,
  title={Learning Diffusion Priors from Observations by Expectation Maximization},
  author={Fran{\c{c}}ois Rozet and G{\'e}r{\^o}me Andry and Fran{\c{c}}ois Lanusse and Gilles Louppe},
  booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
  year={2024},
  url={https://openreview.net/forum?id=7v88Fh6iSM},
}

Acknowledgements

Compute

The present research benefited from computational resources made available on Lucia, the Tier-1 supercomputer of the Walloon Region, infrastructure funded by the Walloon Region under the grant n°1910247. The computational resources have been provided by the Consortium des Équipements de Calcul Intensif (CÉCI), funded by the Fonds de la Recherche Scientifique de Belgique (F.R.S.-FNRS) under the grant n°2.5020.11 and by the Walloon Region.

NYU fastMRI

MRI data used in the preparation of this article were obtained from the NYU fastMRI Initiative database. As such, NYU fastMRI investigators provided data but did not participate in analysis or writing of this report. A listing of NYU fastMRI investigators, subject to updates, can be found at https://fastmri.med.nyu.edu. The primary goal of fastMRI is to test whether machine learning can aid in the reconstruction of medical images.

@misc{Zbontar2018fastMRI,
  author = {Zbontar, Jure and Knoll, Florian and Sriram, Anuroop and Murrell, Tullie and Huang, Zhengnan and Muckley, Matthew J. and Defazio, Aaron and Stern, Ruben and Johnson, Patricia and Bruno, Mary and Parente, Marc and Geras, Krzysztof J. and Katsnelson, Joe and Chandarana, Hersh and Zhang, Zizhao and Drozdzal, Michal and Romero, Adriana and Rabbat, Michael and Vincent, Pascal and Yakubova, Nafissa and Pinkerton, James and Wang, Duo and Owens, Erich and Zitnick, C. Lawrence and Recht, Michael P. and Sodickson, Daniel K. and Lui, Yvonne W.},
  title  = {fastMRI: An Open Dataset and Benchmarks for Accelerated MRI},
  year   = {2018},
  url    = {http://arxiv.org/abs/1811.08839},
}

@article{Knoll2020fastMRI,
  author  = {Knoll, Florian and Zbontar, Jure and Sriram, Anuroop and Muckley, Matthew J. and Bruno, Mary and Defazio, Aaron and Parente, Marc and Geras, Krzysztof J. and Katsnelson, Joe and Chandarana, Hersh and Zhang, Zizhao and Drozdzalv, Michal and Romero, Adriana and Rabbat, Michael and Vincent, Pascal and Pinkerton, James and Wang, Duo and Yakubova, Nafissa and Owens, Erich and Zitnick, C. Lawrence and Recht, Michael P. and Sodickson, Daniel K. and Lui, Yvonne W.},
  journal = {Radiology: Artificial Intelligence},
  title   = {fastMRI: A Publicly Available Raw k-Space and DICOM Dataset of Knee Images for Accelerated MR Image Reconstruction Using Machine Learning},
  year    = {2020},
  doi     = {10.1148/ryai.2020190007},
}