forked from tensorlayer/seq2seq-chatbot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
141 lines (103 loc) · 4.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#! /usr/bin/python
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorlayer as tl
import numpy as np
from tensorlayer.cost import cross_entropy_seq, cross_entropy_seq_with_mask
from tqdm import tqdm
from sklearn.utils import shuffle
from data.twitter import data
from tensorlayer.models.seq2seq import Seq2seq
import os
def initial_setup(data_corpus):
    """Load a preprocessed corpus and return its splits as plain lists.

    Args:
        data_corpus: Name of the corpus directory under ``data/``; the data
            is read from ``data/<data_corpus>/``.

    Returns:
        A 7-tuple ``(metadata, trainX, trainY, testX, testY, validX, validY)``.
        ``metadata`` is returned exactly as ``data.load_data`` produced it.
        Each split is converted with ``.tolist()`` and then passed through
        ``tl.prepro.remove_pad_sequences``.
    """
    corpus_dir = 'data/{}/'.format(data_corpus)
    metadata, idx_q, idx_a = data.load_data(PATH=corpus_dir)

    # split_dataset yields three (questions, answers) pairs: train, test, valid.
    (train_q, train_a), (test_q, test_a), (valid_q, valid_a) = data.split_dataset(idx_q, idx_a)

    # Same processing for every split, in the order they are returned.
    cleaned = [
        tl.prepro.remove_pad_sequences(split.tolist())
        for split in (train_q, train_a, test_q, test_a, valid_q, valid_a)
    ]
    return (metadata, *cleaned)
if __name__ == "__main__":
    # Train a seq2seq chatbot on the chosen corpus, printing sample replies and
    # saving the weights to 'model.npz' as training progresses.
    data_corpus = "twitter"

    # Data preprocessing
    metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus)

    # ---- Parameters ----
    src_len = len(trainX)
    tgt_len = len(trainY)
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if src_len != tgt_len:
        raise ValueError(
            'Training data mismatch: {} source sequences vs {} target sequences'.format(src_len, tgt_len))

    batch_size = 32
    n_step = src_len // batch_size  # only used as the progress-bar total
    src_vocab_size = len(metadata['idx2w'])  # 8002 (0~8001)
    emb_dim = 1024
    # Length used both for padding the decoder/target sequences during training
    # and as the generation length passed to the model at inference time.
    decoder_seq_length = 20

    word2idx = metadata['w2idx']  # dict: word -> index
    idx2word = metadata['idx2w']  # list: index -> word

    unk_id = word2idx['unk']  # 1
    pad_id = word2idx['_']  # 0 (not referenced below)

    # Two extra tokens are appended right after the existing vocabulary.
    start_id = src_vocab_size  # 8002
    end_id = src_vocab_size + 1  # 8003

    word2idx['start_id'] = start_id
    word2idx['end_id'] = end_id
    idx2word = idx2word + ['start_id', 'end_id']

    src_vocab_size = tgt_vocab_size = src_vocab_size + 2

    num_epochs = 50
    vocabulary_size = src_vocab_size

    def inference(seed, top_n):
        """Generate a reply for a query sentence with the current model.

        Args:
            seed: Query text; it is split on single spaces and words missing
                from ``word2idx`` are mapped to ``unk_id``.
            top_n: Forwarded unchanged as the model's ``top_n`` argument.

        Returns:
            The generated words (list of str) up to, but not including, the
            first ``'end_id'`` token.
        """
        model_.eval()
        seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")]
        sentence_id = model_(inputs=[seed_id], seq_length=decoder_seq_length, start_token=start_id, top_n=top_n)
        sentence = []
        for w_id in sentence_id[0]:
            w = idx2word[w_id]
            if w == 'end_id':
                break
            sentence.append(w)
        return sentence

    model_ = Seq2seq(
        decoder_seq_length=decoder_seq_length,
        cell_enc=tf.keras.layers.GRUCell,
        cell_dec=tf.keras.layers.GRUCell,
        n_layer=3,
        n_units=256,
        embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim),
    )

    # Uncomment below statements if you have already saved the model
    # load_weights = tl.files.load_npz(name='model.npz')
    # tl.files.assign_weights(load_weights, model_)

    optimizer = tf.optimizers.Adam(learning_rate=0.001)
    model_.train()

    seeds = ["happy birthday have a nice day",
             "donald trump won last nights presidential debate according to snap online polls"]

    for epoch in range(num_epochs):
        # inference() switches the model to eval mode, so re-enable training mode each epoch.
        model_.train()
        trainX, trainY = shuffle(trainX, trainY, random_state=0)
        total_loss, n_iter = 0, 0

        batches = tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=batch_size, shuffle=False)
        for X, Y in tqdm(batches, total=n_step,
                         desc='Epoch[{}/{}]'.format(epoch + 1, num_epochs), leave=False):
            X = tl.prepro.pad_sequences(X)
            # Targets get end_id appended; decoder inputs get start_id prepended.
            _target_seqs = tl.prepro.sequences_add_end_id(Y, end_id=end_id)
            _target_seqs = tl.prepro.pad_sequences(_target_seqs, maxlen=decoder_seq_length)
            _decode_seqs = tl.prepro.sequences_add_start_id(Y, start_id=start_id, remove_last=False)
            _decode_seqs = tl.prepro.pad_sequences(_decode_seqs, maxlen=decoder_seq_length)
            # Mask derived from the padded targets; handed to the loss as input_mask.
            _target_mask = tl.prepro.sequences_get_mask(_target_seqs)

            with tf.GradientTape() as tape:
                # compute outputs
                output = model_(inputs=[X, _decode_seqs])
                output = tf.reshape(output, [-1, vocabulary_size])
                # compute loss
                loss = cross_entropy_seq_with_mask(logits=output, target_seqs=_target_seqs, input_mask=_target_mask)

            # Only the forward pass and loss need recording; differentiate and
            # update the weights once the tape context has closed.
            grad = tape.gradient(loss, model_.all_weights)
            optimizer.apply_gradients(zip(grad, model_.all_weights))

            total_loss += loss
            n_iter += 1

        # Without this the average below would fail with a bare ZeroDivisionError.
        if n_iter == 0:
            raise RuntimeError(
                'No training batches were produced: {} samples with batch_size={}'.format(src_len, batch_size))

        # printing average loss after every epoch
        print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, num_epochs, total_loss / n_iter))

        # Show a few sample replies for each seed query.
        for seed in seeds:
            print("Query >", seed)
            top_n = 3
            for _ in range(top_n):
                sentence = inference(seed, top_n)
                print(" >", ' '.join(sentence))

        tl.files.save_npz(model_.all_weights, name='model.npz')