Skip to content

Commit

Permalink
replace iou calculation with torchmetrics iou
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Sep 21, 2022
1 parent 844e7a8 commit 2403db5
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 36 deletions.
30 changes: 27 additions & 3 deletions tests/utils/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from utils.metrics import create_metrics_dict, report_classification, iou
import torch

import segmentation_models_pytorch as smp
from torchmetrics import JaccardIndex

# Test arrays: [bs=2, h=2, w=2]  (flattened to 16 elements)
pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2])
pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])


lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
# array with dont care
# array with dont care
Expand Down Expand Up @@ -100,13 +102,24 @@ def test_iou_multi_ignore_idx(self):
"""Evaluate iou calculation.
Multiclass, with ignore_index in array."""
metrics_dict = create_metrics_dict(3)
# with ignore_index == -1
metrics_dict = iou(pred_multi,
lbl_multi_dc,
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
ignore_index=-1)
# NOTE(review): diff view — the first assert is the pre-torchmetrics
# expected value (removed line); the second is the new expected value
# produced by torchmetrics JaccardIndex. Only the second survives.
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.169841"
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.233333"

# with ignore_index == 0
metrics_dict = create_metrics_dict(3)
metrics_dict = iou(pred_multi,
lbl_multi,
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
ignore_index=0)
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.208333"

def test_iou_binary(self):
"""Evaluate iou calculation.
Expand All @@ -124,10 +137,21 @@ def test_iou_binary_ignore_idx(self):
"""Evaluate iou calculation.
Binary, with ignore_index in array."""
metrics_dict = create_metrics_dict(1)
# with ignore_index == -1
metrics_dict = iou(pred_binary,
lbl_binary_dc,
batch_size=2,
num_classes=1,
metric_dict=metrics_dict,
ignore_index=-1)
# NOTE(review): old (removed) vs new torchmetrics expected value below;
# only the second assert survives in the new test.
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.340659"
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.435897"

# with ignore_index == 0
# NOTE(review): this second case looks copy-pasted from the multiclass
# test — it uses pred_multi/lbl_multi and num_classes=3 inside the
# *binary* test, exactly duplicating test_iou_multi_ignore_idx. It
# should presumably use pred_binary/lbl_binary with num_classes=1;
# verify intent and recompute the expected value before fixing.
metrics_dict = create_metrics_dict(3)
metrics_dict = iou(pred_multi,
lbl_multi,
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
ignore_index=0)
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.208333"
85 changes: 52 additions & 33 deletions utils/metrics.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from ast import Raise
from xml.dom import ValidationErr
import numpy as np
from sklearn.metrics import classification_report
from math import sqrt
from torch import IntTensor
from torchmetrics import JaccardIndex

min_val = 1e-6
def create_metrics_dict(num_classes):
def create_metrics_dict(num_classes, ignore_index=None):
# Build the dict of running AverageMeter trackers used during train/eval.
# NOTE(review): diff view — both the old signature (above) and the new one
# (with ignore_index) are shown; only the second survives. The new
# parameter suppresses creation of per-class meters for the ignored class.
# Function continues past this visible span ("Expand Down" below) — the
# remainder is not shown here.

# Binary segmentation is tracked as 2 classes (background + foreground).
num_classes = num_classes + 1 if num_classes == 1 else num_classes

metrics_dict = {'precision': AverageMeter(), 'recall': AverageMeter(), 'fscore': AverageMeter(),
'loss': AverageMeter(), 'iou': AverageMeter()}

for i in range(0, num_classes):
metrics_dict['precision_' + str(i)] = AverageMeter()
metrics_dict['recall_' + str(i)] = AverageMeter()
metrics_dict['fscore_' + str(i)] = AverageMeter()
metrics_dict['iou_' + str(i)] = AverageMeter()
# New behavior: skip per-class meters for the ignored class so callers
# never see (and never average) scores for a class they asked to ignore.
if ignore_index != i:
metrics_dict['precision_' + str(i)] = AverageMeter()
metrics_dict['recall_' + str(i)] = AverageMeter()
metrics_dict['fscore_' + str(i)] = AverageMeter()
metrics_dict['iou_' + str(i)] = AverageMeter()

# Add overall non-background iou metric
metrics_dict['iou_nonbg'] = AverageMeter()
Expand Down Expand Up @@ -89,36 +93,51 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1
return metrics_dict


def iou(pred, label, batch_size, num_classes, metric_dict, ignore_index=None):
    """Calculate class-wise IoU (Jaccard index) and mean IoU via torchmetrics.

    Args:
        pred: flat tensor of predicted class indices.
        label: flat tensor of ground-truth class indices, same shape as pred.
        batch_size: sample count used as the weight for AverageMeter updates.
        num_classes: number of classes; a binary task (1) is promoted to 2.
        metric_dict: dict of AverageMeter objects from create_metrics_dict().
        ignore_index: class index to exclude from scoring, or None.

    Returns:
        metric_dict, updated with 'iou', per-class 'iou_<i>' and 'iou_nonbg'.
    """
    num_classes = num_classes + 1 if num_classes == 1 else num_classes
    # Work on CPU copies. Clone explicitly: .cpu() is a no-op alias for a
    # tensor already on CPU, so the in-place remapping below would otherwise
    # mutate the caller's tensors.
    pred = pred.cpu().clone()
    label = label.cpu().clone()

    # Torchmetrics cannot handle an ignore_index outside 0..num_classes-1.
    # An out-of-range value (e.g. the conventional -1 "don't care") is
    # remapped to class 0 and no ignore_index is passed on.
    # Fixes vs. original: `is not None` instead of truthiness (so 0 is a
    # deliberate, valid ignore_index, not "unset"), and range(num_classes)
    # instead of range(num_classes - 1) (the last class index is valid).
    if ignore_index is not None and ignore_index not in range(num_classes):
        pred[label == ignore_index] = 0
        label[label == ignore_index] = 0
        ignore_index = None

    # Classes that keep a per-class meter after exclusion.
    cls_lst = [j for j in range(num_classes) if j != ignore_index]

    # Per-class scores. absent_score=1 keeps classes missing from both pred
    # and label from dragging the average down.
    jaccard_per_class = JaccardIndex(num_classes=num_classes,
                                     average='none',
                                     ignore_index=ignore_index,
                                     absent_score=1)
    cls_ious = jaccard_per_class(pred, label)

    # NOTE(review): assumes torchmetrics drops the ignored class from the
    # 'none'-averaged output so cls_ious aligns 1:1 with cls_lst — confirm
    # against the installed torchmetrics version.
    if len(cls_ious) > 1:
        for pos, cls in enumerate(cls_lst):
            metric_dict['iou_' + str(cls)].update(cls_ious[pos].item(), batch_size)
    elif len(cls_ious) == 1 and f"iou_{cls_lst[0]}" in metric_dict:
        metric_dict['iou_' + str(cls_lst[0])].update(cls_ious.item(), batch_size)

    # Overall non-background IoU: macro average with class 0 ignored.
    jaccard_nobg = JaccardIndex(num_classes=num_classes,
                                average='macro',
                                ignore_index=0,
                                absent_score=1)
    metric_dict['iou_nonbg'].update(jaccard_nobg(pred, label).item(), batch_size)

    # Mean IoU over the retained classes.
    jaccard_mean = JaccardIndex(num_classes=num_classes,
                                average='macro',
                                ignore_index=ignore_index,
                                absent_score=1)
    metric_dict['iou'].update(jaccard_mean(pred, label).item(), batch_size)
    return metric_dict

#### Benchmark Metrics ####
Expand Down

0 comments on commit 2403db5

Please sign in to comment.