diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 7b7211b9d61..b923d6ea262 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -103,7 +103,7 @@ def my_func(param_a: int, param_b: Optional[float] = None) -> str: >>> my_func(1, 2) 3 - .. note:: If you want to add something. + .. hint:: If you want to add something. """ p = param_b if param_b else 0 return str(param_a + p) diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index 3c7c3eb69f1..d7abe9bb569 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -13,29 +13,29 @@ concurrency: jobs: check-code: - uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@v0.11.7 + uses: Lightning-AI/utilities/.github/workflows/check-typing.yml@v0.11.8 with: - actions-ref: v0.11.7 + actions-ref: v0.11.8 extra-typing: "typing" check-schema: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.7 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.8 check-package: if: github.event.pull_request.draft == false - uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.7 + uses: Lightning-AI/utilities/.github/workflows/check-package.yml@v0.11.8 with: - actions-ref: v0.11.7 + actions-ref: v0.11.8 artifact-name: dist-packages-${{ github.sha }} import-name: "torchmetrics" testing-matrix: | { - "os": ["ubuntu-22.04", "macos-12", "windows-2022"], + "os": ["ubuntu-22.04", "macos-13", "windows-2022"], "python-version": ["3.8", "3.11"] } check-md-links: - uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.7 + uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.8 with: base-branch: master config-file: ".github/markdown-links-config.json" diff --git a/.github/workflows/ci-integrate.yml b/.github/workflows/ci-integrate.yml index a01bd076cb2..11bbe401f83 100644 --- a/.github/workflows/ci-integrate.yml +++ b/.github/workflows/ci-integrate.yml @@ -26,12 +26,12 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-22.04", "macOS-12", "windows-2022"] - python-version: ["3.8", "3.10"] + os: ["ubuntu-22.04", "macOS-13", "windows-2022"] + python-version: ["3.9", "3.11"] requires: ["oldest", "latest"] exclude: - - { python-version: "3.10", requires: "oldest" } - - { python-version: "3.10", os: "windows" } # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192 + - { python-version: "3.11", requires: "oldest" } + - { python-version: "3.11", os: "windows" } # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192 include: - { python-version: "3.10", requires: "latest", os: "ubuntu-22.04" } # - { python-version: "3.10", requires: "latest", os: "macOS-14" } # M1 machine # todo: crashing for MPS out of memory @@ -53,6 +53,8 @@ jobs: - name: source cashing uses: ./.github/actions/pull-caches + with: + requires: ${{ matrix.requires }} - name: set oldest if/only for integrations if: matrix.requires == 'oldest' run: python .github/assistant.py set-oldest-versions --req_files='["requirements/_integrate.txt"]' diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index 0f91e84ed6f..f6e618cfb6e 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -46,7 +46,7 @@ jobs: - "2.5.0" include: # cover additional python and PT combinations - - { os: "ubuntu-22.04", python-version: "3.8", pytorch-version: "1.13.1" } + - { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "1.13.1", requires: "oldest" } - 
{ os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.0.1" } - { os: "ubuntu-22.04", python-version: "3.10", pytorch-version: "2.2.2" } - { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.3.1" } diff --git a/.github/workflows/clear-cache.yml b/.github/workflows/clear-cache.yml index ecd7c6e3ff3..057827d611c 100644 --- a/.github/workflows/clear-cache.yml +++ b/.github/workflows/clear-cache.yml @@ -23,7 +23,7 @@ on: jobs: cron-clear: if: github.event_name == 'schedule' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.7 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.8 with: scripts-ref: v0.11.7 dry-run: ${{ github.event_name == 'pull_request' }} @@ -32,9 +32,9 @@ jobs: direct-clear: if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.7 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.8 with: - scripts-ref: v0.11.7 + scripts-ref: v0.11.8 dry-run: ${{ github.event_name == 'pull_request' }} pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index be5e39fc09a..dce0f0192a2 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -44,7 +44,6 @@ jobs: - name: source cashing uses: ./.github/actions/pull-caches with: - requires: ${{ matrix.requires }} pytorch-version: ${{ matrix.pytorch-version }} pypi-dir: ${{ env.PYPI_CACHE }} diff --git a/.github/workflows/publish-pkg.yml b/.github/workflows/publish-pkg.yml index 7affbea96c2..f7b3f46997f 100644 --- a/.github/workflows/publish-pkg.yml +++ b/.github/workflows/publish-pkg.yml @@ -67,7 +67,7 @@ jobs: - run: ls -lh dist/ # We do this, since failures on test.pypi aren't that bad - name: Publish to Test PyPI - uses: pypa/gh-action-pypi-publish@v1.10.2 + uses: pypa/gh-action-pypi-publish@v1.11.0 with: user: __token__ password: ${{ secrets.test_pypi_password }} @@ -94,7 +94,7 @@ jobs: path: dist - run: ls -lh dist/ - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.10.2 + uses: pypa/gh-action-pypi-publish@v1.11.0 with: user: __token__ password: ${{ secrets.pypi_password }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 8f97f05c16b..636e346940e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,22 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Note: we move fast, but still we preserve 0.1 version (one feature release) back compatibility.** +--- + +## [1.5.2] - 2024-11-07 + +### Changed + +- Re-adding `numpy` 2+ support ([#2804](https://github.com/Lightning-AI/torchmetrics/pull/2804)) + +### Fixed + +- Fixed iou scores in detection for either empty predictions/targets leading to wrong scores ([#2805](https://github.com/Lightning-AI/torchmetrics/pull/2805)) +- Fixed `MetricCollection` compatibility with `torch.jit.script` ([#2813](https://github.com/Lightning-AI/torchmetrics/pull/2813)) +- Fixed assert in PIT ([#2811](https://github.com/Lightning-AI/torchmetrics/pull/2811)) +- Pathed `np.Inf` for `numpy` 2.0+ ([#2826](https://github.com/Lightning-AI/torchmetrics/pull/2826)) + + --- ## [1.5.1] - 2024-10-22 diff --git a/docs/source/pages/implement.rst b/docs/source/pages/implement.rst index 1620ce29cd9..5ab044e1be1 
100644 --- a/docs/source/pages/implement.rst +++ b/docs/source/pages/implement.rst @@ -257,7 +257,7 @@ and tests gets formatted in the following way: 3. ``new_metric(...)``: essentially wraps the ``_update`` and ``_compute`` private functions into one public function that makes up the functional interface for the metric. - .. note:: + .. hint:: The `functional mean squared error `_ metric is a is a great example of how to divide the logic. @@ -270,9 +270,9 @@ and tests gets formatted in the following way: ``_new_metric_compute(...)`` function in its ``compute``. No logic should really be implemented in the module interface. We do this to not have duplicate code to maintain. - .. note:: - The module `MeanSquaredError `_ - metric that corresponds to the above functional example showcases these steps. + .. note:: + The module `MeanSquaredError `_ + metric that corresponds to the above functional example showcases these steps. 4. Remember to add binding to the different relevant ``__init__`` files. @@ -291,7 +291,7 @@ and tests gets formatted in the following way: so that different combinations of inputs and parameters get tested. 5. (optional) If your metric raises any exception, please add tests that showcase this. - .. note:: + .. hint:: The `test file for MSE `_ metric shows how to implement such tests. diff --git a/docs/source/pages/lightning.rst b/docs/source/pages/lightning.rst index a96196396b8..0fd27a4d474 100644 --- a/docs/source/pages/lightning.rst +++ b/docs/source/pages/lightning.rst @@ -13,7 +13,7 @@ TorchMetrics in PyTorch Lightning TorchMetrics was originally created as part of `PyTorch Lightning `_, a powerful deep learning research framework designed for scaling models without boilerplate. -.. note:: +.. caution:: TorchMetrics always offers compatibility with the last 2 major PyTorch Lightning versions, but we recommend always keeping both frameworks up-to-date for the best experience. 
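As a companion to the implement.rst hunks a little further up (split the logic into private `_update`/`_compute` helpers, expose a functional wrapper, and let the module class reuse the same helpers), here is a minimal sketch of that layout. The `my_metric`/`MyMetric` names and the mean-squared-error statistics are placeholders for illustration only and are not part of the torchmetrics codebase.

```python
from typing import Tuple, Union

import torch
from torch import Tensor
from torchmetrics import Metric


def _my_metric_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
    """Accumulate the sufficient statistics (here: sum of squared errors and observation count)."""
    return torch.sum((preds - target) ** 2), preds.numel()


def _my_metric_compute(sum_squared_error: Tensor, num_obs: Union[int, Tensor]) -> Tensor:
    """Turn the accumulated statistics into the final metric value."""
    return sum_squared_error / num_obs


def my_metric(preds: Tensor, target: Tensor) -> Tensor:
    """Functional interface: just wraps the two private helpers."""
    sum_squared_error, num_obs = _my_metric_update(preds, target)
    return _my_metric_compute(sum_squared_error, num_obs)


class MyMetric(Metric):
    """Module interface: calls the same helpers, so no logic is duplicated."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("num_obs", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        sum_squared_error, num_obs = _my_metric_update(preds, target)
        self.sum_squared_error += sum_squared_error
        self.num_obs += num_obs

    def compute(self) -> Tensor:
        return _my_metric_compute(self.sum_squared_error, self.num_obs)
```

With this layout, `my_metric(preds, target)` and `MyMetric()(preds, target)` should agree, mirroring how `MeanSquaredError` wraps the functional `mean_squared_error` referenced in the hunk above.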
@@ -69,9 +69,9 @@ LightningModule `self.log 1.3, <=3.1.0 scikit-learn ==1.2.*; python_version < "3.9" scikit-learn ==1.5.*; python_version > "3.8" # we do not use `> =` because of oldest replcement -cachier ==3.0.1 +cachier ==3.1.2 diff --git a/requirements/audio.txt b/requirements/audio.txt index dcfc5b05740..d970b2373cb 100644 --- a/requirements/audio.txt +++ b/requirements/audio.txt @@ -3,9 +3,10 @@ # this need to be the same as used inside speechmetrics pesq >=0.0.4, <0.0.5 +numpy <2.0 # strict, for compatibility reasons pystoi >=0.4.0, <0.5.0 torchaudio >=0.10.0, <2.6.0 gammatone >=1.0.0, <1.1.0 librosa >=0.9.0, <0.11.0 -onnxruntime >=1.12.0, <1.20 # installing onnxruntime_gpu-gpu failed on macos +onnxruntime >=1.12.0, <1.21 # installing onnxruntime_gpu-gpu failed on macos requests >=2.19.0, <2.33.0 diff --git a/requirements/base.txt b/requirements/base.txt index 35f1b4406c7..9fc7a5e3fa8 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,7 +1,7 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -numpy >1.20.0, <2.0 # strict, for compatibility reasons +numpy >1.20.0 packaging >17.1 torch >=1.10.0, <2.6.0 typing-extensions; python_version < '3.9' diff --git a/requirements/multimodal.txt b/requirements/multimodal.txt index 1a034aa3a1f..2f5c6749bd9 100644 --- a/requirements/multimodal.txt +++ b/requirements/multimodal.txt @@ -1,5 +1,5 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -transformers >=4.42.3, <4.46.0 +transformers >=4.42.3, <4.47.0 piq <=0.8.0 diff --git a/requirements/text.txt b/requirements/text.txt index 62007c3e127..208b87ca34c 100644 --- a/requirements/text.txt +++ b/requirements/text.txt @@ -4,7 +4,7 @@ nltk >3.8.1, <=3.9.1 tqdm <4.67.0 regex >=2021.9.24, <=2024.9.11 -transformers >4.4.0, <4.46.0 +transformers >4.4.0, <4.47.0 mecab-python3 >=1.0.6, <1.1.0 ipadic >=1.0.0, <1.1.0 sentencepiece >=0.2.0, <0.3.0 diff --git a/requirements/typing.txt b/requirements/typing.txt index 01c6897fa9c..271a9dfd690 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ -mypy ==1.11.2 -torch ==2.5.0 +mypy ==1.13.0 +torch ==2.5.1 types-PyYAML types-emoji diff --git a/src/conftest.py b/src/conftest.py index c988c1784c5..5f4a26123d3 100644 --- a/src/conftest.py +++ b/src/conftest.py @@ -36,5 +36,5 @@ def collect(self) -> GeneratorExit: def pytest_collect_file(parent: Path, path: Path) -> Optional[DoctestModule]: """Collect doctests and add the reset_random_seed fixture.""" if path.ext == ".py": - return DoctestModule.from_parent(parent, fspath=path) + return DoctestModule.from_parent(parent, path=Path(path)) return None diff --git a/src/torchmetrics/__about__.py b/src/torchmetrics/__about__.py index 8ae351f52f5..e4a11de08c5 100644 --- a/src/torchmetrics/__about__.py +++ b/src/torchmetrics/__about__.py @@ -1,4 +1,4 @@ -__version__ = "1.5.1" +__version__ = "1.5.2" __author__ = "Lightning-AI et al." 
__author_email__ = "name@pytorchlightning.ai" __license__ = "Apache-2.0" diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index 2fa370cb1c9..15f7be3ae90 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -14,6 +14,13 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) +if package_available("numpy"): + # compatibility for AttributeError: `np.Inf` was removed in the NumPy 2.0 release. Use `np.inf` instead + import numpy + + numpy.Inf = numpy.inf + + if package_available("PIL"): import PIL diff --git a/src/torchmetrics/audio/dnsmos.py b/src/torchmetrics/audio/dnsmos.py index 74d035a7fd4..406b817eb05 100644 --- a/src/torchmetrics/audio/dnsmos.py +++ b/src/torchmetrics/audio/dnsmos.py @@ -54,11 +54,13 @@ class DeepNoiseSuppressionMeanOpinionScore(Metric): - ``dnsmos`` (:class:`~torch.Tensor`): float tensor of DNSMOS values reduced across the batch with shape ``(...,4)`` indicating [p808_mos, mos_sig, mos_bak, mos_ovr] in the last dim. - .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. + .. hint:: + Using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. Install as ``pip install torchmetrics['audio']`` or alternatively `pip install librosa onnxruntime-gpu requests` (if you do not have GPU enabled machine install `onnxruntime` instead of `onnxruntime-gpu`) - .. note:: the ``forward`` and ``compute`` methods in this class return a reduced DNSMOS value + .. caution:: + The ``forward`` and ``compute`` methods in this class return a reduced DNSMOS value for a batch. To obtain the DNSMOS value for each sample, you may use the functional counterpart in :func:`~torchmetrics.functional.audio.dnsmos.deep_noise_suppression_mean_opinion_score`. diff --git a/src/torchmetrics/audio/pesq.py b/src/torchmetrics/audio/pesq.py index dfcf623aa25..ee6a1751359 100644 --- a/src/torchmetrics/audio/pesq.py +++ b/src/torchmetrics/audio/pesq.py @@ -45,12 +45,14 @@ class PerceptualEvaluationSpeechQuality(Metric): - ``pesq`` (:class:`~torch.Tensor`): float tensor of PESQ value reduced across the batch - .. note:: using this metrics requires you to have ``pesq`` install. Either install as ``pip install + .. hint:: + Using this metrics requires you to have ``pesq`` install. Either install as ``pip install torchmetrics[audio]`` or ``pip install pesq``. ``pesq`` will compile with your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall ``pesq``. - .. note:: the ``forward`` and ``compute`` methods in this class return a single (reduced) PESQ value + .. caution:: + The ``forward`` and ``compute`` methods in this class return a single (reduced) PESQ value for a batch. To obtain a PESQ value for each sample, you may use the functional counterpart in :func:`~torchmetrics.functional.audio.pesq.perceptual_evaluation_speech_quality`. diff --git a/src/torchmetrics/audio/srmr.py b/src/torchmetrics/audio/srmr.py index 62a882e7ccc..5e3bad85dcd 100644 --- a/src/torchmetrics/audio/srmr.py +++ b/src/torchmetrics/audio/srmr.py @@ -49,11 +49,12 @@ class SpeechReverberationModulationEnergyRatio(Metric): - ``srmr`` (:class:`~torch.Tensor`): float scaler tensor - .. note:: using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed. + .. hint:: + Using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed. 
Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio`` and ``pip install git+https://github.com/detly/gammatone``. - .. note:: + .. attention:: This implementation is experimental, and might not be consistent with the matlab implementation `SRMRToolbox`_, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have diff --git a/src/torchmetrics/audio/stoi.py b/src/torchmetrics/audio/stoi.py index 253dab3ea38..cea3df3514e 100644 --- a/src/torchmetrics/audio/stoi.py +++ b/src/torchmetrics/audio/stoi.py @@ -50,7 +50,8 @@ class ShortTimeObjectiveIntelligibility(Metric): - ``stoi`` (:class:`~torch.Tensor`): float scalar tensor - .. note:: using this metrics requires you to have ``pystoi`` install. Either install as ``pip install + .. hint:: + Using this metrics requires you to have ``pystoi`` install. Either install as ``pip install torchmetrics[audio]`` or ``pip install pystoi``. Args: diff --git a/src/torchmetrics/classification/calibration_error.py b/src/torchmetrics/classification/calibration_error.py index f1fe5cc17c7..404d9089bbc 100644 --- a/src/torchmetrics/classification/calibration_error.py +++ b/src/torchmetrics/classification/calibration_error.py @@ -214,7 +214,7 @@ class MulticlassCalibrationError(Metric): - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` containing ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/classification/cohen_kappa.py b/src/torchmetrics/classification/cohen_kappa.py index 58537f6acfe..3531eb6b106 100644 --- a/src/torchmetrics/classification/cohen_kappa.py +++ b/src/torchmetrics/classification/cohen_kappa.py @@ -50,7 +50,7 @@ class labels. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -175,7 +175,7 @@ class labels. convert probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/classification/dice.py b/src/torchmetrics/classification/dice.py index 39bc2acabcd..23623b5dffe 100644 --- a/src/torchmetrics/classification/dice.py +++ b/src/torchmetrics/classification/dice.py @@ -75,7 +75,7 @@ class Dice(Metric): - ``'samples'``: Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample). - .. note:: + .. hint:: What is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_average``. 
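Several of the classification hunks above re-tag the same remark as a tip: any dimensions beyond the batch dimension are flattened into the batch dimension before the metric is computed. A small sketch of what that means in practice, assuming `BinaryCohenKappa` behaves as its docstring states (illustrative only, not an added test):

```python
import torch
from torchmetrics.classification import BinaryCohenKappa

# preds/target of shape (N, ...) = (2, 3); the trailing dimension is flattened
preds = torch.tensor([[0.9, 0.1, 0.8], [0.3, 0.7, 0.2]])
target = torch.tensor([[1, 0, 1], [0, 1, 0]])

# both calls should print the same value, since the 2d input is treated
# like the flattened 1d views of shape (6,)
print(BinaryCohenKappa()(preds, target))
print(BinaryCohenKappa()(preds.flatten(), target.flatten()))
```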
diff --git a/src/torchmetrics/classification/hinge.py b/src/torchmetrics/classification/hinge.py index fbd3cf04fdd..5514f98cccc 100644 --- a/src/torchmetrics/classification/hinge.py +++ b/src/torchmetrics/classification/hinge.py @@ -55,7 +55,7 @@ class BinaryHingeLoss(Metric): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -189,7 +189,7 @@ class MulticlassHingeLoss(Metric): ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/classification/jaccard.py b/src/torchmetrics/classification/jaccard.py index 0ea04849696..385009d5a6a 100644 --- a/src/torchmetrics/classification/jaccard.py +++ b/src/torchmetrics/classification/jaccard.py @@ -52,7 +52,7 @@ class BinaryJaccardIndex(BinaryConfusionMatrix): Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -170,7 +170,7 @@ class MulticlassJaccardIndex(MulticlassConfusionMatrix): probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)``. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -307,7 +307,7 @@ class MultilabelJaccardIndex(MultilabelConfusionMatrix): sigmoid per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/classification/matthews_corrcoef.py b/src/torchmetrics/classification/matthews_corrcoef.py index 8fab20badf5..49de1f03795 100644 --- a/src/torchmetrics/classification/matthews_corrcoef.py +++ b/src/torchmetrics/classification/matthews_corrcoef.py @@ -48,7 +48,7 @@ class BinaryMatthewsCorrCoef(BinaryConfusionMatrix): per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -156,7 +156,7 @@ class MulticlassMatthewsCorrCoef(MulticlassConfusionMatrix): probabilities/logits into an int tensor. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, ...)`` - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. 
As output to ``forward`` and ``compute`` the metric returns the following output: @@ -268,7 +268,7 @@ class MultilabelMatthewsCorrCoef(MultilabelConfusionMatrix): per element. Additionally, we convert to int tensor with thresholding using the value in ``threshold``. - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)`` - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/classification/precision_fixed_recall.py b/src/torchmetrics/classification/precision_fixed_recall.py index 9db713f53c6..c761f9aa8a9 100644 --- a/src/torchmetrics/classification/precision_fixed_recall.py +++ b/src/torchmetrics/classification/precision_fixed_recall.py @@ -60,7 +60,7 @@ class BinaryPrecisionAtFixedRecall(BinaryPrecisionRecallCurve): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -193,7 +193,7 @@ class MulticlassPrecisionAtFixedRecall(MulticlassPrecisionRecallCurve): ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing: @@ -338,7 +338,7 @@ class MultilabelPrecisionAtFixedRecall(MultilabelPrecisionRecallCurve): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing: diff --git a/src/torchmetrics/classification/precision_recall_curve.py b/src/torchmetrics/classification/precision_recall_curve.py index 366f11710d4..d0f9c632c02 100644 --- a/src/torchmetrics/classification/precision_recall_curve.py +++ b/src/torchmetrics/classification/precision_recall_curve.py @@ -67,7 +67,7 @@ class BinaryPrecisionRecallCurve(Metric): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -244,7 +244,7 @@ class MulticlassPrecisionRecallCurve(Metric): ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -441,7 +441,7 @@ class MultilabelPrecisionRecallCurve(Metric): - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. 
As output to ``forward`` and ``compute`` the metric returns the following a tuple of either 3 tensors or diff --git a/src/torchmetrics/classification/ranking.py b/src/torchmetrics/classification/ranking.py index 4d2df8a9151..9dda737030d 100644 --- a/src/torchmetrics/classification/ranking.py +++ b/src/torchmetrics/classification/ranking.py @@ -51,7 +51,7 @@ class MultilabelCoverageError(Metric): - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -171,7 +171,7 @@ class MultilabelRankingAveragePrecision(Metric): - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -291,7 +291,7 @@ class MultilabelRankingLoss(Metric): - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: diff --git a/src/torchmetrics/classification/recall_fixed_precision.py b/src/torchmetrics/classification/recall_fixed_precision.py index 76880f2c3da..58a460b7b2e 100644 --- a/src/torchmetrics/classification/recall_fixed_precision.py +++ b/src/torchmetrics/classification/recall_fixed_precision.py @@ -59,7 +59,7 @@ class BinaryRecallAtFixedPrecision(BinaryPrecisionRecallCurve): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns the following output: @@ -194,7 +194,7 @@ class MulticlassRecallAtFixedPrecision(MulticlassPrecisionRecallCurve): ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing: @@ -337,7 +337,7 @@ class MultilabelRecallAtFixedPrecision(MultilabelPrecisionRecallCurve): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. 
As output to ``forward`` and ``compute`` the metric returns a tuple of either 2 tensors or 2 lists containing: diff --git a/src/torchmetrics/classification/roc.py b/src/torchmetrics/classification/roc.py index 2b69df1dac2..fc378e53ee0 100644 --- a/src/torchmetrics/classification/roc.py +++ b/src/torchmetrics/classification/roc.py @@ -54,7 +54,7 @@ class BinaryROC(BinaryPrecisionRecallCurve): ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). The value 1 always encodes the positive class. - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns a tuple of 3 tensors containing: @@ -71,8 +71,8 @@ class BinaryROC(BinaryPrecisionRecallCurve): `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size :math:`\mathcal{O}(n_{thresholds})` (constant memory). - .. note:: - The outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and + .. attention:: + The outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr which are sorted in reversed order during their calculation, such that they are monotome increasing. Args: @@ -191,7 +191,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): ground truth labels, and therefore only contain values in the [0, n_classes-1] range (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns a tuple of either 3 tensors or 3 lists containing @@ -216,8 +216,8 @@ class MulticlassROC(MulticlassPrecisionRecallCurve): `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size :math:`\mathcal{O}(n_{thresholds} \times n_{classes})` (constant memory). - .. note:: - Note that outputted thresholds will be in reversed order to ensure that they corresponds to both fpr + .. attention:: + Note that outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr which are sorted in reversed order during their calculation, such that they are monotome increasing. Args: @@ -357,7 +357,7 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): - ``target`` (:class:`~torch.Tensor`): An int tensor of shape ``(N, C, ...)``. Target should be a tensor containing ground truth labels, and therefore only contain {0,1} values (except if `ignore_index` is specified). - .. note:: + .. tip:: Additional dimension ``...`` will be flattened into the batch dimension. As output to ``forward`` and ``compute`` the metric returns a tuple of either 3 tensors or 3 lists containing @@ -382,8 +382,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve): `thresholds` argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size :math:`\mathcal{O}(n_{thresholds} \times n_{labels})` (constant memory). - .. note:: - The outputted thresholds will be in reversed order to ensure that they corresponds to both fpr and tpr + .. attention:: + The outputted thresholds will be in reversed order to ensure that they correspond to both fpr and tpr which are sorted in reversed order during their calculation, such that they are monotome increasing. 
Args: diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 9fe0bb40761..b4ec0c4e4cc 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -14,7 +14,7 @@ # this is just a bypass for this module name collision with built-in one from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, ClassVar, Dict, Hashable, Iterable, Iterator, List, Mapping, Optional, Sequence, Tuple, Union import torch from torch import Tensor @@ -31,6 +31,30 @@ __doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"] +def _remove_prefix(string: str, prefix: str) -> str: + """Patch for older version with missing method `removeprefix`. + + >>> _remove_prefix("prefix_string", "prefix_") + 'string' + >>> _remove_prefix("not_prefix_string", "prefix_") + 'not_prefix_string' + + """ + return string[len(prefix) :] if string.startswith(prefix) else string + + +def _remove_suffix(string: str, suffix: str) -> str: + """Patch for older version with missing method `removesuffix`. + + >>> _remove_suffix("string_suffix", "_suffix") + 'string' + >>> _remove_suffix("string_suffix_missing", "_suffix") + 'string_suffix_missing' + + """ + return string[: -len(suffix)] if string.endswith(suffix) else string + + class MetricCollection(ModuleDict): """MetricCollection class can be used to chain metrics that have the same call pattern into one single class. @@ -59,7 +83,7 @@ class name as key for the output dict. this argument is ``True`` which enables this feature. Set this argument to `False` for disabling this behaviour. Can also be set to a list of lists of metrics for setting the compute groups yourself. - .. note:: + .. tip:: The compute groups feature can significantly speedup the calculation of metrics under the right conditions. First, the feature is only available when calling the ``update`` method and not when calling ``forward`` method due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric @@ -67,7 +91,7 @@ class name as key for the output dict. reference and a copy of states are instead returned in this case (reference will be reestablished on the next call to ``update``). - .. note:: + .. important:: Metric collections can be nested at initialization (see last example) but the output of the collection will still be a single flatten dictionary combining the prefix and postfix arguments from the nested collection. @@ -168,6 +192,7 @@ class name of the metric: _modules: Dict[str, Metric] # type: ignore[assignment] _groups: Dict[int, List[str]] + __jit_unused_properties__: ClassVar[List[str]] = ["metric_state"] def __init__( self, @@ -558,9 +583,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) if self.prefix: - key = key.removeprefix(self.prefix) + key = _remove_prefix(key, self.prefix) if self.postfix: - key = key.removesuffix(self.postfix) + key = _remove_suffix(key, self.postfix) return self._modules[key] @staticmethod diff --git a/src/torchmetrics/detection/_mean_ap.py b/src/torchmetrics/detection/_mean_ap.py index fd342608360..40b7c038df6 100644 --- a/src/torchmetrics/detection/_mean_ap.py +++ b/src/torchmetrics/detection/_mean_ap.py @@ -202,16 +202,16 @@ class MeanAveragePrecision(Metric): For an example on how to use this metric check the `torchmetrics mAP example`_. - .. 
note:: - ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ]. - Caution: If the initialization parameters are changed, dictionary keys for mAR can change as well. + .. attention:: + The ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ] + **Caution:** If the initialization parameters are changed, dictionary keys for mAR can change as well. The default properties are also accessible via fields and will raise an ``AttributeError`` if not available. - .. note:: + .. important:: This metric is following the mAP implementation of `pycocotools`_ a standard implementation for the mAP metric for object detection. - .. note:: + .. hint:: This metric requires you to have `torchvision` version 0.8.0 or newer installed (with corresponding version 1.7.0 of torch or newer). This metric requires `pycocotools` installed when iou_type is `segm`. Please install with ``pip install torchvision`` or @@ -849,9 +849,9 @@ def __calculate_recall_precision_scores( inds = torch.searchsorted(rc, rec_thresholds.to(rc.device), right=False) num_inds = inds.argmax() if inds.max() >= tp_len else num_rec_thrs - inds = inds[:num_inds] # type: ignore[misc] - prec[:num_inds] = pr[inds] # type: ignore[misc] - score[:num_inds] = det_scores_sorted[inds] # type: ignore[misc] + inds = inds[:num_inds] + prec[:num_inds] = pr[inds] + score[:num_inds] = det_scores_sorted[inds] precision[idx, :, idx_cls, idx_bbox_area, idx_max_det_thresholds] = prec scores[idx, :, idx_cls, idx_bbox_area, idx_max_det_thresholds] = score diff --git a/src/torchmetrics/detection/iou.py b/src/torchmetrics/detection/iou.py index e809b27ce6a..f828889153d 100644 --- a/src/torchmetrics/detection/iou.py +++ b/src/torchmetrics/detection/iou.py @@ -182,14 +182,17 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]] """Update state with predictions and targets.""" _input_validator(preds, target, ignore_score=True) - for p, t in zip(preds, target): - det_boxes = self._get_safe_item_values(p["boxes"]) - gt_boxes = self._get_safe_item_values(t["boxes"]) - self.groundtruth_labels.append(t["labels"]) + for p_i, t_i in zip(preds, target): + det_boxes = self._get_safe_item_values(p_i["boxes"]) + gt_boxes = self._get_safe_item_values(t_i["boxes"]) + self.groundtruth_labels.append(t_i["labels"]) iou_matrix = self._iou_update_fn(det_boxes, gt_boxes, self.iou_threshold, self._invalid_val) # N x M if self.respect_labels: - label_eq = p["labels"].unsqueeze(1) == t["labels"].unsqueeze(0) # N x M + if det_boxes.numel() > 0 and gt_boxes.numel() > 0: + label_eq = p_i["labels"].unsqueeze(1) == t_i["labels"].unsqueeze(0) # N x M + else: + label_eq = torch.eye(iou_matrix.shape[0], dtype=bool, device=iou_matrix.device) # type: ignore[call-overload] iou_matrix[~label_eq] = self._invalid_val self.iou_matrix.append(iou_matrix) diff --git a/src/torchmetrics/detection/mean_ap.py b/src/torchmetrics/detection/mean_ap.py index 948fa554348..070f50bb181 100644 --- a/src/torchmetrics/detection/mean_ap.py +++ b/src/torchmetrics/detection/mean_ap.py @@ -153,14 +153,14 @@ class MeanAveragePrecision(Metric): For an example on how to use this metric check the `torchmetrics mAP example`_. - .. note:: - ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ] + .. attention:: + The ``map`` score is calculated with @[ IoU=self.iou_thresholds | area=all | max_dets=max_detection_thresholds ] e.g. 
the mean average precision for IoU thresholds 0.50, 0.55, 0.60, ..., 0.95 averaged over all classes and all areas and all max detections per image. If the IoU thresholds are changed this value will be calculated with - the new thresholds. Caution: If the initialization parameters are changed, dictionary keys for mAR can change as - well. + the new thresholds. + **Caution:** If the initialization parameters are changed, dictionary keys for mAR can change as well. - .. note:: + .. important:: This metric supports, at the moment, two different backends for the evaluation. The default backend is ``"pycocotools"``, which either require the official `pycocotools`_ implementation or this `fork of pycocotools`_ to be installed. We recommend using the fork as it is better maintained and easily diff --git a/src/torchmetrics/functional/audio/dnsmos.py b/src/torchmetrics/functional/audio/dnsmos.py index 9b0dca883db..8f74f8374c9 100644 --- a/src/torchmetrics/functional/audio/dnsmos.py +++ b/src/torchmetrics/functional/audio/dnsmos.py @@ -194,7 +194,8 @@ def deep_noise_suppression_mean_opinion_score( `DNSMOS P.835 paper `_. - .. note:: using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. Install + .. hint:: + Using this metric requires you to have ``librosa``, ``onnxruntime`` and ``requests`` installed. Install as ``pip install torchmetrics['audio']`` or alternatively ``pip install librosa onnxruntime-gpu requests`` (if you do not have GPU enabled machine install ``onnxruntime`` instead of ``onnxruntime-gpu``) diff --git a/src/torchmetrics/functional/audio/pesq.py b/src/torchmetrics/functional/audio/pesq.py index 516014434bf..04cf1ebace2 100644 --- a/src/torchmetrics/functional/audio/pesq.py +++ b/src/torchmetrics/functional/audio/pesq.py @@ -40,7 +40,8 @@ def perceptual_evaluation_speech_quality( This metric is a wrapper for the `pesq package`_. Note that input will be moved to `cpu` to perform the metric calculation. - .. note:: using this metrics requires you to have ``pesq`` install. Either install as ``pip install + .. hint:: + Using this metric requires you to have ``pesq`` installed. Either install as ``pip install torchmetrics[audio]`` or ``pip install pesq``. Note that ``pesq`` will compile with your currently installed version of numpy, meaning that if you upgrade numpy at some point in the future you will most likely have to reinstall ``pesq``. 
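To make the re-worded caution in the `mean_ap.py` hunks above concrete: the mAR entries of the result dictionary are keyed by the `max_detection_thresholds` setting, so changing that argument changes the key names. A rough sketch is shown below; the exact key names are illustrative and depend on the installed torchmetrics/pycocotools versions.

```python
import torch
from torchmetrics.detection import MeanAveragePrecision

preds = [{"boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]), "scores": torch.tensor([0.9]), "labels": torch.tensor([0])}]
target = [{"boxes": torch.tensor([[12.0, 12.0, 48.0, 48.0]]), "labels": torch.tensor([0])}]

# default max_detection_thresholds -> result dict is expected to contain mar_1, mar_10, mar_100
metric = MeanAveragePrecision(iou_type="bbox")
metric.update(preds, target)
print(sorted(metric.compute().keys()))

# changing the thresholds renames the mAR keys accordingly (e.g. mar_5, mar_50, mar_200)
metric = MeanAveragePrecision(iou_type="bbox", max_detection_thresholds=[5, 50, 200])
metric.update(preds, target)
print(sorted(metric.compute().keys()))
```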
diff --git a/src/torchmetrics/functional/audio/pit.py b/src/torchmetrics/functional/audio/pit.py index 90c4afb01f9..a7cc72d48f7 100644 --- a/src/torchmetrics/functional/audio/pit.py +++ b/src/torchmetrics/functional/audio/pit.py @@ -161,7 +161,7 @@ def permutation_invariant_training( if eval_func not in ["max", "min"]: raise ValueError(f'eval_func can only be "max" or "min" but got {eval_func}') if mode not in ["speaker-wise", "permutation-wise"]: - raise ValueError(f'mode can only be "speaker-wise" or "permutation-wise" but got {eval_func}') + raise ValueError(f'mode can only be "speaker-wise" or "permutation-wise" but got {mode}') if target.ndim < 2: raise ValueError(f"Inputs must be of shape [batch, spk, ...], got {target.shape} and {preds.shape} instead") @@ -182,7 +182,7 @@ def permutation_invariant_training( metric_of_ps = metric_func(ppreds, ptarget) metric_of_ps = torch.mean(metric_of_ps.reshape(batch_size, len(perms), -1), dim=-1) # find the best metric and best permutation - best_metric, best_indexes = eval_op(metric_of_ps, dim=1) # type: ignore[call-overload] + best_metric, best_indexes = eval_op(metric_of_ps, dim=1) best_indexes = best_indexes.detach() best_perm = perms[best_indexes, :] return best_metric, best_perm diff --git a/src/torchmetrics/functional/audio/srmr.py b/src/torchmetrics/functional/audio/srmr.py index 79f710a6a31..5ca00de55f3 100644 --- a/src/torchmetrics/functional/audio/srmr.py +++ b/src/torchmetrics/functional/audio/srmr.py @@ -203,11 +203,12 @@ def speech_reverberation_modulation_energy_ratio( Note: this argument is inherited from `SRMRpy`_. As the translated code is based to pytorch, setting `fast=True` may slow down the speed for calculating this metric on GPU. - .. note:: using this metrics requires you to have ``gammatone`` and ``torchaudio`` installed. + .. hint:: + Using this metric requires you to have ``gammatone`` and ``torchaudio`` installed. Either install as ``pip install torchmetrics[audio]`` or ``pip install torchaudio`` and ``pip install git+https://github.com/detly/gammatone``. - .. note:: + .. attention:: This implementation is experimental, and might not be consistent with the matlab implementation `SRMRToolbox`_, especially the fast implementation. The slow versions, a) fast=False, norm=False, max_cf=128, b) fast=False, norm=True, max_cf=30, have diff --git a/src/torchmetrics/functional/audio/stoi.py b/src/torchmetrics/functional/audio/stoi.py index 48e9e78510b..e029b721ef8 100644 --- a/src/torchmetrics/functional/audio/stoi.py +++ b/src/torchmetrics/functional/audio/stoi.py @@ -39,7 +39,8 @@ def short_time_objective_intelligibility( calculations on CPU, all input will automatically be moved to CPU to perform the metric calculation before being moved back to the original device. - .. note:: using this metrics requires you to have ``pystoi`` install. Either install as ``pip install + .. hint:: + Using this metric requires you to have ``pystoi`` installed. Either install as ``pip install torchmetrics[audio]`` or ``pip install pystoi`` Args: diff --git a/src/torchmetrics/functional/classification/dice.py b/src/torchmetrics/functional/classification/dice.py index 49d66ea9361..78f58fafbb4 100644 --- a/src/torchmetrics/functional/classification/dice.py +++ b/src/torchmetrics/functional/classification/dice.py @@ -106,10 +106,12 @@ def dice( - ``'samples'``: Calculate the metric for each sample, and average the metrics across samples (with equal weights for each sample). - .. 
note:: What is considered a sample in the multi-dimensional multi-class case + .. tip:: + What is considered a sample in the multi-dimensional multi-class case depends on the value of ``mdmc_average``. - .. note:: If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, + .. hint:: + If ``'none'`` and a given class doesn't occur in the ``preds`` or ``target``, the value for the class will be ``nan``. mdmc_average: diff --git a/src/torchmetrics/functional/detection/ciou.py b/src/torchmetrics/functional/detection/ciou.py index 9669029ba73..55560241aec 100644 --- a/src/torchmetrics/functional/detection/ciou.py +++ b/src/torchmetrics/functional/detection/ciou.py @@ -31,6 +31,11 @@ def _ciou_update( from torchvision.ops import complete_box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = complete_box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/detection/diou.py b/src/torchmetrics/functional/detection/diou.py index 13fb0071fed..42492675c98 100644 --- a/src/torchmetrics/functional/detection/diou.py +++ b/src/torchmetrics/functional/detection/diou.py @@ -31,6 +31,11 @@ def _diou_update( from torchvision.ops import distance_box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = distance_box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/detection/giou.py b/src/torchmetrics/functional/detection/giou.py index cc39f813b41..5fd8f873286 100644 --- a/src/torchmetrics/functional/detection/giou.py +++ b/src/torchmetrics/functional/detection/giou.py @@ -31,6 +31,11 @@ def _giou_update( from torchvision.ops import generalized_box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = generalized_box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/detection/iou.py b/src/torchmetrics/functional/detection/iou.py index 3d3cef26bb2..6558af5cd37 100644 --- a/src/torchmetrics/functional/detection/iou.py +++ b/src/torchmetrics/functional/detection/iou.py @@ -32,6 +32,11 @@ def _iou_update( from torchvision.ops import box_iou + if preds.numel() == 0: # if no boxes are predicted + return torch.zeros(target.shape[0], target.shape[0], device=target.device, dtype=torch.float32) + if target.numel() == 0: # if no boxes are true + return torch.zeros(preds.shape[0], preds.shape[0], device=preds.device, dtype=torch.float32) + iou = box_iou(preds, target) if iou_threshold is not None: iou[iou < iou_threshold] = replacement_val diff --git a/src/torchmetrics/functional/image/gradients.py b/src/torchmetrics/functional/image/gradients.py index 03f751f0870..e87d4a75198 100644 --- 
a/src/torchmetrics/functional/image/gradients.py +++ b/src/torchmetrics/functional/image/gradients.py @@ -70,7 +70,8 @@ def image_gradients(img: Tensor) -> Tuple[Tensor, Tensor]: [5., 5., 5., 5., 5.], [0., 0., 0., 0., 0.]]) - .. note:: The implementation follows the 1-step finite difference method as followed + .. note:: + The implementation follows the 1-step finite difference method as followed by the TF implementation. The values are organized such that the gradient of [I(x+1, y)-[I(x, y)]] are at the (x, y) location diff --git a/src/torchmetrics/functional/image/psnr.py b/src/torchmetrics/functional/image/psnr.py index adb80e2bc77..7bd93ba94e1 100644 --- a/src/torchmetrics/functional/image/psnr.py +++ b/src/torchmetrics/functional/image/psnr.py @@ -134,8 +134,8 @@ def peak_signal_noise_ratio( >>> peak_signal_noise_ratio(pred, target) tensor(2.5527) - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. """ if dim is None and reduction != "elementwise_mean": diff --git a/src/torchmetrics/functional/multimodal/clip_iqa.py b/src/torchmetrics/functional/multimodal/clip_iqa.py index 49e710e248a..659c667fc52 100644 --- a/src/torchmetrics/functional/multimodal/clip_iqa.py +++ b/src/torchmetrics/functional/multimodal/clip_iqa.py @@ -271,7 +271,8 @@ def clip_image_quality_assessment( available prompts (see above). If tuple is provided, it must be of length 2 and the first string must be a positive prompt and the second string must be a negative prompt. - .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with + .. hint:: + If using the default `clip_iqa` model, the package `piq` must be installed. Either install with `pip install piq` or `pip install torchmetrics[multimodal]`. Returns: diff --git a/src/torchmetrics/functional/multimodal/clip_score.py b/src/torchmetrics/functional/multimodal/clip_score.py index 920eb6972e6..f70f37d534b 100644 --- a/src/torchmetrics/functional/multimodal/clip_score.py +++ b/src/torchmetrics/functional/multimodal/clip_score.py @@ -135,7 +135,8 @@ def clip_score( textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer to 100 the better. - .. note:: Metric is not scriptable + .. caution:: + Metric is not scriptable Args: images: Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors diff --git a/src/torchmetrics/functional/regression/log_mse.py b/src/torchmetrics/functional/regression/log_mse.py index 96d2938a8ee..3ba27ab9e99 100644 --- a/src/torchmetrics/functional/regression/log_mse.py +++ b/src/torchmetrics/functional/regression/log_mse.py @@ -68,8 +68,8 @@ def mean_squared_log_error(preds: Tensor, target: Tensor) -> Tensor: >>> mean_squared_log_error(x, y) tensor(0.0207) - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. """ sum_squared_log_error, num_obs = _mean_squared_log_error_update(preds, target) diff --git a/src/torchmetrics/image/fid.py b/src/torchmetrics/image/fid.py index 8c2e5d3cf76..ac559ed0c68 100644 --- a/src/torchmetrics/image/fid.py +++ b/src/torchmetrics/image/fid.py @@ -211,9 +211,8 @@ class FrechetInceptionDistance(Metric): that you calculate using `torch.float64` (default is `torch.float32`) which can be set using the `.set_dtype` method of the metric. - .. note:: using this metrics requires you to have torch 1.9 or higher installed - - .. 
note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/image/inception.py b/src/torchmetrics/image/inception.py index 2b125542a17..20d53d10f2b 100644 --- a/src/torchmetrics/image/inception.py +++ b/src/torchmetrics/image/inception.py @@ -49,7 +49,8 @@ class InceptionScore(Metric): ``normalize`` is set to ``False`` images are expected to have dtype uint8 and take values in the ``[0, 255]`` range. All images will be resized to 299 x 299 which is the size of the original training data. - .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` diff --git a/src/torchmetrics/image/kid.py b/src/torchmetrics/image/kid.py index 975edf800eb..018fc7a7511 100644 --- a/src/torchmetrics/image/kid.py +++ b/src/torchmetrics/image/kid.py @@ -97,7 +97,8 @@ class KernelInceptionDistance(Metric): effect and update method expects to have the tensor given to `imgs` argument to be in the correct shape and type that is compatible to the custom feature extractor. - .. note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` diff --git a/src/torchmetrics/image/lpip.py b/src/torchmetrics/image/lpip.py index f1adb73648a..1893fb734ba 100644 --- a/src/torchmetrics/image/lpip.py +++ b/src/torchmetrics/image/lpip.py @@ -47,12 +47,10 @@ class LearnedPerceptualImagePatchSimilarity(Metric): Both input image patches are expected to have shape ``(N, 3, H, W)``. The minimum size of `H, W` depends on the chosen backbone (see `net_type` arg). - .. note:: using this metrics requires you to have ``torchvision`` package installed. Either install as + .. hint:: + Using this metrics requires you to have ``torchvision`` package installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision``. - .. note:: this metric is not scriptable when using ``torch<1.8``. Please update your pytorch installation - if this is a issue. - As input to ``forward`` and ``update`` the metric accepts the following input - ``img1`` (:class:`~torch.Tensor`): tensor with images of shape ``(N, 3, H, W)`` diff --git a/src/torchmetrics/image/mifid.py b/src/torchmetrics/image/mifid.py index 7d7237b7416..a1e2d2f4c0a 100644 --- a/src/torchmetrics/image/mifid.py +++ b/src/torchmetrics/image/mifid.py @@ -84,10 +84,12 @@ class MemorizationInformedFrechetInceptionDistance(Metric): flag ``real`` determines if the images should update the statistics of the real distribution or the fake distribution. - .. note:: using this metrics requires you to have ``scipy`` install. Either install as ``pip install + .. hint:: + Using this metrics requires you to have ``scipy`` install. Either install as ``pip install torchmetrics[image]`` or ``pip install scipy`` - .. 
note:: using this metric with the default feature extractor requires that ``torch-fidelity`` + .. hint:: + Using this metric with the default feature extractor requires that ``torch-fidelity`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torch-fidelity`` diff --git a/src/torchmetrics/image/perceptual_path_length.py b/src/torchmetrics/image/perceptual_path_length.py index b2255941250..24f701e5371 100644 --- a/src/torchmetrics/image/perceptual_path_length.py +++ b/src/torchmetrics/image/perceptual_path_length.py @@ -51,7 +51,8 @@ class PerceptualPathLength(Metric): if `conditional=False`, and `forward(z: Tensor, labels: Tensor) -> Tensor` if `conditional=True`. The returned tensor should have shape `(num_samples, C, H, W)` and be scaled to the range [0, 255]. - .. note:: using this metric with the default feature extractor requires that ``torchvision`` is installed. + .. hint:: + Using this metric with the default feature extractor requires that ``torchvision`` is installed. Either install as ``pip install torchmetrics[image]`` or ``pip install torchvision`` As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/image/psnrb.py b/src/torchmetrics/image/psnrb.py index bac58b84e46..c9fb157d6ff 100644 --- a/src/torchmetrics/image/psnrb.py +++ b/src/torchmetrics/image/psnrb.py @@ -34,7 +34,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric): Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function. This metric is a modified version of PSNR that better supports evaluation of images with blocked artifacts, that oftens occur in compressed images. - .. note:: + .. attention:: Metric only supports grayscale images. If you have RGB images, please convert them to grayscale first. As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 940e393c6d1..a1119cf9d35 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -236,11 +236,11 @@ def add_state( - If the metric state is a ``list``, the synced value will be a ``list`` containing the combined elements from all processes. - Note: + .. important:: When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow the format discussed in the above note. - Note: + .. caution:: The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another diff --git a/src/torchmetrics/multimodal/clip_iqa.py b/src/torchmetrics/multimodal/clip_iqa.py index 7a66de4ae3e..f49113e297f 100644 --- a/src/torchmetrics/multimodal/clip_iqa.py +++ b/src/torchmetrics/multimodal/clip_iqa.py @@ -112,7 +112,8 @@ class CLIPImageQualityAssessment(Metric): positive prompt and the second string must be a negative prompt. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - .. note:: If using the default `clip_iqa` model, the package `piq` must be installed. Either install with + .. hint:: + If using the default `clip_iqa` model, the package `piq` must be installed. Either install with `pip install piq` or `pip install torchmetrics[image]`. 
Raises: diff --git a/src/torchmetrics/multimodal/clip_score.py b/src/torchmetrics/multimodal/clip_score.py index f385fbc145d..92ca7ad6b4f 100644 --- a/src/torchmetrics/multimodal/clip_score.py +++ b/src/torchmetrics/multimodal/clip_score.py @@ -54,7 +54,8 @@ class CLIPScore(Metric): textual CLIP embedding :math:`E_C` for an caption :math:`C`. The score is bound between 0 and 100 and the closer to 100 the better. - .. note:: Metric is not scriptable + .. caution:: + Metric is not scriptable As input to ``forward`` and ``update`` the metric accepts the following input diff --git a/src/torchmetrics/regression/kl_divergence.py b/src/torchmetrics/regression/kl_divergence.py index 8c5c41d3d57..368cdef0adf 100644 --- a/src/torchmetrics/regression/kl_divergence.py +++ b/src/torchmetrics/regression/kl_divergence.py @@ -73,8 +73,8 @@ class KLDivergence(Metric): ValueError: If ``reduction`` is not one of ``'mean'``, ``'sum'``, ``'none'`` or ``None``. - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. Example: >>> from torch import tensor diff --git a/src/torchmetrics/regression/log_mse.py b/src/torchmetrics/regression/log_mse.py index 85e01dba314..da190016844 100644 --- a/src/torchmetrics/regression/log_mse.py +++ b/src/torchmetrics/regression/log_mse.py @@ -52,8 +52,8 @@ class MeanSquaredLogError(Metric): >>> mean_squared_log_error(preds, target) tensor(0.0397) - .. note:: - Half precision is only support on GPU for this metric + .. attention:: + Half precision is only support on GPU for this metric. """ diff --git a/src/torchmetrics/retrieval/base.py b/src/torchmetrics/retrieval/base.py index 5446d8668b1..f9a0a4f8cc4 100644 --- a/src/torchmetrics/retrieval/base.py +++ b/src/torchmetrics/retrieval/base.py @@ -50,10 +50,11 @@ class RetrievalMetric(Metric, ABC): - ``indexes`` (:class:`~torch.Tensor`): A long tensor of shape ``(N, ...)`` which indicate to which query a prediction belongs - .. note:: ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flatten - to single dimension once provided. + .. hint:: + The ``indexes``, ``preds`` and ``target`` must have the same dimension and will be flattened + to single dimension once provided. - .. note:: + .. attention:: Predictions will be first grouped by ``indexes`` and then the real metric, defined by overriding the `_metric` method, will be computed as the mean of the scores over each query. diff --git a/src/torchmetrics/retrieval/precision_recall_curve.py b/src/torchmetrics/retrieval/precision_recall_curve.py index 701e89b1433..9b4f4d38d0e 100644 --- a/src/torchmetrics/retrieval/precision_recall_curve.py +++ b/src/torchmetrics/retrieval/precision_recall_curve.py @@ -303,9 +303,10 @@ class RetrievalRecallAtFixedPrecision(RetrievalPrecisionRecallCurve): - ``indexes`` (:class:`~torch.Tensor`): A long tensor of shape ``(N, ...)`` which indicate to which query a prediction belongs - .. note:: All ``indexes``, ``preds`` and ``target`` must have the same dimension. + .. important:: + All ``indexes``, ``preds`` and ``target`` must have the same dimension. - .. note:: + .. attention:: Predictions will be first grouped by ``indexes`` and then `RetrievalRecallAtFixedPrecision` will be computed as the mean of the `RetrievalRecallAtFixedPrecision` over each query. 
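As a short usage sketch of the retrieval admonitions above (same dimensions, flattening, grouping by ``indexes``); ``RetrievalMRR`` is used here purely as an example and the numbers are made up:

```python
import torch
from torchmetrics.retrieval import RetrievalMRR

# `indexes` says which query each prediction/target pair belongs to (two queries here)
indexes = torch.tensor([0, 0, 0, 1, 1, 1])
preds = torch.tensor([0.9, 0.2, 0.4, 0.1, 0.7, 0.3])
target = torch.tensor([True, False, False, False, True, True])

metric = RetrievalMRR()
# inputs are flattened, grouped per query, and the final score is the mean over queries
score = metric(preds, target, indexes=indexes)
```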
diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index ee11a36136f..e526ecc8456 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -53,6 +53,13 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens denom: denominator tensor, which may contain zeros zero_division: value to replace elements divided by zero + Example: + >>> import torch + >>> num = torch.tensor([1.0, 2.0, 3.0]) + >>> denom = torch.tensor([0.0, 1.0, 2.0]) + >>> _safe_divide(num, denom) + tensor([0.0000, 2.0000, 1.5000]) + """ num = num if num.is_floating_point() else num.float() denom = denom if denom.is_floating_point() else denom.float() @@ -102,6 +109,16 @@ def _auc_compute_without_check(x: Tensor, y: Tensor, direction: float, axis: int def _auc_compute(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: + """Compute area under the curve using the trapezoidal rule. + + Example: + >>> import torch + >>> x = torch.tensor([1, 2, 3, 4]) + >>> y = torch.tensor([1, 2, 3, 4]) + >>> _auc_compute(x, y) + tensor(7.5000) + + """ with torch.no_grad(): if reorder: x, x_idx = torch.sort(x, stable=True) @@ -139,7 +156,7 @@ def auc(x: Tensor, y: Tensor, reorder: bool = False) -> Tensor: def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: """One-dimensional linear interpolation for monotonically increasing sample points. - Returns the one-dimensional piecewise linear interpolant to a function with + Returns the one-dimensional piecewise linear interpolation to a function with given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. Adjusted version of this https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 @@ -152,6 +169,13 @@ def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: Returns: the interpolated values, same size as `x`. + Example: + >>> x = torch.tensor([0.5, 1.5, 2.5]) + >>> xp = torch.tensor([1, 2, 3]) + >>> fp = torch.tensor([1, 2, 3]) + >>> interp(x, xp, fp) + tensor([0.5000, 1.5000, 2.5000]) + """ m = _safe_divide(fp[1:] - fp[:-1], xp[1:] - xp[:-1]) b = fp[:-1] - (m * xp[:-1]) diff --git a/src/torchmetrics/wrappers/multitask.py b/src/torchmetrics/wrappers/multitask.py index fa2f04db97d..04ddd87ad71 100644 --- a/src/torchmetrics/wrappers/multitask.py +++ b/src/torchmetrics/wrappers/multitask.py @@ -43,8 +43,8 @@ class MultitaskWrapper(WrapperMetric): postfix: A string to append after the keys of the output dict. If not provided, will default to an empty string. - .. note:: - The use pre prefix and postfix allows for easily creating task wrappers for training, validation and test. + .. tip:: + The use prefix and postfix allows for easily creating task wrappers for training, validation and test. The arguments are only changing the output keys of the computed metrics and not the input keys. 
This means that a ``MultitaskWrapper`` initialized as ``MultitaskWrapper({"task": Metric()}, prefix="train_")`` will still expect the input to be a dictionary with the key "task", but the output will be a dictionary with the key diff --git a/src/torchmetrics/wrappers/tracker.py b/src/torchmetrics/wrappers/tracker.py index 554902d1092..c1fe9957b1b 100644 --- a/src/torchmetrics/wrappers/tracker.py +++ b/src/torchmetrics/wrappers/tracker.py @@ -219,7 +219,8 @@ def best_metric( ) -> Union[ None, float, - Tuple[float, int], + Tensor, + Tuple[Union[int, float, Tensor], Union[int, float, Tensor]], Tuple[None, None], Dict[str, Union[float, None]], Tuple[Dict[str, Union[float, None]], Dict[str, Union[int, None]]], @@ -260,7 +261,7 @@ def best_metric( if isinstance(self._base_metric, Metric): fn = torch.max if self.maximize else torch.min try: - value, idx = fn(res, 0) # type: ignore[call-overload] + value, idx = fn(res, 0) if return_step: return value.item(), idx.item() return value.item() @@ -277,11 +278,11 @@ def best_metric( else: # this is a metric collection maximize = self.maximize if isinstance(self.maximize, list) else len(res) * [self.maximize] - value, idx = {}, {} + value, idx = {}, {} # type: ignore[assignment] for i, (k, v) in enumerate(res.items()): try: fn = torch.max if maximize[i] else torch.min - out = fn(v, 0) # type: ignore[call-overload] + out = fn(v, 0) value[k], idx[k] = out[0].item(), out[1].item() except (ValueError, RuntimeError) as error: # noqa: PERF203 # todo rank_zero_warn( @@ -290,7 +291,7 @@ def best_metric( "Returning `None` instead.", UserWarning, ) - value[k], idx[k] = None, None + value[k], idx[k] = None, None # type: ignore[assignment] if return_step: return value, idx diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index c5a69077f3c..9412fa79d75 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -685,7 +685,16 @@ def inject_ignore_index(x: Tensor, ignore_index: int) -> Tensor: def remove_ignore_index(target: Tensor, preds: Tensor, ignore_index: Optional[int]) -> Tuple[Tensor, Tensor]: - """Remove samples that are equal to the ignore_index in comparison functions.""" + """Remove samples that are equal to the ignore_index in comparison functions. 
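As a small illustration of the ``prefix``/``postfix`` behaviour described in the tip above (task names and tensors are made up; this assumes the ``prefix`` argument shown in the wrapper's docstring):

```python
import torch
from torchmetrics import MeanSquaredError
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.wrappers import MultitaskWrapper

train_metrics = MultitaskWrapper(
    {"classification": BinaryAccuracy(), "regression": MeanSquaredError()},
    prefix="train_",
)

preds = {"classification": torch.tensor([1, 0, 1]), "regression": torch.tensor([2.0, 3.0, 4.0])}
target = {"classification": torch.tensor([1, 0, 0]), "regression": torch.tensor([2.5, 3.0, 4.5])}

# the input dicts keep the bare task names; only the returned keys gain the prefix,
# e.g. {"train_classification": ..., "train_regression": ...}
results = train_metrics(preds, target)
```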
+ + Example: + >>> target = torch.tensor([0, 1, 2, 3, 4]) + >>> preds = torch.tensor([0, 1, 2, 3, 4]) + >>> ignore_index = 2 + >>> remove_ignore_index(target, preds, ignore_index) + (tensor([0, 1, 3, 4]), tensor([0, 1, 3, 4])) + + """ if ignore_index is not None: idx = target == ignore_index target, preds = deepcopy(target[~idx]), deepcopy(preds[~idx]) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 55062ccbe29..f674e76b376 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -41,6 +41,15 @@ seed_all(42) +def test_metric_collection_jit_script(): + """Test that the MetricCollection can be scripted and jitted.""" + m1 = DummyMetricSum() + m2 = DummyMetricDiff() + metric_collection = MetricCollection([m1, m2]) + scripted = torch.jit.script(metric_collection) + assert isinstance(scripted, torch.jit.ScriptModule) + + def test_metric_collection(tmpdir): """Test that updating the metric collection is equal to individually updating metrics in the collection.""" m1 = DummyMetricSum() diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 30d4a473a84..65e42c00b07 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -55,14 +55,14 @@ def _reference_sklearn_accuracy_binary(preds, target, ignore_index, multidim_ave preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_accuracy(target, preds) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(_reference_sklearn_accuracy(true, pred)) return np.stack(res) @@ -185,7 +185,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if average == "micro": return _reference_sklearn_accuracy(target, preds) confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES))) @@ -207,7 +207,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if average == "micro": res.append(_reference_sklearn_accuracy(true, pred)) else: @@ -445,13 +445,13 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_accuracy(target, preds) accuracy, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, 
pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) accuracy.append(_reference_sklearn_accuracy(true, pred)) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -472,7 +472,7 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) accuracy.append(_reference_sklearn_accuracy(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -480,7 +480,7 @@ def _reference_sklearn_accuracy_multilabel(preds, target, ignore_index, multidim scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(_reference_sklearn_accuracy(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 3691c7305b7..30d4acb470c 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -38,7 +38,7 @@ def _reference_sklearn_auroc_binary(preds, target, max_fpr=None, ignore_index=No target = target.flatten().numpy() if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_roc_auc_score(target, preds, max_fpr=max_fpr) @@ -144,7 +144,7 @@ def _reference_sklearn_auroc_multiclass(preds, target, average="macro", ignore_i target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_roc_auc_score(target, preds, average=average, multi_class="ovr", labels=list(range(NUM_CLASSES))) diff --git a/tests/unittests/classification/test_average_precision.py b/tests/unittests/classification/test_average_precision.py index 51d40839642..da0dc2f56b6 100644 --- a/tests/unittests/classification/test_average_precision.py +++ b/tests/unittests/classification/test_average_precision.py @@ -47,7 +47,7 @@ def _reference_sklearn_avg_precision_binary(preds, target, ignore_index=None): target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_average_precision_score(target, preds) @@ -156,7 +156,7 @@ def _reference_sklearn_avg_precision_multiclass(preds, target, average="macro", target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) res = [] for i in range(NUM_CLASSES): diff --git 
a/tests/unittests/classification/test_calibration_error.py b/tests/unittests/classification/test_calibration_error.py index a2000cc984e..f4fc0703881 100644 --- a/tests/unittests/classification/test_calibration_error.py +++ b/tests/unittests/classification/test_calibration_error.py @@ -44,7 +44,7 @@ def _reference_netcal_binary_calibration_error(preds, target, n_bins, norm, igno target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) metric = ECE if norm == "l1" else MCE return metric(n_bins).measure(preds, target) @@ -154,7 +154,7 @@ def _reference_netcal_multiclass_calibration_error(preds, target, n_bins, norm, if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) metric = ECE if norm == "l1" else MCE return metric(n_bins).measure(preds, target) diff --git a/tests/unittests/classification/test_cohen_kappa.py b/tests/unittests/classification/test_cohen_kappa.py index 40ebbc028bd..1f2585372bd 100644 --- a/tests/unittests/classification/test_cohen_kappa.py +++ b/tests/unittests/classification/test_cohen_kappa.py @@ -37,7 +37,7 @@ def _reference_sklearn_cohen_kappa_binary(preds, target, weights=None, ignore_in if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_cohen_kappa(y1=target, y2=preds, weights=weights) @@ -136,7 +136,7 @@ def _reference_sklearn_cohen_kappa_multiclass(preds, target, weights, ignore_ind preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_cohen_kappa(y1=target, y2=preds, weights=weights) diff --git a/tests/unittests/classification/test_confusion_matrix.py b/tests/unittests/classification/test_confusion_matrix.py index 12f21451949..4d27dfc2069 100644 --- a/tests/unittests/classification/test_confusion_matrix.py +++ b/tests/unittests/classification/test_confusion_matrix.py @@ -46,7 +46,7 @@ def _reference_sklearn_confusion_matrix_binary(preds, target, normalize=None, ig if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1], normalize=normalize) @@ -147,7 +147,7 @@ def _reference_sklearn_confusion_matrix_multiclass(preds, target, normalize=None preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_confusion_matrix(y_true=target, y_pred=preds, normalize=normalize, labels=list(range(NUM_CLASSES))) @@ -298,7 +298,7 @@ def 
_reference_sklearn_confusion_matrix_multilabel(preds, target, normalize=None confmat = [] for i in range(preds.shape[1]): pred, true = preds[:, i], target[:, i] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat.append(sk_confusion_matrix(true, pred, normalize=normalize, labels=[0, 1])) return np.stack(confmat, axis=0) diff --git a/tests/unittests/classification/test_f_beta.py b/tests/unittests/classification/test_f_beta.py index 03c39d336fe..075e37cc699 100644 --- a/tests/unittests/classification/test_f_beta.py +++ b/tests/unittests/classification/test_f_beta.py @@ -63,14 +63,14 @@ def _reference_sklearn_fbeta_score_binary(preds, target, sk_fn, ignore_index, mu preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(sk_fn(true, pred, zero_division=zero_division)) return np.stack(res) @@ -205,7 +205,7 @@ def _reference_sklearn_fbeta_score_multiclass( if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn( target, preds, @@ -220,7 +220,7 @@ def _reference_sklearn_fbeta_score_multiclass( for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if len(pred) == 0 and average == "weighted": # The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division) @@ -417,13 +417,13 @@ def _reference_sklearn_fbeta_score_multilabel_global(preds, target, sk_fn, ignor if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) fbeta_score, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) fbeta_score.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -446,7 +446,7 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) fbeta_score.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -454,7 
+454,7 @@ def _reference_sklearn_fbeta_score_multilabel_local(preds, target, sk_fn, ignore scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) diff --git a/tests/unittests/classification/test_hamming_distance.py b/tests/unittests/classification/test_hamming_distance.py index f7f3686c73b..a7a42db61b0 100644 --- a/tests/unittests/classification/test_hamming_distance.py +++ b/tests/unittests/classification/test_hamming_distance.py @@ -59,14 +59,14 @@ def _reference_sklearn_hamming_distance_binary(preds, target, ignore_index, mult preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_hamming_loss(target, preds) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(_reference_sklearn_hamming_loss(true, pred)) return np.stack(res) @@ -167,7 +167,7 @@ def test_binary_hamming_distance_dtype_gpu(self, inputs, dtype): def _reference_sklearn_hamming_distance_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if average == "micro": return _reference_sklearn_hamming_loss(target, preds) confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) @@ -191,7 +191,7 @@ def _reference_sklearn_hamming_distance_multiclass_local(preds, target, ignore_i for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if average == "micro": res.append(_reference_sklearn_hamming_loss(true, pred)) else: @@ -331,13 +331,13 @@ def _reference_sklearn_hamming_distance_multilabel_global(preds, target, ignore_ if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _reference_sklearn_hamming_loss(target, preds) hamming, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) hamming.append(_reference_sklearn_hamming_loss(true, pred)) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -360,13 +360,13 @@ def _reference_sklearn_hamming_distance_multilabel_local(preds, target, ignore_i for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, 
pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) hamming.append(_reference_sklearn_hamming_loss(true, pred)) else: scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(_reference_sklearn_hamming_loss(true, pred)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) diff --git a/tests/unittests/classification/test_hinge.py b/tests/unittests/classification/test_hinge.py index fb0c62838cc..8963177d7a0 100644 --- a/tests/unittests/classification/test_hinge.py +++ b/tests/unittests/classification/test_hinge.py @@ -38,7 +38,7 @@ def _reference_sklearn_binary_hinge_loss(preds, target, ignore_index): if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) target = 2 * target - 1 return sk_hinge(target, preds) @@ -125,7 +125,7 @@ def _reference_sklearn_multiclass_hinge_loss(preds, target, multiclass_mode, ign if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) preds = np.moveaxis(preds, 1, -1).reshape((-1, preds.shape[1])) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if multiclass_mode == "one-vs-all": enc = OneHotEncoder() diff --git a/tests/unittests/classification/test_jaccard.py b/tests/unittests/classification/test_jaccard.py index e7afdb557a6..0a20a2e458a 100644 --- a/tests/unittests/classification/test_jaccard.py +++ b/tests/unittests/classification/test_jaccard.py @@ -45,7 +45,7 @@ def _reference_sklearn_jaccard_index_binary(preds, target, ignore_index=None, ze if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_jaccard_index(y_true=target, y_pred=preds, zero_division=zero_division) @@ -141,7 +141,7 @@ def _reference_sklearn_jaccard_index_multiclass(preds, target, ignore_index=None preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) if average is None: return sk_jaccard_index( y_true=target, y_pred=preds, average=average, labels=list(range(NUM_CLASSES)), zero_division=zero_division @@ -269,7 +269,7 @@ def _reference_sklearn_jaccard_index_multilabel(preds, target, ignore_index=None scores, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i], target[:, i] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) scores.append(sk_jaccard_index(true, pred, zero_division=zero_division)) weights.append(confmat[1, 0] + confmat[1, 1]) diff --git a/tests/unittests/classification/test_matthews_corrcoef.py b/tests/unittests/classification/test_matthews_corrcoef.py index 2f881604d09..b340db8d713 100644 --- 
a/tests/unittests/classification/test_matthews_corrcoef.py +++ b/tests/unittests/classification/test_matthews_corrcoef.py @@ -46,7 +46,7 @@ def _reference_sklearn_matthews_corrcoef_binary(preds, target, ignore_index=None if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_matthews_corrcoef(y_true=target, y_pred=preds) @@ -138,7 +138,7 @@ def _reference_sklearn_matthews_corrcoef_multiclass(preds, target, ignore_index= preds = np.argmax(preds, axis=1) preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_matthews_corrcoef(y_true=target, y_pred=preds) @@ -228,7 +228,7 @@ def _reference_sklearn_matthews_corrcoef_multilabel(preds, target, ignore_index= if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) preds = (preds >= THRESHOLD).astype(np.uint8) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_matthews_corrcoef(y_true=target, y_pred=preds) diff --git a/tests/unittests/classification/test_precision_fixed_recall.py b/tests/unittests/classification/test_precision_fixed_recall.py index f320d2cf1e9..b6649ad869d 100644 --- a/tests/unittests/classification/test_precision_fixed_recall.py +++ b/tests/unittests/classification/test_precision_fixed_recall.py @@ -58,7 +58,7 @@ def _reference_sklearn_precision_at_fixed_recall_binary(preds, target, min_recal target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _precision_at_recall_x_multilabel(preds, target, min_recall) @@ -169,7 +169,7 @@ def _reference_sklearn_precision_at_fixed_recall_multiclass(preds, target, min_r target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) precision, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py index 00eee202cc0..7717ffa5b0d 100644 --- a/tests/unittests/classification/test_precision_recall.py +++ b/tests/unittests/classification/test_precision_recall.py @@ -49,7 +49,9 @@ seed_all(42) -def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_index, multidim_average, zero_division=0): +def _reference_sklearn_precision_recall_binary( + preds, target, sk_fn, ignore_index, multidim_average, zero_division=0, prob_threshold: float = THRESHOLD +): if multidim_average == "global": preds = preds.view(-1).numpy() target = target.view(-1).numpy() @@ -60,17 +62,17 @@ def _reference_sklearn_precision_recall_binary(preds, target, sk_fn, ignore_inde if np.issubdtype(preds.dtype, np.floating): if not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - preds = (preds >= THRESHOLD).astype(np.uint8) + preds = 
(preds >= prob_threshold).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) res = [] for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) res.append(sk_fn(true, pred, zero_division=zero_division)) return np.stack(res) @@ -197,7 +199,7 @@ def test_binary_precision_recall_half_gpu(self, inputs, module, functional, comp def _reference_sklearn_precision_recall_multiclass( - preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0, num_classes: int = NUM_CLASSES ): if preds.ndim == target.ndim + 1: preds = torch.argmax(preds, 1) @@ -205,12 +207,12 @@ def _reference_sklearn_precision_recall_multiclass( if multidim_average == "global": preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn( target, preds, average=average, - labels=list(range(NUM_CLASSES)) if average is None else None, + labels=list(range(num_classes)) if average is None else None, zero_division=zero_division, ) @@ -220,7 +222,7 @@ def _reference_sklearn_precision_recall_multiclass( for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) if len(pred) == 0 and average == "weighted": # The result of sk_fn([], [], labels=None, average="weighted", zero_division=zero_division) # varies depending on the sklearn version: @@ -235,7 +237,7 @@ def _reference_sklearn_precision_recall_multiclass( true, pred, average=average, - labels=list(range(NUM_CLASSES)) if average is None else None, + labels=list(range(num_classes)) if average is None else None, zero_division=zero_division, ) res.append(0.0 if np.isnan(r).any() else r) @@ -422,13 +424,13 @@ def _reference_sklearn_precision_recall_multilabel_global(preds, target, sk_fn, if average == "micro": preds = preds.flatten() target = target.flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_fn(target, preds, zero_division=zero_division) precision_recall, weights = [], [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) precision_recall.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -451,7 +453,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i for i in range(preds.shape[0]): if average == "micro": pred, true = preds[i].flatten(), target[i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) 
precision_recall.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) weights.append(confmat[1, 1] + confmat[1, 0]) @@ -459,7 +461,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i scores, w = [], [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) scores.append(sk_fn(true, pred, zero_division=zero_division)) confmat = sk_confusion_matrix(true, pred, labels=[0, 1]) w.append(confmat[1, 1] + confmat[1, 0]) @@ -481,7 +483,7 @@ def _reference_sklearn_precision_recall_multilabel_local(preds, target, sk_fn, i def _reference_sklearn_precision_recall_multilabel( - preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0 + preds, target, sk_fn, ignore_index, multidim_average, average, zero_division=0, num_classes: int = NUM_CLASSES ): preds = preds.numpy() target = target.numpy() @@ -493,8 +495,8 @@ def _reference_sklearn_precision_recall_multilabel( target = target.reshape(*target.shape[:2], -1) if ignore_index is None and multidim_average == "global": return sk_fn( - target.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), - preds.transpose(0, 2, 1).reshape(-1, NUM_CLASSES), + target.transpose(0, 2, 1).reshape(-1, num_classes), + preds.transpose(0, 2, 1).reshape(-1, num_classes), average=average, zero_division=zero_division, ) diff --git a/tests/unittests/classification/test_precision_recall_curve.py b/tests/unittests/classification/test_precision_recall_curve.py index 6f78438007e..7c034c528e6 100644 --- a/tests/unittests/classification/test_precision_recall_curve.py +++ b/tests/unittests/classification/test_precision_recall_curve.py @@ -46,7 +46,7 @@ def _reference_sklearn_precision_recall_curve_binary(preds, target, ignore_index target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return sk_precision_recall_curve(target, preds) @@ -159,7 +159,7 @@ def _reference_sklearn_precision_recall_curve_multiclass(preds, target, ignore_i target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) precision, recall, thresholds = [], [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_recall_fixed_precision.py b/tests/unittests/classification/test_recall_fixed_precision.py index 9bdca8950bd..2d73d64f264 100644 --- a/tests/unittests/classification/test_recall_fixed_precision.py +++ b/tests/unittests/classification/test_recall_fixed_precision.py @@ -58,7 +58,7 @@ def _reference_sklearn_recall_at_fixed_precision_binary(preds, target, min_preci target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _recall_at_precision_x_multilabel(preds, target, min_precision) @@ -173,7 +173,7 @@ def 
_reference_sklearn_recall_at_fixed_precision_multiclass(preds, target, min_p target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) recall, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_roc.py b/tests/unittests/classification/test_roc.py index 167ad4876f0..f6cbd173128 100644 --- a/tests/unittests/classification/test_roc.py +++ b/tests/unittests/classification/test_roc.py @@ -37,7 +37,7 @@ def _reference_sklearn_roc_binary(preds, target, ignore_index=None): target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) fpr, tpr, thresholds = sk_roc_curve(target, preds, drop_intermediate=False) thresholds[0] = 1.0 return [np.nan_to_num(x, nan=0.0) for x in [fpr, tpr, thresholds]] @@ -140,7 +140,7 @@ def _reference_sklearn_roc_multiclass(preds, target, ignore_index=None): target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) fpr, tpr, thresholds = [], [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_sensitivity_specificity.py b/tests/unittests/classification/test_sensitivity_specificity.py index df01b67e4ab..e170365c1fb 100644 --- a/tests/unittests/classification/test_sensitivity_specificity.py +++ b/tests/unittests/classification/test_sensitivity_specificity.py @@ -79,7 +79,7 @@ def _reference_sklearn_sensitivity_at_specificity_binary(preds, target, min_spec target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return _sensitivity_at_specificity_x_multilabel(preds, target, min_specificity) @@ -198,7 +198,7 @@ def _reference_sklearn_sensitivity_at_specificity_multiclass(preds, target, min_ target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) sensitivity, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_specificity_sensitivity.py b/tests/unittests/classification/test_specificity_sensitivity.py index 0bafdfe55ea..934d669678a 100644 --- a/tests/unittests/classification/test_specificity_sensitivity.py +++ b/tests/unittests/classification/test_specificity_sensitivity.py @@ -77,7 +77,7 @@ def _reference_sklearn_specificity_at_sensitivity_binary(preds, target, min_sens target = target.flatten().numpy() if np.issubdtype(preds.dtype, np.floating) and not ((preds > 0) & (preds < 1)).all(): preds = sigmoid(preds) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) return 
_specificity_at_sensitivity_x_multilabel(preds, target, min_sensitivity) @@ -192,7 +192,7 @@ def _reference_sklearn_specificity_at_sensitivity_multiclass(preds, target, min_ target = target.numpy().flatten() if not ((preds > 0) & (preds < 1)).all(): preds = softmax(preds, 1) - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) specificity, thresholds = [], [] for i in range(NUM_CLASSES): diff --git a/tests/unittests/classification/test_stat_scores.py b/tests/unittests/classification/test_stat_scores.py index 5ea4c206bc0..2a5a53bb8aa 100644 --- a/tests/unittests/classification/test_stat_scores.py +++ b/tests/unittests/classification/test_stat_scores.py @@ -54,7 +54,7 @@ def _reference_sklearn_stat_scores_binary(preds, target, ignore_index, multidim_ preds = (preds >= THRESHOLD).astype(np.uint8) if multidim_average == "global": - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(y_true=target, y_pred=preds, labels=[0, 1]).ravel() return np.array([tp, fp, tn, fn, tp + fn]) @@ -62,7 +62,7 @@ def _reference_sklearn_stat_scores_binary(preds, target, ignore_index, multidim_ for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(y_true=true, y_pred=pred, labels=[0, 1]).ravel() res.append(np.array([tp, fp, tn, fn, tp + fn])) return np.stack(res) @@ -164,7 +164,7 @@ def test_binary_stat_scores_dtype_gpu(self, inputs, dtype): def _reference_sklearn_stat_scores_multiclass_global(preds, target, ignore_index, average): preds = preds.numpy().flatten() target = target.numpy().flatten() - target, preds = remove_ignore_index(target, preds, ignore_index) + target, preds = remove_ignore_index(target=target, preds=preds, ignore_index=ignore_index) confmat = sk_confusion_matrix(y_true=target, y_pred=preds, labels=list(range(NUM_CLASSES))) tp = np.diag(confmat) fp = confmat.sum(0) - tp @@ -192,7 +192,7 @@ def _reference_sklearn_stat_scores_multiclass_local(preds, target, ignore_index, for pred, true in zip(preds, target): pred = pred.flatten() true = true.flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) confmat = sk_confusion_matrix(y_true=true, y_pred=pred, labels=list(range(NUM_CLASSES))) tp = np.diag(confmat) fp = confmat.sum(0) - tp @@ -431,7 +431,7 @@ def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multi stat_scores = [] for i in range(preds.shape[1]): pred, true = preds[:, i].flatten(), target[:, i].flatten() - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() stat_scores.append(np.array([tp, fp, tn, fn, tp + fn])) res = np.stack(stat_scores, axis=0) @@ -452,7 +452,7 @@ def _reference_sklearn_stat_scores_multilabel(preds, target, ignore_index, multi scores = [] for j in range(preds.shape[1]): pred, true = preds[i, j], target[i, j] - true, pred = remove_ignore_index(true, pred, ignore_index) + true, pred = remove_ignore_index(target=true, preds=pred, 
ignore_index=ignore_index) tn, fp, fn, tp = sk_confusion_matrix(true, pred, labels=[0, 1]).ravel() scores.append(np.array([tp, fp, tn, fn, tp + fn])) stat_scores.append(np.stack(scores, 1)) diff --git a/tests/unittests/detection/test_intersection.py b/tests/unittests/detection/test_intersection.py index a028a014ea1..f44c6017e6c 100644 --- a/tests/unittests/detection/test_intersection.py +++ b/tests/unittests/detection/test_intersection.py @@ -355,6 +355,43 @@ def test_corner_case_only_one_empty_prediction(self, class_metric, functional_me for val in res.values(): assert val == torch.tensor(0.0) + def test_empty_preds_and_target(self, class_metric, functional_metric, reference_metric): + """Check that for either empty preds and targets that the metric returns 0 in these cases before averaging.""" + x = [ + { + "boxes": torch.empty(size=(0, 4), dtype=torch.float32), + "labels": torch.tensor([], dtype=torch.long), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + }, + ] + + y = [ + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + { + "boxes": torch.FloatTensor([[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]]), + "labels": torch.LongTensor([1, 2]), + "scores": torch.FloatTensor([0.9, 0.8]), + }, + ] + metric = class_metric() + metric.update(x, y) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.5) + + metric = class_metric() + metric.update(y, x) + res = metric.compute() + for val in res.values(): + assert val == torch.tensor(0.5) + def test_corner_case(): """See issue: https://github.com/Lightning-AI/torchmetrics/issues/1921.""" diff --git a/tests/unittests/segmentation/test_generalized_dice_score.py b/tests/unittests/segmentation/test_generalized_dice_score.py index 3f8acec842a..ef05face380 100644 --- a/tests/unittests/segmentation/test_generalized_dice_score.py +++ b/tests/unittests/segmentation/test_generalized_dice_score.py @@ -16,6 +16,7 @@ import pytest import torch +from lightning_utilities.core.imports import RequirementCache from monai.metrics.generalized_dice import compute_generalized_dice from torchmetrics.functional.segmentation.generalized_dice import generalized_dice_score from torchmetrics.segmentation.generalized_dice import GeneralizedDiceScore @@ -51,7 +52,8 @@ def _reference_generalized_dice( if input_format == "index": preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1) target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1) - val = compute_generalized_dice(preds, target, include_background=include_background) + monai_extra_arg = {"sum_over_classes": True} if RequirementCache("monai>=1.4.0") else {} + val = compute_generalized_dice(preds, target, include_background=include_background, **monai_extra_arg) if reduce: val = val.mean() return val
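The generalized-dice test above adapts to a MONAI API change by building its extra keyword arguments conditionally. A minimal sketch of that ``RequirementCache`` pattern in isolation, with a placeholder wrapper function that is not part of the test suite:

```python
from lightning_utilities.core.imports import RequirementCache

_MONAI_GE_1_4 = RequirementCache("monai>=1.4.0")  # truthy only when the requirement is satisfied


def reference_generalized_dice(preds, target, include_background=True):
    """Placeholder wrapper: forwards the newer kwarg only to MONAI versions that accept it."""
    from monai.metrics.generalized_dice import compute_generalized_dice

    extra = {"sum_over_classes": True} if _MONAI_GE_1_4 else {}
    return compute_generalized_dice(preds, target, include_background=include_background, **extra)
```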