-
Notifications
You must be signed in to change notification settings - Fork 34
/
confusion_matrix.py
94 lines (77 loc) · 3.47 KB
/
confusion_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import math
import numpy as np
from sklearn.metrics import confusion_matrix
class ConfusionMatrix:
    """Accumulate a multi-class confusion matrix and derive summary metrics.

    The matrix is filled by ``sklearn.metrics.confusion_matrix(target,
    prediction)``, so rows index the ground-truth class and columns index
    the predicted class.

    Attributes:
        mat: ``(nclasses, nclasses)`` float array of accumulated counts.
        valids: per-class ``tp / (tp + fp)`` from the last ``scores()`` call.
        IoU: per-class ``tp / (tp + fp + fn)`` from the last ``scores()`` call.
        mIoU: mean of ``IoU`` over the scored classes.
        accuracy: overall accuracy over the scored classes.
        matStartIdx: first class index included in the metrics; 1 when
            ``useUnlabeled`` is false (class 0 is skipped, presumably the
            "unlabeled" class), otherwise 0.
    """

    def __init__(self, nclasses, classes, useUnlabeled=False):
        """Create an empty matrix for ``nclasses`` classes.

        Args:
            nclasses: number of classes; class ids are ``0 .. nclasses - 1``.
            classes: class descriptions supplied by the caller; stored as-is
                and not interpreted by this class.
            useUnlabeled: when true, class 0 takes part in the metrics;
                when false, row and column 0 are ignored by ``scores()``.
        """
        self.nclasses = nclasses
        self.classes = classes
        self.list_classes = list(range(nclasses))
        self.useUnlabeled = useUnlabeled
        self.matStartIdx = 0 if useUnlabeled else 1
        # reset() allocates every accumulator with the builtin ``float``
        # dtype; the ``np.float`` alias no longer exists in current NumPy.
        self.reset()

    @staticmethod
    def _class_axis(target_shape, prediction_shape):
        """Return the axis of the prediction that holds the per-class scores.

        The class axis is the one whose removal leaves exactly
        ``target_shape``.  Axis 1 (the ``(N, C, ...)`` layout) is tried
        first, then the remaining axes in order.
        """
        rank = len(prediction_shape)
        candidates = [1] + [axis for axis in range(rank) if axis != 1]
        for axis in candidates:
            remaining = prediction_shape[:axis] + prediction_shape[axis + 1:]
            if remaining == target_shape:
                return axis
        raise ValueError("Make sure prediction and target dimension is correct")

    def update_matrix(self, target, prediction):
        """Add one batch of results to the accumulated confusion matrix.

        Args:
            target: ndarray of ground-truth class ids with 1, 2 or 3
                dimensions.
            prediction: ndarray that either has the same number of
                dimensions as ``target`` (already class ids) or one extra
                dimension holding per-class scores, which is reduced with
                ``argmax``.

        Raises:
            TypeError: if either argument is not a ``numpy.ndarray``.
            ValueError: if the ranks or shapes of the two arrays are not
                compatible.
        """
        if not isinstance(target, np.ndarray) or not isinstance(prediction, np.ndarray):
            raise TypeError("Expecting ndarray")
        if target.ndim not in (1, 2, 3):
            raise ValueError("Data with this dimension cannot be handled")

        if prediction.ndim == target.ndim + 1:
            # Per-class scores: collapse the class axis to a class id.
            class_axis = self._class_axis(target.shape, prediction.shape)
            predicted_ids = np.argmax(prediction, axis=class_axis)
        elif prediction.ndim == target.ndim:
            predicted_ids = prediction
        else:
            raise ValueError("Make sure prediction and target dimension is correct")

        if predicted_ids.size != target.size:
            raise ValueError(
                "prediction and target must contain the same number of elements"
            )

        self.mat += confusion_matrix(
            target.ravel(), predicted_ids.ravel(), labels=self.list_classes
        )

    def scores(self):
        """Compute metrics from the accumulated matrix.

        Only rows/columns from ``matStartIdx`` onward are considered.  Any
        ratio whose denominator is zero is reported as 0.

        Returns:
            A tuple ``(valids, accuracy, IoU, mIoU, mat)`` where ``valids``
            and ``IoU`` are per-class arrays of length ``nclasses``,
            ``accuracy`` and ``mIoU`` are scalars, and ``mat`` is the raw
            accumulated matrix.
        """
        start = self.matStartIdx
        scored = self.mat[start:, start:]

        true_pos = np.diag(scored).astype(float)
        predicted_totals = scored.sum(axis=0)  # tp + fp for each class
        actual_totals = scored.sum(axis=1)  # tp + fn for each class
        union = predicted_totals + actual_totals - true_pos  # tp + fp + fn

        self.valids[start:] = np.divide(
            true_pos,
            predicted_totals,
            out=np.zeros(true_pos.shape, dtype=float),
            where=predicted_totals != 0,
        )
        self.IoU[start:] = np.divide(
            true_pos,
            union,
            out=np.zeros(true_pos.shape, dtype=float),
            where=union != 0,
        )

        scored_classes = self.nclasses - start
        if scored_classes > 0:
            self.mIoU = self.IoU[start:].sum() / scored_classes
        else:
            self.mIoU = 0.0

        sample_count = scored.sum()
        if sample_count > 0:
            self.accuracy = true_pos.sum() / sample_count
        else:
            self.accuracy = 0.0

        return self.valids, self.accuracy, self.IoU, self.mIoU, self.mat

    def plot_confusion_matrix(self, filename):
        """Placeholder for plotting; currently only prints ``filename``."""
        # Plot generated confusion matrix
        print(filename)

    def reset(self):
        """Clear the accumulated matrix and every derived metric."""
        self.mat = np.zeros((self.nclasses, self.nclasses), dtype=float)
        self.valids = np.zeros((self.nclasses), dtype=float)
        self.IoU = np.zeros((self.nclasses), dtype=float)
        self.mIoU = 0
        self.accuracy = 0