# train_syllabifier.py (forked from mikekestemont/syllabify)
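"""Train a character-level recurrent syllabifier.

Workflow, as inferred from the code below: read train/dev/test splits from
--input_dir, fit a character vectorizer on the training words, train a
recurrent tagger that predicts a two-way label per character, keep the best
checkpoint by validation loss, and write the model, the vectorizer and
silver-standard syllabifications to --model_dir.

Example invocation (illustrative; any flag left out falls back to the
defaults declared in main()):

    python train_syllabifier.py --input_dir data/splits --model_dir model_s --retrain
"""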
import argparse
import os
import shutil

from sklearn.metrics import accuracy_score
from keras.models import load_model
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau

from syllabification.vectorization import SequenceVectorizer
from syllabification.modelling import build_model
import syllabification.utils as u

def main():
    parser = argparse.ArgumentParser(description='Trains a syllabifier')
    parser.add_argument('--input_dir', type=str,
                        default='data/splits',
                        help='location of the splits folder')
    parser.add_argument('--model_dir', type=str,
                        default='model_s',
                        help='location of the model folder')
    parser.add_argument('--num_epochs', type=int,
                        default=30,
                        help='Number of epochs')
    parser.add_argument('--lr', type=float,
                        default=0.001,
                        help='Initial learning rate')
    parser.add_argument('--dropout', type=float,
                        default=0.25,
                        help='Recurrent dropout')
    parser.add_argument('--num_layers', type=int,
                        default=2,
                        help='Number of recurrent layers')
    parser.add_argument('--retrain', default=False, action='store_true',
                        help='Retrain a model from scratch')
    parser.add_argument('--recurrent_dim', type=int,
                        default=30,
                        help='Number of recurrent dims')
    parser.add_argument('--emb_dim', type=int,
                        default=64,
                        help='Number of character embedding dims')
    parser.add_argument('--batch_size', type=int,
                        default=50,
                        help='Batch size')
    parser.add_argument('--seed', type=int,
                        default=43432,
                        help='Random seed')
    args = parser.parse_args()
    print(args)
    # load the (words, labels) pairs for each split
    train, dev, test = u.load_splits(args.input_dir)
    train_words, train_Y = train
    dev_words, dev_Y = dev
    test_words, test_Y = test

    # fit the character vectorizer on the training words only
    v = SequenceVectorizer().fit(train_words)
    v_path = os.sep.join((args.model_dir, 'vectorizer.json'))

    # vectorize the words and pad the label sequences to the same length
    train_X = v.transform(train_words)
    dev_X = v.transform(dev_words)
    test_X = v.transform(test_words)

    train_Y = v.normalize_len(train_Y)
    dev_Y = v.normalize_len(dev_Y)
    test_Y = v.normalize_len(test_Y)

    # one-hot encode the binary per-character labels
    train_Y = to_categorical(train_Y, num_classes=2)
    dev_Y = to_categorical(dev_Y, num_classes=2)
    test_Y = to_categorical(test_Y, num_classes=2)
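    # Sanity note (inferred, not asserted in the original code): after
    # normalize_len and to_categorical, each label array should have shape
    # (num_words, padded_length, 2), i.e. one two-way target per character
    # position, matching the vectorized inputs above.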
    model = build_model(vectorizer=v, embed_dim=args.emb_dim,
                        num_layers=args.num_layers, lr=args.lr,
                        recurrent_dim=args.recurrent_dim,
                        dropout=args.dropout)
    model.summary()

    m_path = os.sep.join((args.model_dir, 'syllab.model'))

    if args.retrain:
        # retrain from scratch: wipe and recreate the model directory
        try:
            shutil.rmtree(args.model_dir)
        except FileNotFoundError:
            pass
        os.mkdir(args.model_dir)

    # keep only the checkpoint with the lowest validation loss, and lower the
    # learning rate when the validation loss stops improving
    checkpoint = ModelCheckpoint(m_path, monitor='val_loss',
                                 verbose=1, save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.3,
                                  patience=1, min_lr=0.0001,
                                  verbose=1, min_delta=0.001)
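    # With these settings, a single epoch whose val_loss improvement is below
    # min_delta=0.001 multiplies the learning rate by 0.3 (e.g. 0.001 ->
    # 0.0003 -> 0.0001), after which min_lr keeps it pinned at 0.0001.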
    try:
        model.fit(train_X, train_Y, validation_data=(dev_X, dev_Y),
                  epochs=args.num_epochs, batch_size=args.batch_size,
                  shuffle=True, callbacks=[checkpoint, reduce_lr])
    except KeyboardInterrupt:
        # allow training to be interrupted without losing the best checkpoint
        print('\n' + '-' * 64 + '\n')

    v.dump(v_path)

    # reload the best checkpoint saved by ModelCheckpoint
    model = load_model(m_path)
    # evaluate on test: convert the soft predictions to class labels and
    # compare them against the gold labels
    test_silver = u.pred_to_classes(model.predict(test_X))
    test_gold = u.pred_to_classes(test_Y)

    gold_syll, pred_syll = [], []
    for test_item, gold, silver in zip(test_X, test_gold, test_silver):
        # score only positions 1..end, i.e. skip the first symbol and
        # everything from <EOS> onwards (padding)
        end = list(test_item).index(v.syll2idx['<EOS>'])
        gold_syll.append(tuple(gold[1:end]))
        pred_syll.append(tuple(silver[1:end]))

    # character-level accuracy vs. whole-token (exact match) accuracy
    test_acc_syll = accuracy_score([i for s in gold_syll for i in s],
                                   [i for s in pred_syll for i in s])
    test_acc_token = accuracy_score([str(s) for s in gold_syll],
                                    [str(s) for s in pred_syll])
    print('test acc (char):', test_acc_syll)
    print('test acc (token):', test_acc_token)
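    # Illustrative example (made-up labels): for one word with gold labels
    # (0, 1, 0) and predictions (0, 1, 1), the character-level accuracy
    # scores 2/3 correct positions, while the token-level accuracy scores the
    # whole word as wrong (0/1) because the two sequences are not identical.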
    # write the silver-standard syllabifications for the dev and test words
    dev_silver = u.pred_to_classes(model.predict(dev_X))
    with open(os.sep.join((args.model_dir, 'silver_dev.txt')), 'w') as f:
        for token, pred in zip(dev_words, dev_silver):
            f.write(u.stringify(token, pred) + '\n')

    with open(os.sep.join((args.model_dir, 'silver_test.txt')), 'w') as f:
        for token, pred in zip(test_words, test_silver):
            f.write(u.stringify(token, pred) + '\n')

if __name__ == '__main__':
    main()
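# Minimal inference sketch (not part of the training script). It assumes a
# hypothetical SequenceVectorizer.load() counterpart to the dump() call above;
# the remaining calls mirror the ones used in main().
#
#   from keras.models import load_model
#   from syllabification.vectorization import SequenceVectorizer
#   import syllabification.utils as u
#
#   model = load_model('model_s/syllab.model')
#   v = SequenceVectorizer.load('model_s/vectorizer.json')  # hypothetical loader
#   X = v.transform(['syllabification'])
#   preds = u.pred_to_classes(model.predict(X))
#   print(u.stringify('syllabification', preds[0]))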