From 2403db58858ceb765fe64423c72700d1ac986678 Mon Sep 17 00:00:00 2001
From: Mathieu
Date: Wed, 21 Sep 2022 11:14:40 -0400
Subject: [PATCH] replace iou calculation with torchmetrics iou

---
 tests/utils/test_metrics.py | 30 ++++++++++++++++++++++++------
 utils/metrics.py            | 82 +++++++++++++++++++++++++------------------
 2 files changed, 76 insertions(+), 36 deletions(-)

diff --git a/tests/utils/test_metrics.py b/tests/utils/test_metrics.py
index 19d05461..d4fe38ec 100644
--- a/tests/utils/test_metrics.py
+++ b/tests/utils/test_metrics.py
@@ -3,11 +3,13 @@
 from utils.metrics import create_metrics_dict, report_classification, iou
 import torch
+import segmentation_models_pytorch as smp
+from torchmetrics import JaccardIndex
+
 
 # Test arrays: [bs=2, h=2, w=2]
 pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2])
 pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])
-
 lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
 lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
 
 # array with dont care
@@ -100,13 +102,24 @@ def test_iou_multi_ignore_idx(self):
         """Evaluate iou calculation.
         Multiclass, with ignore_index in array."""
         metrics_dict = create_metrics_dict(3)
+        # with ignore_index == -1
         metrics_dict = iou(pred_multi,
                            lbl_multi_dc,
                            batch_size=2,
                            num_classes=3,
                            metric_dict=metrics_dict,
                            ignore_index=-1)
-        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.169841"
+        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.233333"
+
+        # with ignore_index == 0
+        metrics_dict = create_metrics_dict(3)
+        metrics_dict = iou(pred_multi,
+                           lbl_multi,
+                           batch_size=2,
+                           num_classes=3,
+                           metric_dict=metrics_dict,
+                           ignore_index=0)
+        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.208333"
 
     def test_iou_binary(self):
         """Evaluate iou calculation.
@@ -124,10 +137,21 @@ def test_iou_binary_ignore_idx(self):
         """Evaluate iou calculation.
         Binary, with ignore_index in array."""
         metrics_dict = create_metrics_dict(1)
+        # with ignore_index == -1
         metrics_dict = iou(pred_binary,
                            lbl_binary_dc,
                            batch_size=2,
                            num_classes=1,
                            metric_dict=metrics_dict,
                            ignore_index=-1)
-        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.340659"
+        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.435897"
+
+        # with ignore_index == 0
+        metrics_dict = create_metrics_dict(3)
+        metrics_dict = iou(pred_multi,
+                           lbl_multi,
+                           batch_size=2,
+                           num_classes=3,
+                           metric_dict=metrics_dict,
+                           ignore_index=0)
+        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.208333"
\ No newline at end of file
diff --git a/utils/metrics.py b/utils/metrics.py
index ff482bb2..dc295a73 100644
--- a/utils/metrics.py
+++ b/utils/metrics.py
@@ -1,10 +1,11 @@
 import numpy as np
 from sklearn.metrics import classification_report
 from math import sqrt
 from torch import IntTensor
+from torchmetrics import JaccardIndex
 
 min_val = 1e-6
 
-def create_metrics_dict(num_classes):
+def create_metrics_dict(num_classes, ignore_index=None):
 
     num_classes = num_classes + 1 if num_classes == 1 else num_classes
@@ -12,10 +13,11 @@ def create_metrics_dict(num_classes):
               'loss': AverageMeter(),
               'iou': AverageMeter()}
     for i in range(0, num_classes):
-        metrics_dict['precision_' + str(i)] = AverageMeter()
-        metrics_dict['recall_' + str(i)] = AverageMeter()
-        metrics_dict['fscore_' + str(i)] = AverageMeter()
-        metrics_dict['iou_' + str(i)] = AverageMeter()
+        if ignore_index != i:
+            metrics_dict['precision_' + str(i)] = AverageMeter()
+            metrics_dict['recall_' + str(i)] = AverageMeter()
+            metrics_dict['fscore_' + str(i)] = AverageMeter()
+            metrics_dict['iou_' + str(i)] = AverageMeter()
 
     # Add overall non-background iou metric
     metrics_dict['iou_nonbg'] = AverageMeter()
@@ -89,36 +91,50 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1
     return metrics_dict
 
 
-def iou(pred, label, batch_size, num_classes, metric_dict, ignore_index, only_present=True):
+def iou(pred, label, batch_size, num_classes, metric_dict, ignore_index=None):
     """Calculate the intersection over union class-wise and mean-iou"""
-    ious = []
-    num_classes = num_classes + 1 if num_classes == 1 else num_classes
-    pred = pred.cpu()
-    label = label.cpu()
-    pred[label == ignore_index] = ignore_index
-    for i in range(num_classes):
-        c_label = label == i
-        if only_present and c_label.sum() == 0:
-            ious.append(np.nan)
-            continue
-        c_pred = pred == i
-        intersection = (c_pred & c_label).float().sum()
-        union = (c_pred | c_label).float().sum()
-        iou = (intersection + min_val) / (union + min_val)  # minimum value added to avoid Zero division
-        ious.append(iou)
-        metric_dict['iou_' + str(i)].update(iou.item(), batch_size)
-    # Add overall non-background iou metric
-    c_label = (1 <= label) & (label <= num_classes - 1)
-    c_pred = (1 <= pred) & (pred <= num_classes - 1)
-    intersection = (c_pred & c_label).float().sum()
-    union = (c_pred | c_label).float().sum()
-    iou = (intersection + min_val) / (union + min_val)  # minimum value added to avoid Zero division
-    metric_dict['iou_nonbg'].update(iou.item(), batch_size)
-
-    mean_IOU = np.nanmean(ious)
-    if (not only_present) or (not np.isnan(mean_IOU)):
-        metric_dict['iou'].update(mean_IOU, batch_size)
+    num_classes = num_classes + 1 if num_classes == 1 else num_classes
+    # Torchmetrics cannot handle an ignore_index outside the range 0..num_classes-1.
+    # If an invalid ignore_index (e.g. -1) is provided, the affected pixels are
+    # remapped to class 0 and no ignore_index is passed to torchmetrics.
+    if ignore_index is not None and ignore_index not in range(num_classes):
+        pred[label == ignore_index] = 0
+        label[label == ignore_index] = 0
+        ignore_index = None
+
+    cls_lst = list(range(num_classes))
+    if ignore_index is not None:
+        cls_lst.remove(ignore_index)
+
+    jaccard = JaccardIndex(num_classes=num_classes,
+                           average='none',
+                           ignore_index=ignore_index,
+                           absent_score=1)
+    cls_ious = jaccard(pred, label)
+
+    if len(cls_ious) > 1:
+        for cls, cls_iou in zip(cls_lst, cls_ious):
+            metric_dict['iou_' + str(cls)].update(cls_iou.item(), batch_size)
+
+    elif len(cls_ious) == 1:
+        if f"iou_{cls_lst[0]}" in metric_dict:
+            metric_dict['iou_' + str(cls_lst[0])].update(cls_ious.item(), batch_size)
+
+    jaccard_nobg = JaccardIndex(num_classes=num_classes,
+                                average='macro',
+                                ignore_index=0,
+                                absent_score=1)
+    iou_nobg = jaccard_nobg(pred, label)
+    metric_dict['iou_nonbg'].update(iou_nobg.item(), batch_size)
+
+    jaccard = JaccardIndex(num_classes=num_classes,
+                           average='macro',
+                           ignore_index=ignore_index,
+                           absent_score=1)
+    mean_iou = jaccard(pred, label)
+
+    metric_dict['iou'].update(mean_iou.item(), batch_size)
 
     return metric_dict
 
 #### Benchmark Metrics ####
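
For reference, here is a minimal standalone sketch (not part of the patch) of the JaccardIndex behaviour the rewritten iou() relies on. It reuses the constructor arguments from utils/metrics.py and the test tensors from tests/utils/test_metrics.py, and it assumes a torchmetrics release whose JaccardIndex accepts num_classes, average, ignore_index and absent_score directly, as the patched code does. Printed values are illustrative; nothing is asserted.

# Standalone sketch of the torchmetrics calls made by the new iou().
import torch
from torchmetrics import JaccardIndex

pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2])
lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])

# Per-class IoU: one score per class; a class absent from both pred and label
# scores 1 rather than 0, mirroring the absent_score=1 choice in iou().
per_class = JaccardIndex(num_classes=3, average='none', absent_score=1)
print(per_class(pred_multi, lbl_multi))  # tensor of 3 class-wise IoU values

# Macro IoU with the background class ignored, as used for metric_dict['iou_nonbg'].
macro_nobg = JaccardIndex(num_classes=3, average='macro', ignore_index=0, absent_score=1)
print(macro_nobg(pred_multi, lbl_multi))  # scalar mean IoU over classes 1 and 2

# Macro IoU over all classes, as used for the overall metric_dict['iou'].
macro_all = JaccardIndex(num_classes=3, average='macro', absent_score=1)
print(macro_all(pred_multi, lbl_multi))  # scalar mean IoU over classes 0, 1 and 2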