Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce memory usage for certain image metrics #2089

Merged
merged 12 commits into from
Oct 1, 2023
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

-


### Changed

-
- Change default state of `SpectralAngleMapper` and `UniversalImageQualityIndex` to be tensors ([#2089](https://github.com/Lightning-AI/torchmetrics/pull/2089))


### Removed
Expand Down
45 changes: 31 additions & 14 deletions src/torchmetrics/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.sam import _sam_compute, _sam_update
Expand Down Expand Up @@ -75,33 +75,50 @@ class SpectralAngleMapper(Metric):

preds: List[Tensor]
target: List[Tensor]
sum_sam: Tensor
numel: Tensor

def __init__(
self,
reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"Metric `SpectralAngleMapper` will save all targets and predictions in the buffer."
" For large datasets, this may lead to a large memory footprint."
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
if reduction not in ("elementwise_mean", "sum", "none", None):
raise ValueError(
f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
)
if reduction == "none" or reduction is None:
rank_zero_warn(
"Metric `SpectralAngleMapper` will save all targets and predictions in the buffer when using"
"`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.add_state("sum_sam", tensor(0.0), dist_reduce_fx="sum")
self.add_state("numel", tensor(0), dist_reduce_fx="sum")
self.reduction = reduction

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target = _sam_update(preds, target)
self.preds.append(preds)
self.target.append(target)
if self.reduction == "none" or self.reduction is None:
self.preds.append(preds)
self.target.append(target)
else:
sam_score = _sam_compute(preds, target, reduction="sum")
self.sum_sam += sam_score
p_shape = preds.shape
self.numel += p_shape[0] * p_shape[2] * p_shape[3]

def compute(self) -> Tensor:
"""Compute spectra over state."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _sam_compute(preds, target, self.reduction)
if self.reduction == "none" or self.reduction is None:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _sam_compute(preds, target, self.reduction)
return self.sum_sam / self.numel if self.reduction == "elementwise_mean" else self.sum_sam

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
44 changes: 30 additions & 14 deletions src/torchmetrics/image/uqi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, List, Optional, Sequence, Union

from torch import Tensor
from torch import Tensor, tensor
from typing_extensions import Literal

from torchmetrics.functional.image.uqi import _uqi_compute, _uqi_update
Expand Down Expand Up @@ -73,6 +73,8 @@ class UniversalImageQualityIndex(Metric):

preds: List[Tensor]
target: List[Tensor]
sum_uqi: Tensor
numel: Tensor

def __init__(
self,
Expand All @@ -82,29 +84,43 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
rank_zero_warn(
"Metric `UniversalImageQualityIndex` will save all targets and"
" predictions in buffer. For large datasets this may lead"
" to large memory footprint."
)

self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
if reduction not in ("elementwise_mean", "sum", "none", None):
raise ValueError(
f"The `reduction` {reduction} is not valid. Valid options are `elementwise_mean`, `sum`, `none`, None."
)
if reduction is None or reduction == "none":
rank_zero_warn(
"Metric `UniversalImageQualityIndex` will save all targets and predictions in the buffer when using"
"`reduction=None` or `reduction='none'. For large datasets, this may lead to a large memory footprint."
)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
else:
self.add_state("sum_uqi", tensor(0.0), dist_reduce_fx="sum")
self.add_state("numel", tensor(0), dist_reduce_fx="sum")
self.kernel_size = kernel_size
self.sigma = sigma
self.reduction = reduction

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
preds, target = _uqi_update(preds, target)
self.preds.append(preds)
self.target.append(target)
if self.reduction is None or self.reduction == "none":
self.preds.append(preds)
self.target.append(target)
else:
uqi_score = _uqi_compute(preds, target, self.kernel_size, self.sigma, reduction="sum")
self.sum_uqi += uqi_score
ps = preds.shape
self.numel += ps[0] * ps[1] * (ps[2] - self.kernel_size[0] + 1) * (ps[3] - self.kernel_size[1] + 1)

def compute(self) -> Tensor:
"""Compute explained variance over state."""
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction)
if self.reduction == "none" or self.reduction is None:
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
return _uqi_compute(preds, target, self.kernel_size, self.sigma, self.reduction)
return self.sum_uqi / self.numel if self.reduction == "elementwise_mean" else self.sum_uqi

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
Expand Down
Loading