forked from ketranm/neuralHMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.lua
305 lines (269 loc) · 9.79 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
-- Command-line options for training (default) and prediction (-input).
local cmd = torch.CmdLine()
cmd:text('unsupervised HMM')
cmd:text('Data')
cmd:option('-datapath', '../data', 'location of data')
cmd:option('-vocabsize', -1, 'size of dynamic softmax, -1 for using all')
cmd:option('-mnbz', 250, 'size of minibatch')
cmd:option('-maxlen', 40, 'maximum number of tokens per sentence')
cmd:text('model')
-- -hidsize is passed to every emission and transition network constructor.
cmd:option('-hidsize', 128, 'hidden size of the emission and transition networks')
-- NOTE(review): -maxchars is not read in this file (-max_word_len is what
-- feeds loader:getchar); confirm whether another module uses it.
cmd:option('-maxchars', 15, 'use char-softmax')
cmd:option('-kernels', {1,2,3,4,5,6,7}, 'kernels of char-cnn')
cmd:option('-feature_maps', {50, 100, 128, 128, 128, 128, 128}, 'feature map of char-cnn')
cmd:option('-charsize', 15, 'char embedding dim')
cmd:option('-nstates', 45, 'number of latent states')
cmd:option('-conv', false, 'use Char-CNN for emission')
cmd:option('-lstm', false, 'use LSTM for transition')
cmd:option('-max_word_len', 15, 'truncate word that is longer than this, use for Char-CNN')
cmd:option('-nlayers', 3, 'number of lstm layers')
cmd:text('optimization')
-- Upper bound on the inference + Adam rounds per minibatch; train() stops
-- earlier once the log-likelihood stops improving.
cmd:option('-nloops', 6, 'max number of inner inference/update rounds per minibatch')
cmd:option('-niters', 16, 'number of iterations/epochs')
cmd:option('-max_norm', 5, 'max gradient norm')
cmd:option('-dropout', 0.5, 'dropout')
cmd:option('-report_every', 10, 'print out after certain number of mini batches')
cmd:option('-modelpath', '../cv', 'saved model location')
cmd:text('utils')
-- -model is the file-name prefix for checkpoints when training, and the
-- checkpoint to load when predicting with -input.
cmd:option('-model', 'hmm', 'model name prefix (training) or model file to load (prediction)')
cmd:option('-output', '../data/pred.txt', 'output prediction')
cmd:option('-input', '', 'input file to predict')
cmd:option('-cuda', false, 'using cuda')
cmd:option('-debug', false, 'debugging mode')
cmd:text()
-- `opt` is intentionally global: train() and infer() below read it as a
-- global, and required project modules may as well (unverified), so do not
-- make it local without checking them.
opt = cmd:parse(arg or {})
print(opt)
-- Seeding is disabled, so runs are not reproducible; uncomment the
-- manualSeed lines (here and in the -cuda branch) to fix the RNGs.
--torch.manualSeed(42)
require 'nn'
-- Project modules; presumably they define nn.BaumWelch and the global
-- DataLoader used below — TODO confirm.
require 'BaumWelch'
require 'loader'
-- GPU packages are only loaded when -cuda is given.
if opt.cuda then
require 'cunn'
require 'cutorch'
print('using GPU')
--cutorch.manualSeed(42)
end
-- loading data: the loader receives the whole option table.
local loader = DataLoader(opt)
print('vocabulary size: ', loader.vocabsize)
-- Publish vocabulary info on opt for the code below:
-- nobs = vocabulary size, padidx = the loader's padding index (-1 when the
-- loader does not set one), vocab = the loader's vocabulary.
opt.nobs = loader.vocabsize
opt.padidx = loader.padidx or -1
opt.vocab = loader.vocab
require 'optim'
local model_utils = require 'model_utils'
local utils = require 'utils'

--------------------------------------------------------------------------
-- Model construction: prior, emission and transition networks plus the
-- BaumWelch inference module. They are optionally moved to the GPU, freshly
-- initialised, and their weights flattened into one params/gradParams pair.
--------------------------------------------------------------------------
print('create networks')
require 'Prior'
require 'FFTran'
local prior_net = nn.Prior(opt.nstates)

-- Emission model: feed-forward by default, character CNN with -conv.
local emiss_net, trans_net
if not opt.conv then
    print('use Feed-forward Emission Model')
    require 'Emission'
    emiss_net = nn.EmiNet(opt.nobs, opt.nstates, opt.hidsize)
else
    print('use Convolutional Character Model')
    require 'EmiConv'
    emiss_net = nn.EmiConv(loader:getchar(opt.max_word_len), opt.nstates,
                           opt.feature_maps, opt.kernels,
                           opt.charsize, opt.hidsize)
    print('set up Char-CNN completed!')
end

-- Transition model: feed-forward by default, LSTM with -lstm.
if not opt.lstm then
    trans_net = nn.FFTran(opt.nstates, opt.hidsize)
