Merge branch 'master' into docker/py3.11
SkafteNicki authored Oct 5, 2023
2 parents a229514 + f925113 commit fa0b0a8
Showing 38 changed files with 282 additions and 100 deletions.
8 changes: 4 additions & 4 deletions .github/assistant.py
@@ -18,11 +18,11 @@
import re
import sys
import traceback
from distutils.version import LooseVersion
from typing import List, Optional, Tuple, Union

import fire
import requests
from packaging.version import parse
from pkg_resources import parse_requirements

_REQUEST_TIMEOUT = 10
@@ -79,9 +79,9 @@ def set_min_torch_by_python(fpath: str = "requirements/base.txt") -> None:
return
with open(fpath) as fp:
reqs = parse_requirements(fp.readlines())
pkg_ver = [p for p in reqs if p.name == "torch"][0]
pt_ver = min([LooseVersion(v[1]) for v in pkg_ver.specs])
pt_ver = max(LooseVersion(LUT_PYTHON_TORCH[py_ver]), pt_ver)
pkg_ver = next(p for p in reqs if p.name == "torch")
pt_ver = min([parse(v[1]) for v in pkg_ver.specs])
pt_ver = max(parse(LUT_PYTHON_TORCH[py_ver]), pt_ver)
with open(fpath) as fp:
requires = fp.read()
requires = re.sub(r"torch>=[\d\.]+", f"torch>={pt_ver}", requires)
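The swap from `distutils.version.LooseVersion` to `packaging.version.parse` tracks the deprecation of distutils (PEP 632, removed in Python 3.12). A minimal sketch, assuming nothing beyond the `packaging` API, of the comparison behaviour the updated helper relies on:

# Illustrative sketch, not part of the commit.
from packaging.version import parse

# parse() returns PEP 440 Version objects that order correctly, so min()/max()
# over requirement versions behaves as set_min_torch_by_python expects.
versions = [parse(v) for v in ("1.8.1", "1.13.1", "2.0.1")]
print(min(versions), max(versions))        # 1.8.1 2.0.1
print(max(parse("1.11"), parse("1.8.1")))  # 1.11
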
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
@@ -38,22 +38,22 @@ repos:
- id: detect-private-key

- repo: https://github.com/asottile/pyupgrade
rev: v3.9.0
rev: v3.14.0
hooks:
- id: pyupgrade
args: [--py38-plus]
args: ["--py38-plus"]
name: Upgrade code

- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.2.6
hooks:
- id: codespell
additional_dependencies: [tomli]
#args: ["--write-changes"]
args: ["--write-changes"]
exclude: pyproject.toml

- repo: https://github.com/crate-ci/typos
rev: v1.16.12
rev: v1.16.17
hooks:
- id: typos
# empty to do not write fixes
@@ -68,13 +68,13 @@ repos:
args: ["--in-place"]

- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black
name: Format code

- repo: https://github.com/executablebooks/mdformat
rev: 0.7.16
rev: 0.7.17
hooks:
- id: mdformat
additional_dependencies:
@@ -130,7 +130,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.277
rev: v0.0.292
hooks:
- id: ruff
args: ["--fix"]
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `average` argument to multiclass versions of `PrecisionRecallCurve` and `ROC` ([#2084](https://github.com/Lightning-AI/torchmetrics/pull/2084))

### Changed

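A hedged usage sketch of the new `average` argument (the semantics follow the docstring additions further down in this diff; the data here is illustrative):

# Illustrative usage of the new `average` argument, not part of the commit.
import torch
from torchmetrics.classification import MulticlassPrecisionRecallCurve, MulticlassROC

preds = torch.softmax(torch.randn(16, 3), dim=-1)
target = torch.randint(3, (16,))

# average=None (default) returns one curve per class; "micro" flattens all classes into
# a single binary problem; "macro" interpolates the per-class curves onto a common set
# of thresholds and averages them.
pr_curve = MulticlassPrecisionRecallCurve(num_classes=3, average="macro")
precision, recall, thresholds = pr_curve(preds, target)

roc = MulticlassROC(num_classes=3, average="micro")
fpr, tpr, roc_thresholds = roc(preds, target)
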
2 changes: 1 addition & 1 deletion MANIFEST.in
@@ -1,4 +1,4 @@
# Manifest syntax https://docs.python.org/2/distutils/sourcedist.html
# Manifest syntax https://packaging.python.org/en/latest/guides/using-manifest-in/
graft wheelhouse

recursive-exclude __pycache__ *.py[cod] *.orig
1 change: 1 addition & 0 deletions docs/source/links.rst
@@ -164,3 +164,4 @@
.. _Completeness Score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.completeness_score.html
.. _Davies-Bouldin Score: https://en.wikipedia.org/wiki/Davies%E2%80%93Bouldin_index
.. _Fowlkes-Mallows Index: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fowlkes_mallows_score.html#sklearn.metrics.fowlkes_mallows_score
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -22,10 +22,8 @@ addopts = [
#filterwarnings = ["error::FutureWarning"] # ToDo
xfail_strict = true
junit_duration_report = "call"

[tool.coverage.report]
exclude_lines = ["pragma: no cover", "pass"]

[tool.coverage.run]
parallel = true
concurrency = "thread"
@@ -81,6 +79,7 @@ wil = "wil"


[tool.ruff]
target-version = "py38"
line-length = 120
# Enable Pyflakes `E` and `F` codes by default.
select = [
@@ -122,6 +121,8 @@ ignore = [
"S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue # todo
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected. # todo
"B905", # todo: `zip()` without an explicit `strict=` parameter
"PYI024", # todo: Use `typing.NamedTuple` instead of `collections.namedtuple`
"PYI041", # todo: Use `float` instead of `int | float``
]
# Exclude a variety of commonly ignored directories.
exclude = [
1 change: 1 addition & 0 deletions requirements/base.txt
@@ -2,6 +2,7 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

numpy >1.20.0
packaging >17.1
torch >=1.8.1, <=2.0.1
typing-extensions; python_version < '3.9'
lightning-utilities >=0.8.0, <0.10.0
12 changes: 6 additions & 6 deletions src/torchmetrics/audio/__init__.py
@@ -41,16 +41,16 @@
]

if _PESQ_AVAILABLE:
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality # noqa: F401
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

__all__.append("PerceptualEvaluationSpeechQuality")
__all__ += ["PerceptualEvaluationSpeechQuality"]

if _PYSTOI_AVAILABLE:
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility # noqa: F401
from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility

__all__.append("ShortTimeObjectiveIntelligibility")
__all__ += ["ShortTimeObjectiveIntelligibility"]

if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10:
from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio # noqa: F401
from torchmetrics.audio.srmr import SpeechReverberationModulationEnergyRatio

__all__.append("SpeechReverberationModulationEnergyRatio")
__all__ += ["SpeechReverberationModulationEnergyRatio"]
2 changes: 1 addition & 1 deletion src/torchmetrics/audio/stoi.py
@@ -34,7 +34,7 @@ class ShortTimeObjectiveIntelligibility(Metric):
The STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good
alternative to the speech intelligibility index (SII) or the speech transmission index (STI), when you are
interested in the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms,
on speech intelligibility. Description taken from `Cees Taal's website`_ and for further defails see `STOI ref1`_
on speech intelligibility. Description taken from `Cees Taal's website`_ and for further details see `STOI ref1`_
and `STOI ref2`_.
This metric is a wrapper for the `pystoi package`_. As the implementation backend implementation only supports
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/auroc.py
@@ -266,13 +266,15 @@ def __init__(
)
if validate_args:
_multiclass_auroc_arg_validation(num_classes, average, thresholds, ignore_index)
self.average = average
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_auroc_compute(state, self.num_classes, self.average, self.thresholds)
return _multiclass_auroc_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
6 changes: 4 additions & 2 deletions src/torchmetrics/classification/average_precision.py
@@ -264,13 +264,15 @@ def __init__(
)
if validate_args:
_multiclass_average_precision_arg_validation(num_classes, average, thresholds, ignore_index)
self.average = average
self.average = average # type: ignore[assignment]
self.validate_args = validate_args

def compute(self) -> Tensor: # type: ignore[override]
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_average_precision_compute(state, self.num_classes, self.average, self.thresholds)
return _multiclass_average_precision_compute(
state, self.num_classes, self.average, self.thresholds # type: ignore[arg-type]
)

def plot( # type: ignore[override]
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
25 changes: 21 additions & 4 deletions src/torchmetrics/classification/precision_recall_curve.py
@@ -103,6 +103,8 @@ class BinaryPrecisionRecallCurve(Metric):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -266,6 +268,15 @@ class MulticlassPrecisionRecallCurve(Metric):
- If set to a 1D `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot
encoding the targets and flattening the predictions, considering all classes jointly as a binary problem.
If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves
from each class at a combined set of thresholds and then average over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -314,15 +325,17 @@ def __init__(
self,
num_classes: int,
thresholds: Optional[Union[int, List[float], Tensor]] = None,
average: Optional[Literal["micro", "macro"]] = None,
ignore_index: Optional[int] = None,
validate_args: bool = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if validate_args:
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index)
_multiclass_precision_recall_curve_arg_validation(num_classes, thresholds, ignore_index, average)

self.num_classes = num_classes
self.average = average
self.ignore_index = ignore_index
self.validate_args = validate_args

@@ -344,9 +357,11 @@ def update(self, preds: Tensor, target: Tensor) -> None:
if self.validate_args:
_multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
preds, target, _ = _multiclass_precision_recall_curve_format(
preds, target, self.num_classes, self.thresholds, self.ignore_index
preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average
)
state = _multiclass_precision_recall_curve_update(
preds, target, self.num_classes, self.thresholds, self.average
)
state = _multiclass_precision_recall_curve_update(preds, target, self.num_classes, self.thresholds)
if isinstance(state, Tensor):
self.confmat += state
else:
@@ -356,7 +371,7 @@ def update(self, preds: Tensor, target: Tensor) -> None:
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = (dim_zero_cat(self.preds), dim_zero_cat(self.target)) if self.thresholds is None else self.confmat
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds)
return _multiclass_precision_recall_curve_compute(state, self.num_classes, self.thresholds, self.average)

def plot(
self,
@@ -456,6 +471,8 @@ class MultilabelPrecisionRecallCurve(Metric):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
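The ``"macro"`` option described in the docstring above interpolates each class's curve onto a combined set of thresholds before averaging. A hedged, illustrative sketch of that idea using ROC curves for concreteness (not the library's internal implementation):

# Illustrative sketch of "macro" curve averaging, not the library's internal code.
import numpy as np

def macro_average_roc(fprs, tprs):
    """Interpolate per-class ROC curves onto a combined FPR grid and average the TPRs."""
    grid = np.unique(np.concatenate(fprs))
    mean_tpr = np.mean([np.interp(grid, fpr, tpr) for fpr, tpr in zip(fprs, tprs)], axis=0)
    return grid, mean_tpr

# Two toy per-class curves:
fprs = [np.array([0.0, 0.2, 1.0]), np.array([0.0, 0.5, 1.0])]
tprs = [np.array([0.0, 0.8, 1.0]), np.array([0.0, 0.6, 1.0])]
grid, mean_tpr = macro_average_roc(fprs, tprs)
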
15 changes: 14 additions & 1 deletion src/torchmetrics/classification/roc.py
@@ -87,6 +87,8 @@ class BinaryROC(BinaryPrecisionRecallCurve):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -229,6 +231,15 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
average:
If aggregation of curves should be applied. By default, the curves are not aggregated and a curve for
each class is returned. If `average` is set to ``"micro"``, the metric will aggregate the curves by one hot
encoding the targets and flattening the predictions, considering all classes jointly as a binary problem.
If `average` is set to ``"macro"``, the metric will aggregate the curves by first interpolating the curves
from each class at a combined set of thresholds and then average over the classwise interpolated curves.
See `averaging curve objects`_ for more info on the different averaging methods.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
@@ -276,7 +287,7 @@ class MulticlassROC(MulticlassPrecisionRecallCurve):
def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], List[Tensor], List[Tensor]]]:
"""Compute metric."""
state = [dim_zero_cat(self.preds), dim_zero_cat(self.target)] if self.thresholds is None else self.confmat
return _multiclass_roc_compute(state, self.num_classes, self.thresholds)
return _multiclass_roc_compute(state, self.num_classes, self.thresholds, self.average)

def plot(
self,
@@ -381,6 +392,8 @@ class MultilabelROC(MultilabelPrecisionRecallCurve):
- If set to an 1d `tensor` of floats, will use the indicated thresholds in the tensor as
bins for the calculation.
ignore_index:
Specifies a target value that is ignored and does not contribute to the metric calculation
validate_args: bool indicating if input arguments and tensors should be validated for correctness.
Set to ``False`` for faster computations.
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
4 changes: 2 additions & 2 deletions src/torchmetrics/detection/mean_ap.py
@@ -879,7 +879,7 @@ def _get_coco_format(
f"Invalid input box of sample {image_id}, element {k} (expected 4 values, got {len(image_box)})"
)

if type(image_label) != int:
if not isinstance(image_label, int):
raise ValueError(
f"Invalid input class of sample {image_id}, element {k}"
f" (expected value of type integer, got type {type(image_label)})"
@@ -915,7 +915,7 @@

if scores is not None:
score = scores[image_id][k].cpu().tolist()
if type(score) != float:
if not isinstance(score, float):
raise ValueError(
f"Invalid input score of sample {image_id}, element {k}"
f" (expected value of type float, got type {type(score)})"
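The change from `type(x) != int` to `isinstance(x, int)` is the idiomatic check; a small illustration of the behavioural difference (assumed example, not from the commit):

# isinstance() accepts subclasses, which an exact type() comparison does not.
class MyInt(int):
    pass

x = MyInt(3)
print(type(x) == int)      # False: exact-type comparison rejects the subclass
print(isinstance(x, int))  # True: isinstance accepts int and its subclasses
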
12 changes: 6 additions & 6 deletions src/torchmetrics/functional/audio/__init__.py
@@ -42,16 +42,16 @@
]

if _PESQ_AVAILABLE:
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality # noqa: F401
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality

__all__.append("perceptual_evaluation_speech_quality")
__all__ += ["perceptual_evaluation_speech_quality"]

if _PYSTOI_AVAILABLE:
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility # noqa: F401
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility

__all__.append("short_time_objective_intelligibility")
__all__ += ["short_time_objective_intelligibility"]

if _GAMMATONE_AVAILABLE and _TORCHAUDIO_AVAILABLE and _TORCHAUDIO_GREATER_EQUAL_0_10:
from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio # noqa: F401
from torchmetrics.functional.audio.srmr import speech_reverberation_modulation_energy_ratio

__all__.append("speech_reverberation_modulation_energy_ratio")
__all__ += ["speech_reverberation_modulation_energy_ratio"]
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/audio/sdr.py
@@ -63,7 +63,7 @@ def _symmetric_toeplitz(vector: Tensor) -> Tensor:
def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) -> Tuple[Tensor, Tensor]:
r"""Compute the auto correlation of `target` and the cross correlation of `target` and `preds`.
This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz matric of the
This calculation is done using the fast Fourier transform (FFT). Let's denotes the symmetric Toeplitz metric of the
auto correlation of `target` as `R`, the cross correlation as 'b', then solving the equation `Rh=b` could have `h`
as the coordinate of `preds` in the column space of the `corr_len` shifts of `target`.
@@ -81,7 +81,7 @@ def _compute_autocorr_crosscorr(target: Tensor, preds: Tensor, corr_len: int) ->
n_fft = 2 ** math.ceil(math.log2(preds.shape[-1] + target.shape[-1] - 1))

# computes the auto correlation of `target`
# r_0 is the first row of the symmetric Toeplitz matric
# r_0 is the first row of the symmetric Toeplitz metric
t_fft = torch.fft.rfft(target, n=n_fft, dim=-1)
r_0 = torch.fft.irfft(t_fft.real**2 + t_fft.imag**2, n=n_fft)[..., :corr_len]

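A hedged sketch of the relationship the docstring describes (illustrative values, not the library's code): with `R` the symmetric Toeplitz matrix built from the autocorrelation and `b` the cross-correlation, solving `Rh = b` gives the coefficients that express `preds` in terms of shifted copies of `target`.

# Illustrative toy example of solving Rh = b, not the library's code.
import torch

R = torch.tensor([[2.0, 0.5, 0.1],
                  [0.5, 2.0, 0.5],
                  [0.1, 0.5, 2.0]])  # symmetric Toeplitz matrix from the autocorrelation r_0
b = torch.tensor([1.0, 0.3, 0.2])    # cross-correlation of target and preds
h = torch.linalg.solve(R, b)         # projection / filter coefficients
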
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/audio/stoi.py
@@ -35,7 +35,7 @@ def short_time_objective_intelligibility(
STOI-measure is intrusive, i.e., a function of the clean and degraded speech signals. STOI may be a good alternative
to the speech intelligibility index (SII) or the speech transmission index (STI), when you are interested in
the effect of nonlinear processing to noisy speech, e.g., noise reduction, binary masking algorithms, on speech
intelligibility. Description taken from `Cees Taal's website`_ and for further defails see `STOI ref1`_ and
intelligibility. Description taken from `Cees Taal's website`_ and for further details see `STOI ref1`_ and
`STOI ref2`_.
This metric is a wrapper for the `pystoi package`_. As the implementation backend implementation only supports
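A hedged usage sketch of the functional STOI metric documented above (requires the optional `pystoi` backend; the signals here are synthetic and only illustrative):

# Illustrative usage, not part of the commit; assumes pystoi is installed.
import torch
from torchmetrics.functional.audio import short_time_objective_intelligibility

torch.manual_seed(0)
target = torch.randn(8000)                 # stand-in for 1 s of clean speech at 8 kHz
preds = target + 0.1 * torch.randn(8000)   # degraded version of the same signal
score = short_time_objective_intelligibility(preds, target, fs=8000)
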