add tests for metrics calculation
mpelchat04 committed Sep 20, 2022
1 parent c39dfd5 commit 844e7a8
Showing 3 changed files with 168 additions and 19 deletions.
133 changes: 133 additions & 0 deletions tests/utils/test_metrics.py
@@ -0,0 +1,133 @@
import numpy as np
from torchmetrics.functional import precision_recall
from utils.metrics import create_metrics_dict, report_classification, iou
import torch

# Test arrays: flattened [bs=2, h=2, w=4] (16 values each)
pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2])
pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])


lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
# array with dont care
lbl_multi_dc = torch.tensor([-1, -1, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary_dc = torch.tensor([-1, -1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])


class TestMetrics(object):
    def test_create_metrics_dict(self):
        """Evaluate the metrics dictionary creation.
        Binary and multiclass."""
        # binary tasks have 1 class at class definition.
        num_classes = 1
        metrics_dict = create_metrics_dict(num_classes)
        assert 'iou_1' in metrics_dict.keys()
        assert 'iou_2' not in metrics_dict.keys()

        num_classes = 3
        metrics_dict = create_metrics_dict(num_classes)
        assert 'iou_1' in metrics_dict.keys()
        assert 'iou_2' in metrics_dict.keys()
        assert 'iou_3' not in metrics_dict.keys()
        del metrics_dict

    def test_report_classification_multi(self):
        """Evaluate report classification.
        Multiclass, without ignore_index in array."""
        metrics_dict = create_metrics_dict(3)
        metrics_dict = report_classification(pred_multi,
                                             lbl_multi,
                                             batch_size=2,
                                             metrics_dict=metrics_dict,
                                             ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['precision'].val) == "0.327083"
        assert "{:.6f}".format(metrics_dict['recall'].val) == "0.312500"
        assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.314935"

    def test_report_classification_multi_ignore_idx(self):
        """Evaluate report classification.
        Multiclass, with ignore_index in array."""
        metrics_dict = create_metrics_dict(3)
        metrics_dict = report_classification(pred_multi,
                                             lbl_multi_dc,
                                             batch_size=2,
                                             metrics_dict=metrics_dict,
                                             ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['precision'].val) == "0.297619"
        assert "{:.6f}".format(metrics_dict['recall'].val) == "0.285714"
        assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.283163"

    def test_report_classification_binary(self):
        """Evaluate report classification.
        Binary, without ignore_index in array."""
        metrics_dict = create_metrics_dict(1)
        metrics_dict = report_classification(pred_binary,
                                             lbl_binary,
                                             batch_size=2,
                                             metrics_dict=metrics_dict,
                                             ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['precision'].val) == "0.547727"
        assert "{:.6f}".format(metrics_dict['recall'].val) == "0.562500"
        assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.553030"

    def test_report_classification_binary_ignore_idx(self):
        """Evaluate report classification.
        Binary, with ignore_index in array."""
        metrics_dict = create_metrics_dict(1)
        metrics_dict = report_classification(pred_binary,
                                             lbl_binary_dc,
                                             batch_size=2,
                                             metrics_dict=metrics_dict,
                                             ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['precision'].val) == "0.528139"
        assert "{:.6f}".format(metrics_dict['recall'].val) == "0.571429"
        assert "{:.6f}".format(metrics_dict['fscore'].val) == "0.539286"

    def test_iou_multi(self):
        """Evaluate iou calculation.
        Multiclass, without ignore_index in array."""
        metrics_dict = create_metrics_dict(3)
        metrics_dict = iou(pred_multi,
                           lbl_multi,
                           batch_size=2,
                           num_classes=3,
                           metric_dict=metrics_dict,
                           ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.185185"

    def test_iou_multi_ignore_idx(self):
        """Evaluate iou calculation.
        Multiclass, with ignore_index in array."""
        metrics_dict = create_metrics_dict(3)
        metrics_dict = iou(pred_multi,
                           lbl_multi_dc,
                           batch_size=2,
                           num_classes=3,
                           metric_dict=metrics_dict,
                           ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.169841"

    def test_iou_binary(self):
        """Evaluate iou calculation.
        Binary, without ignore_index in array."""
        metrics_dict = create_metrics_dict(1)
        metrics_dict = iou(pred_binary,
                           lbl_binary,
                           batch_size=2,
                           num_classes=1,
                           metric_dict=metrics_dict,
                           ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.361111"

    def test_iou_binary_ignore_idx(self):
        """Evaluate iou calculation.
        Binary, with ignore_index in array."""
        metrics_dict = create_metrics_dict(1)
        metrics_dict = iou(pred_binary,
                           lbl_binary_dc,
                           batch_size=2,
                           num_classes=1,
                           metric_dict=metrics_dict,
                           ignore_index=-1)
        assert "{:.6f}".format(metrics_dict['iou'].val) == "0.340659"
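
For reference, a minimal standalone sketch of the masking-and-averaging behaviour these IoU tests pin down. It is not the project's utils.metrics.iou(); the helper name naive_mean_iou is purely illustrative. Predictions are overwritten with the ignore value wherever the label carries it, IoU is computed for each class present in the label, and the mean is taken; on the binary arrays it reproduces the expected 0.361111 and 0.340659.

import torch

# Binary test arrays copied from the test module above.
pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])
lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
lbl_binary_dc = torch.tensor([-1, -1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])

def naive_mean_iou(pred, label, num_classes, ignore_index=-1):
    """Illustrative only: mean IoU over the classes present in the label."""
    num_classes = num_classes + 1 if num_classes == 1 else num_classes  # binary task -> background + class 1
    pred = pred.clone()
    pred[label == ignore_index] = ignore_index  # ignored pixels drop out of every class
    ious = []
    for i in range(num_classes):
        target = label == i
        if target.sum() == 0:  # skip classes absent from the label
            continue
        guess = pred == i
        intersection = (guess & target).sum().item()
        union = (guess | target).sum().item()
        ious.append(intersection / union)
    return sum(ious) / len(ious)

print(f"{naive_mean_iou(pred_binary, lbl_binary, num_classes=1):.6f}")     # 0.361111
print(f"{naive_mean_iou(pred_binary, lbl_binary_dc, num_classes=1):.6f}")  # 0.340659
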
31 changes: 17 additions & 14 deletions train_segmentation.py
@@ -339,7 +339,8 @@ def evaluation(eval_loader,
batch_metrics=None,
dataset='val',
device=None,
debug=False):
debug=False,
dontcare=-1):
"""
Evaluate the model and return the updated metrics
:param eval_loader: data loader
@@ -404,14 +405,14 @@ def evaluation(eval_loader,
f"{len(eval_loader)}. Metrics in validation loop won't be computed")
if (batch_index + 1) % batch_metrics == 0: # +1 to skip val loop at very beginning
a, segmentation = torch.max(outputs_flatten, dim=1)
eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics)
eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics, dontcare)
eval_metrics = report_classification(segmentation, labels_flatten, batch_size, eval_metrics,
ignore_index=eval_loader.dataset.dontcare)
elif (dataset == 'tst') and (batch_metrics is not None):
ignore_index=dontcare)
elif (dataset == 'tst'):
a, segmentation = torch.max(outputs_flatten, dim=1)
eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics)
eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes, eval_metrics, dontcare)
eval_metrics = report_classification(segmentation, labels_flatten, batch_size, eval_metrics,
ignore_index=eval_loader.dataset.dontcare)
ignore_index=dontcare)

logging.debug(OrderedDict(dataset=dataset, loss=f'{eval_metrics["loss"].avg:.4f}'))

@@ -424,11 +425,11 @@

if eval_metrics['loss'].avg:
logging.info(f"\n{dataset} Loss: {eval_metrics['loss'].avg:.4f}")
if batch_metrics is not None:
logging.info(f"\n{dataset} precision: {eval_metrics['precision'].avg}")
logging.info(f"\n{dataset} recall: {eval_metrics['recall'].avg}")
logging.info(f"\n{dataset} fscore: {eval_metrics['fscore'].avg}")
logging.info(f"\n{dataset} iou: {eval_metrics['iou'].avg}")
if batch_metrics is not None or dataset == 'tst':
logging.info(f"\n{dataset} precision: {eval_metrics['precision'].avg:.4f}")
logging.info(f"\n{dataset} recall: {eval_metrics['recall'].avg:.4f}")
logging.info(f"\n{dataset} fscore: {eval_metrics['fscore'].avg:.4f}")
logging.info(f"\n{dataset} iou: {eval_metrics['iou'].avg:.4f}")

return eval_metrics

@@ -524,7 +525,7 @@ def train(cfg: DictConfig) -> None:
# info on the hdf5 name
samples_size = get_key_def("input_dim", cfg['dataset'], expected_type=int, default=256)
overlap = get_key_def("overlap", cfg['dataset'], expected_type=int, default=0)
min_annot_perc = get_key_def('min_annotated_percent', cfg['dataset'], expected_type=int, default=0)
min_annot_perc = get_key_def('min_annotated_percent', cfg['dataset'], default=0)
samples_folder_name = (
f'samples{samples_size}_overlap{overlap}_min-annot{min_annot_perc}_{num_bands}bands_{experiment_name}'
)
@@ -688,7 +689,8 @@ def train(cfg: DictConfig) -> None:
device=device,
scale=scale,
vis_params=vis_params,
debug=debug)
debug=debug,
dontcare=dontcare_val)
val_loss = val_report['loss'].avg
if 'val_log' in locals(): # only save the value if a tracker is setup
if batch_metrics is not None:
@@ -745,7 +747,8 @@ def train(cfg: DictConfig) -> None:
dataset='tst',
scale=scale,
vis_params=vis_params,
device=device)
device=device,
dontcare=dontcare_val)
if 'tst_log' in locals(): # only save the value if a tracker is setup
tst_log.add_values(tst_report, num_epochs)

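The change above threads the dataset's dontcare value through evaluation() as an explicit argument instead of reading eval_loader.dataset.dontcare inside the loop. A hedged sketch of the resulting per-batch metric update, assuming flattened [N, num_classes] logits and [N] labels; update_eval_metrics is a hypothetical wrapper, only iou() and report_classification() come from the commit.

import torch
from utils.metrics import iou, report_classification

def update_eval_metrics(outputs_flatten, labels_flatten, batch_size, num_classes,
                        eval_metrics, dontcare=-1):
    # Per-pixel class index taken from the class dimension of the logits.
    _, segmentation = torch.max(outputs_flatten, dim=1)
    # The ignore value is now passed explicitly to both metric helpers.
    eval_metrics = iou(segmentation, labels_flatten, batch_size, num_classes,
                       eval_metrics, dontcare)
    eval_metrics = report_classification(segmentation, labels_flatten, batch_size,
                                         eval_metrics, ignore_index=dontcare)
    return eval_metrics
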
23 changes: 18 additions & 5 deletions utils/metrics.py
@@ -1,9 +1,13 @@
import numpy as np
from sklearn.metrics import classification_report
from math import sqrt
from torch import IntTensor

min_val = 1e-6
def create_metrics_dict(num_classes):
num_classes = num_classes if num_classes == 1 else num_classes + 1

num_classes = num_classes + 1 if num_classes == 1 else num_classes

metrics_dict = {'precision': AverageMeter(), 'recall': AverageMeter(), 'fscore': AverageMeter(),
'loss': AverageMeter(), 'iou': AverageMeter()}

@@ -59,7 +63,15 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1
"""Computes precision, recall and f-score for each class and average of all classes.
http://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
"""
class_report = classification_report(label.cpu(), pred.cpu(), output_dict=True, zero_division=1)
pred = pred.cpu()
label = label.cpu()
pred[label == ignore_index] = ignore_index

# Required to remove ignore_index from scikit-learn's classification report
n = max(IntTensor.item(pred.amax()), IntTensor.item(label.amax()))
labels = np.arange(n+1)

class_report = classification_report(label, pred, labels=labels, output_dict=True, zero_division=1)

class_score = {}
for key, value in class_report.items():
@@ -77,12 +89,13 @@ def report_classification(pred, label, batch_size, metrics_dict, ignore_index=-1
return metrics_dict


def iou(pred, label, batch_size, num_classes, metric_dict, only_present=True):
def iou(pred, label, batch_size, num_classes, metric_dict, ignore_index, only_present=True):
"""Calculate the intersection over union class-wise and mean-iou"""
ious = []
num_classes = num_classes if num_classes == 1 else num_classes + 1
num_classes = num_classes + 1 if num_classes == 1 else num_classes
pred = pred.cpu()
label = label.cpu()
pred[label == ignore_index] = ignore_index
for i in range(num_classes):
c_label = label == i
if only_present and c_label.sum() == 0:
@@ -122,7 +135,7 @@ class ComputePixelMetrics():
def __init__(self, label, pred, num_classes):
self.label = label
self.pred = pred
self.num_classes = num_classes if num_classes == 1 else num_classes + 1
self.num_classes = num_classes + 1 if num_classes == 1 else num_classes

def update(self, metric_func):
metric = {}
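
The heart of the report_classification() change is the ignore_index handling: predictions are forced to the sentinel wherever the label is ignored, and scikit-learn's classification_report is restricted to the real class ids so the sentinel never enters the averages. A standalone sketch of that trick, with masked_class_report and the tiny tensors as illustrative stand-ins rather than code from the commit.

import numpy as np
import torch
from sklearn.metrics import classification_report

def masked_class_report(pred, label, ignore_index=-1):
    pred = pred.cpu().clone()
    label = label.cpu()
    pred[label == ignore_index] = ignore_index            # neutralize ignored pixels
    n = int(max(pred.max().item(), label.max().item()))   # highest real class id
    valid_labels = np.arange(n + 1)                       # 0..n, excludes the negative sentinel
    return classification_report(label, pred, labels=valid_labels,
                                 output_dict=True, zero_division=1)

report = masked_class_report(torch.tensor([0, 1, 1, 0]), torch.tensor([-1, 1, 0, 0]))
print(report['weighted avg'])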
