-
Notifications
You must be signed in to change notification settings - Fork 12
/
evaluate.py
153 lines (128 loc) · 5.46 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
Evaluation of -.tar.gz file.
Yu Fang - March 2019
"""
from eval import *
import os, sys, glob
import shutil
import tarfile
import xml.dom.minidom
from os.path import join as osj
# Ground-truth annotation directories, resolved against the working directory
# at import time. reg_gt_path is read for "-trackA"; str_gt_path is read for
# "-trackB1" and "-trackB2".
reg_gt_path = os.path.abspath("./annotations/trackA/")
str_gt_path = os.path.abspath("./annotations/trackB/")


def process_missing_files(file_lst, cur_gt_num, track_name=None):
    """Add the ground-truth objects of unsubmitted files to a running total.

    Each listed ground-truth XML file is parsed and its objects are counted,
    so that files with no result annotation still contribute to the
    ground-truth total.

    Args:
        file_lst: names of ground-truth files that have no result annotation.
            Each name is joined onto the track's annotation directory.
            Entries whose last dot-separated part is not ``xml`` are skipped.
        cur_gt_num: ground-truth object count accumulated so far.
        track_name: ``"-trackA"``, ``"-trackB1"`` or ``"-trackB2"``. When
            omitted, the module-level ``track`` assigned by the command-line
            entry point is used, which keeps two-argument callers working.

    Returns:
        ``cur_gt_num`` plus, for track A, the number of ``<table>`` elements
        found, or, for track B1/B2, the summed length of
        ``Table(...).find_adj_relations()`` over every ``<table>`` element.

    Raises:
        ValueError: if the track is not one of the three recognised values
            (previously this case fell through and returned ``None``).
    """
    if track_name is None:
        # Backward-compatible fallback: the script entry point sets ``track``.
        track_name = track

    if track_name == "-trackA":
        gt_dir = reg_gt_path
        count_relations = False
    elif track_name in ("-trackB1", "-trackB2"):
        gt_dir = str_gt_path
        count_relations = True
    else:
        raise ValueError("unknown track: {!r}".format(track_name))

    # Iterate the list that was passed in, not the module-level gt_file_lst.
    for filename in file_lst:
        gt_path = osj(gt_dir, filename)
        if os.path.basename(gt_path).split(".")[-1] != "xml":
            continue

        gt_root = xml.dom.minidom.parse(gt_path).documentElement
        table_elements = gt_root.getElementsByTagName("table")

        if not count_relations:
            # Track A: every <table> element is one ground-truth object.
            cur_gt_num += len(table_elements)
            continue

        # Track B: build all tables first, then count their adjacency
        # relations. ``Table`` is not defined in this file; presumably it is
        # provided by ``from eval import *`` -- confirm against eval.py.
        tables = [Table(element) for element in table_elements]
        cur_gt_num += sum(len(table.find_adj_relations()) for table in tables)

    return cur_gt_num
# IoU thresholds reported, in the order each per-file ``result`` list stores
# them (indices 0-3; the last entry of ``result`` is the ground-truth filename).
_IOU_LABELS = ("0.6", "0.7", "0.8", "0.9")


def _xml_only(names):
    """Return the entries of ``names`` whose last dot-separated part is ``xml``."""
    return [name for name in names if name.split(".")[-1] == "xml"]


def _precision_recall_f1(correct, res_total, gt_total):
    """Return ``(precision, recall, f1)`` for one IoU threshold.

    Raises ``ZeroDivisionError`` when ``res_total`` or ``gt_total`` is zero.
    A threshold with no true positives yields an F1 of ``0.0`` instead of
    dividing by zero.
    """
    precision = correct / res_total
    recall = correct / gt_total
    if precision + recall == 0:
        return precision, recall, 0.0
    return precision, recall, 2 * precision * recall / (precision + recall)


def _extract_submission(archive_path, dest):
    """Recreate ``dest`` as an empty directory and unpack a .tar.gz into it.

    Raises ``FileNotFoundError`` if the archive does not exist and
    ``tarfile.TarError`` if it cannot be read or a member is rejected.
    """
    if os.path.exists(dest):
        shutil.rmtree(dest)
    os.makedirs(dest)
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Submissions are untrusted: the "data" filter rejects absolute
            # paths, ".." traversal and links pointing outside ``dest``.
            tar.extractall(path=dest, filter="data")
        else:
            # NOTE(security): this Python has no extraction filters, so a
            # crafted archive can write outside ``dest``. Only evaluate
            # trusted archives on such interpreters.
            tar.extractall(path=dest)


if __name__ == '__main__':
    if len(sys.argv) < 3:
        print("Usage: {} <-trackA|-trackB1|-trackB2> <result.tar.gz>".format(sys.argv[0]))
        sys.exit(1)

    # ``track`` and ``gt_file_lst`` stay module-level names on purpose:
    # process_missing_files may read them as globals.
    track = sys.argv[1]
    result_path = sys.argv[2]

    gt_file_lst = []
    if track == "-trackA":
        gt_file_lst = os.listdir(reg_gt_path)
    elif track in ("-trackB1", "-trackB2"):
        gt_file_lst = os.listdir(str_gt_path)

    untar_path = "./untar_file/"
    try:
        _extract_submission(result_path, untar_path)
    except FileNotFoundError:
        print("Tar.gz file path incorrect, please check your spelling.")
        # Stop here: continuing would only end in a misleading
        # "no adjacency relations" error further down.
        sys.exit(1)
    except tarfile.TarError as err:
        print("Could not read {} as a tar.gz archive: {}".format(result_path, err))
        sys.exit(1)

    # Evaluate every extracted XML file. ``eval`` here is the name pulled in
    # by ``from eval import *`` (it shadows the builtin); each returned object
    # exposes ``.result`` as read below.
    res_lst = []
    for dirpath, _dirnames, filenames in os.walk(untar_path):
        for name in _xml_only(filenames):
            cur_filepath = osj(os.path.abspath(dirpath), name)
            res_lst.append(eval(track, cur_filepath))
    print("\n")

    # Totals over all evaluated files. Ground-truth and result totals are
    # taken from the 0.6 entry only, as they are expected to be the same at
    # every threshold.
    correct = [0] * len(_IOU_LABELS)
    gt_six = 0
    res_six = 0
    for each_file in res_lst:
        gt_name = each_file.result[-1]
        # Tick off the ground-truth file this result covers; tolerate names
        # that are absent (e.g. already removed) instead of raising.
        if gt_name in gt_file_lst:
            gt_file_lst.remove(gt_name)
        for idx in range(len(_IOU_LABELS)):
            correct[idx] += each_file.result[idx].truePos
        gt_six += each_file.result[0].gtTotal
        res_six += each_file.result[0].resTotal

    # Whatever is left has no result annotation; only XML files count.
    gt_file_lst = _xml_only(gt_file_lst)
    if gt_file_lst:
        print("\nWarning: missing result annotations for file: {}\n".format(gt_file_lst))
        # Recall must be measured against ground truth, so seed with gt_six.
        gt_total = process_missing_files(gt_file_lst, gt_six)
    else:
        gt_total = gt_six

    try:
        print("Evaluation of {}".format(track.replace("-", "")))
        for label, hits in zip(_IOU_LABELS, correct):
            p, r, f1 = _precision_recall_f1(hits, res_six, gt_total)
            print("IOU @ {} -\nprecision: {}\nrecall: {}\nf1: {}\n".format(label, p, r, f1))
    except ZeroDivisionError:
        print("Error: no adjacency relations are found, please check the file input.")