forked from yhenon/pytorch-retinanet
-
Notifications
You must be signed in to change notification settings - Fork 3
/
pagexml2csv.py
96 lines (81 loc) · 3.7 KB
/
pagexml2csv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
import xml.etree.ElementTree as ET
import os
import pdb
import pagexml
import sys
import csv
import re
import argparse
import glob
# Helpers that turn PAGE-XML text objects into RetinaNet CSV annotation rows.
def get_coords_and_transcript(pxml, textobject, key):
    """Return the bounding box, transcript and property tag of a text object.

    Args:
        pxml: a loaded PageXML document; only its ``getPoints``,
            ``getTextEquiv`` and ``getPropertyValue`` methods are used.
        textobject: the PAGE-XML element (e.g. a Word or TextLine) to read.
        key: name of the property whose value is returned as ``tag``.

    Returns:
        Tuple ``(x0, y0, x1, y1, transcript, tag)``. ``(x0, y0)`` and
        ``(x1, y1)`` are the integer minimum and maximum corners of the
        axis-aligned box enclosing all of the object's points. ``transcript``
        is the TextEquiv text with surrounding whitespace stripped. ``tag`` is
        whatever ``getPropertyValue`` returns for ``key``.

    Raises:
        ValueError: if the object has no coordinate points.
    """
    coords = pxml.getPoints(textobject)
    if len(coords) == 0:
        raise ValueError("text object has no coordinate points")
    # Take the envelope of every point instead of two fixed corner indices.
    # Coords may be polygons with any number of points listed in any order,
    # and the CSV output needs x0 < x1 and y0 < y1. For rectangles listed
    # top-left first (and for 2-point boxes) this equals the fixed-corner pick.
    xs = [point.x for point in coords]
    ys = [point.y for point in coords]
    x0, y0 = int(min(xs)), int(min(ys))
    x1, y1 = int(max(xs)), int(max(ys))
    transcript = pxml.getTextEquiv(textobject).strip()
    tag = pxml.getPropertyValue(textobject, key=key)
    return x0, y0, x1, y1, transcript, tag
def _str_to_bool(value):
    """Interpret a command-line flag value (or a non-string default) as a bool.

    The strings "", "0", "false", "f", "no", "n" and "off" (case-insensitive)
    mean False; any other string means True. Non-strings use ``bool(value)``.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "f", "no", "n", "off")
    return bool(value)


def _document_rows(pxml, root, opts):
    """Yield ``(label, row)`` for every usable box of the document loaded in ``pxml``.

    ``root`` is the directory the XML file was found in; relative page image
    paths are joined onto it. ``opts`` is the parsed command-line namespace.
    """
    by_line = opts.seg_lev.lower() == 'textline'
    for page in pxml.select('_:Page'):
        page_im_file = os.path.join(os.getcwd(), root, pxml.getPageImageFilename(page))
        for region in pxml.select('_:TextRegion', page):
            reg_tag = pxml.getPropertyValue(region, key=opts.property_key)
            for textline in pxml.select('_:TextLine', region):
                objects = [textline] if by_line else pxml.select('_:Word', textline)
                for obj in objects:
                    x0, y0, x1, y1, transcription, tag = get_coords_and_transcript(pxml, obj, opts.property_key)
                    # Skip degenerate (zero-area or inverted) boxes.
                    if x0 >= x1 or y0 >= y1:
                        continue
                    if not opts.get_property:
                        label = 'text'
                    elif tag:
                        label = tag
                    else:
                        # Object has no own value: fall back to its region's property.
                        label = reg_tag
                    yield label, [page_im_file, x0, y0, x1, y1, label, transcription]


def main(args=None):
    """Convert a directory tree of PAGE-XML files to RetinaNet CSV ground truth.

    Every ``*.xml`` file under ``--pxml_dir`` is loaded. For every Word (or
    every TextLine, with ``--seg_lev TextLine``) inside every TextRegion, one
    row ``image_path,x0,y0,x1,y1,label,transcript`` is written to ``--fout``.
    Boxes with zero or negative width/height are skipped.

    The label is ``'text'`` unless ``--get_property`` is true. In that case it
    is the object's ``--property_key`` property, falling back to the enclosing
    region's property when the object's own value is empty.

    If ``--classes_out`` is given, every label actually written is listed
    there as ``label,index`` in first-seen order, so the classes file always
    matches the annotations.

    Args:
        args: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Convert pagexml files to RetinaNet network csv groundtruth.')
    parser.add_argument('--pxml_dir', help='Path of directory with pagexml files.', default=".")
    parser.add_argument('--fout', help='Path of gt file to be read by the model.', default="train.csv")
    parser.add_argument('--classes_out', help='Path to save text category classes.')
    parser.add_argument('--seg_lev', help='segmentation level of the boxes to get (Word/TextLine)', default="Word")
    # type= converts "False"/"0"/... to a real bool; a plain string is always truthy.
    parser.add_argument('--get_property', type=_str_to_bool, default=False,
                        help='use the --property_key property as the box label instead of "text" (true/false)')
    parser.add_argument('--property_key', help='key to get property from pagexml', default='category')
    opts = parser.parse_args(args)

    pagexml.set_omnius_schema()
    pxml = pagexml.PageXML()
    # Insertion-ordered mapping label -> class index, built from rows actually written.
    labels = {}
    pxml_dir = os.path.join(os.getcwd(), opts.pxml_dir)
    # newline='' is required by the csv module for files it writes; the
    # explicit encoding keeps non-ASCII transcripts independent of the locale.
    with open(opts.fout, 'w', newline='', encoding='utf-8') as csv_out:
        writer = csv.writer(csv_out, delimiter=',')
        for root, _dirs, files in os.walk(pxml_dir):
            for fname in files:
                if not fname.endswith('.xml'):
                    continue
                pxml.loadXml(os.path.join(root, fname))
                for label, row in _document_rows(pxml, root, opts):
                    labels.setdefault(label, len(labels))
                    writer.writerow(row)

    if opts.classes_out is not None:
        with open(opts.classes_out, 'w', newline='', encoding='utf-8') as classes_out:
            writer_classes = csv.writer(classes_out, delimiter=',')
            for label, idx in labels.items():
                writer_classes.writerow([label, idx])
if __name__ == "__main__":
    # Run as a script: hand the command-line arguments (minus the program name) to main().
    main(sys.argv[1:])