-
Notifications
You must be signed in to change notification settings - Fork 9
/
voc.py
68 lines (57 loc) · 2.77 KB
/
voc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import os
import xml.etree.ElementTree as ET
import pickle
def parse_voc_annotation(ann_dir, img_dir, cache_name, labels=None):
    """Parse Pascal-VOC style XML annotations into per-image records.

    If ``cache_name`` already exists it is unpickled and returned as-is.
    Otherwise every entry of ``ann_dir`` is parsed (in sorted order) and the
    result is pickled to ``cache_name`` before being returned.

    NOTE: an existing cache is returned without checking that it was built
    from the same ``ann_dir`` / ``img_dir`` / ``labels``; delete the cache
    file when any of those change.

    Args:
        ann_dir: Directory holding the XML annotation files.
        img_dir: Directory holding the images. A record's ``'filename'`` is
            ``img_dir`` joined with the annotation file's base name plus
            ``.jpg``; it is only set when the XML has a ``filename`` element.
        cache_name: Path of the pickle cache to read from or write to.
        labels: Optional collection of object names to keep. ``None`` or an
            empty collection keeps every object.

    Returns:
        A tuple ``(all_insts, seen_labels)``. ``all_insts`` is a list of dicts
        with an ``'object'`` list (dicts with ``'name'`` and, when present in
        the XML, ``'xmin'``/``'ymin'``/``'xmax'``/``'ymax'``) plus
        ``'filename'``/``'width'``/``'height'`` when found. Images with no
        kept object are omitted. ``seen_labels`` maps every object name
        encountered (kept or filtered out) to its occurrence count.
    """
    if os.path.exists(cache_name):
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # cache_name at files this function wrote itself.
        with open(cache_name, 'rb') as handle:
            cache = pickle.load(handle)
        return cache['all_insts'], cache['seen_labels']

    wanted = [] if labels is None else labels
    all_insts = []
    seen_labels = {}

    for ann in sorted(os.listdir(ann_dir)):
        ann_path = os.path.join(ann_dir, ann)
        try:
            tree = ET.parse(ann_path)
        except Exception as e:
            # Best effort: unreadable or malformed entries are reported and
            # skipped rather than aborting the whole scan.
            print(e)
            print('Ignore this bad annotation: ' + ann_path)
            continue

        img = {'object': []}
        for elem in tree.iter():
            if 'filename' in elem.tag:
                # Derive the image name from the annotation file name, swapping
                # only the extension (a plain str.replace('xml', 'jpg') would
                # also rewrite 'xml' appearing elsewhere in the name).
                stem = os.path.splitext(ann)[0]
                img['filename'] = os.path.join(img_dir, stem + '.jpg')
            if 'width' in elem.tag:
                img['width'] = int(elem.text)
            if 'height' in elem.tag:
                img['height'] = int(elem.text)
            if 'object' in elem.tag or 'part' in elem.tag:
                obj = {}
                for attr in elem:
                    if 'name' in attr.tag:
                        name = attr.text
                        obj['name'] = name
                        # Count every label, including ones filtered out below.
                        seen_labels[name] = seen_labels.get(name, 0) + 1
                        if len(wanted) > 0 and name not in wanted:
                            break
                        # obj is appended now and filled in place when the
                        # bndbox child is reached later in this loop.
                        img['object'].append(obj)
                    if 'bndbox' in attr.tag:
                        for dim in attr:
                            for key in ('xmin', 'ymin', 'xmax', 'ymax'):
                                if key in dim.tag:
                                    obj[key] = int(round(float(dim.text)))

        if len(img['object']) > 0:
            all_insts.append(img)

    cache = {'all_insts': all_insts, 'seen_labels': seen_labels}
    with open(cache_name, 'wb') as handle:
        pickle.dump(cache, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return all_insts, seen_labels