else
    print('use LSTM for transition')
    require 'RNNTran'
    trans_net = nn.RNNTran(opt.nobs, opt.nstates, opt.hidsize,
                           opt.nlayers, opt.dropout)
end

local inference = nn.BaumWelch(opt.padidx)

if opt.cuda then
    for _, net in ipairs({prior_net, emiss_net, trans_net, inference}) do
        net:cuda()
    end
end
for _, net in ipairs({prior_net, emiss_net, trans_net}) do
    net:reset()
end

-- Flat views over all trainable weights and their gradients. The argument
-- order fixes the layout of `params`, which presumably must match any
-- checkpoint later restored with params:copy.
local params, gradParams =
    model_utils.combine_all_parameters(emiss_net, trans_net, prior_net)
-- Uniform initialisation reportedly works best for the feed-forward model;
-- currently disabled:
--params:uniform(-1e-3, 1e-3)
--- Trim trailing padding columns from a padded minibatch.
-- A column counts as real when its token ids sum to more than the batch
-- size, i.e. it holds some id above 1. This assumes pad id 1 and real ids
-- >= 1 (infer() pads with 1) — TODO confirm against the loader.
-- When every column looks like padding, the full width is kept.
-- @param input 2-D tensor (batch x time) of token ids
-- @return `input` restricted to columns 1 .. last real column
function process(input)
    local batch_size, width = input:size(1), input:size(2)
    local keep = width
    local col = width
    while col >= 1 do
        if input[{{}, col}]:sum() > batch_size then
            keep = col
            break
        end
        col = col - 1
    end
    return input[{{}, {1, keep}}]
