Bugfix: ignore index + macro combination in classification metrics #2443

Draft · wants to merge 4 commits into base: master
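For context on the bug: with average="macro" combined with ignore_index, a class that occurs only at ignored positions ends up with tp + fn == 0, and its zero score used to be averaged in, deflating the macro result. A minimal sketch of the intended behavior; the tensors and the expected outcome are illustrative assumptions, not taken from the PR:

    import torch
    from torchmetrics.functional.classification import multiclass_recall

    preds = torch.tensor([2, 1, 2, 1])
    target = torch.tensor([2, 1, 0, 1])  # class 0 occurs only at the ignored position

    # With ignore_index=0 the macro average should run over classes 1 and 2
    # alone, not over all three classes.
    score = multiclass_recall(preds, target, num_classes=3, average="macro", ignore_index=0)
    print(score)  # expected: the mean of the class-1 and class-2 recalls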
4 changes: 2 additions & 2 deletions src/torchmetrics/classification/precision_recall.py
@@ -269,7 +269,7 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+            "precision", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
         )

     def plot(
@@ -702,7 +702,7 @@ def compute(self) -> Tensor:
         """Compute metric."""
         tp, fp, tn, fn = self._final_state()
         return _precision_recall_reduce(
-            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average, top_k=self.top_k
+            "recall", tp, fp, tn, fn, average=self.average, multidim_average=self.multidim_average
         )

     def plot(
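The two hunks above change the compute() methods of the Precision and Recall modules to stop forwarding top_k into the reduce step. For orientation, the module interface is exercised like this (a usage sketch with made-up values):

    import torch
    from torchmetrics.classification import MulticlassPrecision

    metric = MulticlassPrecision(num_classes=3, average="macro", ignore_index=0)
    metric.update(torch.tensor([2, 1, 1]), torch.tensor([2, 1, 0]))
    print(metric.compute())  # compute() calls _precision_recall_reduce on the tp/fp/tn/fn states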
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/accuracy.py
@@ -85,7 +85,7 @@ def _accuracy_reduce(
         return _safe_divide(tp, tp + fn)

     score = _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else _safe_divide(tp, tp + fn)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fn)


 def binary_accuracy(
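_accuracy_reduce leans on _safe_divide, torchmetrics' zero-guarded division. A simplified stand-in for readers following the arithmetic (my own reimplementation, not the library source):

    import torch

    def safe_divide(num: torch.Tensor, denom: torch.Tensor) -> torch.Tensor:
        """Elementwise num / denom that returns 0.0 wherever denom == 0."""
        num, denom = num.float(), denom.float()
        zero = denom == 0
        return torch.where(zero, torch.zeros_like(num), num / denom.masked_fill(zero, 1.0))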
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/f_beta.py
@@ -54,7 +54,7 @@ def _fbeta_reduce(
         return _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)

     fbeta_score = _safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)
-    return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fp, fn)
+    return _adjust_weights_safe_divide(fbeta_score, average, multilabel, tp, fn)


 def _binary_fbeta_score_arg_validation(
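The reduce above implements the standard definition F_beta = (1 + beta^2) * tp / ((1 + beta^2) * tp + beta^2 * fn + fp); this PR only changes the final macro weighting. In per-class sketch form, reusing the safe_divide stand-in from above:

    def fbeta_per_class(tp: torch.Tensor, fp: torch.Tensor, fn: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        # beta > 1 weighs recall more heavily, beta < 1 weighs precision more heavily.
        beta2 = beta**2
        return safe_divide((1 + beta2) * tp, (1 + beta2) * tp + beta2 * fn + fp)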
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/hamming.py
@@ -80,7 +80,7 @@ def _hamming_distance_reduce(
         return 1 - _safe_divide(tp, tp + fn)

     score = 1 - _safe_divide(tp + tn, tp + tn + fp + fn) if multilabel else 1 - _safe_divide(tp, tp + fn)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fn)


 def binary_hamming_distance(
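Hamming distance here is one minus the accuracy-style ratio, so it shares the same macro weighting. In sketch form (again reusing safe_divide from above):

    def hamming_per_class(tp: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
        # 1 - tp / (tp + fn); a class with no support (tp + fn == 0) comes out
        # as 1 here and is then zero-weighted by the macro adjustment.
        return 1 - safe_divide(tp, tp + fn)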
23 changes: 19 additions & 4 deletions src/torchmetrics/functional/classification/precision_recall.py
@@ -43,7 +43,6 @@ def _precision_recall_reduce(
     average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
     multidim_average: Literal["global", "samplewise"] = "global",
     multilabel: bool = False,
-    top_k: int = 1,
 ) -> Tensor:
     different_stat = fp if stat == "precision" else fn  # this is what differs between the two scores
     if average == "binary":
@@ -55,7 +54,7 @@
         return _safe_divide(tp, tp + different_stat)

     score = _safe_divide(tp, tp + different_stat)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fn)


 def binary_precision(
@@ -237,7 +236,15 @@ def multiclass_precision(
         preds, target, num_classes, top_k, average, multidim_average, ignore_index
     )
     return _precision_recall_reduce(
-        "precision", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k
+        "precision",
+        tp,
+        fp,
+        tn,
+        fn,
+        average=average,
+        multidim_average=multidim_average,
+        top_k=top_k,
+        ignore_index=ignore_index,
     )


@@ -523,7 +530,15 @@ def multiclass_recall(
         preds, target, num_classes, top_k, average, multidim_average, ignore_index
     )
     return _precision_recall_reduce(
-        "recall", tp, fp, tn, fn, average=average, multidim_average=multidim_average, top_k=top_k
+        "recall",
+        tp,
+        fp,
+        tn,
+        fn,
+        average=average,
+        multidim_average=multidim_average,
+        top_k=top_k,
+        ignore_index=ignore_index,
     )

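With the expanded calls above, both top_k and ignore_index now reach the reduce step of the multiclass functional metrics. A usage sketch (the tensors are illustrative):

    import torch
    from torchmetrics.functional.classification import multiclass_precision

    preds = torch.randn(8, 4).softmax(dim=-1)  # probabilities over 4 classes
    target = torch.randint(0, 4, (8,))
    multiclass_precision(preds, target, num_classes=4, top_k=2, average="macro", ignore_index=0)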
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/classification/specificity.py
@@ -51,7 +51,7 @@ def _specificity_reduce(
         return _safe_divide(tn, tn + fp)

     specificity_score = _safe_divide(tn, tn + fp)
-    return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fp, fn)
+    return _adjust_weights_safe_divide(specificity_score, average, multilabel, tp, fn)


 def binary_specificity(
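Specificity follows the same pattern with tn / (tn + fp) per class, and its macro mask now also keys on target support. A usage sketch (illustrative values):

    import torch
    from torchmetrics.functional.classification import multiclass_specificity

    multiclass_specificity(
        torch.tensor([0, 2, 1]), torch.tensor([0, 1, 1]),
        num_classes=3, average="macro", ignore_index=0,
    )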
8 changes: 6 additions & 2 deletions src/torchmetrics/utilities/compute.py
@@ -56,7 +56,11 @@ def _safe_divide(num: Tensor, denom: Tensor) -> Tensor:


 def _adjust_weights_safe_divide(
-    score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
+    score: Tensor,
+    average: Optional[str],
+    multilabel: bool,
+    tp: Tensor,
+    fn: Tensor,
 ) -> Tensor:
     if average is None or average == "none":
         return score
@@ -65,7 +69,7 @@ def _adjust_weights_safe_divide(
     else:
         weights = torch.ones_like(score)
     if not multilabel:
-        weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0
+        weights[tp + fn == 0] = 0.0
     return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
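This is the heart of the fix: for macro and weighted averaging, per-class weights are now zeroed exactly when a class has no positive targets (tp + fn == 0), regardless of top_k, so classes emptied out by ignore_index drop out of the mean. A self-contained sketch of the non-multilabel macro path (simplified; the library guards the final division with _safe_divide):

    import torch

    def macro_average(score: torch.Tensor, tp: torch.Tensor, fn: torch.Tensor) -> torch.Tensor:
        # Zero-weight classes without positive targets, then renormalize.
        weights = torch.ones_like(score)
        weights[tp + fn == 0] = 0.0
        return (weights * score).sum(-1) / weights.sum(-1)

    # Class 0 has no support, so its zero score is excluded from the mean.
    macro_average(torch.tensor([0.0, 1.0, 0.5]),
                  tp=torch.tensor([0, 3, 1]), fn=torch.tensor([0, 0, 1]))  # -> 0.75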
8 changes: 2 additions & 6 deletions tests/unittests/classification/test_accuracy.py
@@ -192,9 +192,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim
     acc_per_class = confmat.diagonal() / confmat.sum(axis=1)
     acc_per_class[np.isnan(acc_per_class)] = 0.0
     if average == "macro":
-        acc_per_class = acc_per_class[
-            (np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0
-        ]
+        acc_per_class = acc_per_class[np.bincount(target, minlength=NUM_CLASSES) != 0.0]
         return acc_per_class.mean()
     if average == "weighted":
         weights = confmat.sum(1)
@@ -215,9 +213,7 @@ def _reference_sklearn_accuracy_multiclass(preds, target, ignore_index, multidim
         acc_per_class = confmat.diagonal() / confmat.sum(axis=1)
         acc_per_class[np.isnan(acc_per_class)] = 0.0
         if average == "macro":
-            acc_per_class = acc_per_class[
-                (np.bincount(pred, minlength=NUM_CLASSES) + np.bincount(true, minlength=NUM_CLASSES)) != 0.0
-            ]
+            acc_per_class = acc_per_class[np.bincount(true, minlength=NUM_CLASSES) != 0.0]
         res.append(acc_per_class.mean() if len(acc_per_class) > 0 else 0.0)
     elif average == "weighted":
         weights = confmat.sum(1)
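The sklearn-based reference now decides class presence from the targets alone, mirroring the tp + fn == 0 rule: previously a class that appeared only in the predictions was kept in the macro mean (contributing an accuracy of 0); now it is dropped. In numpy terms (illustrative values):

    import numpy as np

    target = np.array([1, 1, 2])  # class 0 never occurs as a target
    preds = np.array([0, 1, 2])   # even though it occurs as a prediction
    keep = np.bincount(target, minlength=3) != 0
    print(keep)  # [False  True  True] -> class 0 is excluded from the macro mean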