Skip to content

Commit

Permalink
fix tests for init test tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Sep 21, 2022
1 parent 2403db5 commit ec8a8ca
Showing 1 changed file with 50 additions and 37 deletions.
87 changes: 50 additions & 37 deletions tests/utils/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import numpy as np
from torchmetrics.functional import precision_recall
from utils.metrics import create_metrics_dict, report_classification, iou
import torch

import segmentation_models_pytorch as smp
from torchmetrics import JaccardIndex
import pytest

# Test arrays: [bs=2, h=2, w,2]
pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2])
pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])
def init_tensors():
pred_multi = torch.tensor([0, 0, 2, 2, 0, 2, 1, 2, 1, 0, 2, 2, 1, 0, 2, 2])
pred_binary = torch.tensor([0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1])

lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
# array with dont care
lbl_multi_dc = torch.tensor([-1, -1, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary_dc = torch.tensor([-1, -1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
lbl_multi = torch.tensor([1, 0, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary = torch.tensor([1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
# array with dont care
lbl_multi_dc = torch.tensor([-1, -1, 2, 2, 0, 1, 2, 0, 2, 2, 0, 0, 1, 2, 0, 1])
lbl_binary_dc = torch.tensor([-1, -1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1])
return {'pred_multi': pred_multi,
'pred_binary': pred_binary,
'lbl_multi': lbl_multi,
'lbl_binary': lbl_binary,
'lbl_multi_dc': lbl_multi_dc,
'lbl_binary_dc': lbl_binary_dc}


class TestMetrics(object):
Expand All @@ -37,9 +40,10 @@ def test_create_metrics_dict(self):
def test_report_classification_multi(self):
"""Evaluate report classification.
Multiclass, without ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(3)
metrics_dict = report_classification(pred_multi,
lbl_multi,
metrics_dict = report_classification(t['pred_multi'],
t['lbl_multi'],
batch_size=2,
metrics_dict=metrics_dict,
ignore_index=-1)
Expand All @@ -50,9 +54,10 @@ def test_report_classification_multi(self):
def test_report_classification_multi_ignore_idx(self):
"""Evaluate report classification.
Multiclass, with ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(3)
metrics_dict = report_classification(pred_multi,
lbl_multi_dc,
metrics_dict = report_classification(t['pred_multi'],
t['lbl_multi_dc'],
batch_size=2,
metrics_dict=metrics_dict,
ignore_index=-1)
Expand All @@ -63,9 +68,10 @@ def test_report_classification_multi_ignore_idx(self):
def test_report_classification_binary(self):
"""Evaluate report classification.
Binary, without ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(1)
metrics_dict = report_classification(pred_binary,
lbl_binary,
metrics_dict = report_classification(t['pred_binary'],
t['lbl_binary'],
batch_size=2,
metrics_dict=metrics_dict,
ignore_index=-1)
Expand All @@ -76,9 +82,10 @@ def test_report_classification_binary(self):
def test_report_classification_binary_ignore_idx(self):
"""Evaluate report classification.
Binary, without ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(1)
metrics_dict = report_classification(pred_binary,
lbl_binary_dc,
metrics_dict = report_classification(t['pred_binary'],
t['lbl_binary_dc'],
batch_size=2,
metrics_dict=metrics_dict,
ignore_index=-1)
Expand All @@ -89,9 +96,10 @@ def test_report_classification_binary_ignore_idx(self):
def test_iou_multi(self):
"""Evaluate iou calculation.
Multiclass, without ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(3)
metrics_dict = iou(pred_multi,
lbl_multi,
metrics_dict = iou(t['pred_multi'],
t['lbl_multi'],
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
Expand All @@ -101,20 +109,22 @@ def test_iou_multi(self):
def test_iou_multi_ignore_idx(self):
"""Evaluate iou calculation.
Multiclass, with ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(3)
# wih ignore_index == -1
metrics_dict = iou(pred_multi,
lbl_multi_dc,
metrics_dict = iou(t['pred_multi'],
t['lbl_multi_dc'],
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
ignore_index=-1)
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.233333"

# with ignore_index == 0
t = init_tensors()
metrics_dict = create_metrics_dict(3)
metrics_dict = iou(pred_multi,
lbl_multi,
metrics_dict = iou(t['pred_multi'],
t['lbl_multi'],
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
Expand All @@ -124,9 +134,10 @@ def test_iou_multi_ignore_idx(self):
def test_iou_binary(self):
"""Evaluate iou calculation.
Binary, without ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(1)
metrics_dict = iou(pred_binary,
lbl_binary,
metrics_dict = iou(t['pred_binary'],
t['lbl_binary'],
batch_size=2,
num_classes=1,
metric_dict=metrics_dict,
Expand All @@ -136,22 +147,24 @@ def test_iou_binary(self):
def test_iou_binary_ignore_idx(self):
"""Evaluate iou calculation.
Binary, with ignore_index in array."""
t = init_tensors()
metrics_dict = create_metrics_dict(1)
# with ignore_index == -1
metrics_dict = iou(pred_binary,
lbl_binary_dc,
metrics_dict = iou(t['pred_binary'],
t['lbl_binary_dc'],
batch_size=2,
num_classes=1,
metric_dict=metrics_dict,
ignore_index=-1)
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.435897"

# with ignore_index == 0
t = init_tensors()
metrics_dict = create_metrics_dict(3)
metrics_dict = iou(pred_multi,
lbl_multi,
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
ignore_index=0)
assert "{:.6f}".format(metrics_dict['iou'].val) == "0.208333"
with pytest.raises(ValueError):
metrics_dict = iou(t['pred_binary'],
t['lbl_binary_dc'],
batch_size=2,
num_classes=3,
metric_dict=metrics_dict,
ignore_index=0)

0 comments on commit ec8a8ca

Please sign in to comment.