normalize_text.py
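
"""Normalize a text file line by line with a Whisper text normalizer, optionally
splitting the work across several worker processes.

Usage (as parsed in the __main__ block below):

    python normalize_text.py INPUT_FILE LANG NUM_WORKERS [CLEANING]

LANG "en" selects EnglishTextNormalizer; anything else selects BasicTextNormalizer.
With CLEANING set to a non-zero value, lines whose normalized form lost 30% or
more of their tokens are dropped and the script writes INPUT_FILE.norm.clean plus
INPUT_FILE.clean; otherwise it writes INPUT_FILE.norm.
"""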
import os
import sys
import torch.multiprocessing as mp


def safe_readline(f):
    """Read a line even if the current seek position landed inside a multi-byte
    UTF-8 character: back up one byte at a time until decoding succeeds."""
    pos = f.tell()
    while True:
        try:
            return f.readline()
        except UnicodeDecodeError:
            pos -= 1
            f.seek(pos)  # search where this character begins


class Normalizer:

    @staticmethod
    def find_offsets(filename, num_chunks):
        """
        :param filename: string
        :param num_chunks: int
        :return: a list of offsets (positions to start and stop reading)
        """
        with open(filename, 'r', encoding='utf-8') as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_chunks
            offsets = [0 for _ in range(num_chunks + 1)]
            for i in range(1, num_chunks):
                f.seek(chunk_size * i)
                safe_readline(f)
                offsets[i] = f.tell()
            return offsets
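
    # Illustrative example (the numbers are made up, not from the original): for a
    # 1,000-byte file and num_chunks = 4, find_offsets seeks to bytes 250, 500 and
    # 750, advances each position to the start of the next full line, and might
    # return [0, 253, 511, 768, 0]. The trailing entry stays 0 by construction;
    # downstream, an end of 0 never satisfies `0 < end < f.tell()`, so the last
    # worker simply reads until EOF.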

    @staticmethod
    def normalize_file_single_thread(filename, normalizer, worker_id=0,
                                     offset=0, end=-1):
        """Normalize the byte range [offset, end) of `filename` with `normalizer`
        and return a dict holding the normalized lines, the original lines and
        the worker id."""
        result = dict()
        data = list()
        data_org = list()
        with open(filename, 'r', encoding='utf-8') as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            count = 0
            while line:
                if 0 < end < f.tell():
                    break
                _line = line.strip()
                outline = normalizer(_line)
                data.append(outline)
                data_org.append(_line)
                line = f.readline()
                count += 1
        result['data'] = data
        result['data_org'] = data_org
        result['id'] = worker_id
        result['total'] = count  # number of lines this worker processed
        return result

    @staticmethod
    def normalize_file(filename, normalizer, num_workers=1):
        """Normalize `filename` with `normalizer`, optionally spreading the work
        over `num_workers` processes. Returns (original_lines, normalized_lines)."""
        result = dict()
        for i in range(num_workers):
            result[i] = dict()

        def merge_result(bin_result):
            result[bin_result['id']]['data'] = bin_result['data']
            result[bin_result['id']]['data_org'] = bin_result['data_org']

        offsets = Normalizer.find_offsets(filename, num_workers)

        if num_workers > 1:
            pool = mp.Pool(processes=num_workers)
            mp_results = []
            for worker_id in range(num_workers):
                mp_results.append(pool.apply_async(
                    Normalizer.normalize_file_single_thread,
                    args=(filename, normalizer, worker_id,
                          offsets[worker_id], offsets[worker_id + 1]),
                ))
            pool.close()
            pool.join()
            for r in mp_results:
                merge_result(r.get())
        else:
            sp_result = Normalizer.normalize_file_single_thread(
                filename, normalizer, 0, offsets[0], offsets[1])
            merge_result(sp_result)

        final_result = list()
        org_data = list()
        # put the data into the lists according to the worker indices
        for idx in range(num_workers):
            final_result += result[idx]['data']
            org_data += result[idx]['data_org']
        return org_data, final_result
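
# Usage sketch (illustrative; `my_normalizer` is a placeholder for any str -> str
# callable -- with num_workers > 1 it must be picklable, e.g. a top-level function):
#
#     org_lines, norm_lines = Normalizer.normalize_file("corpus.txt", my_normalizer,
#                                                        num_workers=4)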


if __name__ == '__main__':
    from whisper_normalizer.basic import BasicTextNormalizer
    from whisper_normalizer.english import EnglishTextNormalizer

    input_file = sys.argv[1]
    lang = sys.argv[2]
    num_workers = int(sys.argv[3])
    cleaning = 0 if len(sys.argv) == 4 else int(sys.argv[4])

    if lang == "en":
        normalizer = EnglishTextNormalizer()
    else:
        normalizer = BasicTextNormalizer()

    print("Normalizing file {}".format(input_file))
    # problem: all lines still have to be held in RAM before writing
    org_data, normalized_data = Normalizer.normalize_file(input_file, normalizer,
                                                          num_workers=num_workers)

    if cleaning != 0:
        output_file = input_file + ".norm.clean"
        writer = open(output_file, 'w')
        org_writer = open(input_file + ".clean", "w")
        print("Done. Now cleaning and writing to {}".format(output_file))
        for org_line, line in zip(org_data, normalized_data):
            org_parts = org_line.split()
            parts = line.split()
            # skip lines where normalization removed 30% or more of the tokens
            if len(parts) <= 0.7 * len(org_parts):
                continue
            org_writer.write(org_line + "\n")
            writer.write(line + "\n")
        org_writer.close()
        writer.close()
    else:
        output_file = input_file + ".norm"
        writer = open(output_file, 'w')
        print("Done. Now writing to {}".format(output_file))
        for line in normalized_data:
            writer.write(line + "\n")
        writer.close()

    print("Finished.")