-
Notifications
You must be signed in to change notification settings - Fork 349
/
Copy pathProcess.py
97 lines (72 loc) · 3.24 KB
/
Process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import tempfile

import dill as pickle
import pandas as pd
import torchtext
from torchtext.legacy import data

from Batch import MyIterator, batch_size_fn
from Tokenize import tokenize
def read_data(opt):
    """Replace the file paths in ``opt.src_data`` / ``opt.trg_data`` with their lines.

    Each of the two attributes that is not None is treated as a path to a text
    file. The file is read as UTF-8, leading/trailing whitespace of the whole
    file is stripped, and the text is split on newlines, so the attribute ends
    up holding a list of strings (one sentence per element). An attribute that
    is None is left untouched.

    Exits the process with status 1 (after printing an error) if a file does
    not exist.
    """
    if opt.src_data is not None:
        opt.src_data = _read_lines(opt.src_data)
    if opt.trg_data is not None:
        opt.trg_data = _read_lines(opt.trg_data)


def _read_lines(path):
    """Return the newline-separated lines of *path*; exit with status 1 if missing."""
    try:
        # Explicit UTF-8: the corpora are multilingual, and the platform default
        # encoding (e.g. cp1252 on Windows) would mis-decode them.
        with open(path, encoding='utf-8') as f:
            text = f.read()
    except FileNotFoundError:
        print("error: '" + path + "' file not found")
        # Non-zero status so callers/shell scripts can detect the failure
        # (quit() would have exited with status 0).
        raise SystemExit(1) from None
    # NOTE: an empty file yields [''] because str.split always returns at least
    # one element.
    return text.strip().split('\n')
def create_fields(opt):
    """Create the source/target torchtext ``Field`` objects, or load saved ones.

    Two lower-casing ``data.Field`` objects are built using the tokenizers
    returned by ``tokenize(opt.src_lang)`` / ``tokenize(opt.trg_lang)``; only
    the target field gets ``'<sos>'`` / ``'<eos>'`` boundary tokens.

    If ``opt.load_weights`` is set, both fields are replaced by the objects
    unpickled from ``<load_weights>/SRC.pkl`` and ``<load_weights>/TRG.pkl``.
    If either cannot be loaded, an error is printed and the process exits
    with status 1.

    A language not in the supported list only produces a warning; execution
    continues and the name is still passed to ``tokenize``.

    Returns:
        tuple: ``(SRC, TRG)``.
    """
    spacy_langs = ['en_core_web_sm', 'fr_core_news_sm', 'de', 'es', 'pt', 'it', 'nl']
    # NOTE(review): this list mixes spaCy model package names ('en_core_web_sm')
    # with bare language codes ('de'); confirm which form Tokenize.tokenize expects.
    if opt.src_lang not in spacy_langs:
        print(f"invalid src language: {opt.src_lang}, supported languages: {spacy_langs}")
    if opt.trg_lang not in spacy_langs:
        print(f"invalid trg language: {opt.trg_lang}, supported languages: {spacy_langs}")

    print("loading spacy tokenizers...")
    # NOTE: tokenizers are loaded even when the fields are about to be replaced
    # by the pickled ones below.
    t_src = tokenize(opt.src_lang)
    t_trg = tokenize(opt.trg_lang)

    TRG = data.Field(lower=True, tokenize=t_trg.tokenizer, init_token='<sos>', eos_token='<eos>')
    SRC = data.Field(lower=True, tokenize=t_src.tokenizer)

    if opt.load_weights is not None:
        print("loading presaved fields...")
        try:
            # SECURITY: unpickling (dill) can execute arbitrary code; only point
            # load_weights at directories you trust.
            with open(f'{opt.load_weights}/SRC.pkl', 'rb') as f:
                SRC = pickle.load(f)
            with open(f'{opt.load_weights}/TRG.pkl', 'rb') as f:
                TRG = pickle.load(f)
        except Exception:
            # Missing files and corrupt/incompatible pickles are both fatal here.
            print("error opening SRC.pkl and TRG.pkl field files, please ensure they are in " + opt.load_weights + "/")
            raise SystemExit(1) from None

    return SRC, TRG
