Skip to content

Commit

Permalink
fix:Reference Metric in multiclass pecision recall unittests provides…
Browse files Browse the repository at this point in the history
… wrong answer when ignore_index is specified
  • Loading branch information
rittik9 committed Nov 24, 2024
1 parent 2365437 commit 259c4bd
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tests/unittests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ def _reference_sklearn_precision_recall_multiclass(
if preds.ndim == target.ndim + 1:
preds = torch.argmax(preds, 1)

valid_labels = list(range(NUM_CLASSES))
if ignore_index is not None:
valid_labels = [label for label in valid_labels if label != ignore_index]

if multidim_average == "global":
preds = preds.numpy().flatten()
target = target.numpy().flatten()
Expand All @@ -210,7 +214,7 @@ def _reference_sklearn_precision_recall_multiclass(
target,
preds,
average=average,
labels=list(range(NUM_CLASSES)) if average is None else None,
labels=valid_labels if average in ("macro", "weighted") else None,
zero_division=zero_division,
)

Expand All @@ -235,7 +239,7 @@ def _reference_sklearn_precision_recall_multiclass(
true,
pred,
average=average,
labels=list(range(NUM_CLASSES)) if average is None else None,
labels=valid_labels if average in ("macro", "weighted") else None,
zero_division=zero_division,
)
res.append(0.0 if np.isnan(r).any() else r)
Expand Down

0 comments on commit 259c4bd

Please sign in to comment.