From 2387f2afb6d0f3dfe5ccbd60fbb1b2a0ac67a194 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Oct 2023 15:46:01 +0200 Subject: [PATCH 1/3] [pre-commit.ci] pre-commit suggestions (#2136) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka --- .github/assistant.py | 2 +- .pre-commit-config.yaml | 16 ++++++++-------- pyproject.toml | 5 +++-- src/torchmetrics/audio/__init__.py | 12 ++++++------ src/torchmetrics/audio/stoi.py | 2 +- src/torchmetrics/detection/mean_ap.py | 4 ++-- src/torchmetrics/functional/audio/__init__.py | 12 ++++++------ src/torchmetrics/functional/audio/sdr.py | 4 ++-- src/torchmetrics/functional/audio/stoi.py | 2 +- .../functional/detection/__init__.py | 14 ++++++-------- src/torchmetrics/functional/image/ssim.py | 2 +- src/torchmetrics/functional/text/__init__.py | 7 +++---- src/torchmetrics/functional/text/chrf.py | 2 +- src/torchmetrics/functional/text/eed.py | 4 ++-- src/torchmetrics/functional/text/helper.py | 2 +- src/torchmetrics/functional/text/rouge.py | 2 +- src/torchmetrics/functional/text/sacre_bleu.py | 2 +- src/torchmetrics/functional/text/squad.py | 2 +- src/torchmetrics/image/ssim.py | 2 +- src/torchmetrics/retrieval/fall_out.py | 2 +- src/torchmetrics/text/__init__.py | 7 +++---- src/torchmetrics/wrappers/tracker.py | 2 +- tests/unittests/bases/test_composition.py | 2 +- tests/unittests/text/test_rouge.py | 2 +- 24 files changed, 55 insertions(+), 58 deletions(-) diff --git a/.github/assistant.py b/.github/assistant.py index 68f9eeef8f7..c7d19a9c319 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -79,7 +79,7 @@ def set_min_torch_by_python(fpath: str = "requirements/base.txt") -> None: return with open(fpath) as fp: reqs = parse_requirements(fp.readlines()) - pkg_ver = [p for p in reqs if p.name == "torch"][0] + pkg_ver = next(p for p in reqs if p.name == "torch") pt_ver = min([LooseVersion(v[1]) for v in pkg_ver.specs]) pt_ver = max(LooseVersion(LUT_PYTHON_TORCH[py_ver]), pt_ver) with open(fpath) as fp: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 43de956fd86..c37fdaefcd7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,22 +38,22 @@ repos: - id: detect-private-key - repo: https://github.com/asottile/pyupgrade - rev: v3.9.0 + rev: v3.14.0 hooks: - id: pyupgrade - args: [--py38-plus] + args: ["--py38-plus"] name: Upgrade code - repo: https://github.com/codespell-project/codespell - rev: v2.2.5 + rev: v2.2.6 hooks: - id: codespell additional_dependencies: [tomli] - #args: ["--write-changes"] + args: ["--write-changes"] exclude: pyproject.toml - repo: https://github.com/crate-ci/typos - rev: v1.16.12 + rev: v1.16.17 hooks: - id: typos # empty to do not write fixes @@ -68,13 +68,13 @@ repos: args: ["--in-place"] - repo: https://github.com/psf/black - rev: 23.7.0 + rev: 23.9.1 hooks: - id: black name: Format code - repo: https://github.com/executablebooks/mdformat - rev: 0.7.16 + rev: 0.7.17 hooks: - id: mdformat additional_dependencies: @@ -130,7 +130,7 @@ repos: - id: text-unicode-replacement-char - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.277 + rev: v0.0.292 hooks: - id: ruff args: ["--fix"] diff --git a/pyproject.toml b/pyproject.toml index 8fcbef0e92a..9114383c58a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,8 @@ addopts = [ #filterwarnings = ["error::FutureWarning"] # ToDo xfail_strict = true junit_duration_report = "call" - [tool.coverage.report] exclude_lines = ["pragma: no cover", "pass"] - [tool.coverage.run] parallel = true concurrency = "thread" @@ -81,6 +79,7 @@ wil = "wil" [tool.ruff] +target-version = "py38" line-length = 120 # Enable Pyflakes `E` and `F` codes by default. select = [ @@ -122,6 +121,8 @@ ignore = [ "S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo "S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo "B905", # todo: `zip()` without an explicit `strict=` parameter + "PYI024", # todo: Use `typing.NamedTuple` instead of `collections.namedtuple` + "PYI041", # todo: Use `float` instead of `int | float`` ] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/src/torchmetrics/audio/__init__.py b/src/torchmetrics/audio/__init__.py index 1df6c22645e..31c01171c01 100644 --- a/src/torchmetrics/audio/__init__.py +++ b/src/torchmetrics/audio/__init__.py @@ -41,16 +41,16 @@ ] if _PESQ_AVAILABLE: - from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401 + from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality - __all__.append("PerceptualEvaluationSpeechQuality") + __all__ += ["PerceptualEvaluationSpeechQuality"] if _PYSTOI_AVAILABLE: - from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility # noqa: F401 + from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility - __all__.append("ShortTimeObjectiveIntelligibility") + __all__ += ["ShortTimeObjectiveIntelligibility"] if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10: - from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio # noqa: F401 + from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio - __all__.append("SpeechReverberationModulationEnergyRatio") + __all__ += ["SpeechReverberationModulationEnergyRatio"] diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 32c92f5a515..a60473066dd 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -34,7 +34,7 @@ class ShortTimeObjectiveIntelligibility(Metric): The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms, - on speech intelligibility. Description taken from `Cees Taal's website`_ and for further defails see `STOI ref1`_ + on speech intelligibility. Description taken from `Cees Taal's website`_ and for further details see `STOI ref1`_ and `STOI ref2`_. This metric is a wrapper for the `pystoi package`_. As the implementation backend implementation only supports diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 6cc80a62317..74d9d212a6d 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -879,7 +879,7 @@ def _get_coco_format( f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})" ) - if type(image_label) != int: + if not isinstance(image_label, int): raise ValueError( f"Invalid input class of sample {image_id}, element {k}" f" (expected value of type integer, got type {type(image_label)})" @@ -915,7 +915,7 @@ def _get_coco_format( if scores is not None: score = scores[image_id][k].cpu().tolist() - if type(score) != float: + if not isinstance(score, float): raise ValueError( f"Invalid input score of sample {image_id}, element {k}" f" (expected value of type float, got type {type(score)})" diff --git a/src/torchmetrics/functional/audio/__init__.py b/src/torchmetrics/functional/audio/__init__.py index b6469c7aace..077442b0b83 100644 --- a/src/torchmetrics/functional/audio/__init__.py +++ b/src/torchmetrics/functional/audio/__init__.py @@ -42,16 +42,16 @@ ] if _PESQ_AVAILABLE: - from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality # noqa: F401 + from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality - __all__.append("perceptual_evaluation_speech_quality") + __all__ += ["perceptual_evaluation_speech_quality"] if _PYSTOI_AVAILABLE: - from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility # noqa: F401 + from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility - __all__.append("short_time_objective_intelligibility") + __all__ += ["short_time_objective_intelligibility"] if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10: - from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio # noqa: F401 + from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio - __all__.append("speech_reverberation_modulation_energy_ratio") + __all__ += ["speech_reverberation_modulation_energy_ratio"] diff --git a/src/torchmetrics/functional/audio/sdr.py b/src/torchmetrics/functional/audio/sdr.py index 9e47d9489c3..9f4043948b6 100644 --- a/src/torchmetrics/functional/audio/sdr.py +++ b/src/torchmetrics/functional/audio/sdr.py @@ -63,7 +63,7 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor: def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> Tuple[Tensor, Tensor]: r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds`. - This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz matric of the + This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz metric of the auto correlation of `target` as `R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h` as the coordinate of `preds` in the column space of the `corr_len` shifts of `target`. @@ -81,7 +81,7 @@ def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1)) # computes the auto correlation of `target` - # r_0 is the first row of the symmetric Toeplitz matric + # r_0 is the first row of the symmetric Toeplitz metric t_fft = torch.fft.rfft(target, n=n_fft, dim=-1) r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len] diff --git a/src/torchmetrics/functional/audio/stoi.py b/src/torchmetrics/functional/audio/stoi.py index 56d077a4982..81d7a6ffda7 100644 --- a/src/torchmetrics/functional/audio/stoi.py +++ b/src/torchmetrics/functional/audio/stoi.py @@ -35,7 +35,7 @@ def short_time_objective_intelligibility( STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms, on speech - intelligibility. Description taken from `Cees Taal's website`_ and for further defails see `STOI ref1`_ and + intelligibility. Description taken from `Cees Taal's website`_ and for further details see `STOI ref1`_ and `STOI ref2`_. This metric is a wrapper for the `pystoi package`_. As the implementation backend implementation only supports diff --git a/src/torchmetrics/functional/detection/__init__.py b/src/torchmetrics/functional/detection/__init__.py index 85a2d12e39c..8f818c7b2df 100644 --- a/src/torchmetrics/functional/detection/__init__.py +++ b/src/torchmetrics/functional/detection/__init__.py @@ -22,15 +22,13 @@ __all__ = ["modified_panoptic_quality", "panoptic_quality"] if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_8: - from torchmetrics.functional.detection.giou import generalized_intersection_over_union # noqa: F401 - from torchmetrics.functional.detection.iou import intersection_over_union # noqa: F401 + from torchmetrics.functional.detection.giou import generalized_intersection_over_union + from torchmetrics.functional.detection.iou import intersection_over_union - __all__.append("generalized_intersection_over_union") - __all__.append("intersection_over_union") + __all__ += ["generalized_intersection_over_union", "intersection_over_union"] if _TORCHVISION_AVAILABLE and _TORCHVISION_GREATER_EQUAL_0_13: - from torchmetrics.functional.detection.ciou import complete_intersection_over_union # noqa: F401 - from torchmetrics.functional.detection.diou import distance_intersection_over_union # noqa: F401 + from torchmetrics.functional.detection.ciou import complete_intersection_over_union + from torchmetrics.functional.detection.diou import distance_intersection_over_union - __all__.append("complete_intersection_over_union") - __all__.append("distance_intersection_over_union") + __all__ += ["complete_intersection_over_union", "distance_intersection_over_union"] diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 3f7bc7fcb4b..d0e9d15c6dc 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -479,7 +479,7 @@ def multiscale_structural_similarity_index_measure( the range is calculated as the difference and input is clamped between the values. k1: Parameter of structural similarity index measure. k2: Parameter of structural similarity index measure. - betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image + betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image resolutions. normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the training stability. This `normalize` argument is out of scope of the original implementation [1], and it is diff --git a/src/torchmetrics/functional/text/__init__.py b/src/torchmetrics/functional/text/__init__.py index 504a453b79e..9282be6fbae 100644 --- a/src/torchmetrics/functional/text/__init__.py +++ b/src/torchmetrics/functional/text/__init__.py @@ -47,8 +47,7 @@ if _TRANSFORMERS_GREATER_EQUAL_4_4: - from torchmetrics.functional.text.bert import bert_score # noqa: F401 - from torchmetrics.functional.text.infolm import infolm # noqa: F401 + from torchmetrics.functional.text.bert import bert_score + from torchmetrics.functional.text.infolm import infolm - __all__.append("bert_score") - __all__.append("infolm") + __all__ += ["bert_score", "infolm"] diff --git a/src/torchmetrics/functional/text/chrf.py b/src/torchmetrics/functional/text/chrf.py index 4ed62c7fc49..1b6baca5f76 100644 --- a/src/torchmetrics/functional/text/chrf.py +++ b/src/torchmetrics/functional/text/chrf.py @@ -268,7 +268,7 @@ def _calculate_fscore( beta: A parameter determining an importance of recall w.r.t. precision. If `beta=1`, their importance is equal. Return: - A chrF/chrF++ score. This function is universal both for sentence-level and corpus-level calucation. + A chrF/chrF++ score. This function is universal both for sentence-level and corpus-level calculation. """ diff --git a/src/torchmetrics/functional/text/eed.py b/src/torchmetrics/functional/text/eed.py index de5c50fa786..f6f562f5338 100644 --- a/src/torchmetrics/functional/text/eed.py +++ b/src/torchmetrics/functional/text/eed.py @@ -145,7 +145,7 @@ def _eed_function( next_row = [inf] * (len(hyp) + 1) for w in range(1, len(ref) + 1): - for i in range(0, len(hyp) + 1): + for i in range(len(hyp) + 1): if i > 0: next_row[i] = min( next_row[i - 1] + deletion, @@ -252,7 +252,7 @@ def _eed_compute(sentence_level_scores: List[Tensor]) -> Tensor: def _preprocess_sentences( preds: Union[str, Sequence[str]], target: Sequence[Union[str, Sequence[str]]], - language: Union[Literal["en"], Literal["ja"]], + language: Literal["en", "ja"], ) -> Tuple[Union[str, Sequence[str]], Sequence[Union[str, Sequence[str]]]]: """Preprocess strings according to language requirements. diff --git a/src/torchmetrics/functional/text/helper.py b/src/torchmetrics/functional/text/helper.py index 4fe72fcf635..d4c9ff7ae04 100644 --- a/src/torchmetrics/functional/text/helper.py +++ b/src/torchmetrics/functional/text/helper.py @@ -242,7 +242,7 @@ def _add_cache(self, prediction_tokens: List[str], edit_distance: List[List[Tupl node = value[0] # type: ignore def _find_cache(self, prediction_tokens: List[str]) -> Tuple[int, List[List[Tuple[int, _EditOperations]]]]: - """Find the already calculated rows of the Levenshtein edit distance matric. + """Find the already calculated rows of the Levenshtein edit distance metric. Args: prediction_tokens: A tokenized predicted sentence. diff --git a/src/torchmetrics/functional/text/rouge.py b/src/torchmetrics/functional/text/rouge.py index ff04f76cd2c..4bd7e27bc10 100644 --- a/src/torchmetrics/functional/text/rouge.py +++ b/src/torchmetrics/functional/text/rouge.py @@ -490,7 +490,7 @@ def rouge_score( if not isinstance(rouge_keys, tuple): rouge_keys = (rouge_keys,) for key in rouge_keys: - if key not in ALLOWED_ROUGE_KEYS.keys(): + if key not in ALLOWED_ROUGE_KEYS: raise ValueError(f"Got unknown rouge key {key}. Expected to be one of {list(ALLOWED_ROUGE_KEYS.keys())}") rouge_keys_values = [ALLOWED_ROUGE_KEYS[key] for key in rouge_keys] diff --git a/src/torchmetrics/functional/text/sacre_bleu.py b/src/torchmetrics/functional/text/sacre_bleu.py index ff34fb174b4..af247be76d5 100644 --- a/src/torchmetrics/functional/text/sacre_bleu.py +++ b/src/torchmetrics/functional/text/sacre_bleu.py @@ -333,7 +333,7 @@ def sacre_bleu_score( if tokenize not in AVAILABLE_TOKENIZERS: raise ValueError(f"Argument `tokenize` expected to be one of {AVAILABLE_TOKENIZERS} but got {tokenize}.") - if tokenize not in _SacreBLEUTokenizer._TOKENIZE_FN.keys(): + if tokenize not in _SacreBLEUTokenizer._TOKENIZE_FN: raise ValueError( f"Unsupported tokenizer selected. Please, choose one of {list(_SacreBLEUTokenizer._TOKENIZE_FN.keys())}" ) diff --git a/src/torchmetrics/functional/text/squad.py b/src/torchmetrics/functional/text/squad.py index 2440333c61f..01dfb4ec0e6 100644 --- a/src/torchmetrics/functional/text/squad.py +++ b/src/torchmetrics/functional/text/squad.py @@ -119,7 +119,7 @@ def _squad_input_check( ) answers: Dict[str, Union[List[str], List[int]]] = target["answers"] # type: ignore[assignment] - if "text" not in answers.keys(): + if "text" not in answers: raise KeyError( "Expected keys in a 'answers' are 'text'." "Please make sure that 'answer' maps to a `SQuAD` format dictionary.\n" diff --git a/src/torchmetrics/image/ssim.py b/src/torchmetrics/image/ssim.py index ac0808ea653..5056589fd14 100644 --- a/src/torchmetrics/image/ssim.py +++ b/src/torchmetrics/image/ssim.py @@ -249,7 +249,7 @@ class MultiScaleStructuralSimilarityIndexMeasure(Metric): The ``data_range`` must be given when ``dim`` is not None. k1: Parameter of structural similarity index measure. k2: Parameter of structural similarity index measure. - betas: Exponent parameters for individual similarities and contrastive sensitivies returned by different image + betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image resolutions. normalize: When MultiScaleStructuralSimilarityIndexMeasure loss is used for training, it is desirable to use normalizes to improve the training stability. This `normalize` argument is out of scope of the original diff --git a/src/torchmetrics/retrieval/fall_out.py b/src/torchmetrics/retrieval/fall_out.py index 52a00298665..eea6283898d 100644 --- a/src/torchmetrics/retrieval/fall_out.py +++ b/src/torchmetrics/retrieval/fall_out.py @@ -40,7 +40,7 @@ class RetrievalFallOut(RetrievalMetric): As output to ``forward`` and ``compute`` the metric returns the following output: - - ``fo@k`` (:class:`~torch.Tensor`): A tensor with the computed metric + - ``fallout@k`` (:class:`~torch.Tensor`): A tensor with the computed metric All ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten at the beginning, so that for example, a tensor of shape ``(N, M)`` is treated as ``(N * M, )``. Predictions will be first grouped by diff --git a/src/torchmetrics/text/__init__.py b/src/torchmetrics/text/__init__.py index 52d01026f5a..48807a98fc4 100644 --- a/src/torchmetrics/text/__init__.py +++ b/src/torchmetrics/text/__init__.py @@ -45,8 +45,7 @@ ] if _TRANSFORMERS_GREATER_EQUAL_4_4: - from torchmetrics.text.bert import BERTScore # noqa: F401 - from torchmetrics.text.infolm import InfoLM # noqa: F401 + from torchmetrics.text.bert import BERTScore + from torchmetrics.text.infolm import InfoLM - __all__.append("BERTScore") - __all__.append("InfoLM") + __all__ += ["BERTScore", "InfoLM"] diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index 6809fdb115f..7e9913b23a9 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -152,7 +152,7 @@ def compute_all(self) -> Any: """Compute the metric value for all tracked metrics. Return: - By default will try stacking the results from all increaments into a single tensor if the tracked base + By default will try stacking the results from all increments into a single tensor if the tracked base object is a single metric. If a metric collection is provided a dict of stacked tensors will be returned. If the stacking process fails a list of the computed results will be returned. diff --git a/tests/unittests/bases/test_composition.py b/tests/unittests/bases/test_composition.py index 19863276bed..f33d37f2015 100644 --- a/tests/unittests/bases/test_composition.py +++ b/tests/unittests/bases/test_composition.py @@ -67,7 +67,7 @@ def test_metrics_add(second_operand, expected_result): @pytest.mark.parametrize( ("second_operand", "expected_result"), - [(DummyMetric(3), tensor(2)), (3, tensor(2)), (3, tensor(2)), (tensor(3), tensor(2))], + [(DummyMetric(3), tensor(2)), (3, tensor(2)), (tensor(3), tensor(2))], ) def test_metrics_and(second_operand, expected_result): """Test that `and` operator works and returns a compositional metric.""" diff --git a/tests/unittests/text/test_rouge.py b/tests/unittests/text/test_rouge.py index e2d31418e03..fe1ba4cbfcd 100644 --- a/tests/unittests/text/test_rouge.py +++ b/tests/unittests/text/test_rouge.py @@ -73,7 +73,7 @@ def _compute_rouge_score( aggregator_avg = BootstrapAggregator() if accumulate == "best": - key_curr = list(list_results[0].keys())[0] + key_curr = next(iter(list_results[0].keys())) all_fmeasure = torch.tensor([v[key_curr].fmeasure for v in list_results]) highest_idx = torch.argmax(all_fmeasure).item() aggregator.add_scores(list_results[highest_idx]) From b12a64723ae4548bfa1cbe189a5a549fd10d087b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 4 Oct 2023 12:10:15 +0200 Subject: [PATCH 2/3] Add `average` to curve metrics (#2084) Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- CHANGELOG.md | 2 +- docs/source/links.rst | 1 + src/torchmetrics/classification/auroc.py | 6 +- .../classification/average_precision.py | 6 +- .../classification/precision_recall_curve.py | 25 +++++- src/torchmetrics/classification/roc.py | 15 +++- .../classification/precision_recall_curve.py | 83 +++++++++++++++---- .../functional/classification/roc.py | 58 ++++++++++--- src/torchmetrics/utilities/compute.py | 26 ++++++ .../test_precision_recall_curve.py | 19 +++++ tests/unittests/classification/test_roc.py | 15 ++++ 11 files changed, 220 insertions(+), 36 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2600ee1526e..ff0dd083d3e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084)) ### Changed diff --git a/docs/source/links.rst b/docs/source/links.rst index 3685eb8ae1a..8c8dc593fe3 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -164,3 +164,4 @@ .. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html .. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index .. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score +.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html diff --git a/src/torchmetrics/classification/auroc.py b/src/torchmetrics/classification/auroc.py index e8973ca226b..428c961ec1d 100644 --- a/src/torchmetrics/classification/auroc.py +++ b/src/torchmetrics/classification/auroc.py @@ -266,13 +266,15 @@ def __init__( ) if validate_args: _multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index) - self.average = average + self.average = average # type: ignore[assignment] self.validate_args = validate_args def compute(self) -> Tensor: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat - return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds) + return _multiclass_auroc_compute( + state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type] + ) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/classification/average_precision.py b/src/torchmetrics/classification/average_precision.py index ae26d212a51..569c00b73d8 100644 --- a/src/torchmetrics/classification/average_precision.py +++ b/src/torchmetrics/classification/average_precision.py @@ -264,13 +264,15 @@ def __init__( ) if validate_args: _multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index) - self.average = average + self.average = average # type: ignore[assignment] self.validate_args = validate_args def compute(self) -> Tensor: # type: ignore[override] """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat - return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds) + return _multiclass_average_precision_compute( + state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type] + ) def plot( # type: ignore[override] self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 25f4025f52f..9996cfd683e 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -103,6 +103,8 @@ class BinaryPrecisionRecallCurve(Metric): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -266,6 +268,15 @@ class MulticlassPrecisionRecallCurve(Metric): - If set to a 1D `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -314,15 +325,17 @@ def __init__( self, num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) self.num_classes = num_classes + self.average = average self.ignore_index = ignore_index self.validate_args = validate_args @@ -344,9 +357,11 @@ def update(self, preds: Tensor, target: Tensor) -> None: if self.validate_args: _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index) preds, target, _ = _multiclass_precision_recall_curve_format( - preds, target, self.num_classes, self.thresholds, self.ignore_index + preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average + ) + state = _multiclass_precision_recall_curve_update( + preds, target, self.num_classes, self.thresholds, self.average ) - state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds) if isinstance(state, Tensor): self.confmat += state else: @@ -356,7 +371,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat - return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds) + return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, @@ -456,6 +471,8 @@ class MultilabelPrecisionRecallCurve(Metric): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index a76c3ebc02a..a391cd2046f 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -87,6 +87,8 @@ class BinaryROC(BinaryPrecisionRecallCurve): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -229,6 +231,15 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. @@ -276,7 +287,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute metric.""" state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat - return _multiclass_roc_compute(state, self.num_classes, self.thresholds) + return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average) def plot( self, @@ -381,6 +392,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + ignore_index: + Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. Set to ``False`` for faster computations. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. diff --git a/src/torchmetrics/functional/classification/precision_recall_curve.py b/src/torchmetrics/functional/classification/precision_recall_curve.py index c9b576703fe..64958267737 100644 --- a/src/torchmetrics/functional/classification/precision_recall_curve.py +++ b/src/torchmetrics/functional/classification/precision_recall_curve.py @@ -20,7 +20,7 @@ from typing_extensions import Literal from torchmetrics.utilities.checks import _check_same_shape -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _safe_divide, interp from torchmetrics.utilities.data import _bincount, _cumsum from torchmetrics.utilities.enums import ClassificationTask @@ -363,6 +363,7 @@ def _multiclass_precision_recall_curve_arg_validation( num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ) -> None: """Validate non tensor input. @@ -373,6 +374,8 @@ def _multiclass_precision_recall_curve_arg_validation( """ if not isinstance(num_classes, int) or num_classes < 2: raise ValueError(f"Expected argument `num_classes` to be an integer larger than 1, but got {num_classes}") + if average not in (None, "micro", "macro"): + raise ValueError(f"Expected argument `average` to be one of None, 'micro' or 'macro', but got {average}") _binary_precision_recall_curve_arg_validation(thresholds, ignore_index) @@ -423,6 +426,7 @@ def _multiclass_precision_recall_curve_format( num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, ignore_index: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: """Convert all input to the right format. @@ -443,6 +447,10 @@ def _multiclass_precision_recall_curve_format( if not torch.all((preds >= 0) * (preds <= 1)): preds = preds.softmax(1) + if average == "micro": + preds = preds.flatten() + target = torch.nn.functional.one_hot(target, num_classes=num_classes).flatten() + thresholds = _adjust_threshold_arg(thresholds, preds.device) return preds, target, thresholds @@ -452,6 +460,7 @@ def _multiclass_precision_recall_curve_update( target: Tensor, num_classes: int, thresholds: Optional[Tensor], + average: Optional[Literal["micro", "macro"]] = None, ) -> Union[Tensor, Tuple[Tensor, Tensor]]: """Return the state to calculate the pr-curve with. @@ -461,6 +470,8 @@ def _multiclass_precision_recall_curve_update( """ if thresholds is None: return preds, target + if average == "micro": + return _binary_precision_recall_curve_update(preds, target, thresholds) if preds.numel() * num_classes <= 1_000_000: update_fn = _multiclass_precision_recall_curve_update_vectorized else: @@ -520,6 +531,7 @@ def _multiclass_precision_recall_curve_compute( state: Union[Tensor, Tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], + average: Optional[Literal["micro", "macro"]] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: """Compute the final pr-curve. @@ -527,6 +539,9 @@ def _multiclass_precision_recall_curve_compute( original input, then we dynamically compute the binary classification curve in an iterative way. """ + if average == "micro": + return _binary_precision_recall_curve_compute(state, thresholds) + if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -535,15 +550,37 @@ def _multiclass_precision_recall_curve_compute( recall = _safe_divide(tps, tps + fns) precision = torch.cat([precision, torch.ones(1, num_classes, dtype=precision.dtype, device=precision.device)]) recall = torch.cat([recall, torch.zeros(1, num_classes, dtype=recall.dtype, device=recall.device)]) - return precision.T, recall.T, thresholds - - precision_list, recall_list, threshold_list = [], [], [] - for i in range(num_classes): - res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i) - precision_list.append(res[0]) - recall_list.append(res[1]) - threshold_list.append(res[2]) - return precision_list, recall_list, threshold_list + precision = precision.T + recall = recall.T + thres = thresholds + tensor_state = True + else: + precision_list, recall_list, thres_list = [], [], [] + for i in range(num_classes): + res = _binary_precision_recall_curve_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i) + precision_list.append(res[0]) + recall_list.append(res[1]) + thres_list.append(res[2]) + tensor_state = False + + if average == "macro": + thres = thres.repeat(num_classes) if tensor_state else torch.cat(thres_list, 0) + thres = thres.sort().values + mean_precision = precision.flatten() if tensor_state else torch.cat(precision_list, 0) + mean_precision = mean_precision.sort().values + mean_recall = torch.zeros_like(mean_precision) + for i in range(num_classes): + mean_recall += interp( + mean_precision, + precision[i] if tensor_state else precision_list[i], + recall[i] if tensor_state else recall_list[i], + ) + mean_recall /= num_classes + return mean_precision, mean_recall, thres + + if tensor_state: + return precision, recall, thres + return precision_list, recall_list, thres_list def multiclass_precision_recall_curve( @@ -551,6 +588,7 @@ def multiclass_precision_recall_curve( target: Tensor, num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -590,6 +628,13 @@ def multiclass_precision_recall_curve( - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. @@ -643,13 +688,18 @@ def multiclass_precision_recall_curve( """ if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, target, num_classes, thresholds, ignore_index + preds, + target, + num_classes, + thresholds, + ignore_index, + average, ) - state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) - return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average) + return _multiclass_precision_recall_curve_compute(state, num_classes, thresholds, average) def _multilabel_precision_recall_curve_arg_validation( @@ -892,6 +942,7 @@ def precision_recall_curve( thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -940,7 +991,9 @@ def precision_recall_curve( if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_precision_recall_curve(preds, target, num_classes, thresholds, ignore_index, validate_args) + return multiclass_precision_recall_curve( + preds, target, num_classes, thresholds, average, ignore_index, validate_args + ) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") diff --git a/src/torchmetrics/functional/classification/roc.py b/src/torchmetrics/functional/classification/roc.py index 65d2c16dc87..d61b920aa9b 100644 --- a/src/torchmetrics/functional/classification/roc.py +++ b/src/torchmetrics/functional/classification/roc.py @@ -33,7 +33,7 @@ _multilabel_precision_recall_curve_update, ) from torchmetrics.utilities import rank_zero_warn -from torchmetrics.utilities.compute import _safe_divide +from torchmetrics.utilities.compute import _safe_divide, interp from torchmetrics.utilities.enums import ClassificationTask @@ -163,7 +163,11 @@ def _multiclass_roc_compute( state: Union[Tensor, Tuple[Tensor, Tensor]], num_classes: int, thresholds: Optional[Tensor], + average: Optional[Literal["micro", "macro"]] = None, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: + if average == "micro": + return _binary_roc_compute(state, thresholds, pos_label=1) + if isinstance(state, Tensor) and thresholds is not None: tps = state[:, :, 1, 1] fps = state[:, :, 0, 1] @@ -172,14 +176,32 @@ def _multiclass_roc_compute( tpr = _safe_divide(tps, tps + fns).flip(0).T fpr = _safe_divide(fps, fps + tns).flip(0).T thres = thresholds.flip(0) + tensor_state = True else: - fpr, tpr, thres = [], [], [] # type: ignore[assignment] + fpr_list, tpr_list, thres_list = [], [], [] for i in range(num_classes): res = _binary_roc_compute((state[0][:, i], state[1]), thresholds=None, pos_label=i) - fpr.append(res[0]) - tpr.append(res[1]) - thres.append(res[2]) - return fpr, tpr, thres + fpr_list.append(res[0]) + tpr_list.append(res[1]) + thres_list.append(res[2]) + tensor_state = False + + if average == "macro": + thres = thres.repeat(num_classes) if tensor_state else torch.cat(thres_list, dim=0) + thres = thres.sort(descending=True).values + mean_fpr = fpr.flatten() if tensor_state else torch.cat(fpr_list, dim=0) + mean_fpr = mean_fpr.sort().values + mean_tpr = torch.zeros_like(mean_fpr) + for i in range(num_classes): + mean_tpr += interp( + mean_fpr, fpr[i] if tensor_state else fpr_list[i], tpr[i] if tensor_state else tpr_list[i] + ) + mean_tpr /= num_classes + return mean_fpr, mean_tpr, thres + + if tensor_state: + return fpr, tpr, thres + return fpr_list, tpr_list, thres_list def multiclass_roc( @@ -187,6 +209,7 @@ def multiclass_roc( target: Tensor, num_classes: int, thresholds: Optional[Union[int, List[float], Tensor]] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -229,6 +252,13 @@ def multiclass_roc( - If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as bins for the calculation. + average: + If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for + each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot + encoding the targets and flattening the predictions, considering all classes jointly as a binary problem. + If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves + from each class at a combined set of thresholds and then average over the classwise interpolated curves. + See `averaging curve objects`_ for more info on the different averaging methods. ignore_index: Specifies a target value that is ignored and does not contribute to the metric calculation validate_args: bool indicating if input arguments and tensors should be validated for correctness. @@ -282,13 +312,18 @@ def multiclass_roc( """ if validate_args: - _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index) + _multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average) _multiclass_precision_recall_curve_tensor_validation(preds, target, num_classes, ignore_index) preds, target, thresholds = _multiclass_precision_recall_curve_format( - preds, target, num_classes, thresholds, ignore_index + preds, + target, + num_classes, + thresholds, + ignore_index, + average, ) - state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds) - return _multiclass_roc_compute(state, num_classes, thresholds) + state = _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average) + return _multiclass_roc_compute(state, num_classes, thresholds, average) def _multilabel_roc_compute( @@ -440,6 +475,7 @@ def roc( thresholds: Optional[Union[int, List[float], Tensor]] = None, num_classes: Optional[int] = None, num_labels: Optional[int] = None, + average: Optional[Literal["micro", "macro"]] = None, ignore_index: Optional[int] = None, validate_args: bool = True, ) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]: @@ -506,7 +542,7 @@ def roc( if task == ClassificationTask.MULTICLASS: if not isinstance(num_classes, int): raise ValueError(f"`num_classes` is expected to be `int` but `{type(num_classes)} was passed.`") - return multiclass_roc(preds, target, num_classes, thresholds, ignore_index, validate_args) + return multiclass_roc(preds, target, num_classes, thresholds, average, ignore_index, validate_args) if task == ClassificationTask.MULTILABEL: if not isinstance(num_labels, int): raise ValueError(f"`num_labels` is expected to be `int` but `{type(num_labels)} was passed.`") diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 0ae6de91cbd..c8cb48a8cdb 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -131,3 +131,29 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: """ x, y = _auc_format_inputs(x, y) return _auc_compute(x, y, reorder=reorder) + + +def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: + """One-dimensional linear interpolation for monotonically increasing sample points. + + Returns the one-dimensional piecewise linear interpolant to a function with + given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. + + Adjusted version of this https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 + + Args: + x: the :math:`x`-coordinates at which to evaluate the interpolated values. + xp: the :math:`x`-coordinates of the data points, must be increasing. + fp: the :math:`y`-coordinates of the data points, same length as `xp`. + + Returns: + the interpolated values, same size as `x`. + + """ + m = _safe_divide(fp[1:] - fp[:-1], xp[1:] - xp[:-1]) + b = fp[:-1] - (m * xp[:-1]) + + indices = torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1 + indices = torch.clamp(indices, 0, len(m) - 1) + + return m[indices] * x + b[indices] diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 834c5be5161..7167c9711bb 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -280,6 +280,25 @@ def test_multiclass_error_on_wrong_dtypes(self, inputs): with pytest.raises(ValueError, match="Expected `preds` to be a float tensor, but got.*"): multiclass_precision_recall_curve(preds[0].long(), target[0], num_classes=NUM_CLASSES) + @pytest.mark.parametrize("average", ["macro", "micro"]) + @pytest.mark.parametrize("thresholds", [None, 100]) + def test_multiclass_average(self, inputs, average, thresholds): + """Test that the average argument works as expected.""" + preds, target = inputs + output = multiclass_precision_recall_curve( + preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=average + ) + assert all(isinstance(o, torch.Tensor) for o in output) + none_output = multiclass_precision_recall_curve( + preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=None + ) + if average == "macro": + assert len(output[0]) == len(none_output[0][0]) * NUM_CLASSES + assert len(output[1]) == len(none_output[1][0]) * NUM_CLASSES + assert ( + len(output[2]) == (len(none_output[2][0]) if thresholds is None else len(none_output[2])) * NUM_CLASSES + ) + def _sklearn_precision_recall_curve_multilabel(preds, target, ignore_index=None): precision, recall, thresholds = [], [], [] diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 41f5f1d9253..4829b078ec3 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -251,6 +251,21 @@ def test_multiclass_roc_threshold_arg(self, inputs, threshold_fn): assert torch.allclose(r1[i], r2[i]) assert torch.allclose(t1[i], t2) + @pytest.mark.parametrize("average", ["macro", "micro"]) + @pytest.mark.parametrize("thresholds", [None, 100]) + def test_multiclass_average(self, inputs, average, thresholds): + """Test that the average argument works as expected.""" + preds, target = inputs + output = multiclass_roc(preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=average) + assert all(isinstance(o, torch.Tensor) for o in output) + none_output = multiclass_roc(preds[0], target[0], num_classes=NUM_CLASSES, thresholds=thresholds, average=None) + if average == "macro": + assert len(output[0]) == len(none_output[0][0]) * NUM_CLASSES + assert len(output[1]) == len(none_output[1][0]) * NUM_CLASSES + assert ( + len(output[2]) == (len(none_output[2][0]) if thresholds is None else len(none_output[2])) * NUM_CLASSES + ) + def _sklearn_roc_multilabel(preds, target, ignore_index=None): fpr, tpr, thresholds = [], [], [] From f9251133f5c39c15cba444757ae5da729af10636 Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 4 Oct 2023 06:44:00 -0500 Subject: [PATCH 3/3] Replace distutils with packaging (#2137) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- .github/assistant.py | 6 +++--- MANIFEST.in | 2 +- requirements/base.txt | 1 + src/torchmetrics/utilities/imports.py | 4 ++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/assistant.py b/.github/assistant.py index c7d19a9c319..f2f20699a5e 100644 --- a/.github/assistant.py +++ b/.github/assistant.py @@ -18,11 +18,11 @@ import re import sys import traceback -from distutils.version import LooseVersion from typing import List, Optional, Tuple, Union import fire import requests +from packaging.version import parse from pkg_resources import parse_requirements _REQUEST_TIMEOUT = 10 @@ -80,8 +80,8 @@ def set_min_torch_by_python(fpath: str = "requirements/base.txt") -> None: with open(fpath) as fp: reqs = parse_requirements(fp.readlines()) pkg_ver = next(p for p in reqs if p.name == "torch") - pt_ver = min([LooseVersion(v[1]) for v in pkg_ver.specs]) - pt_ver = max(LooseVersion(LUT_PYTHON_TORCH[py_ver]), pt_ver) + pt_ver = min([parse(v[1]) for v in pkg_ver.specs]) + pt_ver = max(parse(LUT_PYTHON_TORCH[py_ver]), pt_ver) with open(fpath) as fp: requires = fp.read() requires = re.sub(r"torch>=[\d\.]+", f"torch>={pt_ver}", requires) diff --git a/MANIFEST.in b/MANIFEST.in index 81cf1a457dd..51ef144f59b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ -# Manifest syntax https://docs.python.org/2/distutils/sourcedist.html +# Manifest syntax https://packaging.python.org/en/latest/guides/using-manifest-in/ graft wheelhouse recursive-exclude __pycache__ *.py[cod] *.orig diff --git a/requirements/base.txt b/requirements/base.txt index 536c920e6f5..1d957ae399d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,6 +2,7 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment numpy >1.20.0 +packaging >17.1 torch >=1.8.1, <=2.0.1 typing-extensions; python_version < '3.9' lightning-utilities >=0.8.0, <0.10.0 diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py index e1ffba72a71..924571279f2 100644 --- a/src/torchmetrics/utilities/imports.py +++ b/src/torchmetrics/utilities/imports.py @@ -15,13 +15,13 @@ import operator import shutil import sys -from distutils.version import LooseVersion from typing import Optional from lightning_utilities.core.imports import compare_version, package_available +from packaging.version import Version, parse _PYTHON_VERSION = ".".join(map(str, [sys.version_info.major, sys.version_info.minor, sys.version_info.micro])) -_PYTHON_LOWER_3_8 = LooseVersion(_PYTHON_VERSION) < LooseVersion("3.8") +_PYTHON_LOWER_3_8 = parse(_PYTHON_VERSION) < Version("3.8") _TORCH_LOWER_1_12_DEV: Optional[bool] = compare_version("torch", operator.lt, "1.12.0.dev") _TORCH_GREATER_EQUAL_1_9: Optional[bool] = compare_version("torch", operator.ge, "1.9.0") _TORCH_GREATER_EQUAL_1_10: Optional[bool] = compare_version("torch", operator.ge, "1.10.0")