end
--- Train the neural HMM, starting with random-restart probing.
--
-- The first 5 epochs are probes. Each starts from a fresh initialisation,
-- and the best probe (highest average log-likelihood) is saved, with its
-- Adam and gradient-noise state, to `<modelpath>/<model>.iter0.t7`.
-- Right after the last probe that checkpoint is restored. Training then
-- continues for iterations 1 .. opt.niters - 1, and the raw parameter tensor
-- is saved after each epoch to `<modelpath>/<model>.iter<N>.t7`.
--
-- Each minibatch gets up to opt.nloops rounds. In each round
-- `inference:run` (nn.BaumWelch) returns `count` and the log-likelihood `f`,
-- and one Adam step is taken. The rounds stop early once the relative
-- improvement of `f` falls below 1e-4 (or `f` gets worse).
function train()
    trans_net:training()
    emiss_net:training()
    prior_net:training()

    local optim_config, optim_states = {}, {}
    local nprobes = 5
    -- Start at -inf so the first probe is always checkpointed. The former
    -- floor of -1000 could leave no file on disk and break the restore below.
    local best_start_loglik = -math.huge
    local iter = 0

    -- Annealed Gaussian noise added to the gradient. Tensor:normal takes a
    -- standard deviation, so this schedule controls the noise std.
    local gnoise = {t = 0, tau = 0.01, gamma = 0.55}
    gnoise.noise = gradParams.new()
    gnoise.noise:resizeAs(gradParams)

    while iter < opt.niters do
        local loglik = 0
        local data = loader:train()
        for j = 1, #data do
            local input = process(data[j])
            if opt.cuda then
                input = input:cuda()
            else
                input = input:long()
            end

            local count, f, prev_f
            for _ = 1, opt.nloops do
                local log_prior = prior_net:log_prob(input)
                local log_trans = trans_net:log_prob(input)
                local log_emiss = emiss_net:log_prob(input)
                count, f = inference:run(input, {log_emiss, log_trans, log_prior})
                if prev_f then
                    -- f is negative, so -(f - prev_f) / prev_f is the relative
                    -- improvement; stop when it stalls or turns negative.
                    if -(f - prev_f) / prev_f < 1e-4 then
                        break
                    end
                end
                prev_f = f

                local feval = function(x)
                    gradParams:zero()
                    if params ~= x then params:copy(x) end
                    prior_net:update(input, count[1]:mul(-1 / opt.nstates))
                    trans_net:update(input, count[2]:mul(-1 / opt.nstates))
                    emiss_net:update(input, count[3]:mul(-1 / opt.nstates))
                    -- L2 regularisation (disabled): gradParams:add(1e-3, params)
                    utils.scaleClip(gradParams, 3)
                    local std = gnoise.tau / torch.pow(1 + gnoise.t, gnoise.gamma)
                    gnoise.noise:normal(0, std)
                    gradParams:add(gnoise.noise)
                    gnoise.t = gnoise.t + 1
                    -- No loss value is computed here; only the gradient matters.
                    return nil, gradParams
                end
                optim.adam(feval, params, optim_config, optim_states)
            end

            loglik = loglik + f
            if j % opt.report_every == 0 then
                io.write(string.format('iter %d\tloglik %.4f\t %.3f\r', iter, loglik/j, j/#data))
                io.flush()
                collectgarbage()
            end
        end

        local curr_loss = loglik / #data
        -- During probing iter stays 0, so every probe targets the same file.
        local modelfile = string.format('%s/%s.iter%d.t7', opt.modelpath, opt.model, iter)
        if nprobes > 0 then
            print(string.format('current loss %.3f\tbest %.3f\tremained probes %d', curr_loss, best_start_loglik, nprobes - 1))
            if curr_loss > best_start_loglik then
                best_start_loglik = curr_loss
                paths.mkdir(paths.dirname(modelfile))
                print(string.format('probe: %d loglik %.4f ||| states: %s', nprobes, curr_loss, modelfile))
                local probe_states = {params = params, optim_config = optim_config, optim_states = optim_states, t = gnoise.t}
                torch.save(modelfile, probe_states)
            end
            nprobes = nprobes - 1

            if nprobes > 0 then
                -- Reseed and restart from scratch for the next probe.
                torch.seed()
                if opt.cuda then
                    cutorch.seed()
                end
                optim_config = {}
                optim_states = {}
                gnoise.t = 0
                prior_net:reset()
                trans_net:reset()
                emiss_net:reset()
            else
                -- Probing is over. Restore the best probe now instead of
                -- spending an epoch on a fresh model whose result is discarded.
                print('end of probing, use the best probing model to continue training!')
                if best_start_loglik > -math.huge then
                    local probe_states = torch.load(modelfile)
                    params:copy(probe_states.params)
                    optim_config = probe_states.optim_config
                    optim_states = probe_states.optim_states
                    gnoise.t = probe_states.t
                else
                    -- Every probe produced a non-comparable (NaN) loss, so
                    -- nothing was saved; keep the current model.
                    print('no probe checkpoint was saved; continuing from the current model')
                end
                nprobes = -1
                iter = iter + 1
            end
        else
            paths.mkdir(paths.dirname(modelfile))
            torch.save(modelfile, params)
            print(string.format('saved: %s\tloglik %.4f\titer %d', modelfile, curr_loss, iter))
            iter = iter + 1
        end
    end
end
--- Tag every line of `textfile` with its argmax hidden-state sequence.
--
-- Loads parameters from `modelfile` and switches the networks to evaluation
-- mode. Sentences are then processed in batches of 256: each batch is
-- padded with id 1 to its longest sentence, and the prior, transition and
-- emission log-probabilities are computed once per batch. `inference:argmax`
-- then runs per sentence on that sentence's slices. One line of
-- space-separated results is written to `predfile` per input line.
--
-- @param textfile path of the text to tag; each line is tensorized separately
-- @param predfile path the predictions are written to (overwritten)
-- @param modelfile checkpoint written by train(): either the bare parameter
--   tensor (post-probing epochs) or a probe table with a `params` field
function infer(textfile, predfile, modelfile)
    print(string.format('load model: %s', modelfile))
    local checkpoint = torch.load(modelfile)
    -- Probe checkpoints (iter0) store a table instead of the raw tensor.
    if type(checkpoint) == 'table' and checkpoint.params then
        checkpoint = checkpoint.params
    end
    params:copy(checkpoint)

    emiss_net:precompute()
    trans_net:evaluate()
    emiss_net:evaluate()
    prior_net:evaluate()
    prior_net:precompute()

    local sents = {}
    for line in io.lines(textfile) do
        sents[#sents + 1] = loader:tensorize(line):view(1, -1)
    end
    -- Open the output only after the input was read successfully, so a bad
    -- input path does not truncate an existing prediction file.
    local fw = assert(io.open(predfile, 'w'))

    local n = #sents
    local mnbz = 256
    for i = 1, n, mnbz do
        local last = math.min(i + mnbz - 1, n)
        local bs = last - i + 1
        local max_seq_len = 0
        for k = i, last do
            max_seq_len = math.max(max_seq_len, sents[k]:numel())
        end

        -- Pad with 1, the same pad value process() assumes during training.
        local input = torch.IntTensor(bs, max_seq_len):fill(1)
        for k = 1, bs do
            local x = sents[i + k - 1]
            input[{{k}, {1, x:numel()}}] = x
        end
        if opt.cuda then
            input = input:cuda()
        end
        io.write(string.format('sent: %d\r', i))
        io.flush()

        local log_prior = prior_net:log_prob(input)
        local log_trans = trans_net:log_prob(input)
        local log_emiss = emiss_net:log_prob(input)
        for k = 1, bs do
            local x = sents[i + k - 1]
            local len = x:numel()
            -- Drop the padded tail before decoding this sentence.
            local ex = log_emiss[{{k}, {1, len}}]
            local tx = log_trans[{{k}, {1, len}}]
            local predx = inference:argmax(x, {ex, tx, log_prior})
            fw:write(table.concat(torch.totable(predx), ' '), '\n')
        end
    end
    fw:close()
end
--- main script: train unless an -input file is given, in which case tag it
-- with the model named by -model and write the result to -output.
if opt.input == '' then
    train()
else
    infer(opt.input, opt.output, opt.model)
end