def create_dataset(opt, SRC, TRG):
    """Build the training iterator from the parallel sentences on ``opt``.

    Reads ``opt.src_data`` / ``opt.trg_data`` (sequences of sentences, paired
    by position), ``opt.max_strlen``, ``opt.batchsize``, ``opt.device``,
    ``opt.load_weights`` and ``opt.checkpoint``.

    Steps:
      1. Drop every pair in which either side contains ``opt.max_strlen`` or
         more spaces.
      2. Load the remaining pairs into a ``data.TabularDataset`` through a
         private temporary CSV file (always removed afterwards).
      3. Wrap the dataset in a shuffling ``MyIterator``.
      4. If no saved weights are being loaded, build both vocabularies. When
         ``opt.checkpoint > 0``, also create ``weights/`` and pickle the fields
         into it. If ``weights/`` already exists, print a hint and exit with
         status 1.
      5. Store the ``'<pad>'`` indices in ``opt.src_pad`` / ``opt.trg_pad``
         and the number of batches in ``opt.train_len``.

    Raises:
        ValueError: if source and target contain a different number of lines.

    Returns:
        The ``MyIterator`` over the training data.
    """
    print("creating dataset and iterator... ")

    src_lines = list(opt.src_data)
    trg_lines = list(opt.trg_data)
    if len(src_lines) != len(trg_lines):
        raise ValueError(
            "source and target data must have the same number of lines "
            f"(got {len(src_lines)} and {len(trg_lines)})")
    df = pd.DataFrame({'src': src_lines, 'trg': trg_lines}, columns=["src", "trg"])

    # Space count is a cheap proxy for sentence length in words.
    mask = (df['src'].str.count(' ') < opt.max_strlen) & (df['trg'].str.count(' ') < opt.max_strlen)
    df = df.loc[mask]

    data_fields = [('src', SRC), ('trg', TRG)]
    # TabularDataset is built from a file path, so round-trip through a
    # uniquely named temp file instead of a fixed name in the working directory
    # (which collides between concurrent runs and leaked on failure).
    fd, csv_path = tempfile.mkstemp(suffix='.csv')
    os.close(fd)
    try:
        df.to_csv(csv_path, index=False)
        # to_csv writes a "src,trg" header row; skip it so it is not loaded as a
        # training example (and its words are not added to the vocab).
        train = data.TabularDataset(csv_path, format='csv', fields=data_fields, skip_header=True)
    finally:
        os.remove(csv_path)

    train_iter = MyIterator(train, batch_size=opt.batchsize, device=opt.device,
                            repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
                            batch_size_fn=batch_size_fn, train=True, shuffle=True)

    if opt.load_weights is None:
        SRC.build_vocab(train)
        TRG.build_vocab(train)
        if opt.checkpoint > 0:
            try:
                os.mkdir("weights")
            except FileExistsError:
                print("weights folder already exists, run program with -load_weights weights to load them")
                raise SystemExit(1) from None
            with open('weights/SRC.pkl', 'wb') as f:
                pickle.dump(SRC, f)
            with open('weights/TRG.pkl', 'wb') as f:
                pickle.dump(TRG, f)

    opt.src_pad = SRC.vocab.stoi['<pad>']
    opt.trg_pad = TRG.vocab.stoi['<pad>']

    # NOTE: counting requires one full pass over the iterator.
    opt.train_len = get_len(train_iter)

    return train_iter
def get_len(train):
    """Return the number of items produced by iterating *train* once.

    The iterable is consumed completely; for an iterator over batches this
    performs a full pass. An empty iterable yields 0. The previous version
    returned the last index, which is one less than the count, and raised
    ``UnboundLocalError`` on empty input.
    """
    return sum(1 for _ in train)