Use lightning utils apply_to_collection (Lightning-AI#2013)
* refactor to use lightning utils
* increase requirement
SkafteNicki authored Aug 22, 2023
1 parent 928fbfc commit 09fccca
Showing 9 changed files with 10 additions and 61 deletions.
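
Every change below follows one pattern: the in-tree helper torchmetrics.utilities.data.apply_to_collection is deleted and all call sites import the equivalent function from lightning-utilities instead, with the minimum pin raised to 0.8.0 (presumably the release whose top-level export is needed here; the commit message only says "increase requirement"). A minimal sketch of the swap, assuming lightning-utilities >=0.8.0 is installed:

# Before this commit, the helper was bundled inside torchmetrics:
# from torchmetrics.utilities.data import apply_to_collection

# After this commit, the same helper comes from the lightning-utilities dependency:
from lightning_utilities import apply_to_collection

squared = apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2)
print(squared)  # [64, 0, 4, 36, 49] -- matches the doctest of the removed in-tree version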
2 changes: 1 addition & 1 deletion requirements.txt
@@ -5,4 +5,4 @@ numpy >1.20.0
 torch >=1.8.1, <=2.0.1
 typing-extensions; python_version < '3.9'
 packaging # hotfix for utils, can be dropped with lit-utils >=0.5
-lightning-utilities >=0.7.0, <0.10.0
+lightning-utilities >=0.8.0, <0.10.0
2 changes: 1 addition & 1 deletion src/torchmetrics/detection/mean_ap.py
@@ -18,14 +18,14 @@

 import numpy as np
 import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch import distributed as dist
 from typing_extensions import Literal

 from torchmetrics.detection.helpers import _fix_empty_tensors, _input_validator, _validate_iou_type_arg
 from torchmetrics.metric import Metric
 from torchmetrics.utilities import rank_zero_warn
-from torchmetrics.utilities.data import apply_to_collection
 from torchmetrics.utilities.imports import (
     _MATPLOTLIB_AVAILABLE,
     _PYCOCOTOOLS_AVAILABLE,
2 changes: 1 addition & 1 deletion src/torchmetrics/metric.py
@@ -23,13 +23,13 @@
 from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, Sequence, Tuple, Union

 import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import Module

 from torchmetrics.utilities.data import (
     _flatten,
     _squeeze_if_scalar,
-    apply_to_collection,
     dim_zero_cat,
     dim_zero_max,
     dim_zero_mean,
2 changes: 0 additions & 2 deletions src/torchmetrics/utilities/__init__.py
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torchmetrics.utilities.checks import check_forward_full_state_property
-from torchmetrics.utilities.data import apply_to_collection
 from torchmetrics.utilities.distributed import class_reduce, reduce
 from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn

 __all__ = [
     "check_forward_full_state_property",
-    "apply_to_collection",
     "class_reduce",
     "reduce",
     "rank_zero_debug",
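
Note that this change also drops apply_to_collection from the public torchmetrics.utilities namespace and its __all__. A sketch of the migration for downstream code that imported it from there (lightning-utilities is already a required dependency, so no new install is needed):

# No longer available after this commit:
# from torchmetrics.utilities import apply_to_collection

# Import it from lightning-utilities instead:
from lightning_utilities import apply_to_collection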
54 changes: 2 additions & 52 deletions src/torchmetrics/utilities/data.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Union

 import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor

 from torchmetrics.utilities.exceptions import TorchMetricsUserWarning
@@ -152,57 +153,6 @@ def to_categorical(x: Tensor, argmax_dim: int = 1) -> Tensor:
     return torch.argmax(x, dim=argmax_dim)


-def apply_to_collection(
-    data: Any,
-    dtype: Union[type, tuple],
-    function: Callable,
-    *args: Any,
-    wrong_dtype: Optional[Union[type, tuple]] = None,
-    **kwargs: Any,
-) -> Any:
-    """Recursively applies a function to all elements of a certain dtype.
-
-    Args:
-        data: the collection to apply the function to
-        dtype: the given function will be applied to all elements of this dtype
-        function: the function to apply
-        *args: positional arguments (will be forwarded to call of ``function``)
-        wrong_dtype: the given function won't be applied if this type is specified and the given collection is of
-            the :attr:`wrong_dtype` even if it is of type :attr:`dtype`
-        **kwargs: keyword arguments (will be forwarded to call of ``function``)
-
-    Returns:
-        the resulting collection
-
-    Example:
-        >>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=Tensor, function=lambda x: x ** 2)
-        tensor([64,  0,  4, 36, 49])
-        >>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2)
-        [64, 0, 4, 36, 49]
-        >>> apply_to_collection(dict(abc=123), dtype=int, function=lambda x: x ** 2)
-        {'abc': 15129}
-    """
-    elem_type = type(data)
-
-    # Breaking condition
-    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
-        return function(data, *args, **kwargs)
-
-    # Recursively apply to collection items
-    if isinstance(data, Mapping):
-        return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs) for k, v in data.items()})
-
-    if isinstance(data, tuple) and hasattr(data, "_fields"):  # named tuple
-        return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
-
-    if isinstance(data, Sequence) and not isinstance(data, str):
-        return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
-
-    # data is neither of dtype, nor a collection
-    return data
-
-
 def _squeeze_scalar_element_tensor(x: Tensor) -> Tensor:
     return x.squeeze() if x.numel() == 1 else x
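
The deleted implementation above and its lightning_utilities replacement follow the same recursion rules: mappings, namedtuples, and sequences are rebuilt with their original container types, and anything that is neither of dtype nor a collection passes through untouched. A small sketch of that behavior (the Batch namedtuple is a hypothetical container used only for this demo):

from collections import namedtuple

import torch
from lightning_utilities import apply_to_collection
from torch import Tensor

Batch = namedtuple("Batch", ["preds", "target"])  # hypothetical demo container

data = {"batches": [Batch(preds=torch.ones(2), target=torch.zeros(2))]}
out = apply_to_collection(data, dtype=Tensor, function=lambda t: t + 1)

# Container types are preserved: dict -> dict, list -> list, Batch -> Batch.
assert isinstance(out["batches"][0], Batch)
assert torch.equal(out["batches"][0].preds, torch.full((2,), 2.0))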
2 changes: 1 addition & 1 deletion src/torchmetrics/wrappers/bootstrapping.py
@@ -15,11 +15,11 @@
 from typing import Any, Dict, Optional, Sequence, Union

 import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import ModuleList

 from torchmetrics.metric import Metric
-from torchmetrics.utilities import apply_to_collection
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 from torchmetrics.wrappers.abstract import WrapperMetric
2 changes: 1 addition & 1 deletion src/torchmetrics/wrappers/multioutput.py
@@ -15,11 +15,11 @@
 from typing import Any, List, Optional, Sequence, Tuple, Union

 import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor
 from torch.nn import ModuleList

 from torchmetrics.metric import Metric
-from torchmetrics.utilities import apply_to_collection
 from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
 from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 from torchmetrics.wrappers.abstract import WrapperMetric
3 changes: 2 additions & 1 deletion tests/unittests/helpers/testers.py
@@ -20,9 +20,10 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities import apply_to_collection
 from torch import Tensor, tensor
 from torchmetrics import Metric
-from torchmetrics.utilities.data import _flatten, apply_to_collection
+from torchmetrics.utilities.data import _flatten

 from unittests import NUM_PROCESSES
2 changes: 1 addition & 1 deletion tests/unittests/wrappers/test_bootstrapping.py
@@ -18,11 +18,11 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities import apply_to_collection
 from sklearn.metrics import mean_squared_error, precision_score, recall_score
 from torch import Tensor
 from torchmetrics.classification import MulticlassPrecision, MulticlassRecall
 from torchmetrics.regression import MeanSquaredError
-from torchmetrics.utilities import apply_to_collection
 from torchmetrics.wrappers.bootstrapping import BootStrapper, _bootstrap_sampler

 from unittests.helpers import seed_all
