Skip to content

Commit

Permalink
Mutual Information Score (Lightning-AI#2008)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 25, 2023
1 parent 5ed0f83 commit 2780a32
Show file tree
Hide file tree
Showing 12 changed files with 552 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `MutualInformationScore` metric to cluster package ([#2008](https://github.com/Lightning-AI/torchmetrics/pull/2008)


### Changed
Expand Down
21 changes: 21 additions & 0 deletions docs/source/clustering/mutual_info_score.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
.. customcarditem::
:header: Mutual Information Score
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/clustering.svg
:tags: Clustering

.. include:: ../links.rst

########################
Mutual Information Score
########################

Module Interface
________________

.. autoclass:: torchmetrics.clustering.MutualInfoScore
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.clustering.mutual_info_score
8 changes: 8 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ Or directly from conda

classification/*

.. toctree::
:maxdepth: 2
:name: clustering
:caption: Clustering
:glob:

clustering/*

.. toctree::
:maxdepth: 2
:name: detection
Expand Down
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,5 @@
.. _CIOU: https://arxiv.org/abs/2005.03572
.. _DIOU: https://arxiv.org/abs/1911.08287v1
.. _GIOU: https://arxiv.org/abs/1902.09630
.. _Mutual Information Score: https://en.wikipedia.org/wiki/Mutual_information
.. _pycocotools: https://github.com/cocodataset/cocoapi/tree/master/PythonAPI/pycocotools
18 changes: 18 additions & 0 deletions src/torchmetrics/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.clustering.mutual_info_score import MutualInfoScore

__all__ = [
"MutualInfoScore",
]
125 changes: 125 additions & 0 deletions src/torchmetrics/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor

from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score
from torchmetrics.metric import Metric
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

if not _MATPLOTLIB_AVAILABLE:
__doctest_skip__ = ["MutualInfoScore.plot"]


class MutualInfoScore(Metric):
r"""Compute `Mutual Information Score`_.
.. math::
MI(U,V) = \sum_{i=1}^{\abs{U}} \sum_{j=1}^{\abs{V}} \frac{\abs{U_i\cap V_j}}{N}
\log\frac{N\abs{U_i\cap V_j}}{\abs{U_i}\abs{V_j}}
Where :math:`U` is a tensor of target values, :math:`V` is a tensor of predictions,
:math:`\abs{U_i}` is the number of samples in cluster :math:`U_i`, and
:math:`\abs{V_i}` is the number of samples in cluster :math:`V_i`.
The metric is symmetric, therefore swapping :math:`U` and :math:`V` yields
the same mutual information score.
As input to ``forward`` and ``update`` the metric accepts the following input:
- ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
- ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
As output of ``forward`` and ``compute`` the metric returns the following output:
- ``mi_score`` (:class:`~torch.Tensor`): A tensor with the Mutual Information Score
Args:
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
Example:
>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> mi_score = MutualInfoScore()
>>> mi_score(preds, target)
tensor(0.5004)
"""

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
contingency: Tensor

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
self.preds.append(preds)
self.target.append(target)

def compute(self) -> Tensor:
"""Compute mutual information over state."""
return mutual_info_score(dim_zero_cat(self.preds), dim_zero_cat(self.target))

def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
"""Plot a single or multiple values from the metric.
Args:
val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
If no value is provided, will automatically call `metric.compute` and plot that result.
ax: An matplotlib axis object. If provided will add plot to that axis
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError:
If `matplotlib` is not installed
.. plot::
:scale: 75
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> metric = MutualInfoScore()
>>> metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
.. plot::
:scale: 75
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> metric = MutualInfoScore()
>>> for _ in range(10):
... metric.update(torch.randint(0, 4, (10,)), torch.randint(0, 4, (10,)))
>>> fig_, ax_ = metric.plot(metric.compute())
"""
return self._plot(val, ax)
16 changes: 16 additions & 0 deletions src/torchmetrics/functional/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.functional.clustering.mutual_info_score import mutual_info_score

__all__ = ["mutual_info_score"]
79 changes: 79 additions & 0 deletions src/torchmetrics/functional/clustering/mutual_info_score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, tensor

from torchmetrics.functional.clustering.utils import calculate_contingency_matrix, check_cluster_labels


def _mutual_info_score_update(preds: Tensor, target: Tensor) -> Tensor:
"""Update and return variables required to compute the mutual information score.
Args:
preds: predicted class labels
target: ground truth class labels
Returns:
contingency: contingency matrix
"""
check_cluster_labels(preds, target)
return calculate_contingency_matrix(preds, target)


def _mutual_info_score_compute(contingency: Tensor) -> Tensor:
"""Compute the mutual information score based on the contingency matrix.
Args:
contingency: contingency matrix
Returns:
mutual_info: mutual information score
"""
n = contingency.sum()
u = contingency.sum(dim=1)
v = contingency.sum(dim=0)

# Check if preds or target labels only have one cluster
if u.size() == 1 or v.size() == 1:
return tensor(0.0)

# Find indices of nonzero values in U and V
nzu, nzv = torch.nonzero(contingency, as_tuple=True)
contingency = contingency[nzu, nzv]

# Calculate MI using entries corresponding to nonzero contingency matrix entries
log_outer = torch.log(u[nzu]) + torch.log(v[nzv])
mutual_info = contingency / n * (torch.log(n) + torch.log(contingency) - log_outer)
return mutual_info.sum()


def mutual_info_score(preds: Tensor, target: Tensor) -> Tensor:
"""Compute mutual information between two clusterings.
Args:
preds: predicted classes
target: ground truth classes
Example:
>>> from torchmetrics.functional.clustering import mutual_info_score
>>> target = torch.tensor([0, 3, 2, 2, 1])
>>> preds = torch.tensor([1, 3, 2, 0, 1])
>>> mutual_info_score(preds, target)
tensor(1.0549)
"""
contingency = _mutual_info_score_update(preds, target)
return _mutual_info_score_compute(contingency)
101 changes: 101 additions & 0 deletions src/torchmetrics/functional/clustering/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def calculate_contingency_matrix(
preds: Tensor, target: Tensor, eps: Optional[float] = None, sparse: bool = False
) -> Tensor:
"""Calculate contingency matrix.
Args:
preds: predicted labels
target: ground truth labels
eps: value added to contingency matrix
sparse: If True, returns contingency matrix as a sparse matrix. Else, return as dense matrix.
`eps` must be `None` if `sparse` is `True`.
Returns:
contingency: contingency matrix of shape (n_classes_target, n_classes_preds)
Example:
>>> import torch
>>> from torchmetrics.functional.clustering.utils import calculate_contingency_matrix
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> calculate_contingency_matrix(preds, target, eps=1e-16)
tensor([[1.0000e+00, 1.0000e-16, 1.0000e+00],
[1.0000e+00, 1.0000e+00, 1.0000e-16],
[1.0000e-16, 1.0000e+00, 1.0000e-16]])
"""
if eps is not None and sparse is True:
raise ValueError("Cannot specify `eps` and return sparse tensor.")
if preds.ndim != 1 or target.ndim != 1:
raise ValueError(f"Expected 1d `preds` and `target` but got {preds.ndim} and {target.dim}.")

preds_classes, preds_idx = torch.unique(preds, return_inverse=True)
target_classes, target_idx = torch.unique(target, return_inverse=True)

n_classes_preds = preds_classes.size(0)
n_classes_target = target_classes.size(0)

contingency = torch.sparse_coo_tensor(
torch.stack(
(
target_idx,
preds_idx,
)
),
torch.ones(target_idx.shape[0], dtype=preds_idx.dtype, device=preds_idx.device),
(
n_classes_target,
n_classes_preds,
),
)

if not sparse:
contingency = contingency.to_dense()
if eps:
contingency = contingency + eps

return contingency


def check_cluster_labels(preds: Tensor, target: Tensor) -> None:
"""Check shape of input tensors and if they are real, discrete tensors.
Args:
preds: predicted labels
target: ground truth labels
"""
_check_same_shape(preds, target)
if preds.ndim != 1:
raise ValueError(f"Expected arguments to be 1d tensors but got {preds.ndim} and {target.ndim}")
if (
torch.is_floating_point(preds)
or torch.is_complex(preds)
or torch.is_floating_point(target)
or torch.is_complex(target)
):
raise ValueError(
f"Expected real, discrete values but received {preds.dtype} for"
f"predictions and {target.dtype} for target labels instead."
)
Empty file.
Loading

0 comments on commit 2780a32

Please sign in to comment.