This is the official Python implementation of the paper Variational Wasserstein gradient flow (paper on arXiv) by Jiaojiao Fan, Qinsheng Zhang, Amirhossein Taghvaei and Yongxin Chen.
The repository contains reproducible PyTorch source code for computing Wasserstein gradient flows in high dimensions, using a variational estimate of the target functional.
The codebase is tested on CUDA version 11.4 and PyTorch version 1.10.1+cu113.
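A quick way to compare a local environment against the tested versions is sketched below. This script is not part of the repository; the helper names `split_torch_version`, `matches_tested`, and `report` are hypothetical and used only for illustration. It assumes that matching the PyTorch release number (the part before `+`) is the relevant check, which the repository itself does not state.

```python
import importlib.util

# Tested PyTorch version as stated in this README.
TESTED_TORCH = "1.10.1+cu113"


def split_torch_version(version):
    """Split '1.10.1+cu113' into ('1.10.1', 'cu113'); the local tag may be empty."""
    release, _, local = version.partition("+")
    return release, local


def matches_tested(version, tested=TESTED_TORCH):
    """Return True if the release part (before '+') equals the tested release."""
    return split_torch_version(version)[0] == split_torch_version(tested)[0]


def report():
    """Describe the local PyTorch install relative to the tested version."""
    if importlib.util.find_spec("torch") is None:
        return "PyTorch is not installed"
    import torch

    status = "matches" if matches_tested(torch.__version__) else "differs from"
    return (
        f"PyTorch {torch.__version__} {status} the tested release "
        f"{split_torch_version(TESTED_TORCH)[0]}; "
        f"CUDA build: {torch.version.cuda}; "
        f"GPU available: {torch.cuda.is_available()}"
    )


if __name__ == "__main__":
    print(report())
```

A differing version does not necessarily mean the code will fail; it only indicates that the environment is outside what was tested.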
To reproduce all experiments except image generation, go to the toy folder and follow the instructions in toy/README.md:
cd toy
To reproduce the image generation experiment, go to the image folder and follow the instructions in image/README.md:
cd image
@inproceedings{
fan2022variational,
title={Variational Wasserstein gradient flow},
author={Fan, Jiaojiao and Zhang, Qinsheng and Taghvaei, Amirhossein and Chen, Yongxin},
booktitle={International Conference on Machine Learning},
year={2022}
}