From 15c0a8ab7c9e8c6cbc389e76179d2d5f57ce1269 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Fri, 22 Dec 2023 12:27:06 +0100 Subject: [PATCH] fix macro-f1 dimensions --- chebai/callbacks/epoch_metrics.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py index 75cae77d..30fde179 100644 --- a/chebai/callbacks/epoch_metrics.py +++ b/chebai/callbacks/epoch_metrics.py @@ -149,25 +149,27 @@ def apply_metric(self, target, pred, mode="train"): class MacroF1(torchmetrics.Metric): - def __init__(self, n_labels, dist_sync_on_step=False): + def __init__(self, n_labels, dist_sync_on_step=False, threshold=0.5): super().__init__(dist_sync_on_step=dist_sync_on_step) self.add_state( - "true_positives", default=torch.empty((0, n_labels)), dist_reduce_fx="sum" + "true_positives", default=torch.zeros((n_labels)), dist_reduce_fx="sum" ) self.add_state( "positive_predictions", - default=torch.empty((0, n_labels)), + default=torch.empty((n_labels)), dist_reduce_fx="sum", ) self.add_state( - "positive_labels", default=torch.empty((0, n_labels)), dist_reduce_fx="sum" + "positive_labels", default=torch.empty((n_labels)), dist_reduce_fx="sum" ) + self.threshold = threshold def update(self, preds: torch.Tensor, labels: torch.Tensor): - self.true_positives += torch.sum(torch.logical_and(preds, labels), dim=1) - self.positive_predictions += torch.sum(preds, dim=1) - self.positive_labels += torch.sum(labels, dim=1) + tps = torch.sum(torch.logical_and(preds > self.threshold, labels), dim=0) + self.true_positives += tps + self.positive_predictions += torch.sum(preds, dim=0) + self.positive_labels += torch.sum(labels, dim=0) def compute(self): mask = torch.logical_and(