Skip to content

Commit

Permalink
fix macro-f1 dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
sfluegel committed Dec 22, 2023
1 parent 78ae996 commit 15c0a8a
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions chebai/callbacks/epoch_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,25 +149,27 @@ def apply_metric(self, target, pred, mode="train"):


class MacroF1(torchmetrics.Metric):
def __init__(self, n_labels, dist_sync_on_step=False):
def __init__(self, n_labels, dist_sync_on_step=False, threshold=0.5):
super().__init__(dist_sync_on_step=dist_sync_on_step)

self.add_state(
"true_positives", default=torch.empty((0, n_labels)), dist_reduce_fx="sum"
"true_positives", default=torch.zeros((n_labels)), dist_reduce_fx="sum"
)
self.add_state(
"positive_predictions",
default=torch.empty((0, n_labels)),
default=torch.empty((n_labels)),
dist_reduce_fx="sum",
)
self.add_state(
"positive_labels", default=torch.empty((0, n_labels)), dist_reduce_fx="sum"
"positive_labels", default=torch.empty((n_labels)), dist_reduce_fx="sum"
)
self.threshold = threshold

def update(self, preds: torch.Tensor, labels: torch.Tensor):
self.true_positives += torch.sum(torch.logical_and(preds, labels), dim=1)
self.positive_predictions += torch.sum(preds, dim=1)
self.positive_labels += torch.sum(labels, dim=1)
tps = torch.sum(torch.logical_and(preds > self.threshold, labels), dim=0)
self.true_positives += tps
self.positive_predictions += torch.sum(preds, dim=0)
self.positive_labels += torch.sum(labels, dim=0)

def compute(self):
mask = torch.logical_and(
Expand Down

0 comments on commit 15c0a8a

Please sign in to comment.