-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.lua
168 lines (156 loc) · 7.16 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
local status, cunn = pcall(require, 'nn')
local stringx = require 'pl.stringx'
LookupTable = nn.LookupTable
require 'torch'
require 'CSPPT/utils/util'
require 'CSPPT/utils/queue'
require 'CSPPT/utils/heap'
require 'CSPPT/utils/reader'
require 'CSPPT/utils/metrics'
require 'CSPPT/nets/lstm'
require 'CSPPT/nets/rnn'
--[[
#########################################################
Deep LSTM/RNN SLU implementation via torch
Su Zhu, Wengong Jin
Email: [email protected], [email protected]
Speech Lab, Shanghai Jiao Tong University
#########################################################
--]]
-- Command-line interface: `cmd` declares every option (name, default value,
-- help text) grouped under section headings, and `options` receives the
-- parsed values that the rest of this script reads.
-- Note: the number of layers and the per-layer sizes are not options of
-- their own; they are derived from -hidden_prototype further down.
local cmd = torch.CmdLine()
do
  -- Each section is { heading, { { name, default, help }, ... } }; headings
  -- and options are registered with cmd in exactly this order.
  local sections = {
    { 'General Options:', {
      { '-train', '', 'training set file' },
      { '-valid', '', 'validation set file' },
      { '-test', '', 'test set file' },
      { '-read_model', '', 'read model from this file' },
      { '-print_model', '', 'print model to this file' },
      { '-vocab', '', 'read vocab from this file' },
      { '-outlabel', '', 'read output label from this file' },
      { '-print_vocab', '', 'print vocab to this file' },
      { '-trace_level', 0, 'trace level' },
      { '-test_only', 0, 'only test a test file using an existing model' },
    } },
    { 'Model Options:', {
      { '-rnn_type', 'lstm', 'recurrent type: lstm or rnn' },
      { '-emb_size', 100, 'word embedding dimension' },
      { '-word_win_left', 0, 'number of words in the previous context window' },
      { '-word_win_right', 0, 'number of words in the next context window' },
      { '-hidden_prototype', '200-300', 'hidden layer dimension of each hidden layer' },
    } },
    { 'Runtime Options:', {
      { '-deviceId', 1, 'train model on ith gpu' },
      { '-random_seed', 7, 'set initial random seed' },
    } },
    { 'Training Options:', {
      { '-alpha', 0.08, 'initial learning rate' },
      { '-beta', 0, 'regularization constant' },
      { '-momentum', 0, 'momentum' },
      { '-dropout', 0, 'dropout rate at each non-recurrent layer' },
      { '-batch_size', 32, 'number of minibatch' },
      { '-bptt', 10, 'back propagation through time' },
      { '-alpha_decay', 0.6, 'alpha *= alpha_decay if no improvement on validation set' },
      { '-init_weight', 0.1, 'all weights will be set to [-init_weight, init_weight] during initialization' },
      { '-max_norm', 50, 'threshold of gradient clipping (2-norm)' },
      { '-max_epoch', 20, 'max number of epoch' },
      { '-min_improvement', 1.01, 'start learning rate decay when improvement less than threshold' },
      { '-shuffle', 1, 'whether to shuffle data before each epoch' },
    } },
  }
  for _, section in ipairs(sections) do
    cmd:text(section[1])
    for _, spec in ipairs(section[2]) do
      cmd:option(spec[1], spec[2], spec[3])
    end
  end
end
-- Parse the script's command-line arguments into a table keyed by option name.
local options = cmd:parse(arg)
-- Seed the random number generators (random_seed comes from CSPPT/utils/util)
-- before any stochastic work happens.
random_seed(options.random_seed)

-- Build the vocabulary.
--   * Word entries are read from -vocab when it is given, otherwise they are
--     collected from the training set.
--   * build_vocab's second argument is true exactly when no -outlabel file is
--     given; presumably it then also gathers the output labels itself (TODO
--     confirm in the Vocab class). With -outlabel, the labels come from that
--     file via build_vocab_output.
local vocab = Vocab()
local vocab_source = options.train
if options.vocab ~= '' then
  vocab_source = options.vocab
end
local labels_from_source = (options.outlabel == '')
vocab:build_vocab(vocab_source, labels_from_source)
if not labels_from_source then
  vocab:build_vocab_output(options.outlabel)
end
options.vocab_size = vocab:vocab_size()

