Closes TorchSpatiotemporal/tsl#23 #24

Open

wants to merge 7 commits into dev
167 changes: 109 additions & 58 deletions tsl/metrics/torch/metric_base.py
@@ -1,12 +1,15 @@
import inspect
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
from torchmetrics import Metric
from torchmetrics.utilities.checks import _check_same_shape

from tsl.typing import Slicer
from tsl.utils.python_utils import parse_slicing_string


def convert_to_masked_metric(metric_fn, **kwargs):
"""
Expand Down Expand Up @@ -36,89 +39,137 @@ class MaskedMetric(Metric):

    In particular, a `MaskedMetric` accounts for missing values in the input
    sequences by accepting a boolean mask as additional input.
    Multiple metric functions can be specified, in which case their results
    are averaged; weights can be assigned to compute a weighted average of
    the different metrics.

    Args:
        metric_fn (Sequence[callable], callable):
            Base function to compute the metric point-wise; multiple
            functions can be passed as a sequence.
        mask_nans (bool): Whether to automatically mask nan values.
            (default: :obj:`False`)
        mask_inf (bool): Whether to automatically mask infinite values.
            (default: :obj:`False`)
        metric_fn_kwargs (Sequence[dict], dict, optional):
            Keyword arguments needed by :obj:`metric_fn`.
            Use a sequence of keyword arguments if different
            :obj:`metric_fn` require different arguments.
            (default: :obj:`None`)
        weights (Sequence[float], optional):
            Weight assigned to each :obj:`metric_fn`.
            Use a sequence if different :obj:`metric_fn`
            require different weights.
            (default: :obj:`None`)
        at (str, Sequence[Tuple[Slicer, ...] | str], Tuple[Slicer, ...],
            Slicer, optional):
            Numpy-style slicing that selects the part of the output on
            which the metrics are computed. Either a single slicing for
            all metrics or a sequence with one slicing per metric.
            Each slicing can be a proper slicing tuple or a string
            containing just the part you would put inside the square
            brackets to index an array/tensor.
            (default: :obj:`...`)
        full_state_update (bool, optional): Set this to overwrite the
            :obj:`full_state_update` value of the
            :obj:`torchmetrics.Metric` base class.
            (default: :obj:`None`)
    """

    is_differentiable: bool = None
    higher_is_better: bool = None
    full_state_update: bool = None

    def __init__(
        self,
        metric_fn: Union[Sequence[Callable], Callable],
        metric_fn_kwargs: Optional[Union[Sequence[Dict[str, Any]],
                                         Dict[str, Any]]] = None,
        mask_nans: bool = False,
        mask_inf: bool = False,
        at: Union[str, Sequence[Union[Tuple[Slicer, ...], str]],
                  Tuple[Slicer, ...], Slicer] = ...,
        weights: Optional[Sequence[float]] = None,
        full_state_update: Optional[bool] = None,
        **kwargs: Any,
    ):
        # set 'full_state_update' before Metric instantiation
        if full_state_update is not None:
            self.__dict__['full_state_update'] = full_state_update
        super(MaskedMetric, self).__init__(**kwargs)
        # every sequence-valued argument must provide one item per metric;
        # strings are excluded, as a string is a single slicing pattern
        assert len({
            len(e)
            for e in (metric_fn, metric_fn_kwargs, at, weights)
            if isinstance(e, Sequence) and not isinstance(e, str)
        }) <= 1, ("All sequences used as masked metric arguments "
                  "must have the same length.")
        if metric_fn_kwargs is None:
            metric_fn_kwargs = {}
        if isinstance(metric_fn, Sequence) and isinstance(
                metric_fn_kwargs, Sequence):
            self.metric_fn = tuple(
                partial(fn, **fn_kwargs)
                for fn, fn_kwargs in zip(metric_fn, metric_fn_kwargs))
        elif isinstance(metric_fn, Sequence):
            self.metric_fn = tuple(
                partial(fn, **metric_fn_kwargs) for fn in metric_fn)
        else:
            self.metric_fn = (partial(metric_fn, **metric_fn_kwargs), )
        if isinstance(at, str) or not isinstance(at, Sequence):
            at = (at, )
        at = [parse_slicing_string(e) if isinstance(e, str) else e for e in at]
        self.at = at * len(self.metric_fn) if len(at) == 1 else at
        if weights is None:
            self.weights = (1.0, ) * len(self.metric_fn)
        else:
            self.weights = weights
        self.mask_nans = mask_nans
        self.mask_inf = mask_inf
self.add_state("value",
dist_reduce_fx="sum",
default=torch.tensor(0.0, dtype=torch.float))
self.add_state("numel",
dist_reduce_fx="sum",
default=torch.tensor(0.0, dtype=torch.float))

def _check_mask(self, mask, val, at=...):
        if mask is None:
            mask = torch.ones_like(val, dtype=torch.bool)
        else:
            # slice the mask like the prediction before validating it
            mask = mask[at].bool()
        _check_same_shape(mask, val)
        if self.mask_nans:
            mask = mask & ~torch.isnan(val)
        if self.mask_inf:
            mask = mask & ~torch.isinf(val)
        return mask

    def is_masked(self, mask):
        return self.mask_inf or self.mask_nans or (mask is not None)

    def update(self, y_hat, y, mask=None):
        _check_same_shape(y_hat, y)
        for i in range(len(self.metric_fn)):
            val = self.metric_fn[i](y_hat[self.at[i]], y[self.at[i]])
            if self.is_masked(mask):
                # keep the full mask intact: _check_mask slices it with the
                # i-th slicing, so store the result in a local variable
                mask_i = self._check_mask(mask, val, self.at[i])
                val[~mask_i] = 0
                numel = mask_i.sum()
            else:
                numel = val.numel()
            self.value += val.sum() * self.weights[i]
            self.numel += numel

    def compute(self):
        if self.numel > 0:
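With this rework, one `MaskedMetric` instance can wrap several point-wise functions and accumulate them into a single weighted score. A minimal usage sketch of the new constructor (the functions, slicing strings, weights, and shapes below are illustrative only, and the exact string grammar accepted by `parse_slicing_string` is not shown in this diff):

import torch
from torch.nn import functional as F

from tsl.metrics.torch.metric_base import MaskedMetric

# Average an MAE computed on feature 0 with an MSE computed on the full
# output; the 'at' strings are what you would write inside the square
# brackets when indexing the prediction tensor.
metric = MaskedMetric(
    metric_fn=(F.l1_loss, F.mse_loss),
    metric_fn_kwargs=({'reduction': 'none'}, {'reduction': 'none'}),
    at=('..., 0', '...'),
    weights=(0.5, 0.5),
)

y_hat = torch.randn(8, 12, 4, 2)  # [batch, time, nodes, features]
y = torch.randn(8, 12, 4, 2)
mask = torch.rand(8, 12, 4, 2) > 0.1  # True where values are observed

metric.update(y_hat, y, mask)
print(metric.compute())  # aggregated value (compute() body is collapsed above)

Note that `update()` slices prediction, target, and mask with each metric's slicing before applying the corresponding function, so the two metrics above see tensors of different shapes.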
85 changes: 62 additions & 23 deletions tsl/metrics/torch/metrics.py
@@ -1,4 +1,4 @@
from typing import Any, Optional

import torch
from torch.nn import functional as F
@@ -14,106 +14,144 @@ class MaskedMAE(MaskedMetric):
"""Mean Absolute Error Metric.

Args:
mask_nans (bool, optional): Whether to automatically mask nan values.
mask_inf (bool, optional): Whether to automatically mask infinite
mask_nans (bool): Whether to automatically mask nan values.
(default: :obj:`False`)
mask_inf (bool): Whether to automatically mask infinite
values.
(default: :obj:`False`)
at (int, optional): Whether to compute the metric only w.r.t. a certain
time step.
time step.
(default: :obj:`None`)
dim (int): The index of the dimension that represents time in a batch.
Relevant only when also 'at' is defined.
Default assumes [b t n f] format.
(default: :obj:`1`)
"""

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self,
                 mask_nans: bool = False,
                 mask_inf: bool = False,
                 at: Optional[int] = None,
                 dim: int = 1,
                 **kwargs: Any):
        super(MaskedMAE, self).__init__(metric_fn=F.l1_loss,
                                        mask_nans=mask_nans,
                                        mask_inf=mask_inf,
                                        metric_fn_kwargs={'reduction': 'none'},
                                        at=at,
                                        dim=dim,
                                        **kwargs)


class MaskedMAPE(MaskedMetric):
    """Mean Absolute Percentage Error Metric.

    Args:
        mask_nans (bool): Whether to automatically mask nan values.
            (default: :obj:`False`)
        at (int, optional): If provided, compute the metric only at this
            time step.
            (default: :obj:`None`)
        dim (int): The index of the dimension that represents time in a
            batch. Relevant only when :obj:`at` is also defined.
            Default assumes [b t n f] format.
            (default: :obj:`1`)
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self,
                 mask_nans: bool = False,
                 at: Optional[int] = None,
                 dim: int = 1,
                 **kwargs: Any):
        super(MaskedMAPE,
              self).__init__(metric_fn=mape,
                             mask_nans=mask_nans,
                             mask_inf=True,
                             metric_fn_kwargs={'reduction': 'none'},
                             at=at,
                             dim=dim,
                             **kwargs)


class MaskedMSE(MaskedMetric):
    """Mean Squared Error Metric.

    Args:
        mask_nans (bool): Whether to automatically mask nan values.
            (default: :obj:`False`)
        mask_inf (bool): Whether to automatically mask infinite values.
            (default: :obj:`False`)
        at (int, optional): If provided, compute the metric only at this
            time step.
            (default: :obj:`None`)
        dim (int): The index of the dimension that represents time in a
            batch. Relevant only when :obj:`at` is also defined.
            Default assumes [b t n f] format.
            (default: :obj:`1`)
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self,
                 mask_nans: bool = False,
                 mask_inf: bool = False,
                 at: Optional[int] = None,
                 dim: int = 1,
                 **kwargs: Any):
        super(MaskedMSE, self).__init__(metric_fn=F.mse_loss,
                                        mask_nans=mask_nans,
                                        mask_inf=mask_inf,
                                        metric_fn_kwargs={'reduction': 'none'},
                                        at=at,
                                        dim=dim,
                                        **kwargs)


class MaskedMRE(MaskedMetric):
    """Mean Relative Error Metric.

    Args:
        mask_nans (bool): Whether to automatically mask nan values.
            (default: :obj:`False`)
        mask_inf (bool): Whether to automatically mask infinite values.
            (default: :obj:`False`)
        at (int, optional): If provided, compute the metric only at this
            time step.
            (default: :obj:`None`)
        dim (int): The index of the dimension that represents time in a
            batch. Relevant only when :obj:`at` is also defined.
            Default assumes [b t n f] format.
            (default: :obj:`1`)
    """

    is_differentiable: bool = True
    higher_is_better: bool = False
    full_state_update: bool = False

    def __init__(self,
                 mask_nans: bool = False,
                 mask_inf: bool = False,
                 at: Optional[int] = None,
                 dim: int = 1,
                 **kwargs: Any):
        super(MaskedMRE, self).__init__(metric_fn=F.l1_loss,
                                        mask_nans=mask_nans,
                                        mask_inf=mask_inf,
                                        metric_fn_kwargs={'reduction': 'none'},
                                        at=at,
                                        dim=dim,
                                        **kwargs)
        self.add_state('tot',
                       dist_reduce_fx='sum',
@@ -138,10 +176,11 @@ def compute(self):
        return self.value

    def update(self, y_hat, y, mask=None):
        if self.at is not None:
            y_hat = y_hat.select(self.dim, self.at)
            y = y.select(self.dim, self.at)
            if mask is not None:
                mask = mask.select(self.dim, self.at)
        if self.is_masked(mask):
            val, numel, tot = self._compute_masked(y_hat, y, mask)
        else:
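In the concrete metrics, `at` remains an integer time step, and the new `dim` argument selects which axis is time (sliced via `Tensor.select`). A hedged usage sketch, with made-up shapes and horizon index, and assuming the base class accepts the `dim` keyword these subclasses forward (the relevant part of the base-class diff is collapsed above):

import torch

from tsl.metrics.torch.metrics import MaskedMAE

# MAE restricted to forecasting horizon 3 of a [batch, time, nodes,
# features] batch; time is dimension 1, which is also the default.
mae_at_3 = MaskedMAE(mask_nans=True, at=3, dim=1)

y_hat = torch.randn(8, 12, 4, 2)
y = torch.randn(8, 12, 4, 2)
mask = torch.rand(8, 12, 4, 2) > 0.1  # True where values are observed

mae_at_3.update(y_hat, y, mask)
print(mae_at_3.compute())

Since `select` drops the chosen dimension, the metric is then computed over all remaining entries of that single time step.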