-
Notifications
You must be signed in to change notification settings - Fork 35
/
train_language_model.py
executable file
·144 lines (108 loc) · 4.69 KB
/
train_language_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python
# Entry-point script for training a language model with the onmt toolkit.
# The top-level code below only parses command-line options and sets a few
# process-wide flags in onmt.constants; the real work happens in main().
from __future__ import division
import onmt
import onmt.markdown
import onmt.modules
import argparse
import torch
import torch.nn as nn
from torch import cuda
from torch.autograd import Variable
import math
import time, datetime
from onmt.train_utils.trainer import XETrainer
from onmt.modules.loss import NMTLossFunc, NMTAndCTCLossFunc
from onmt.model_factory import build_language_model, optimize_model
from onmt.data.lm_dataset import LanguageModelDataset
from collections import defaultdict
parser = argparse.ArgumentParser(description='train.py')
onmt.markdown.add_md_help_argument(parser)
from options import make_parser
# Please look at the options file to see the options regarding models and data
parser = make_parser(parser)
opt = parser.parse_args()
print(opt)
# An ugly hack to have weight norm on / off
onmt.constants.weight_norm = opt.weight_norm
onmt.constants.checkpointing = opt.checkpointing
onmt.constants.max_position_length = opt.max_position_length
# Use static dropout if checkpointing > 0
if opt.checkpointing > 0:
    onmt.constants.static = True
# Warn (but do not abort) when a GPU is present and none was requested.
if torch.cuda.is_available() and not opt.gpus:
    print("WARNING: You have a CUDA device, should run with -gpus 0")
# Seed the (CPU) RNG for reproducibility of initialization and shuffling.
torch.manual_seed(opt.seed)
def main():
    """Load the preprocessed dataset, build the language model, and run
    cross-entropy training via XETrainer.

    Reads all configuration from the module-level ``opt`` namespace.
    Raises NotImplementedError for any ``data_format`` other than 'raw'
    and for multi-GPU / virtual-GPU configurations.
    """
    start = time.time()
    print("Loading data from '%s'" % opt.data)
    if opt.data_format == 'raw':
        # 'raw' format: the whole preprocessed dataset is one torch pickle.
        dataset = torch.load(opt.data)
        elapse = str(datetime.timedelta(seconds=int(time.time() - start)))
        print("Done after %s" % elapse)
        dicts = dataset['dicts']
        # For backward compatibility: older datasets lack the *_lang keys,
        # so wrap in a defaultdict that yields None for missing keys.
        train_dict = defaultdict(lambda: None, dataset['train'])
        valid_dict = defaultdict(lambda: None, dataset['valid'])
        if train_dict['src_lang'] is not None:
            # Multilingual dataset: language ids were stored at preprocess time.
            assert 'langs' in dicts
            train_src_langs = train_dict['src_lang']
            train_tgt_langs = train_dict['tgt_lang']
        else:
            # allocate new languages
            dicts['langs'] = {'src': 0, 'tgt': 1}
            train_src_langs = list()
            train_tgt_langs = list()
            # Allocation one for the bilingual case
            train_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            train_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))
        # LM training only consumes the target side of the corpus.
        train_data = LanguageModelDataset(
            dataset['train']['tgt'], train_tgt_langs,
            batch_size_sents=opt.batch_size_sents,
            seq_length=opt.lm_seq_length)
        if valid_dict['src_lang'] is not None:
            assert 'langs' in dicts
            valid_src_langs = valid_dict['src_lang']
            valid_tgt_langs = valid_dict['tgt_lang']
        else:
            # allocate new languages
            valid_src_langs = list()
            valid_tgt_langs = list()
            # Allocation one for the bilingual case
            valid_src_langs.append(torch.Tensor([dicts['langs']['src']]))
            valid_tgt_langs.append(torch.Tensor([dicts['langs']['tgt']]))
        valid_data = LanguageModelDataset(
            dataset['valid']['tgt'], valid_tgt_langs,
            batch_size_sents=opt.batch_size_sents,
            seq_length=opt.lm_seq_length)
        if opt.load_from:
            # Resume from a checkpoint; map_location keeps tensors on CPU so
            # loading works regardless of which device saved them.
            checkpoint = torch.load(opt.load_from, map_location=lambda storage, loc: storage)
            print("* Loading dictionaries from the checkpoint")
            # NOTE: the checkpoint dicts replace the dataset dicts so the
            # model vocabulary matches the saved weights.
            dicts = checkpoint['dicts']
        else:
            # presumably pads the vocab size to a multiple for GPU efficiency
            # — behavior defined in the project's dict implementation.
            dicts['tgt'].patch(opt.patch_vocab_multiplier)
            checkpoint = None
        if "src" in dicts:
            print(' * vocabulary size. source = %d; target = %d' %
                  (dicts['src'].size(), dicts['tgt'].size()))
        else:
            print(' * vocabulary size. target = %d' %
                  (dicts['tgt'].size()))
        print(' * number of training sentences. %d' %
              train_data.size())
        print(' * maximum batch size (words per batch). %d' % (opt.batch_size_sents * opt.lm_seq_length))
    else:
        # Only the 'raw' (single pickle) data format is supported here.
        raise NotImplementedError
    print('Building model...')
    model = build_language_model(opt, dicts)
    optimize_model(model)
    """ Building the loss function """
    # Cross-entropy with label smoothing over the target vocabulary.
    loss_function = NMTLossFunc(opt.model_size, dicts['tgt'].size(), label_smoothing=opt.label_smoothing)
    n_params = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % n_params)
    if len(opt.gpus) > 1 or opt.virtual_gpu > 1:
        raise NotImplementedError("Multi-GPU training is not supported ATM.")
    else:
        # XETrainer drives the optimization loop; passing the checkpoint
        # (or None) lets it resume optimizer/epoch state when applicable.
        trainer = XETrainer(model, loss_function, train_data, valid_data, dicts, opt)
        trainer.run(checkpoint=checkpoint)
if __name__ == "__main__":
    main()