-- From here on options.vocab holds the Vocab object rather than a file name;
-- the network constructors receive it through `options`.
options.vocab = vocab
if options.print_vocab ~= '' then
  vocab:save(options.print_vocab, options.print_vocab .. '.label')
end
-- With -trace_level above zero, hand the parsed options to cmd:log. In
-- torch7's CmdLine this prints the option table and wraps `print` so that
-- later output is also written to the given file; '/dev/null' means only the
-- console copy survives.
do
  local tracing = options.trace_level > 0
  if tracing then
    cmd:log('/dev/null', options)
    io.stdout:flush()
  end
end
-- Derive the recurrent stack layout from -hidden_prototype, a '-'-separated
-- list with one hidden-layer size per layer (e.g. '200-300' gives two layers
-- of 200 and 300 units). This sets:
--   options.layers         number of hidden layers
--   options.hidden_size[i] size of layer i (a positive integer)
-- Malformed specs are rejected here with a clear message. Without this check,
-- an empty list or a nil size from tonumber() would reach the network
-- constructor and fail far from the cause.
local hidden_prototype = stringx.split(options.hidden_prototype, '-')
if #hidden_prototype == 0 then
  error(string.format("invalid -hidden_prototype '%s': expected layer sizes such as '200-300'",
    tostring(options.hidden_prototype)), 0)
end
options.layers = #hidden_prototype
options.hidden_size = {}
for i = 1, #hidden_prototype do
  local size = tonumber(hidden_prototype[i])
  if size == nil or size < 1 or size ~= math.floor(size) then
    error(string.format("invalid -hidden_prototype '%s': size of layer %d ('%s') is not a positive integer",
      tostring(options.hidden_prototype), i, tostring(hidden_prototype[i])), 0)
  end
  options.hidden_size[i] = size
