-
Notifications
You must be signed in to change notification settings - Fork 117
/
text_train.py
123 lines (103 loc) · 4.63 KB
/
text_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#encoding:utf-8
from __future__ import print_function
from text_model import *
from loader import *
from sklearn import metrics
import sys
import os
import time
from datetime import timedelta
def evaluate(sess, x_, y_):
data_len = len(x_)
batch_eval = batch_iter(x_, y_, 128)
total_loss = 0.0
total_acc = 0.0
for x_batch, y_batch in batch_eval:
batch_len = len(x_batch)
feed_dict = feed_data(x_batch, y_batch, 1.0)
loss, acc = sess.run([model.loss, model.acc], feed_dict=feed_dict)
total_loss += loss * batch_len
total_acc += acc * batch_len
return total_loss / data_len, total_acc / data_len
def feed_data(x_batch, y_batch, keep_prob):
feed_dict = {
model.input_x: x_batch,
model.input_y: y_batch,
model.keep_prob:keep_prob
}
return feed_dict
def train():
print("Configuring TensorBoard and Saver...")
tensorboard_dir = './tensorboard/textcnn'
save_dir = './checkpoints/textcnn'
if not os.path.exists(tensorboard_dir):
os.makedirs(tensorboard_dir)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = os.path.join(save_dir, 'best_validation')
print("Loading training and validation data...")
start_time = time.time()
x_train, y_train = process_file(config.train_filename, word_to_id, cat_to_id, config.seq_length)
x_val, y_val = process_file(config.val_filename, word_to_id, cat_to_id, config.seq_length)
print("Time cost: %.3f seconds...\n" % (time.time() - start_time))
tf.summary.scalar("loss", model.loss)
tf.summary.scalar("accuracy", model.acc)
merged_summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(tensorboard_dir)
saver = tf.train.Saver()
session = tf.Session()
session.run(tf.global_variables_initializer())
writer.add_graph(session.graph)
print('Training and evaluating...')
best_val_accuracy = 0
last_improved = 0 # record global_step at best_val_accuracy
require_improvement = 1000 # break training if not having improvement over 1000 iter
flag=False
for epoch in range(config.num_epochs):
batch_train = batch_iter(x_train, y_train, config.batch_size)
start = time.time()
print('Epoch:', epoch + 1)
for x_batch, y_batch in batch_train:
feed_dict = feed_data(x_batch, y_batch, config.keep_prob)
_, global_step, train_summaries, train_loss, train_accuracy = session.run([model.optim, model.global_step,
merged_summary, model.loss,
model.acc], feed_dict=feed_dict)
if global_step % config.print_per_batch == 0:
end = time.time()
val_loss, val_accuracy = evaluate(session, x_val, y_val)
writer.add_summary(train_summaries, global_step)
# If improved, save the model
if val_accuracy > best_val_accuracy:
saver.save(session, save_path)
best_val_accuracy = val_accuracy
last_improved=global_step
improved_str = '*'
else:
improved_str = ''
print("step: {},train loss: {:.3f}, train accuracy: {:.3f}, val loss: {:.3f}, val accuracy: {:.3f},training speed: {:.3f}sec/batch {}\n".format(
global_step, train_loss, train_accuracy, val_loss, val_accuracy,
(end - start) / config.print_per_batch,improved_str))
start = time.time()
if global_step - last_improved > require_improvement:
print("No optimization over 1000 steps, stop training")
flag = True
break
if flag:
break
config.lr *= config.lr_decay
if __name__ == '__main__':
print('Configuring CNN model...')
config = TextConfig()
filenames = [config.train_filename, config.test_filename, config.val_filename]
if not os.path.exists(config.vocab_filename):
build_vocab(filenames, config.vocab_filename, config.vocab_size)
#read vocab and categories
categories,cat_to_id = read_category()
words,word_to_id = read_vocab(config.vocab_filename)
config.vocab_size = len(words)
# trans vector file to numpy file
if not os.path.exists(config.vector_word_npz):
export_word2vec_vectors(word_to_id, config.vector_word_filename, config.vector_word_npz)
config.pre_trianing = get_training_word2vec_vectors(config.vector_word_npz)
model = TextCNN(config)
train()