Skip to content

Commit

Permalink
Update split_train_val.py
Browse files Browse the repository at this point in the history
  • Loading branch information
aimuch authored Nov 1, 2020
1 parent 7a510eb commit c74db60
Showing 1 changed file with 37 additions and 36 deletions.
73 changes: 37 additions & 36 deletions split_train_val.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Author : Andy Liu
# Last modified: 2020-06-15
# Last modified: 2020-11-01

import os
import shutil
Expand All @@ -9,80 +9,81 @@
import random
from tqdm import tqdm

val_rate = 0.8
val_rate = 0.2

train_img = "./train_img"
train_xml = "./train_xml"
train_label = "./train_label"
val_img = "./val_img"
val_xml = "./val_xml"
val_label = "./val_label"
if os.path.exists(train_img):
shutil.rmtree(train_img)
os.makedirs(train_img)

if os.path.exists(train_xml):
shutil.rmtree(train_xml)
os.makedirs(train_xml)
if os.path.exists(train_label):
shutil.rmtree(train_label)
os.makedirs(train_label)

if os.path.exists(val_img):
shutil.rmtree(val_img)
os.makedirs(val_img)

if os.path.exists(val_xml):
shutil.rmtree(val_xml)
os.makedirs(val_xml)
if os.path.exists(val_label):
shutil.rmtree(val_label)
os.makedirs(val_label)

def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('srcxml', help='xml directory', type=str)
parser.add_argument('srclabel', help='label directory', type=str)
parser.add_argument('srcimg', help='images directory', type=str)

args = parser.parse_args()
return args


def pick(srcxml, srcimg):
srcxml = os.path.abspath(srcxml)
if srcxml[-1] == "/":
srcxml = srcxml[:-1]
def pick(srclabel, srcimg):
srclabel = os.path.abspath(srclabel)
if srclabel[-1] == "/":
srclabel = srclabel[:-1]

xmllists = os.listdir(srcxml)
random.shuffle(xmllists)
# val_num = int(len(xmllists)*val_rate)
val_num = 800
valset = random.sample(xmllists, val_num)
labellists = os.listdir(srclabel)
random.shuffle(labellists)
val_num = int(len(labellists)*val_rate)
# val_num = 800
valset = random.sample(labellists, val_num)


for xml in tqdm(xmllists):
xml_info = xml.split(".")
if xml_info[-1] != "xml":
continue
for label in tqdm(labellists):
label_info = label.split(".")
if label_info[-1] != "xml":
if label_info[-1] != "txt":
continue

xml_path = os.path.join(srcxml, xml)
img_path = os.path.join(srcimg, xml_info[0]+".jpg")
label_path = os.path.join(srclabel, label)
img_path = os.path.join(srcimg, label_info[0]+".jpg")
if(not os.path.exists(img_path)):
img_path = os.path.join(srcimg, xml_info[0]+".png")
if xml in valset:
xml_path_dst = os.path.join(val_xml, os.path.basename(xml_path))
img_path = os.path.join(srcimg, label_info[0]+".png")
if label in valset:
label_path_dst = os.path.join(val_label, os.path.basename(label_path))
img_path_dst = os.path.join(val_img, os.path.basename(img_path))
shutil.copyfile(xml_path, xml_path_dst)
shutil.copyfile(label_path, label_path_dst)
shutil.copyfile(img_path, img_path_dst)
else:
xml_path_dst = os.path.join(train_xml, os.path.basename(xml_path))
label_path_dst = os.path.join(train_label, os.path.basename(label_path))
img_path_dst = os.path.join(train_img, os.path.basename(img_path))
shutil.copyfile(xml_path, xml_path_dst)
shutil.copyfile(label_path, label_path_dst)
shutil.copyfile(img_path, img_path_dst)

if __name__ == '__main__':
args = parse_args()
srcxml = args.srcxml
srclabel = args.srclabel
srcimg = args.srcimg

if not os.path.exists(srcxml):
print("Error !!! %s is not exists, please check the parameter"%srcxml)
if not os.path.exists(srclabel):
print("Error !!! %s is not exists, please check the parameter"%srclabel)
sys.exit(0)
if not os.path.exists(srcimg):
print("Error !!! %s is not exists, please check the parameter"%srcimg)
sys.exit(0)

pick(srcxml, srcimg)
pick(srclabel, srcimg)
print("Done!")

0 comments on commit c74db60

Please sign in to comment.