end
-- Build the network selected by -rnn_type ('lstm' or 'rnn', as the option's
-- help text advertises), then call slu:init(options.read_model). With
-- -read_model set, that call presumably loads saved weights (TODO confirm in
-- CSPPT/nets/*.lua).
-- Unknown types are rejected rather than silently falling back to RNN, so a
-- typo such as '-rnn_type LSTM' fails loudly instead of training the wrong
-- model.
local slu
if options.rnn_type == 'lstm' then
  slu = LSTM(options)
elseif options.rnn_type == 'rnn' then
  slu = RNN(options)
else
  error(string.format("unknown -rnn_type '%s' (expected 'lstm' or 'rnn')",
    tostring(options.rnn_type)), 0)
end
slu:init(options.read_model)
-- Training / evaluation driver (runs once, at script level).
--
-- Training mode (-train and -valid given, -test_only 0):
--   * evaluate the initial model on -valid and -test ("epoch 0") and save it
--     to -print_model;
--   * for each of -max_epoch epochs: optionally shuffle -train, train one
--     epoch, evaluate on -valid and -test, and re-save the model to
--     -print_model whenever validation F1 beats the best of epochs >= 1;
--   * print the elapsed time (in minutes) and the best epoch's scores.
-- Otherwise: run CSPPT/test/convert2list.py and, when -test_only is 1 and
-- -read_model is set, evaluate the loaded model on -test.
--
-- NOTE(review): slu:evaluate's second argument looks like an output file
-- name (e.g. '<valid>.iter3'); confirm against the LSTM/RNN classes.
local start_time = torch.tic()
-- Only referenced by the disabled learning-rate schedule inside the loop.
local alpha_decay = false

-- Quote `s` as a single POSIX shell word, so that paths containing spaces or
-- shell metacharacters reach the command intact.
local function shell_quote(s)
  return "'" .. (tostring(s):gsub("'", "'\\''")) .. "'"
end

-- Interpret os.execute's first result. Lua 5.1 / plain LuaJIT return the exit
-- status (0 = success); Lua 5.2+ (and LuaJIT built with 5.2 compat) return true.
local function command_succeeded(status)
  return status == true or status == 0
end

if options.train ~= '' and options.valid ~= '' and options.test_only == 0 then
  local result = {}
  local best_f1 = -1
  local len, best_ce, res_valid = slu:evaluate(options.valid, options.valid .. '.iter0')
  local test_len, test_ce, res_test = slu:evaluate(options.test, options.test .. '.iter0')
  print('Epoch 0 validation result: words = ' .. len .. ', CE = ' .. best_ce .. ', F1 = ' .. string.format('%.4f',res_valid['F1']))
  print('Epoch 0 test result: words = ' .. test_len .. ', CE = ' .. test_ce .. ', F1 = ' .. string.format('%.4f',res_test['F1']))
  io.stdout:flush()
  slu:save_model(options.print_model)
  local print_model = options.print_model
  for iter = 1, options.max_epoch do
    print('Start training epoch ' .. iter .. ', learning rate: ' .. string.format("%.3f", options.alpha))
    if options.shuffle == 1 then
      -- NOTE(review): this script path is relative to the working directory,
      -- while the requires and convert2list.py use a 'CSPPT/' prefix; confirm
      -- which directory the script is meant to be launched from.
      local shuffle_cmd = './utils/shuf.sh ' .. shell_quote(options.train) .. ' ' .. shell_quote(options.random_seed)
      if not command_succeeded(os.execute(shuffle_cmd)) then
        io.stderr:write('Warning: shuffle command failed, training file left as is: ' .. shuffle_cmd .. '\n')
      end
    end
    slu:train_one_epoch(options.train)
    -- Declared local so the per-epoch validation CE does not leak into _G.
    local valid_ce
    len, valid_ce, res_valid = slu:evaluate(options.valid, options.valid .. '.iter' .. iter)
    test_len, test_ce, res_test = slu:evaluate(options.test, options.test .. '.iter' .. iter)
    print('Epoch ' .. iter .. ' validation result: tested words = ' .. len .. ', CE = ' .. valid_ce .. ', F1 = ' .. string.format('%.4f',res_valid['F1']))
    print('Epoch ' .. iter .. ' test result: words = ' .. test_len .. ', CE = ' .. test_ce .. ', F1 = ' .. string.format('%.4f',res_test['F1']))
    -- Disabled: CE-based learning-rate decay schedule.
    --[[if alpha_decay or best_ce / valid_ce < options.min_improvement then
      options.alpha = options.alpha * options.alpha_decay
      alpha_decay = true
    elseif iter == options.max_epoch then
      options.max_epoch = options.max_epoch + 1
    end--]]
    -- Disabled: CE-based model restore / per-epoch checkpointing.
    --[[if best_ce < valid_ce then
      slu:restore(print_model)
      print('Model is restored to previous epoch.')
    else
      best_ce = valid_ce
      print_model = options.print_model .. '.iter' .. iter
      slu:save_model(print_model)
    end--]]
    -- Model selection by validation F1: keep (and save) the best epoch.
    if best_f1 < res_valid['F1'] then
      slu:save_model(print_model)
      best_f1 = res_valid['F1']
      print('NEW BEST: epoch ' .. iter .. ' best valid F1 ' .. res_valid['F1'] .. ' test F1 ' .. res_test['F1'])
      result['vf1'], result['vp'], result['vr'], result['vce'] = res_valid['F1'], res_valid['precision'], res_valid['recall'], valid_ce
      result['tf1'], result['tp'], result['tr'], result['tce'] = res_test['F1'], res_test['precision'], res_test['recall'], test_ce
      result['iter'] = iter
    end
    io.stdout:flush()
  end
  local elapsed_time = torch.toc(start_time) / 60 -- minutes
  print('Training finished, elapsed time = ' .. string.format('%.1f', elapsed_time))
  if result['iter'] ~= nil then
    print('BEST RESULT: epoch ' .. result['iter'] .. ' best valid F1 ' .. result['vf1'] .. ' test F1 ' .. result['tf1'])
  else
    -- No epoch ran (-max_epoch < 1), or validation F1 never compared greater
    -- than the initial -1 (e.g. NaN), so there is no best epoch to report.
    print('BEST RESULT: none (no training epoch improved validation F1)')
  end
else
  -- Preprocessing step; the script's stderr is redirected into
  -- CSPPT/test/error_test.txt, so report failures explicitly.
  local convert_cmd = "python CSPPT/test/convert2list.py 2> CSPPT/test/error_test.txt"
  if not command_succeeded(os.execute(convert_cmd)) then
    io.stderr:write('Warning: convert2list.py failed; see CSPPT/test/error_test.txt\n')
  end
  if options.test_only == 1 and options.read_model ~= '' then
    slu:evaluate(options.test, options.test .. '.result')
  else
    io.stderr:write('Nothing to evaluate: training needs -train and -valid (with -test_only 0); '
      .. 'testing needs -test_only 1 and -read_model\n')
  end
end