This repository provides precise, efficient, and extensible implementations of the popular metrics for generative model evaluation, including:
- Inception Score (ISC)
- Fréchet Inception Distance (FID)
- Kernel Inception Distance (KID)
- Perceptual Path Length (PPL)
Precision: Unlike many other reimplementations, the values produced by torch-fidelity match reference implementations up to machine precision. This allows using torch-fidelity for reporting metrics in papers instead of scattered and slow reference implementations. Read more about precision of this code.
Efficiency: Feature sharing between different metrics saves recomputation time, and an additional caching level avoids recomputing features and statistics whenever possible. High efficiency allows using torch-fidelity in the training loop, for example at the end of every epoch.
Extensibility: Going beyond 2D image generation is easy due to high modularity and abstraction of the metrics from input data, models, and feature extractors. For example, one can swap out the InceptionV3 feature extractor for one accepting 3D scan volumes, such as those used in MRI.
TL;DR: fast and reliable GAN evaluation in PyTorch
pip install torch-fidelity
Below are three examples of using torch-fidelity to evaluate metrics from the command line. See more examples in the documentation.
Inception Score of CIFAR-10 training split:
> fidelity --gpu 0 --isc --input1 cifar10-train
inception_score_mean: 11.23678
inception_score_std: 0.09514061
Inception Score of a directory of images stored in ~/images/:
> fidelity --gpu 0 --isc --input1 ~/images/
Efficient computation of ISC and PPL for input1, and FID and KID between a generative model stored in ~/generator.onnx and CIFAR-10 training split:
> fidelity \
    --gpu 0 \
    --isc \
    --fid \
    --kid \
    --ppl \
    --input1 ~/generator.onnx \
    --input1-model-z-type normal \
    --input1-model-z-size 128 \
    --input1-model-num-samples 50000 \
    --input2 cifar10-train
When tracking the performance of generative models as they train, evaluating metrics after every epoch becomes prohibitively expensive due to long computation times. torch_fidelity tackles this problem by making full use of caching to avoid recomputing common features and per-metric statistics whenever possible. Computing all metrics for 50000 32x32 generated images and cifar10-train takes only 2 min 26 s on an NVIDIA P100 GPU, compared to >10 min with the original codebases. Thus, computing metrics 20 times over the whole training cycle adds less than one hour to the overall training time.
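The overhead estimate above can be checked with a few lines of arithmetic. This is only a sketch: the constant and function names are illustrative and not part of torch-fidelity, and the 10-minute figure is the lower bound quoted above, not a measured value.

```python
# Per-evaluation timings quoted above (NVIDIA P100, 50000 generated
# 32x32 images compared against cifar10-train).
TORCH_FIDELITY_SECONDS = 2 * 60 + 26  # 2 min 26 s
REFERENCE_SECONDS = 10 * 60           # ">10 min": a lower bound only
NUM_EVALUATIONS = 20                  # evaluations per training cycle


def total_overhead_minutes(seconds_per_eval, num_evals=NUM_EVALUATIONS):
    """Total time spent on metric evaluation over a training run, in minutes."""
    return seconds_per_eval * num_evals / 60


fidelity_minutes = total_overhead_minutes(TORCH_FIDELITY_SECONDS)
reference_minutes = total_overhead_minutes(REFERENCE_SECONDS)

print(f"torch-fidelity: {fidelity_minutes:.1f} min")
print(f"original codebases: more than {reference_minutes:.0f} min")
```

With these figures, 20 evaluations cost roughly 49 minutes, while the same schedule with the original codebases would take more than 200 minutes.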
In the following example, assume an unconditional image generation setting with CIFAR-10 and a generative model, generator, which takes a 128-dimensional standard normal noise vector.
First, import the module:
import torch_fidelity
Add the following lines to the end-of-epoch evaluation:
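# Wrap the generator: latent size 128, standard normal latents, 0 classes (unconditional)
wrapped_generator = torch_fidelity.GenerativeModelModuleWrapper(generator, 128, 'normal', 0)
# Compare generated samples against the CIFAR-10 training split
metrics_dict = torch_fidelity.calculate_metrics(
    input1=wrapped_generator,
    input2='cifar10-train',
    cuda=True,
    isc=True,
    fid=True,
    kid=True,
    verbose=False,
)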
The resulting dictionary with computed metrics can be logged directly to tensorboard, wandb, or the console:
print(metrics_dict)
Output:
{
    'inception_score_mean': 11.23678,
    'inception_score_std': 0.09514061,
    'frechet_inception_distance': 18.12198,
    'kernel_inception_distance_mean': 0.01369556,
    'kernel_inception_distance_std': 0.001310059
}
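For console logging during training, a one-line summary per epoch is often easier to scan than a raw dictionary. The sketch below uses only the standard library; the helper name format_metrics and its epoch argument are hypothetical and not part of torch-fidelity, and the values are copied from the output above.

```python
def format_metrics(metrics_dict, epoch):
    """Render a metrics dictionary as one log line (hypothetical helper)."""
    # Sort by metric name so the column order is stable across epochs.
    parts = [f"{name}={value:.5f}" for name, value in sorted(metrics_dict.items())]
    return f"epoch {epoch}: " + ", ".join(parts)


# Values copied from the example output above.
metrics_dict = {
    'inception_score_mean': 11.23678,
    'inception_score_std': 0.09514061,
    'frechet_inception_distance': 18.12198,
    'kernel_inception_distance_mean': 0.01369556,
    'kernel_inception_distance_std': 0.001310059,
}

line = format_metrics(metrics_dict, epoch=1)
print(line)
```

The same loop over metrics_dict.items() could pass each name and value to a scalar logger of choice instead of building a string.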
Refer to sngan_cifar10.py for a complete training example.
Evolution of fixed generator latents in the example:
A generator checkpoint resulting from training the example can be downloaded here.
Citation is recommended to reinforce the evaluation protocol in works relying on torch-fidelity. To ensure reproducibility when citing this repository, use the following BibTeX:
@misc{obukhov2020torchfidelity,
author={Anton Obukhov and Maximilian Seitzer and Po-Wei Wu and Semen Zhydenko and Jonathan Kyl and Elvis Yu-Jing Lin},
year=2020,
title={High-fidelity performance metrics for generative models in PyTorch},
url={https://github.com/toshas/torch-fidelity},
publisher={Zenodo},
version={v0.2.0},
doi={10.5281/zenodo.3786540},
note={Version: 0.2.0, DOI: 10.5281/zenodo.3786540}
}