New metric: Perceptual Path Length #1939

Merged
merged 40 commits into from
Aug 9, 2023
0d94f75
initial implementation
SkafteNicki Jul 25, 2023
aa16025
add more to module
SkafteNicki Jul 25, 2023
e126a88
changelog
SkafteNicki Jul 25, 2023
0e78bb2
add some docstrings
SkafteNicki Jul 25, 2023
38b35f5
add to doc pages
SkafteNicki Jul 25, 2023
1e0ecc0
more docs
SkafteNicki Jul 25, 2023
54be531
improve testing
SkafteNicki Jul 25, 2023
1d78d8c
Merge branch 'master' into newmetric/ppl
SkafteNicki Jul 31, 2023
ed8afaa
fix typing issues
SkafteNicki Jul 31, 2023
753aab6
fix docs build
SkafteNicki Jul 31, 2023
6b10f84
improve generator testing
SkafteNicki Jul 31, 2023
37c4de3
compatibility with older
SkafteNicki Jul 31, 2023
e12c9ea
improve testing
SkafteNicki Aug 2, 2023
c20f0d3
docstrings + doctests
SkafteNicki Aug 2, 2023
47b4ee6
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 2, 2023
1730ac6
fix
SkafteNicki Aug 3, 2023
d214170
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 3, 2023
8b6220e
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 3, 2023
f366b55
skip on missing import
SkafteNicki Aug 3, 2023
e38b846
Merge branch 'master' into newmetric/ppl
Borda Aug 3, 2023
8c4e2d2
Merge branch 'master' into newmetric/ppl
mergify[bot] Aug 4, 2023
426b4ae
Merge branch 'master' into newmetric/ppl
Borda Aug 4, 2023
e9b7a6d
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 7, 2023
62cde3e
Merge branch 'master' into newmetric/ppl
Borda Aug 7, 2023
650165d
Merge branch 'master' into newmetric/ppl
Borda Aug 7, 2023
49081a8
Merge branch 'master' into newmetric/ppl
Borda Aug 8, 2023
4037cc1
move requirement to tests
SkafteNicki Aug 8, 2023
1fe9475
fix link
SkafteNicki Aug 8, 2023
0d0f771
add resize functionality
SkafteNicki Aug 8, 2023
65a4a00
reformat to use own implementation of lpips
SkafteNicki Aug 8, 2023
a29d1d5
add tests
SkafteNicki Aug 8, 2023
d07a27d
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 8, 2023
6912d0c
Merge branch 'master' into newmetric/ppl
Borda Aug 8, 2023
fecf65c
req.
Borda Aug 8, 2023
0139533
Merge branch 'master' into newmetric/ppl
mergify[bot] Aug 8, 2023
dbcbf04
Merge branch 'master' into newmetric/ppl
SkafteNicki Aug 9, 2023
e2808e4
fix mypy
SkafteNicki Aug 9, 2023
715a360
skip on random
SkafteNicki Aug 9, 2023
b2d79a7
device placement
SkafteNicki Aug 9, 2023
e07de45
seed
SkafteNicki Aug 9, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


- Added `PerceptualPathLength` to image package ([#1939](https://github.com/Lightning-AI/torchmetrics/pull/1939))


### Changed

-
@@ -13,3 +13,9 @@ ________________
.. autoclass:: torchmetrics.image.lpip.LearnedPerceptualImagePatchSimilarity
    :noindex:
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.learned_perceptual_image_patch_similarity
    :noindex:
23 changes: 23 additions & 0 deletions docs/source/image/perceptual_path_length.rst
@@ -0,0 +1,23 @@
.. customcarditem::
   :header: Perceptual Path Length
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/image_classification.svg
   :tags: Image

.. include:: ../links.rst

############################
Perceptual Path Length (PPL)
############################

Module Interface
________________

.. autoclass:: torchmetrics.image.perceptual_path_length.PerceptualPathLength
    :noindex:
    :exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.image.perceptual_path_length.perceptual_path_length
    :noindex:
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -144,3 +144,4 @@
.. _Seamless Scene Segmentation paper: https://arxiv.org/abs/1905.01220
.. _Fleiss kappa: https://en.wikipedia.org/wiki/Fleiss%27_kappa
.. _VIF: https://ieeexplore.ieee.org/abstract/document/1576816
.. _PPL: https://arxiv.org/pdf/1812.04948.pdf
2 changes: 1 addition & 1 deletion requirements/image.txt
@@ -3,5 +3,5 @@

scipy >1.0.0, <1.11.0
torchvision >=0.8, <=0.15.2
torch-fidelity <=0.3.0
torch-fidelity @ git+https://github.com/toshas/torch-fidelity@master
lpips <=0.1.4
4 changes: 4 additions & 0 deletions src/torchmetrics/functional/image/__init__.py
@@ -14,6 +14,8 @@
from torchmetrics.functional.image.d_lambda import spectral_distortion_index
from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
from torchmetrics.functional.image.gradients import image_gradients
from torchmetrics.functional.image.lpips import learned_perceptual_image_patch_similarity
from torchmetrics.functional.image.perceptual_path_length import perceptual_path_length
from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
from torchmetrics.functional.image.psnrb import peak_signal_noise_ratio_with_blocked_effect
from torchmetrics.functional.image.rase import relative_average_spectral_error
@@ -41,4 +43,6 @@
    "total_variation",
    "universal_image_quality_index",
    "visual_information_fidelity",
    "learned_perceptual_image_patch_similarity",
    "perceptual_path_length",
]
251 changes: 251 additions & 0 deletions src/torchmetrics/functional/image/perceptual_path_length.py
@@ -0,0 +1,251 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Literal, Optional, Tuple, Union

import torch
from torch import Tensor, nn

from torchmetrics.utilities.imports import _TORCH_FIDELITY_AVAILABLE, _TORCH_GREATER_EQUAL_1_10

if _TORCH_FIDELITY_AVAILABLE:
    from torch_fidelity.noise import batch_lerp, batch_slerp_any, batch_slerp_unit
    from torch_fidelity.utils import create_sample_similarity
else:
    batch_lerp = batch_slerp_any = batch_slerp_unit = None
    create_sample_similarity = None
    __doctest_skip__ = ["perceptual_path_length"]


class _GeneratorType(nn.Module):
    @property
    def num_classes(self) -> int:
        raise NotImplementedError

    def sample(self, num_samples: int) -> Tensor:
        raise NotImplementedError


def _validate_generator_model(generator: _GeneratorType, conditional: bool = False) -> None:
    """Validate that the user provided generator has the right methods and attributes.

    Args:
        generator: Generator model
        conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input).

    """
    if not hasattr(generator, "sample"):
        raise NotImplementedError(
            "The generator must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where the"
            " returned tensor has shape `(num_samples, z_size)`."
        )
    if not callable(generator.sample):
        raise ValueError("The generator's `sample` method must be callable.")
    if conditional and not hasattr(generator, "num_classes"):
        raise AttributeError("The generator must have a `num_classes` attribute when `conditional=True`.")
    if conditional and not isinstance(generator.num_classes, int):
        raise ValueError("The generator's `num_classes` attribute must be an integer when `conditional=True`.")


def _perceptual_path_length_validate_arguments(
    num_samples: int = 10_000,
    conditional: bool = False,
    batch_size: int = 128,
    interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp",
    epsilon: float = 1e-4,
    resize: Optional[int] = 64,
    lower_discard: Optional[float] = 0.01,
    upper_discard: Optional[float] = 0.99,
) -> None:
    """Validate arguments for perceptual path length."""
    if not (isinstance(num_samples, int) and num_samples > 0):
        raise ValueError(f"Argument `num_samples` must be a positive integer, but got {num_samples}.")
    if not isinstance(conditional, bool):
        raise ValueError(f"Argument `conditional` must be a boolean, but got {conditional}.")
    if not (isinstance(batch_size, int) and batch_size > 0):
        raise ValueError(f"Argument `batch_size` must be a positive integer, but got {batch_size}.")
    if interpolation_method not in ["lerp", "slerp_any", "slerp_unit"]:
        raise ValueError(
            "Argument `interpolation_method` must be one of 'lerp', 'slerp_any', 'slerp_unit',"
            f" but got {interpolation_method}."
        )
    if not (isinstance(epsilon, float) and epsilon > 0):
        raise ValueError(f"Argument `epsilon` must be a positive float, but got {epsilon}.")
    if resize is not None and not (isinstance(resize, int) and resize > 0):
        raise ValueError(f"Argument `resize` must be a positive integer or `None`, but got {resize}.")
    if lower_discard is not None and not (isinstance(lower_discard, float) and 0 <= lower_discard <= 1):
        raise ValueError(
            f"Argument `lower_discard` must be a float between 0 and 1 or `None`, but got {lower_discard}."
        )
    if upper_discard is not None and not (isinstance(upper_discard, float) and 0 <= upper_discard <= 1):
        raise ValueError(
            f"Argument `upper_discard` must be a float between 0 and 1 or `None`, but got {upper_discard}."
        )


def _interpolate(
    latents1: Tensor,
    latents2: Tensor,
    epsilon: float = 1e-4,
    interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp",
) -> Tensor:
    """Interpolate between two sets of latents.

    Args:
        latents1: First set of latents.
        latents2: Second set of latents.
        epsilon: Spacing between the points on the path between latent points.
        interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'.

    """
    if latents1.shape != latents2.shape:
        raise ValueError("Latents must have the same shape.")
    if interpolation_method == "lerp":
        return batch_lerp(latents1, latents2, epsilon)
    # each slerp option dispatches to the torch-fidelity helper of the same name
    if interpolation_method == "slerp_any":
        return batch_slerp_any(latents1, latents2, epsilon)
    if interpolation_method == "slerp_unit":
        return batch_slerp_unit(latents1, latents2, epsilon)
    raise ValueError(
        f"Interpolation method {interpolation_method} not supported. Choose from 'lerp', 'slerp_any', 'slerp_unit'."
    )


def perceptual_path_length(
    generator: _GeneratorType,
    num_samples: int = 10_000,
    conditional: bool = False,
    batch_size: int = 64,
    interpolation_method: Literal["lerp", "slerp_any", "slerp_unit"] = "lerp",
    epsilon: float = 1e-4,
    resize: Optional[int] = 64,
    lower_discard: Optional[float] = 0.01,
    upper_discard: Optional[float] = 0.99,
    sim_net: Optional[nn.Module] = None,
    device: Union[str, torch.device] = "cpu",
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Computes the perceptual path length (`PPL`_) of a generator model.

    The perceptual path length can be used to measure the consistency of interpolation in latent-space models. It is
    defined as

    .. math::
        PPL = \mathbb{E}\left[\frac{1}{\epsilon^2} D(G(I(z_1, z_2, t)), G(I(z_1, z_2, t+\epsilon)))\right]

    where :math:`G` is the generator, :math:`I` is the interpolation function, :math:`D` is a similarity metric,
    :math:`z_1` and :math:`z_2` are two sets of latent points, and :math:`t` is a parameter between 0 and 1. The metric
    thus works by interpolating between two sets of latent points, and measuring the similarity between the generated
    images. The expectation is approximated by sampling :math:`z_1` and :math:`z_2` from the generator, and averaging
    the calculated distances. The similarity metric :math:`D` is by default the `LPIPS`_ metric, but can be changed by
    setting the `sim_net` argument.

    The provided generator model must have a `sample` method with signature `sample(num_samples: int) -> Tensor` where
    the returned tensor has shape `(num_samples, z_size)`. If the generator is conditional, it must also have a
    `num_classes` attribute.

    Args:
        generator: Generator model, with specific requirements. See above.
        num_samples: Number of samples to use for the PPL computation.
        conditional: Whether the generator is conditional or not (i.e. whether it takes labels as input).
        batch_size: Batch size to use for the PPL computation.
        interpolation_method: Interpolation method to use. Choose from 'lerp', 'slerp_any', 'slerp_unit'.
        epsilon: Spacing between the points on the path between latent points.
        resize: Resize images to this size before computing the similarity between generated images.
        lower_discard: Lower quantile to discard from the distances, before computing the mean and standard deviation.
        upper_discard: Upper quantile to discard from the distances, before computing the mean and standard deviation.
        sim_net: Similarity network to use. If `None`, a default network is used.
        device: Device to use for the computation.

    Returns:
        A tuple containing the mean, the standard deviation, and the distances that remain after discarding.

    Example::
        >>> from torchmetrics.functional.image import perceptual_path_length
        >>> import torch
        >>> _ = torch.manual_seed(42)
        >>> class DummyGenerator(torch.nn.Module):
        ...     def __init__(self, z_size) -> None:
        ...         super().__init__()
        ...         self.z_size = z_size
        ...         self.model = torch.nn.Linear(z_size, 3*128*128)
        ...     def forward(self, z):
        ...         return self.model(z).reshape(-1, 3, 128, 128)
        ...     def sample(self, num_samples):
        ...         return torch.randn(num_samples, self.z_size)
        >>> generator = DummyGenerator(2)
        >>> perceptual_path_length(generator, num_samples=10)  # doctest: +NORMALIZE_WHITESPACE
        (tensor(0.0756),
         tensor(0.0678),
         tensor([0.0489, 0.1433, 0.1778, 0.1632, 0.0255, 0.0511, 0.0024, 0.0613, 0.0071]))

    """
    if not _TORCH_FIDELITY_AVAILABLE:
        raise ModuleNotFoundError(
            "Metric `perceptual_path_length` requires Torch Fidelity which is not installed."
            " Install with `pip install torch-fidelity` or `pip install torchmetrics[image]`"
        )
    _perceptual_path_length_validate_arguments(
        num_samples, conditional, batch_size, interpolation_method, epsilon, resize, lower_discard, upper_discard
    )
    _validate_generator_model(generator, conditional)
    generator = generator.to(device)

    latent1 = generator.sample(num_samples).to(device)
    latent2 = generator.sample(num_samples).to(device)
    latent2 = _interpolate(latent1, latent2, epsilon, interpolation_method=interpolation_method)

    if conditional:
        labels = torch.randint(0, generator.num_classes, (num_samples,)).to(device)

    if sim_net is None:
        sim_net = create_sample_similarity(
            "lpips-vgg16",
            sample_similarity_resize=resize,
            # compare the device type so that "cuda:0" and `torch.device` objects are also recognized
            cuda=torch.device(device).type == "cuda",
            verbose=False,
        )

    decorator = torch.inference_mode if _TORCH_GREATER_EQUAL_1_10 else torch.no_grad
    with decorator():
        distances = []
        num_batches = math.ceil(num_samples / batch_size)
        for batch_idx in range(num_batches):
            batch_latent1 = latent1[batch_idx * batch_size : (batch_idx + 1) * batch_size].to(device)
            batch_latent2 = latent2[batch_idx * batch_size : (batch_idx + 1) * batch_size].to(device)

            if conditional:
                batch_labels = labels[batch_idx * batch_size : (batch_idx + 1) * batch_size].to(device)
                outputs = generator(
                    torch.cat((batch_latent1, batch_latent2), dim=0), torch.cat((batch_labels, batch_labels), dim=0)
                )
            else:
                outputs = generator(torch.cat((batch_latent1, batch_latent2), dim=0))

            out1, out2 = outputs.chunk(2, dim=0)

            similarity = sim_net(out1, out2)
            dist = similarity / epsilon**2
            distances.append(dist.detach())

        distances = torch.cat(distances)

        lower = torch.quantile(distances, lower_discard, interpolation="lower") if lower_discard is not None else 0.0
        upper = (
            torch.quantile(distances, upper_discard, interpolation="lower")
            if upper_discard is not None
            else distances.max()
        )
        distances = distances[(distances >= lower) & (distances <= upper)]

    return distances.mean(), distances.std(), distances
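
For readers skimming the diff, the flow of `perceptual_path_length` (sample two latents, step a distance `epsilon` along the path between them, compare the two generated outputs, scale by `1 / epsilon**2`, trim the outer quantiles, then report mean and standard deviation) can be sketched in plain Python. This is an illustrative toy under stated assumptions, not the PR's implementation: `toy_generator`, `toy_distance`, `lower_quantile`, and `toy_ppl` are hypothetical names, squared Euclidean distance stands in for LPIPS, and only `lerp` interpolation is shown.

```python
import math
import random


def lerp(a, b, t):
    """Linear interpolation between two latent vectors (lists of floats)."""
    return [x + (y - x) * t for x, y in zip(a, b)]


def toy_generator(z):
    """Hypothetical stand-in for G: maps a latent vector to a tiny 'image'."""
    return [math.sin(v) for v in z] + [math.cos(v) for v in z]


def toy_distance(img1, img2):
    """Hypothetical stand-in for D (LPIPS in the PR): squared Euclidean distance."""
    return sum((p - q) ** 2 for p, q in zip(img1, img2))


def lower_quantile(sorted_values, q):
    """Quantile that picks the lower neighbouring element, as `interpolation="lower"` does."""
    return sorted_values[math.floor(q * (len(sorted_values) - 1))]


def toy_ppl(num_samples=200, z_size=4, epsilon=1e-4, lower_discard=0.01, upper_discard=0.99, seed=0):
    rng = random.Random(seed)
    distances = []
    for _ in range(num_samples):
        z1 = [rng.gauss(0.0, 1.0) for _ in range(z_size)]
        z2 = [rng.gauss(0.0, 1.0) for _ in range(z_size)]
        # second point sits a fraction `epsilon` along the path from z1 to z2
        z_eps = lerp(z1, z2, epsilon)
        dist = toy_distance(toy_generator(z1), toy_generator(z_eps))
        distances.append(dist / epsilon**2)

    # discard values outside the [lower, upper] quantile band before aggregating
    ordered = sorted(distances)
    lower = lower_quantile(ordered, lower_discard)
    upper = lower_quantile(ordered, upper_discard)
    kept = [d for d in distances if lower <= d <= upper]

    mean = sum(kept) / len(kept)
    # sample standard deviation (divides by n - 1)
    std = math.sqrt(sum((d - mean) ** 2 for d in kept) / (len(kept) - 1))
    return mean, std, kept


if __name__ == "__main__":
    mean, std, kept = toy_ppl()
    print(f"mean={mean:.4f} std={std:.4f} kept={len(kept)}")
```

Because the trimming step drops values outside the quantile band, the third return value can hold fewer entries than `num_samples`, which matches the nine distances shown in the docstring example for `num_samples=10`.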
2 changes: 2 additions & 0 deletions src/torchmetrics/image/__init__.py
@@ -45,11 +45,13 @@
    from torchmetrics.image.fid import FrechetInceptionDistance
    from torchmetrics.image.inception import InceptionScore
    from torchmetrics.image.kid import KernelInceptionDistance
    from torchmetrics.image.perceptual_path_length import PerceptualPathLength

    __all__ += [
        "FrechetInceptionDistance",
        "InceptionScore",
        "KernelInceptionDistance",
        "PerceptualPathLength",
    ]

if _LPIPS_AVAILABLE: