Jiaojiao Fan, Amirhossein Taghvaei, Yongxin Chen [arXiv]
@misc{jiao2021nwb,
title={Scalable Computations of Wasserstein Barycenter via Input Convex Neural Networks},
author={Fan, Jiaojiao and Taghvaei, Amirhossein and Chen, Yongxin},
year={2021}
}
- The following Python packages are required:
  - pytorch (>= 1.3.1)
  - GPUtil
  - sklearn
  - seaborn
  - POT
  - matplotlib
  - jacinle
  - pytorch_fid
- Git Large File Storage (Git LFS) must be initialized to download the input data.
- The scripts in the root directory, such as `G2G_sameW_3loop.py`, are the core code of our NWB (Neural Wasserstein Barycenter) implementation. For example, if you are interested in the Gaussian example, run the following in a terminal: `python G2G_sameW_3loop.py --parameter_1 param1_value --parameter_2 param2_value`.
- `generator_example/` contains the scripts that generate comparison results and visualizations. For example, if you are interested in the Gaussian example, run the following in a terminal: `python ./generator_example/NWB1_gmm.py --parameter_1 param1_value --parameter_2 param2_value`.
- `optimal_transport_modules/` contains the auxiliary utility modules.
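A quick way to confirm that the required Python packages listed above are installed is to look up each package's import module. The sketch below is not part of the repository, and `missing_packages` and `REQUIRED_MODULES` are hypothetical names. It assumes the usual import names for these packages (for example, pytorch is imported as `torch` and POT as `ot`). It only checks that each module can be found and does not check the pytorch version requirement.

```python
import importlib.util

# Map each required package to the module name it is imported under.
# Some differ from the package name: pytorch -> torch, POT -> ot
# (sklearn is installed from PyPI as scikit-learn).
REQUIRED_MODULES = {
    "pytorch": "torch",
    "GPUtil": "GPUtil",
    "sklearn": "sklearn",
    "seaborn": "seaborn",
    "POT": "ot",
    "matplotlib": "matplotlib",
    "jacinle": "jacinle",
    "pytorch_fid": "pytorch_fid",
}


def missing_packages(required=REQUIRED_MODULES):
    """Return the package names whose top-level module cannot be found.

    find_spec only locates the module; it does not import it.
    """
    return [
        package
        for package, module in required.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_packages()
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages can be found.")
```

Running this file before launching an experiment prints any package that still needs to be installed.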
The configuration of an experiment is entirely described by the `optimal_transport_modules/cfg.py` config file. To change a parameter, edit it there.
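A config file and command-line flags such as `--parameter_1 param1_value` often fit together through a common pattern: the config holds the default values, and flags override them for a single run. The sketch below illustrates only that general pattern. It does not reproduce the actual contents of `cfg.py`, and `Config`, `parameter_1`, `parameter_2`, and `parse_config` are hypothetical names.

```python
import argparse
from dataclasses import dataclass, fields


# Hypothetical stand-in for the settings in optimal_transport_modules/cfg.py;
# the real field names and defaults are defined in that file.
@dataclass
class Config:
    parameter_1: float = 0.1
    parameter_2: int = 100


def parse_config(argv=None):
    """Build a Config from its defaults, letting --<field_name> <value>
    on the command line override any field (argv=None reads sys.argv)."""
    defaults = Config()
    parser = argparse.ArgumentParser()
    for f in fields(Config):
        default = getattr(defaults, f.name)
        # Convert the flag's string value to the same type as the default.
        parser.add_argument(f"--{f.name}", type=type(default), default=default)
    args = parser.parse_args(argv)
    return Config(**vars(args))


if __name__ == "__main__":
    print(parse_config())
```

With this pattern, `python script.py --parameter_2 7` changes only `parameter_2` for that run and keeps `parameter_1` at its default.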
This repository is still under construction. If you encounter a bug when running the code, please open an issue. Thank you!