diff --git a/CHANGELOG.md b/CHANGELOG.md index de726f75ba1..6c18db1b231 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442)) +- Added `LogAUC` metric to classification package ([#2377](https://github.com/Lightning-AI/torchmetrics/pull/2377)) + + - Added `NegativePredictiveValue` to classification metrics ([#2433](https://github.com/Lightning-AI/torchmetrics/pull/2433)) diff --git a/docs/source/classification/logauc.rst b/docs/source/classification/logauc.rst new file mode 100644 index 00000000000..a213d0177cb --- /dev/null +++ b/docs/source/classification/logauc.rst @@ -0,0 +1,55 @@ +.. customcarditem:: + :header: Log Area Receiver Operating Characteristic (LogAUC) + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Classification + +.. include:: ../links.rst + +####### +Log AUC +####### + +Module Interface +________________ + +.. autoclass:: torchmetrics.LogAUC + :exclude-members: update, compute + :special-members: __new__ + +BinaryLogAUC +^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.BinaryLogAUC + :exclude-members: update, compute + +MulticlassLogAUC +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MulticlassLogAUC + :exclude-members: update, compute + +MultilabelLogAUC +^^^^^^^^^^^^^^^^ + +.. autoclass:: torchmetrics.classification.MultilabelLogAUC + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.logauc + +binary_logauc +^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.binary_logauc + +multiclass_logauc +^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multiclass_logauc + +multilabel_logauc +^^^^^^^^^^^^^^^^^ + +.. autofunction:: torchmetrics.functional.classification.multilabel_logauc diff --git a/docs/source/links.rst b/docs/source/links.rst index 4f2cbe6ad53..73804ce7ea1 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -177,4 +177,5 @@ .. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance .. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html .. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis +.. _Log AUC: https://pubmed.ncbi.nlm.nih.gov/20735049/ .. 
_Negative Predictive Value: https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values diff --git a/requirements/classification_test.txt b/requirements/classification_test.txt index 6417cf66360..468a5effe8d 100644 --- a/requirements/classification_test.txt +++ b/requirements/classification_test.txt @@ -5,3 +5,4 @@ pandas >1.4.0, <=2.2.3 netcal >1.0.0, <1.4.0 # calibration_error numpy <2.2.0 fairlearn # group_fairness +PyTDC ==0.4.1 ; python_version <"3.12" # locauc, temporal_dependency diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index fb822787cc2..4ff997183f6 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -68,6 +68,7 @@ HammingDistance, HingeLoss, JaccardIndex, + LogAUC, MatthewsCorrCoef, NegativePredictiveValue, Precision, @@ -196,6 +197,7 @@ "JaccardIndex", "KLDivergence", "KendallRankCorrCoef", + "LogAUC", "LogCoshError", "MatchErrorRate", "MatthewsCorrCoef", diff --git a/src/torchmetrics/classification/__init__.py b/src/torchmetrics/classification/__init__.py index 86f334c970d..bbc5321bf7a 100644 --- a/src/torchmetrics/classification/__init__.py +++ b/src/torchmetrics/classification/__init__.py @@ -57,6 +57,7 @@ MulticlassJaccardIndex, MultilabelJaccardIndex, ) +from torchmetrics.classification.logauc import BinaryLogAUC, LogAUC, MulticlassLogAUC, MultilabelLogAUC from torchmetrics.classification.matthews_corrcoef import ( BinaryMatthewsCorrCoef, MatthewsCorrCoef, @@ -223,6 +224,10 @@ "MulticlassSensitivityAtSpecificity", "MultilabelSensitivityAtSpecificity", "SensitivityAtSpecificity", + "BinaryLogAUC", + "LogAUC", + "MulticlassLogAUC", + "MultilabelLogAUC", "BinaryNegativePredictiveValue", "MulticlassNegativePredictiveValue", "MultilabelNegativePredictiveValue", diff --git a/src/torchmetrics/classification/logauc.py b/src/torchmetrics/classification/logauc.py new file mode 100644 index 00000000000..6ccf7fd5d77 --- /dev/null +++ b/src/torchmetrics/classification/logauc.py @@ -0,0 +1,507 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional, Sequence, Tuple, Type, Union + +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.classification.base import _ClassificationTaskWrapper +from torchmetrics.classification.roc import BinaryROC, MulticlassROC, MultilabelROC +from torchmetrics.functional.classification.logauc import ( + _binary_logauc_compute, + _reduce_logauc, + _validate_fpr_range, +) +from torchmetrics.metric import Metric +from torchmetrics.utilities.enums import ClassificationTask +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["BinaryLogAUC.plot", "MulticlassLogAUC.plot", "MultilabelLogAUC.plot"] + + +class BinaryLogAUC(BinaryROC): + r"""Compute the `Log AUC`_ score for binary classification tasks. 
+ + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, ...)`` containing probabilities or logits for + each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and + therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the + positive class. + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``logauc`` (:class:`~torch.Tensor`): A single scalar with the logauc score. + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log + AUC score. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torch import tensor + >>> from torchmetrics.classification import BinaryLogAUC + >>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05]) + >>> target = tensor([1, 0, 0, 0, 0]) + >>> metric = BinaryLogAUC() + >>> metric(preds, target) + tensor(1.) 
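For intuition, the computation described above can be sketched in a few lines of plain NumPy: restrict the ROC curve to the FPR window (interpolating the TPR at the two endpoints), integrate the TPR against log10(FPR) with the trapezoidal rule, and rescale by the width of the window in log space so that a perfect classifier scores 1.0. The `logauc_sketch` helper below is a hypothetical simplification for illustration only, not the implementation added in this PR (which works on torch tensors and guards against degenerate curves):

import numpy as np

def logauc_sketch(fpr, tpr, fpr_range=(0.001, 0.1)):
    # hypothetical helper: a simplified log AUC, not the torchmetrics implementation
    lo, hi = fpr_range
    inside = (fpr > lo) & (fpr < hi)
    grid = np.concatenate(([lo], fpr[inside], [hi]))  # FPR window, including both endpoints
    tpr_win = np.interp(grid, fpr, tpr)               # TPR linearly interpolated onto the window
    log_fpr = np.log10(grid)
    area = 0.5 * ((tpr_win[1:] + tpr_win[:-1]) * np.diff(log_fpr)).sum()  # trapezoidal rule
    return area / (np.log10(hi) - np.log10(lo))       # rescale so a perfect classifier gives 1.0

# a ranking that reaches TPR = 1 at a tiny FPR scores ~1.0, matching the example above
fpr = np.array([0.0, 1e-4, 1.0])
tpr = np.array([0.0, 1.0, 1.0])
print(round(float(logauc_sketch(fpr, tpr)), 4))  # 1.0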
+ + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + + def __init__( + self, + fpr_range: Tuple[float, float] = (0.001, 0.1), + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(thresholds=thresholds, ignore_index=ignore_index, validate_args=validate_args, **kwargs) + if validate_args: + _validate_fpr_range(fpr_range) + self.fpr_range = fpr_range + + def compute(self) -> Tensor: # type: ignore[override] + """Computes the log AUC score.""" + fpr, tpr, _ = super().compute() + return _binary_logauc_compute(fpr, tpr, fpr_range=self.fpr_range) + + def plot( # type: ignore[override] + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single + >>> import torch + >>> from torchmetrics.classification import BinaryLogAUC + >>> metric = BinaryLogAUC() + >>> metric.update(torch.rand(20,), torch.randint(2, (20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import BinaryLogAUC + >>> metric = BinaryLogAUC() + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(20,), torch.randint(2, (20,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) + + +class MulticlassLogAUC(MulticlassROC): + r"""Compute the `Log AUC`_ score for multiclass classification tasks. + + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits + for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto + apply softmax per sample. + - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and + therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``logauc`` (:class:`~torch.Tensor`): If `average=None|"none"` then a 1d tensor of shape (n_classes, ) will + be returned with logauc score per class. If `average="macro"` then a single scalar is returned. + + Additional dimension ``...`` will be flattened into the batch dimension. 
+ + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + num_classes: Integer specifying the number of classes + fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log + AUC score. + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``"macro"``: Calculate score for each class and average them + - ``"none"`` or ``None``: calculates score for each class and applies no reduction + + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torch import tensor + >>> from torchmetrics.classification import MulticlassLogAUC + >>> preds = tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = tensor([0, 1, 3, 2]) + >>> metric = MulticlassLogAUC(num_classes=5, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.4000) + >>> metric = MulticlassLogAUC(num_classes=5, average=None, thresholds=None) + >>> metric(preds, target) + tensor([1., 1., 0., 0., 0.]) + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + plot_legend_name: str = "Class" + + def __init__( + self, + num_classes: int, + fpr_range: Tuple[float, float] = (0.001, 0.1), + average: Optional[Literal["macro", "none"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + super().__init__( + num_classes=num_classes, + thresholds=thresholds, + average=None, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) + if validate_args: + _validate_fpr_range(fpr_range) + self.fpr_range = fpr_range + self.average2 = average # self.average is already used by parent class + + def compute(self) -> Tensor: # type: ignore[override] + """Computes the log AUC score.""" + fpr, tpr, _ = super().compute() + return _reduce_logauc(fpr, tpr, fpr_range=self.fpr_range, average=self.average2) + + def plot( # type: ignore[override] + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single + >>> import torch + >>> from torchmetrics.classification import MulticlassLogAUC + >>> metric = MulticlassLogAUC(num_classes=3) + >>> metric.update(torch.randn(20, 3), torch.randint(3,(20,))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import MulticlassLogAUC + >>> metric = MulticlassLogAUC(num_classes=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.randn(20, 3), torch.randint(3, (20,)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) + + +class MultilabelLogAUC(MultilabelROC): + r"""Compute the `Log AUC`_ score for multilabel classification tasks. + + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): A float tensor of shape ``(N, C, ...)`` containing probabilities or logits + for each observation. If preds has values outside [0,1] range we consider the input to be logits and will auto + apply sigmoid per element.
+ - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` containing ground truth labels, and + therefore only contain {0,1} values (except if `ignore_index` is specified). + + As output to ``forward`` and ``compute`` the metric returns the following output: + + - ``logauc`` (:class:`~torch.Tensor`): If `average=None|"none"` then a 1d tensor of shape (num_labels, ) will + be returned with logauc score per class. If `average="macro"` then a single scalar is returned. + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + num_labels: Integer specifying the number of labels + fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log + AUC score. + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``"macro"``: Calculate the score for each label and average them + - ``"none"`` or ``None``: calculates score for each label and applies no reduction + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torch import tensor + >>> from torchmetrics.classification import MultilabelLogAUC + >>> preds = tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... 
[1, 1, 1]]) + >>> metric = MultilabelLogAUC(num_labels=3, average="macro", thresholds=None) + >>> metric(preds, target) + tensor(0.3945) + >>> metric = MultilabelLogAUC(num_labels=3, average=None, thresholds=None) + >>> metric(preds, target) + tensor([0.5000, 0.0000, 0.6835]) + + """ + + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = False + plot_lower_bound: float = 0.0 + plot_upper_bound: float = 1.0 + plot_legend_name: str = "Label" + + def __init__( + self, + num_labels: int, + fpr_range: Tuple[float, float] = (0.001, 0.1), + average: Optional[Literal["macro", "none"]] = None, + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> None: + if validate_args: + _validate_fpr_range(fpr_range) + self.fpr_range = fpr_range + self.average2 = average # self.average is already used by parent class + super().__init__( + num_labels=num_labels, + thresholds=thresholds, + ignore_index=ignore_index, + validate_args=validate_args, + **kwargs, + ) + + def compute(self) -> Tensor: # type: ignore[override] + """Computes the log AUC score.""" + fpr, tpr, _ = super().compute() + return _reduce_logauc(fpr, tpr, fpr_range=self.fpr_range, average=self.average2) + + def plot( # type: ignore[override] + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> # Example plotting a single + >>> import torch + >>> from torchmetrics.classification import MultilabelLogAUC + >>> metric = MultilabelLogAUC(num_labels=3) + >>> metric.update(torch.rand(20,3), torch.randint(2, (20,3))) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> # Example plotting multiple values + >>> import torch + >>> from torchmetrics.classification import MultilabelLogAUC + >>> metric = MultilabelLogAUC(num_labels=3) + >>> values = [ ] + >>> for _ in range(10): + ... values.append(metric(torch.rand(20,3), torch.randint(2, (20,3)))) + >>> fig_, ax_ = metric.plot(values) + + """ + return self._plot(val, ax) + + +class LogAUC(_ClassificationTaskWrapper): + r"""Compute the `Log AUC`_ score for classification tasks. + + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + This module is a simple wrapper to get the task specific versions of this metric, which is done by setting the + ``task`` argument to either ``'binary'``, ``'multiclass'`` or ``'multilabel'``. See the documentation of + :class:`~torchmetrics.classification.BinaryLogAUC`, :class:`~torchmetrics.classification.MulticlassLogAUC` and + :class:`~torchmetrics.classification.MultilabelLogAUC` for the specific details of each argument's influence and + examples.
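The wrapper class only dispatches to one of the task-specific metrics above via ``__new__``; a short usage sketch (the binary inputs and the resulting score are taken from the ``BinaryLogAUC`` example earlier in this file):

from torch import tensor
from torchmetrics import LogAUC
from torchmetrics.classification import BinaryLogAUC

preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05])
target = tensor([1, 0, 0, 0, 0])

metric = LogAUC(task="binary", fpr_range=(0.001, 0.1))  # the wrapper returns the task-specific metric
assert isinstance(metric, BinaryLogAUC)
metric(preds, target)  # tensor(1.), identical to calling BinaryLogAUC directly

# the multiclass and multilabel tasks additionally require num_classes / num_labels
mc_metric = LogAUC(task="multiclass", num_classes=5)
ml_metric = LogAUC(task="multilabel", num_labels=3)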
+ + """ + + def __new__( # type: ignore[misc] + cls: Type["LogAUC"], + task: Literal["binary", "multiclass", "multilabel"], + thresholds: Optional[Union[int, List[float], Tensor]] = None, + fpr_range: Optional[Tuple[float, float]] = (0.001, 0.1), + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, + **kwargs: Any, + ) -> Metric: + """Initialize task metric.""" + task = ClassificationTask.from_str(task) + kwargs.update({ + "thresholds": thresholds, + "fpr_range": fpr_range, + "ignore_index": ignore_index, + "validate_args": validate_args, + }) + if task == ClassificationTask.BINARY: + return BinaryLogAUC(**kwargs) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + return MulticlassLogAUC(num_classes, **kwargs) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return MultilabelLogAUC(num_labels, **kwargs) + raise ValueError(f"Task {task} not supported!") diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index f76d907c907..a4f175ce02d 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -35,6 +35,7 @@ hamming_distance, hinge_loss, jaccard_index, + logauc, matthews_corrcoef, negative_predictive_value, precision, @@ -169,6 +170,7 @@ "jaccard_index", "kendall_rank_corrcoef", "kl_divergence", + "logauc", "log_cosh_error", "match_error_rate", "matthews_corrcoef", diff --git a/src/torchmetrics/functional/classification/__init__.py b/src/torchmetrics/functional/classification/__init__.py index 73ff9fcc1ea..925f977e419 100644 --- a/src/torchmetrics/functional/classification/__init__.py +++ b/src/torchmetrics/functional/classification/__init__.py @@ -71,6 +71,7 @@ multiclass_jaccard_index, multilabel_jaccard_index, ) +from torchmetrics.functional.classification.logauc import binary_logauc, logauc, multiclass_logauc, multilabel_logauc from torchmetrics.functional.classification.matthews_corrcoef import ( binary_matthews_corrcoef, matthews_corrcoef, @@ -240,6 +241,10 @@ "demographic_parity", "equal_opportunity", "precision_at_fixed_recall", + "binary_logauc", + "multiclass_logauc", + "multilabel_logauc", + "logauc", "binary_negative_predictive_value", "multiclass_negative_predictive_value", "multilabel_negative_predictive_value", diff --git a/src/torchmetrics/functional/classification/logauc.py b/src/torchmetrics/functional/classification/logauc.py new file mode 100644 index 00000000000..5cb1f90e4cf --- /dev/null +++ b/src/torchmetrics/functional/classification/logauc.py @@ -0,0 +1,356 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Optional, Tuple, Union + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.functional.classification.roc import binary_roc, multiclass_roc, multilabel_roc +from torchmetrics.utilities import rank_zero_warn +from torchmetrics.utilities.compute import _auc_compute_without_check, _safe_divide +from torchmetrics.utilities.data import interp +from torchmetrics.utilities.enums import ClassificationTask + + +def _validate_fpr_range(fpr_range: Tuple[float, float]) -> None: + """Validate the `fpr_range` argument for the logauc metric.""" + if not isinstance(fpr_range, tuple) or len(fpr_range) != 2: + raise ValueError(f"The `fpr_range` should be a tuple of two floats, but got {type(fpr_range)}.") + if not (0 <= fpr_range[0] < fpr_range[1] <= 1): + raise ValueError(f"The `fpr_range` should be a tuple of two floats in the range [0, 1], but got {fpr_range}.") + + +def _binary_logauc_compute( + fpr: Tensor, + tpr: Tensor, + fpr_range: Tuple[float, float] = (0.001, 0.1), +) -> Tensor: + """Compute the logauc score for binary classification tasks.""" + fpr_range = torch.tensor(fpr_range).to(fpr.device) + if fpr.numel() < 2 or tpr.numel() < 2: + rank_zero_warn( + "At least two values for both fpr and tpr are required to compute the log AUC. Returning a score of 0." + ) + return torch.tensor(0.0, device=fpr.device) + + tpr = torch.cat([tpr, interp(fpr_range, fpr, tpr)]).sort().values + fpr = torch.cat([fpr, fpr_range]).sort().values + + log_fpr = torch.log10(fpr) + bounds = torch.log10(fpr_range) + + lower_bound_idx = torch.where(log_fpr == bounds[0])[0][-1] + upper_bound_idx = torch.where(log_fpr == bounds[1])[0][-1] + + trimmed_log_fpr = log_fpr[lower_bound_idx : upper_bound_idx + 1] + trimmed_tpr = tpr[lower_bound_idx : upper_bound_idx + 1] + + # compute area and rescale it to the range of fpr + return _auc_compute_without_check(trimmed_log_fpr, trimmed_tpr, 1.0) / (bounds[1] - bounds[0]) + + +def _reduce_logauc( + fpr: Union[Tensor, List[Tensor]], + tpr: Union[Tensor, List[Tensor]], + fpr_range: Tuple[float, float] = (0.001, 0.1), + average: Optional[Literal["macro", "weighted", "none"]] = "macro", + weights: Optional[Tensor] = None, +) -> Tensor: + """Reduce the logauc score to a single value for multiclass and multilabel classification tasks.""" + scores = [] + for fpr_i, tpr_i in zip(fpr, tpr): + scores.append(_binary_logauc_compute(fpr_i, tpr_i, fpr_range)) + scores = torch.stack(scores) + if torch.isnan(scores).any(): + rank_zero_warn( + f"LogAUC score for one or more classes/labels was `nan`. Ignoring these classes in {average}-average." + ) + idx = ~torch.isnan(scores) + if average is None or average == "none": + return scores + if average == "macro": + return scores[idx].mean() + if average == "weighted" and weights is not None: + weights = _safe_divide(weights[idx], weights[idx].sum()) + return (scores[idx] * weights).sum() + raise ValueError(f"Got unknown average parameter: {average}. Please choose one of ['macro', 'weighted', 'none'].") + + +def binary_logauc( + preds: Tensor, + target: Tensor, + fpr_range: Tuple[float, float] = (0.001, 0.1), + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Compute the `Log AUC`_ score for binary classification tasks.
+ + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with ground truth labels + fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log + AUC score. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Returns: + A single scalar with the log auc score + + Example: + >>> from torchmetrics.functional.classification import binary_logauc + >>> from torch import tensor + >>> preds = tensor([0.75, 0.05, 0.05, 0.05, 0.05]) + >>> target = tensor([1, 0, 0, 0, 0]) + >>> binary_logauc(preds, target) + tensor(1.) + + """ + _validate_fpr_range(fpr_range) + fpr, tpr, _ = binary_roc(preds, target, thresholds, ignore_index, validate_args) + return _binary_logauc_compute(fpr, tpr, fpr_range) + + +def multiclass_logauc( + preds: Tensor, + target: Tensor, + num_classes: int, + fpr_range: Tuple[float, float] = (0.001, 0.1), + average: Optional[Literal["macro", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Compute the `Log AUC`_ score for multiclass classification tasks. 
+ + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + softmax per sample. + - ``target`` (int tensor): ``(N, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_classes: Integer specifying the number of classes + fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log + AUC score. + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. + + average: + Defines the reduction that is applied over classes. Should be one of the following: + + - ``macro``: Calculate score for each class and average them + - ``"none"`` or ``None``: calculates score for each class and applies no reduction + + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multiclass_logauc + >>> preds = torch.tensor([[0.75, 0.05, 0.05, 0.05, 0.05], + ... [0.05, 0.75, 0.05, 0.05, 0.05], + ... [0.05, 0.05, 0.75, 0.05, 0.05], + ... 
[0.05, 0.05, 0.05, 0.75, 0.05]]) + >>> target = torch.tensor([0, 1, 3, 2]) + >>> multiclass_logauc(preds, target, num_classes=5, average="macro", thresholds=None) + tensor(0.4000) + >>> multiclass_logauc(preds, target, num_classes=5, average=None, thresholds=None) + tensor([1., 1., 0., 0., 0.]) + + """ + if validate_args: + _validate_fpr_range(fpr_range) + fpr, tpr, _ = multiclass_roc( + preds, target, num_classes, thresholds, average=None, ignore_index=ignore_index, validate_args=validate_args + ) + return _reduce_logauc(fpr, tpr, fpr_range, average) + + +def multilabel_logauc( + preds: Tensor, + target: Tensor, + num_labels: int, + fpr_range: Tuple[float, float] = (0.001, 0.1), + average: Optional[Literal["macro", "none"]] = "macro", + thresholds: Optional[Union[int, List[float], Tensor]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Tensor: + r"""Compute the `Log AUC`_ score for multilabel classification tasks. + + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + Accepts the following input tensors: + + - ``preds`` (float tensor): ``(N, C, ...)``. Preds should be a tensor containing probabilities or logits for each + observation. If preds has values outside [0,1] range we consider the input to be logits and will auto apply + sigmoid per element. + - ``target`` (int tensor): ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore + only contain {0,1} values (except if `ignore_index` is specified). + + Additional dimension ``...`` will be flattened into the batch dimension. + + The implementation both supports calculating the metric in a non-binned but accurate version and a binned version + that is less accurate but more memory efficient. Setting the `thresholds` argument to `None` will activate the + non-binned version that uses memory of size :math:`\mathcal{O}(n_{samples})` whereas setting the `thresholds` + argument to either an integer, list or a 1d tensor will use a binned version that uses memory of + size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). + + Args: + preds: Tensor with predictions + target: Tensor with true labels + num_labels: Integer specifying the number of labels + fpr_range: 2-element tuple with the lower and upper bound of the false positive rate range to compute the log + AUC score. + average: + Defines the reduction that is applied over labels. Should be one of the following: + + - ``macro``: Calculate score for each label and average them + - ``"none"`` or ``None``: calculates score for each label and applies no reduction + + thresholds: + Can be one of: + + - If set to `None`, will use a non-binned approach where thresholds are dynamically calculated from + all the data. Most accurate but also most memory consuming approach. + - If set to an `int` (larger than 1), will use that number of thresholds linearly spaced from + 0 to 1 as bins for the calculation. + - If set to an `list` of floats, will use the indicated thresholds in the list as bins for the calculation + - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as + bins for the calculation. 
+ + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation + validate_args: bool indicating if input arguments and tensors should be validated for correctness. + Set to ``False`` for faster computations. + + Example: + >>> from torchmetrics.functional.classification import multilabel_logauc + >>> preds = torch.tensor([[0.75, 0.05, 0.35], + ... [0.45, 0.75, 0.05], + ... [0.05, 0.55, 0.75], + ... [0.05, 0.65, 0.05]]) + >>> target = torch.tensor([[1, 0, 1], + ... [0, 0, 0], + ... [0, 1, 1], + ... [1, 1, 1]]) + >>> multilabel_logauc(preds, target, num_labels=3, average="macro", thresholds=None) + tensor(0.3945) + >>> multilabel_logauc(preds, target, num_labels=3, average=None, thresholds=None) + tensor([0.5000, 0.0000, 0.6835]) + + """ + fpr, tpr, _ = multilabel_roc(preds, target, num_labels, thresholds, ignore_index, validate_args) + return _reduce_logauc(fpr, tpr, fpr_range, average=average) + + +def logauc( + preds: Tensor, + target: Tensor, + task: Literal["binary", "multiclass", "multilabel"], + thresholds: Optional[Union[int, List[float], Tensor]] = None, + num_classes: Optional[int] = None, + num_labels: Optional[int] = None, + fpr_range: Tuple[float, float] = (0.001, 0.1), + average: Optional[Literal["macro", "none"]] = None, + ignore_index: Optional[int] = None, + validate_args: bool = True, +) -> Optional[Tensor]: + r"""Compute the `Log AUC`_ score for classification tasks. + + The score is computed by first computing the ROC curve, which then is interpolated to the specified range of false + positive rates (FPR) and then the log is taken of the FPR before the area under the curve (AUC) is computed. The + score is commonly used in applications where the positive and negative are imbalanced and a low false positive rate + is of high importance. + + """ + task = ClassificationTask.from_str(task) + if task == ClassificationTask.BINARY: + return binary_logauc(preds, target, fpr_range, thresholds, ignore_index, validate_args) + if task == ClassificationTask.MULTICLASS: + if not isinstance(num_classes, int): + raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") + return multiclass_logauc( + preds, target, num_classes, fpr_range, average, thresholds, ignore_index, validate_args + ) + if task == ClassificationTask.MULTILABEL: + if not isinstance(num_labels, int): + raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") + return multilabel_logauc(preds, target, num_labels, fpr_range, average, thresholds, ignore_index, validate_args) + return None diff --git a/src/torchmetrics/utilities/data.py b/src/torchmetrics/utilities/data.py index e5bb148a9ce..8133ef9f8f1 100644 --- a/src/torchmetrics/utilities/data.py +++ b/src/torchmetrics/utilities/data.py @@ -242,3 +242,28 @@ def allclose(tensor1: Tensor, tensor2: Tensor) -> bool: if tensor1.dtype != tensor2.dtype: tensor2 = tensor2.to(dtype=tensor1.dtype) return torch.allclose(tensor1, tensor2) + + +def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: + """Interpolation function comparable to numpy.interp. 
+ + Args: + x: x-coordinates where to evaluate the interpolated values + xp: x-coordinates of the data points + fp: y-coordinates of the data points + + """ + # Sort xp and fp based on xp for compatibility with np.interp + sorted_indices = torch.argsort(xp) + xp = xp[sorted_indices] + fp = fp[sorted_indices] + + # Calculate slopes for each interval + slopes = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) + + # Identify where x falls relative to xp + indices = torch.searchsorted(xp, x) - 1 + indices = torch.clamp(indices, 0, len(slopes) - 1) + + # Compute interpolated values + return fp[indices] + slopes[indices] * (x - xp[indices]) diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index 28bda373600..ef6fcf331aa 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -59,5 +59,6 @@ _SENTENCEPIECE_AVAILABLE = RequirementCache("sentencepiece") _SCIPI_AVAILABLE = RequirementCache("scipy") _SKLEARN_GREATER_EQUAL_1_3 = RequirementCache("scikit-learn>=1.3.0") +_PYTDC_AVAILABLE = RequirementCache("pyTDC") _LATEX_AVAILABLE: bool = shutil.which("latex") is not None diff --git a/tests/unittests/classification/test_logauc.py b/tests/unittests/classification/test_logauc.py new file mode 100644 index 00000000000..26cb395f45e --- /dev/null +++ b/tests/unittests/classification/test_logauc.py @@ -0,0 +1,415 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
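The ``interp`` helper added to ``torchmetrics.utilities.data`` above is intended to mirror ``numpy.interp`` for 1-d piecewise-linear interpolation, which is what the log AUC computation relies on. A small sanity check of that equivalence is sketched below; note that, unlike ``numpy.interp``, the torch helper extrapolates linearly outside the range of ``xp`` instead of clamping to the endpoint values, so the comparison only holds for points inside the data range:

import numpy as np
import torch
from torchmetrics.utilities.data import interp

xp = torch.tensor([0.0, 0.25, 1.0])
fp = torch.tensor([0.0, 0.5, 1.0])
x = torch.tensor([0.1, 0.5, 0.9])

torch_res = interp(x, xp, fp)                             # piecewise-linear interpolation in torch
numpy_res = np.interp(x.numpy(), xp.numpy(), fp.numpy())
assert np.allclose(torch_res.numpy(), numpy_res)          # agreement inside the range of xp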
+from functools import partial + +import numpy as np +import pytest +import torch +from scipy.special import expit as sigmoid +from scipy.special import softmax +from torchmetrics.utilities.imports import _PYTDC_AVAILABLE + +if _PYTDC_AVAILABLE: + from tdc.evaluator import range_logAUC + +from torchmetrics.classification.logauc import BinaryLogAUC, LogAUC, MulticlassLogAUC, MultilabelLogAUC +from torchmetrics.functional.classification.logauc import binary_logauc, multiclass_logauc, multilabel_logauc +from torchmetrics.functional.classification.roc import binary_roc +from torchmetrics.metric import Metric + +from unittests import NUM_CLASSES +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester, inject_ignore_index, remove_ignore_index +from unittests.classification._inputs import _binary_cases, _multiclass_cases, _multilabel_cases + +seed_all(42) + + +def _binary_compare_implementation(preds, target, fpr_range, ignore_index=None): + """Binary comparison function for logauc.""" + preds = preds.flatten().numpy() + target = target.flatten().numpy() + if not ((preds > 0) & (preds < 1)).all(): + preds = sigmoid(preds) + target, preds = remove_ignore_index(target, preds, ignore_index) + return range_logAUC(target, preds, FPR_range=fpr_range) + + +@pytest.mark.skipif(not _PYTDC_AVAILABLE, reason="test requires pytdc installed.") +@pytest.mark.parametrize("inputs", (_binary_cases[1], _binary_cases[2], _binary_cases[4], _binary_cases[5])) +class TestBinaryLogAUC(MetricTester): + """Test class for `BinaryLogAUC` metric.""" + + atol = 1e-2 + + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + @pytest.mark.parametrize("fpr_range", [(0.001, 0.1), (0.01, 0.1), (0.1, 0.2)]) + def test_binary_logauc(self, inputs, ddp, fpr_range): + """Test class implementation of metric.""" + preds, target = inputs + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=BinaryLogAUC, + reference_metric=partial(_binary_compare_implementation, fpr_range=fpr_range), + metric_args={ + "fpr_range": fpr_range, + "thresholds": None, + }, + ) + + @pytest.mark.parametrize("fpr_range", [(0.001, 0.1), (0.01, 0.1), (0.1, 0.2)]) + @pytest.mark.parametrize("ignore_index", [None, -1]) + def test_binary_logauc_functional(self, inputs, fpr_range, ignore_index): + """Test functional implementation of metric.""" + preds, target = inputs + if ignore_index is not None: + target = inject_ignore_index(target, ignore_index) + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=binary_logauc, + reference_metric=partial(_binary_compare_implementation, fpr_range=fpr_range, ignore_index=ignore_index), + metric_args={ + "fpr_range": fpr_range, + "thresholds": None, + "ignore_index": ignore_index, + }, + ) + + def test_binary_logauc_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=BinaryLogAUC, + metric_functional=binary_logauc, + metric_args={"thresholds": None}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_logauc_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + + if (preds < 0).any() and dtype == torch.half: + pytest.xfail(reason="torch.sigmoid in metric does not support cpu + half precision") + self.run_precision_test_cpu( 
+ preds=preds, + target=target, + metric_module=BinaryLogAUC, + metric_functional=binary_logauc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_binary_logauc_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=BinaryLogAUC, + metric_functional=binary_logauc, + metric_args={"thresholds": None}, + dtype=dtype, + ) + + @pytest.mark.parametrize("threshold_fn", [lambda x: x, lambda x: x.numpy().tolist()], ids=["as tensor", "as list"]) + def test_binary_logauc_threshold_arg(self, inputs, threshold_fn): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + + for pred, true in zip(preds, target): + _, _, t = binary_roc(pred, true, thresholds=None) + ap1 = binary_logauc(pred, true, thresholds=None) + ap2 = binary_logauc(pred, true, thresholds=threshold_fn(t.flip(0))) + assert torch.allclose(ap1, ap2, atol=self.atol) + + +def _multiclass_compare_implementation(preds, target, fpr_range, average): + """Multiclass comparison function for logauc.""" + preds = preds.permute(0, 2, 1).reshape(-1, NUM_CLASSES).numpy() if preds.ndim == 3 else preds.numpy() + target = target.flatten().numpy() + if not ((preds > 0) & (preds < 1)).all(): + preds = softmax(preds, 1) + + scores = [] + for i in range(NUM_CLASSES): + p, t = preds[:, i], (target == i).astype(int) + scores.append(range_logAUC(t, p, FPR_range=fpr_range)) + if average == "macro": + return np.mean(scores) + return scores + + +@pytest.mark.skipif(not _PYTDC_AVAILABLE, reason="test requires pytdc installed.") +@pytest.mark.parametrize( + "inputs", (_multiclass_cases[1], _multiclass_cases[2], _multiclass_cases[4], _multiclass_cases[5]) +) +class TestMulticlassLogAUC(MetricTester): + """Test class for `MulticlassLogAUC` metric.""" + + atol = 1e-2 + + @pytest.mark.parametrize("fpr_range", [(0.001, 0.1), (0.01, 0.1), (0.1, 0.2)]) + @pytest.mark.parametrize("average", ["macro", None]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_multiclass_logauc(self, inputs, fpr_range, average, ddp): + """Test class implementation of metric.""" + preds, target = inputs + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MulticlassLogAUC, + reference_metric=partial(_multiclass_compare_implementation, fpr_range=fpr_range, average=average), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "fpr_range": fpr_range, + "average": average, + }, + ) + + @pytest.mark.parametrize("fpr_range", [(0.001, 0.1), (0.01, 0.1), (0.1, 0.2)]) + @pytest.mark.parametrize("average", ["macro", None]) + def test_multiclass_logauc_functional(self, inputs, fpr_range, average): + """Test functional implementation of metric.""" + preds, target = inputs + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multiclass_logauc, + reference_metric=partial(_multiclass_compare_implementation, fpr_range=fpr_range, average=average), + metric_args={ + "thresholds": None, + "num_classes": NUM_CLASSES, + "fpr_range": fpr_range, + "average": average, + }, + ) + + def test_multiclass_logauc_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, 
target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MulticlassLogAUC, + metric_functional=multiclass_logauc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_logauc_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + + if dtype == torch.half and not ((preds > 0) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MulticlassLogAUC, + metric_functional=multiclass_logauc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_logauc_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MulticlassLogAUC, + metric_functional=multiclass_logauc, + metric_args={"thresholds": None, "num_classes": NUM_CLASSES}, + dtype=dtype, + ) + + def test_multiclass_logauc_threshold_arg(self, inputs): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + if (preds < 0).any(): + preds = preds.softmax(dim=-1) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 2)) + 1e-6 # rounding will simulate binning + ap1 = multiclass_logauc(pred, true, num_classes=NUM_CLASSES, average="macro", thresholds=None) + ap2 = multiclass_logauc( + pred, true, num_classes=NUM_CLASSES, average="macro", thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2, atol=self.atol) + + +def _multilabel_compare_implementation(preds, target, fpr_range, average): + if preds.ndim > 2: + target = target.transpose(2, 1).reshape(-1, NUM_CLASSES) + preds = preds.transpose(2, 1).reshape(-1, NUM_CLASSES) + target = target.numpy() + preds = preds.numpy() + if not ((preds > 0) & (preds < 1)).all(): + preds = sigmoid(preds) + scores = [] + for i in range(NUM_CLASSES): + p, t = preds[:, i], target[:, i] + scores.append(range_logAUC(t, p, FPR_range=fpr_range)) + if average == "macro": + return np.mean(scores) + return scores + + +@pytest.mark.skipif(not _PYTDC_AVAILABLE, reason="test requires pytdc installed.") +@pytest.mark.parametrize( + "inputs", (_multilabel_cases[1], _multilabel_cases[2], _multilabel_cases[4], _multilabel_cases[5]) +) +class TestMultilabelLogAUC(MetricTester): + """Test class for `MultilabelLogAUC` metric.""" + + atol = 1e-2 + + @pytest.mark.parametrize("fpr_range", [(0.001, 0.1), (0.01, 0.1), (0.1, 0.2)]) + @pytest.mark.parametrize("average", ["macro", None]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_multilabel_logauc(self, inputs, ddp, fpr_range, average): + """Test class implementation of metric.""" + preds, target = inputs + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=MultilabelLogAUC, + reference_metric=partial(_multilabel_compare_implementation, fpr_range=fpr_range, average=average), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "fpr_range": fpr_range, + }, + ) + + @pytest.mark.parametrize("fpr_range", [(0.001, 0.1), (0.01, 0.1), (0.1, 
0.2)]) + @pytest.mark.parametrize("average", ["macro", None]) + def test_multilabel_logauc_functional(self, inputs, fpr_range, average): + """Test functional implementation of metric.""" + preds, target = inputs + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=multilabel_logauc, + reference_metric=partial(_multilabel_compare_implementation, fpr_range=fpr_range, average=average), + metric_args={ + "thresholds": None, + "num_labels": NUM_CLASSES, + "average": average, + "fpr_range": fpr_range, + }, + ) + + def test_multiclass_logauc_differentiability(self, inputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + preds, target = inputs + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=MultilabelLogAUC, + metric_functional=multilabel_logauc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + ) + + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multilabel_logauc_dtype_cpu(self, inputs, dtype): + """Test dtype support of the metric on CPU.""" + preds, target = inputs + + if dtype == torch.half and not ((preds > 0) & (preds < 1)).all(): + pytest.xfail(reason="half support for torch.softmax on cpu not implemented") + self.run_precision_test_cpu( + preds=preds, + target=target, + metric_module=MultilabelLogAUC, + metric_functional=multilabel_logauc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.parametrize("dtype", [torch.half, torch.double]) + def test_multiclass_logauc_dtype_gpu(self, inputs, dtype): + """Test dtype support of the metric on GPU.""" + preds, target = inputs + self.run_precision_test_gpu( + preds=preds, + target=target, + metric_module=MultilabelLogAUC, + metric_functional=multilabel_logauc, + metric_args={"thresholds": None, "num_labels": NUM_CLASSES}, + dtype=dtype, + ) + + def test_multilabel_logauc_threshold_arg(self, inputs): + """Test that different types of `thresholds` argument lead to same result.""" + preds, target = inputs + if (preds < 0).any(): + preds = sigmoid(preds) + for pred, true in zip(preds, target): + pred = torch.tensor(np.round(pred.numpy(), 1)) + 1e-6 # rounding will simulate binning + ap1 = multilabel_logauc(pred, true, num_labels=NUM_CLASSES, average="macro", thresholds=None) + ap2 = multilabel_logauc( + pred, true, num_labels=NUM_CLASSES, average="macro", thresholds=torch.linspace(0, 1, 100) + ) + assert torch.allclose(ap1, ap2, atol=self.atol) + + +@pytest.mark.parametrize( + "metric", + [ + BinaryLogAUC, + partial(MulticlassLogAUC, num_classes=NUM_CLASSES), + partial(MultilabelLogAUC, num_labels=NUM_CLASSES), + ], +) +@pytest.mark.parametrize("thresholds", [None, 100, [0.3, 0.5, 0.7, 0.9], torch.linspace(0, 1, 10)]) +def test_valid_input_thresholds(recwarn, metric, thresholds): + """Test valid formats of the threshold argument.""" + metric(thresholds=thresholds) + assert len(recwarn) == 0, "Warning was raised when it should not have been." 
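Beyond checking that every accepted ``thresholds`` format is warning-free, the accuracy/memory trade-off described in the docstrings can be illustrated directly. The snippet below is a rough sketch on made-up random data; the binned score only approximates the exact one, and the gap depends on the data and the number of bins:

import torch
from torchmetrics.classification import BinaryLogAUC

torch.manual_seed(0)
preds = torch.rand(1000)
target = torch.randint(2, (1000,))

exact = BinaryLogAUC(thresholds=None)  # non-binned: keeps all predictions, O(n_samples) memory
binned = BinaryLogAUC(thresholds=200)  # binned: fixed-size state, O(n_thresholds) memory

print(exact(preds, target), binned(preds, target))  # the binned score approximates the exact one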
+ + +@pytest.mark.parametrize( + ("metric", "kwargs"), + [ + (BinaryLogAUC, {"task": "binary"}), + (MulticlassLogAUC, {"task": "multiclass", "num_classes": 3}), + (MultilabelLogAUC, {"task": "multilabel", "num_labels": 3}), + (None, {"task": "not_valid_task"}), + ], +) +def test_wrapper_class(metric, kwargs, base_metric=LogAUC): + """Test the wrapper class.""" + assert issubclass(base_metric, Metric) + if metric is None: + with pytest.raises(ValueError, match=r"Invalid *"): + base_metric(**kwargs) + else: + instance = base_metric(**kwargs) + assert isinstance(instance, metric) + assert isinstance(instance, Metric) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index efb7077682e..4ebd41fd300 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -47,6 +47,7 @@ BinaryHammingDistance, BinaryHingeLoss, BinaryJaccardIndex, + BinaryLogAUC, BinaryMatthewsCorrCoef, BinaryPrecision, BinaryPrecisionRecallCurve, @@ -66,6 +67,7 @@ MulticlassHammingDistance, MulticlassHingeLoss, MulticlassJaccardIndex, + MulticlassLogAUC, MulticlassMatthewsCorrCoef, MulticlassPrecision, MulticlassPrecisionRecallCurve, @@ -80,6 +82,7 @@ MultilabelFBetaScore, MultilabelHammingDistance, MultilabelJaccardIndex, + MultilabelLogAUC, MultilabelMatthewsCorrCoef, MultilabelPrecision, MultilabelPrecisionRecallCurve, @@ -384,6 +387,19 @@ _multilabel_randint_input, id="multilabel specificity", ), + pytest.param(BinaryLogAUC, _rand_input, _binary_randint_input, id="binary log auc"), + pytest.param( + partial(MulticlassLogAUC, num_classes=3), + _multiclass_randn_input, + _multiclass_randint_input, + id="multiclass log auc", + ), + pytest.param( + partial(MultilabelLogAUC, num_labels=3), + _multilabel_rand_input, + _multilabel_randint_input, + id="multilabel log auc", + ), pytest.param( partial(MultilabelCoverageError, num_labels=3), _multilabel_rand_input,