Skip to content

Commit

Permalink
xml2yolotxt_1nx.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aimuch committed Jun 14, 2020
1 parent 8f542ad commit 31117c3
Showing 1 changed file with 57 additions and 31 deletions.
88 changes: 57 additions & 31 deletions xml2yolotxt_1nx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,30 @@
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Module-level configuration for the XML -> YOLO txt conversion.
# ---------------------------------------------------------------------------

# When True, boxes that carry no group_id are skipped during conversion.
WITH_GROUP_ID = False
# When True, labels listed in `sub_classes` are kept in addition to `classes`.
WITH_SUB_CLASSES = True
# When True, grouped sub-class boxes get their own output folders (created
# below) and are collected per group_id during conversion.
WITH_PROCESS_SUB_CLASSES = True
# Fraction of the converted images that is sampled into the training list.
TRAIN_RATIO = 0.8

# Accepted labels. `classes` used to be assigned twice in a row; only the
# final (lower-case) value ever took effect, so the dead assignment is gone.
classes = ["car", "van", "bus", "truck"]
sub_classes = ["brakelight", "headlight"]

# Input/output locations, all relative to the current working directory.
error_log = "./log.txt"
draw_path = "./draw_img"
output_txt_path = "./output_txt"
output_img_path = "./output_img"
output_txt_path_subclass = "./output_txt_subclass"
output_img_path_subclass = "./output_img_subclass"
txt_train_path = "./train.txt"
txt_val_path = "./val.txt"

# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of `if not os.path.exists(p): os.makedirs(p)`.
os.makedirs(output_txt_path, exist_ok=True)
os.makedirs(output_img_path, exist_ok=True)
if WITH_PROCESS_SUB_CLASSES:
    # The sub-class folders are only needed when sub-class processing is on.
    os.makedirs(output_txt_path_subclass, exist_ok=True)
    os.makedirs(output_img_path_subclass, exist_ok=True)

def parse_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -143,20 +152,20 @@ def convert_annotations(xml_dirs, img_dirs):
txt_file_path = os.path.join(output_txt_path, name+"_"+frame_id+".txt")
img_dst_path = os.path.join(output_img_path, os.path.basename(img_src_path))
txt_file = open(txt_file_path, 'w')

# obj_groups layout: {group_id: {'vehicle': [xtl, ytl, xbr, ybr], 'brakelight': [[xtl1, ytl1, xbr1, ybr1], [xtl2, ytl2, xbr2, ybr2]]}}
obj_groups = {}
for box in boxes:
if box.hasAttribute('group_id'):
group_id = int(box.getAttribute('group_id'))
group_id = box.getAttribute('group_id')
else:
group_id = None

if not group_id and WITH_GROUP_ID:
if WITH_GROUP_ID and not group_id:
continue

label = box.getAttribute('label')

if label not in classes:
continue

xtl = float(box.getAttribute('xtl'))
ytl = float(box.getAttribute('ytl'))
xbr = float(box.getAttribute('xbr'))
Expand All @@ -173,37 +182,54 @@ def convert_annotations(xml_dirs, img_dirs):
light_type = box_attr.childNodes[0].data
label = light_type

if WITH_SUB_CLASSES:
if label not in classes and label not in sub_classes:
continue
else:
if label not in classes:
continue

txt_file.write(label + " " + " ".join([str(a) for a in bb]) + '\n')

if WITH_PROCESS_SUB_CLASSES and group_id:
if label in classes:
obj_groups[group_id] = {"vehicle":[xtl, ytl, xbr, ybr]} # use a fixed key so the group is easy to iterate over later
elif label in sub_classes and not obj_groups[group_id].has_key(label):
obj_groups[group_id] = {label:[[xtl, ytl, xbr, ybr]]}
elif label in sub_classes:
obj_groups[group_id][label].append([xtl, ytl, xbr, ybr])

txt_file.close()
shutil.copyfile(img_src_path, img_dst_path)

log_file.close()
# Process sub class
for group_id, objs in obj_groups.items(): # group
for label, coordinate in objs: # classes
if label in sub_classes:
xtl1 = objs["vehicle"][0]
ytl1 = objs["vehicle"][1]
xbr1 = objs["vehicle"][2]
ybr1 = objs["vehicle"][3]
index = 0
for coor in coordinate: # sub classes
name_subclass = os.path.basename(img_src_path).split(".")[0]+"_"+group_id+"_"+str(index)
txt_file_path_subclass = os.path.join(output_txt_path_subclass, name_subclass+".txt")
txt_file_subclass = open(txt_file_path_subclass, 'w')
img_dst_sub_name = name_subclass + ".jpg"
img_dst_path_subclass = os.path.join(output_img_path_subclass, img_dst_sub_name)
xtl2 = coor[0] - objs["vehicle"][0]
ytl2 = coor[1] - objs["vehicle"][1]
xbr2 = coor[2] - objs["vehicle"][0]
ybr2 = coor[3] - objs["vehicle"][1]

bb = convert((xbr1-xtl1, ybr1-ytl1), (xtl2, xbr2, ytl2, ybr2))
txt_file_subclass.write(label, " ".join([str(a) for a in bb]) + '\n')

txt_file_subclass.close()
cv2.imwrite(img_dst_path_subclass, img[xtl2:xbr2, ytl2:ybr2])
index += 1


img_list = os.listdir(output_img_path)
trainnum = int(len(img_list)*TRAIN_RATIO)
trainset = random.sample(img_list, trainnum)
txt_train_list = open(txt_train_path, 'w')
txt_val_list = open(txt_val_path, 'w')
for img in img_list:
if img in trainset:
txt_train_list.write(os.path.abspath(img) + '\n')
else:
txt_val_list.write(os.path.abspath(img) + '\n')


## TODO
cv2.rectangle(img, (int(xbr), int(ybr)), (int(xtl), int(ytl)), (0, 0, 255))
# print(label, xtl, ytl, xbr, ybr)
label_id = classes.index(label)
b = (float(xtl), float(xbr), float(ytl), float(ybr))
bb = convert((w, h), b)
out_file.write(str(label_id) + " " + " ".join([str(a) for a in bb]) + '\n')
# out_file.flush()
cv2.rectangle(img, (xtl, ytl), (xbr, ybr), (255, 0, 0))

cv2.imwrite(new_img, img)
log_file.close()

print("Path of txt folder = ", os.path.abspath(output_txt_path))
print("Path of train text = ", os.path.abspath(txt_train_path))
Expand Down

0 comments on commit 31117c3

Please sign in to comment.