Skip to content

Commit

Permalink
Merge branch 'master' into rittik/multiclassf1_topk
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 authored Nov 21, 2024
2 parents 27d4e59 + 0d3494f commit fe3ac74
Show file tree
Hide file tree
Showing 293 changed files with 2,760 additions and 1,208 deletions.
12 changes: 6 additions & 6 deletions .github/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import re
import sys
from typing import List, Optional, Tuple, Union
from typing import Optional, Union

import fire
from packaging.version import parse
Expand Down Expand Up @@ -83,7 +83,7 @@ def _replace_requirement(fpath: str, old_str: str = "", new_str: str = "") -> No
fp.write(req)

@staticmethod
def replace_str_requirements(old_str: str, new_str: str, req_files: List[str] = REQUIREMENTS_FILES) -> None:
def replace_str_requirements(old_str: str, new_str: str, req_files: list[str] = REQUIREMENTS_FILES) -> None:
"""Replace a particular string in all requirements files."""
if isinstance(req_files, str):
req_files = [req_files]
Expand All @@ -96,7 +96,7 @@ def replace_min_requirements(fpath: str) -> None:
AssistantCLI._replace_requirement(fpath, old_str=">=", new_str="==")

@staticmethod
def set_oldest_versions(req_files: List[str] = REQUIREMENTS_FILES) -> None:
def set_oldest_versions(req_files: list[str] = REQUIREMENTS_FILES) -> None:
"""Set the oldest version for requirements."""
AssistantCLI.set_min_torch_by_python()
if isinstance(req_files, str):
Expand All @@ -109,8 +109,8 @@ def changed_domains(
pr: Optional[int] = None,
auth_token: Optional[str] = None,
as_list: bool = False,
general_sub_pkgs: Tuple[str] = _PKG_WIDE_SUBPACKAGES,
) -> Union[str, List[str]]:
general_sub_pkgs: tuple[str] = _PKG_WIDE_SUBPACKAGES,
) -> Union[str, list[str]]:
"""Determine what domains were changed in particular PR."""
import github

Expand Down Expand Up @@ -139,7 +139,7 @@ def changed_domains(
return "unittests"

# parse domains
def _crop_path(fname: str, paths: List[str]) -> str:
def _crop_path(fname: str, paths: list[str]) -> str:
for p in paths:
fname = fname.replace(p, "")
return fname
Expand Down
2 changes: 0 additions & 2 deletions .github/workflows/_focus-diff.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
#with:
# python-version: 3.8

- name: Get PR diff
id: diff-domains
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
testing-matrix: |
{
"os": ["ubuntu-22.04", "macos-13", "windows-2022"],
"python-version": ["3.8", "3.11"]
"python-version": ["3.9", "3.11"]
}
check-md-links:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
- "2.5.0"
include:
# cover additional python and PT combinations
- { os: "ubuntu-20.04", python-version: "3.8", pytorch-version: "2.0.1", requires: "oldest" }
- { os: "ubuntu-20.04", python-version: "3.9", pytorch-version: "2.0.1", requires: "oldest" }
- { os: "ubuntu-22.04", python-version: "3.11", pytorch-version: "2.4.1" }
- { os: "ubuntu-22.04", python-version: "3.12", pytorch-version: "2.5.0" }
# standard mac machine, not the M1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/publish-pkg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: 3.8
python-version: "3.10"

- name: Install dependencies
run: >-
Expand Down
33 changes: 23 additions & 10 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,56 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `NormalizedRootMeanSquaredError` metric to regression subpackage ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))
-


- Added `NegativePredictiveValue` to classification metrics ([#2433](https://github.com/Lightning-AI/torchmetrics/pull/2433))
### Changed

-

- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786))

### Removed

- Added a new audio metric `NISQA` ([#2792](https://github.com/PyTorchLightning/metrics/pull/2792))
-


- Added `Dice` metric to segmentation metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
### Fixed

-

- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754))

---

## [1.6.0] - 2024-11-12

### Added

- Added audio metric `NISQA` ([#2792](https://github.com/PyTorchLightning/metrics/pull/2792))
- Added classification metric `LogAUC` ([#2377](https://github.com/Lightning-AI/torchmetrics/pull/2377))
- Added classification metric `NegativePredictiveValue` ([#2433](https://github.com/Lightning-AI/torchmetrics/pull/2433))
- Added regression metric `NormalizedRootMeanSquaredError` ([#2442](https://github.com/Lightning-AI/torchmetrics/pull/2442))
- Added segmentation metric `Dice` ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))
- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786))
- Added support for propagation of the autograd graph in ddp setting ([#2754](https://github.com/Lightning-AI/torchmetrics/pull/2754))

### Changed

- Changed naming and input order arguments in `KLDivergence` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))


### Deprecated

- Deprecated Dice from classification metrics ([#2725](https://github.com/Lightning-AI/torchmetrics/pull/2725))


### Removed

- Changed minimum supported Pytorch version to 2.0 ([#2671](https://github.com/Lightning-AI/torchmetrics/pull/2671))


- Dropped support for Python 3.8 ([#2827](https://github.com/Lightning-AI/torchmetrics/pull/2827))
- Removed `num_outputs` in `R2Score` ([#2800](https://github.com/Lightning-AI/torchmetrics/pull/2800))

### Fixed


- Fixed segmentation `Dice` + `GeneralizedDice` for 2d index tensors ([#2832](https://github.com/Lightning-AI/torchmetrics/pull/2832))
- Fixed mixed results of `rouge_score` with `accumulate='best'` ([#2830](https://github.com/Lightning-AI/torchmetrics/pull/2830))


Expand Down
8 changes: 4 additions & 4 deletions _samples/bert_score-own_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

from pprint import pprint
from typing import Dict, List, Union
from typing import Union

import torch
from torch import Tensor, nn
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(self) -> None:
self.PAD_TOKEN: torch.zeros(1, _MODEL_DIM),
}

def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) -> Dict[str, Tensor]:
def __call__(self, sentences: Union[str, list[str]], max_len: int = _MAX_LEN) -> dict[str, Tensor]:
"""Call method to tokenize user input.
The `__call__` method must be defined for this class. To ensure the functionality, the `__call__` method
Expand All @@ -69,7 +69,7 @@ def __call__(self, sentences: Union[str, List[str]], max_len: int = _MAX_LEN) ->
Python dictionary containing the keys `input_ids` and `attention_mask` with corresponding values.
"""
output_dict: Dict[str, Tensor] = {}
output_dict: dict[str, Tensor] = {}
if isinstance(sentences, str):
sentences = [sentences]
# Add special tokens
Expand All @@ -96,7 +96,7 @@ def get_user_model_encoder(num_layers: int = _NUM_LAYERS, d_model: int = _MODEL_
return nn.TransformerEncoder(encoder_layer, num_layers=num_layers)


def user_forward_fn(model: Module, batch: Dict[str, Tensor]) -> Tensor:
def user_forward_fn(model: Module, batch: dict[str, Tensor]) -> Tensor:
"""User forward function used for the computation of model embeddings.
This function might be arbitrarily complicated inside. However, to ensure functionality, it should obey the
Expand Down
2 changes: 1 addition & 1 deletion _samples/rouge_score-own_normalizer_and_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
"""

import re
from collections.abc import Sequence
from pprint import pprint
from typing import Sequence

from torchmetrics.text.rouge import ROUGEScore

Expand Down
55 changes: 55 additions & 0 deletions docs/source/classification/logauc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
.. customcarditem::
:header: Log Area Receiver Operating Characteristic (LogAUC)
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
:tags: Classification

.. include:: ../links.rst

#######
Log AUC
#######

Module Interface
________________

.. autoclass:: torchmetrics.LogAUC
:exclude-members: update, compute
:special-members: __new__

BinaryLogAUC
^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.BinaryLogAUC
:exclude-members: update, compute

MulticlassLogAUC
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MulticlassLogAUC
:exclude-members: update, compute

MultilabelLogAUC
^^^^^^^^^^^^^^^^

.. autoclass:: torchmetrics.classification.MultilabelLogAUC
:exclude-members: update, compute

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.logauc

binary_logauc
^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.binary_logauc

multiclass_logauc
^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.multiclass_logauc

multilabel_logauc
^^^^^^^^^^^^^^^^^

.. autofunction:: torchmetrics.functional.classification.multilabel_logauc
1 change: 1 addition & 0 deletions docs/source/links.rst
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,5 @@
.. _Hausdorff Distance: https://en.wikipedia.org/wiki/Hausdorff_distance
.. _averaging curve objects: https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
.. _Procrustes Disparity: https://en.wikipedia.org/wiki/Procrustes_analysis
.. _Log AUC: https://pubmed.ncbi.nlm.nih.gov/20735049/
.. _Negative Predictive Value: https://en.wikipedia.org/wiki/Positive_and_negative_predictive_values
3 changes: 1 addition & 2 deletions examples/audio/signal_to_noise_ratio.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

# %%
# Import necessary libraries
from typing import Tuple

import matplotlib.animation as animation
import matplotlib.pyplot as plt
Expand All @@ -20,7 +19,7 @@
# Generate a clean signal (simulating a high-quality recording)


def generate_clean_signal(length: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
def generate_clean_signal(length: int = 1000) -> tuple[np.ndarray, np.ndarray]:
"""Generate a clean signal (sine wave)"""
t = np.linspace(0, 1, length)
signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave, representing the clean recording
Expand Down
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = [
]

[tool.ruff]
target-version = "py38"
target-version = "py39"
line-length = 120

#[tool.ruff.pycodestyle]
Expand Down Expand Up @@ -68,6 +68,8 @@ lint.per-file-ignores."setup.py" = [
lint.per-file-ignores."src/**" = [
"ANN401",
"S310", # todo: Audit URL open for permitted schemes. Allowing use of `file:` or custom schemes is often unexpected.
"UP006", # todo: Use `list` instead of `List` for type annotation
"UP035", # todo: `typing.List` is deprecated, use `list` instead
]
lint.per-file-ignores."tests/**" = [
"ANN001",
Expand All @@ -77,9 +79,6 @@ lint.per-file-ignores."tests/**" = [
"S101",
"S301", # todo: `pickle` and modules that wrap it can be unsafe when used to deserialize untrusted data, possible security issue
]
lint.unfixable = [
"F401",
]
# Unlike Flake8, default to a complexity level of 10.
lint.mccabe.max-complexity = 10
# Use Google-style docstrings.
Expand Down
4 changes: 2 additions & 2 deletions requirements/_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ codecov ==2.1.13
coverage ==7.6.*
codecov ==2.1.13
pytest ==8.3.*
pytest-cov ==5.0.0
pytest-cov ==6.0.0
pytest-doctestplus ==1.2.1
pytest-rerunfailures ==14.0
pytest-timeout ==2.3.1
pytest-xdist ==3.6.1
phmdoctest ==1.4.0

psutil ==6.*
pyGithub >2.0.0, <2.5.0
pyGithub >2.0.0, <2.6.0
fire ==0.7.*

cloudpickle >1.3, <=3.1.0
Expand Down
1 change: 1 addition & 0 deletions requirements/classification_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pandas >1.4.0, <=2.2.3
netcal >1.0.0, <1.4.0 # calibration_error
numpy <2.2.0
fairlearn # group_fairness
PyTDC ==0.4.1 ; python_version <"3.12" # locauc, temporal_dependency
4 changes: 2 additions & 2 deletions requirements/text.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

nltk >3.8.1, <=3.9.1
tqdm <4.67.0
regex >=2021.9.24, <=2024.9.11
tqdm <4.68.0
regex >=2021.9.24, <=2024.11.6
transformers >4.4.0, <4.47.0
mecab-python3 >=1.0.6, <1.1.0
ipadic >=1.0.0, <1.1.0
Expand Down
9 changes: 5 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import glob
import os
import re
from collections.abc import Iterable, Iterator
from functools import partial
from importlib.util import module_from_spec, spec_from_file_location
from itertools import chain
from pathlib import Path
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
from typing import Any, Optional, Union

from pkg_resources import Requirement, yield_lines
from setuptools import find_packages, setup
Expand Down Expand Up @@ -97,7 +98,7 @@ def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_Requiremen

def _load_requirements(
path_dir: str, file_name: str = "base.txt", unfreeze: bool = not _FREEZE_REQUIREMENTS
) -> List[str]:
) -> list[str]:
"""Load requirements from a file.
>>> _load_requirements(_PATH_REQUIRE)
Expand Down Expand Up @@ -161,7 +162,7 @@ def _load_py_module(fname: str, pkg: str = "torchmetrics"):
BASE_REQUIREMENTS = _load_requirements(path_dir=_PATH_REQUIRE, file_name="base.txt")


def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.txt",)) -> dict:
def _prepare_extras(skip_pattern: str = "^_", skip_files: tuple[str] = ("base.txt",)) -> dict:
"""Preparing extras for the package listing requirements.
Args:
Expand Down Expand Up @@ -215,7 +216,7 @@ def _prepare_extras(skip_pattern: str = "^_", skip_files: Tuple[str] = ("base.tx
include_package_data=True,
zip_safe=False,
keywords=["deep learning", "machine learning", "pytorch", "metrics", "AI"],
python_requires=">=3.8",
python_requires=">=3.9",
setup_requires=[],
install_requires=BASE_REQUIREMENTS,
extras_require=_prepare_extras(),
Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/__about__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
__version__ = "1.6.0dev"
__version__ = "1.7.0dev"
__author__ = "Lightning-AI et al."
__author_email__ = "[email protected]"
__license__ = "Apache-2.0"
__copyright__ = f"Copyright (c) 2020-2023, {__author__}."
__copyright__ = f"Copyright (c) 2020-2024, {__author__}."
__homepage__ = "https://github.com/Lightning-AI/torchmetrics"
__docs__ = "PyTorch native Metrics"
__docs_url__ = "https://lightning.ai/docs/torchmetrics/stable/"
Expand Down
Loading

0 comments on commit fe3ac74

Please sign in to comment.