Update infer
indiejoseph committed Apr 27, 2017
1 parent cb77b8d commit 9890070
Showing 3 changed files with 26 additions and 19 deletions.
infer.py (3 changes: 2 additions & 1 deletion)

@@ -15,6 +15,7 @@
 flags = tf.app.flags
 flags.DEFINE_string("checkpoint", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
 flags.DEFINE_string("logdir", "log", "Log directory [log]")
+flags.DEFINE_float("temperature", 0.5, "temperature")
 FLAGS = flags.FLAGS

 def main(_):
@@ -38,7 +39,7 @@ def main(_):
                           vocab_size=vocab_size, pad_token_id=0, unk_token_id=UNK_ID,
                           emb_size=emb_size, memory_size=config["memory_size"],
                           keep_prob=config["keep_prob"], learning_rate=config["learning_rate"],
-                          grad_clip=config["grad_clip"], infer=True)
+                          grad_clip=config["grad_clip"], temperature=config["temperature"], infer=True)

     init = tf.global_variables_initializer()
     saver = tf.train.Saver()
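
Note: the new flag threads a sampling temperature through to DialogueModel. As a point of reference (not code from this repo), temperature rescales the decoder logits before sampling: values below 1.0 sharpen the softmax toward greedy decoding, values above 1.0 flatten it. A minimal NumPy sketch:

import numpy as np

def sample_with_temperature(logits, temperature=0.5):
    # Lower temperature sharpens the distribution; temperature=1.0 is plain softmax sampling.
    scaled = logits / temperature
    probs = np.exp(scaled - np.max(scaled))  # shift by the max for numerical stability
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)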
model.py (33 changes: 19 additions & 14 deletions)

@@ -22,7 +22,7 @@ def _count_param_size(tvars):
 class DialogueModel(object):
     def __init__(self, batch_size, max_seq_length, vocab_size,
                  start_token_id=1, end_token_id=2, pad_token_id=0, unk_token_id=3,
-                 emb_size=100, memory_size=100, keep_prob=0.5, temperature=0.9,
+                 emb_size=100, memory_size=100, keep_prob=0.5, temperature=0.5, antilm=0.55,
                  learning_rate=0.001, grad_clip=5.0, infer=False):

         self._batch_size = batch_size
@@ -38,6 +38,7 @@ def __init__(self, batch_size, max_seq_length, vocab_size,
         self._end_token_id = end_token_id
         self._pad_token_id = pad_token_id
         self._infer = infer
+        self._antilm = antilm

         self.input_data = tf.placeholder(tf.int32, [batch_size, max_seq_length], name="input_data")
         self.input_lengths = tf.placeholder(tf.int32, shape=[batch_size], name="input_lengths")
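
Note: antilm is stored here but not yet used anywhere in this diff. Assuming the name refers to the anti-language-model (MMI-antiLM) weight of Li et al. (2016), https://arxiv.org/abs/1510.03055, decoding would rescore a candidate reply T for a source S roughly as below; this is a sketch under that assumption, not this repo's code:

def antilm_rescore(logprob_t_given_s, logprob_t, antilm=0.55):
    # MMI-antiLM objective: log p(T|S) - lambda * log p(T),
    # penalizing generic replies that any context would produce.
    return logprob_t_given_s - antilm * logprob_t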
@@ -99,10 +100,10 @@ def seq2seq(self, inputs, fw_cell, bw_cell, ctx_cell, dec_cell, reuse=False):
             if reuse:
                 scope.reuse_variables()
             enc_outputs, enc_state = self.encode(fw_cell, bw_cell, inputs)
-            ctx_output, ctx_state = self.contextual(ctx_cell, enc_state)
-            dec_outputs, dec_state = self.decode(dec_cell, enc_outputs, ctx_output)
-            outputs = dec_outputs.rnn_output
-            output_ids = dec_outputs.sample_id
+            ctx_outputs, ctx_state = self.contextual(ctx_cell, enc_state)
+            dec_outputs, dec_sample_id, dec_state = self.decode(dec_cell, enc_outputs, ctx_outputs)
+            outputs = dec_outputs
+            output_ids = dec_sample_id
             output_state = dec_state.cell_state

         return outputs, output_ids, output_state, ctx_state
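
Note: decode() now unpacks the BasicDecoderOutput namedtuple itself, so seq2seq() receives plain tensors instead of reaching into dec_outputs.rnn_output and dec_outputs.sample_id. The resulting calling convention, with shapes per tf.contrib.seq2seq (model is a hypothetical instance):

# rnn_output: [batch_size, time, vocab_size] float32 logits
# sample_id:  [batch_size, time] int32 token ids
# dec_state:  final AttentionWrapperState; its .cell_state is what seq2seq() returns
rnn_output, sample_id, dec_state = model.decode(dec_cell, enc_outputs, ctx_outputs)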
@@ -128,34 +129,38 @@ def contextual(self, ctx_cell, enc_state):
         with tf.variable_scope("context"):
             _, ctx_state = ctx_cell(enc_state, self.initial_state)
             # Sec 3.2.3 in https://arxiv.org/pdf/1507.02221.pdf
-            ctx_output = tf.tanh(tf.matmul(ctx_state, self.ctx_w) + self.ctx_b)
-            return ctx_output, ctx_state
+            ctx_outputs = tf.tanh(tf.matmul(ctx_state, self.ctx_w) + self.ctx_b)
+            return ctx_outputs, ctx_state

     def decode(self, dec_cell, enc_outputs, ctx_outputs):
         with tf.variable_scope("decode"):
+            batch_size = self._batch_size
+
             attn_mech = seq2seq.BahdanauAttention(self._memory_size, enc_outputs, self.input_lengths)
             dec_cell = CondWrapper(dec_cell, ctx_outputs)
             dec_cell = seq2seq.AttentionWrapper(dec_cell, attn_mech, self._memory_size)
-            dec_initial_state = dec_cell.zero_state(batch_size=self._batch_size, dtype=tf.float32)
+            dec_initial_state = dec_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
             helper_build_fn = self._infer_helper if self._infer else self._train_helper

-            output_layer = layers_core.Dense(self._vocab_size, use_bias=True)
+            output_layer = layers_core.Dense(self._vocab_size, use_bias=True, activation=None)
             decoder = seq2seq.BasicDecoder(cell=dec_cell,
                                            helper=helper_build_fn(),
                                            initial_state=dec_initial_state,
                                            output_layer=output_layer)
-            dec_outputs, dec_state = seq2seq.dynamic_decode(decoder,
-                                                            impute_finished=True,
-                                                            maximum_iterations=self._max_seq_length)
-            return dec_outputs, dec_state
+            dec_output, dec_state = seq2seq.dynamic_decode(decoder,
+                                                           impute_finished=True,
+                                                           maximum_iterations=self._max_seq_length)
+            rnn_output = dec_output.rnn_output
+            sample_id = dec_output.sample_id
+            return rnn_output, sample_id, dec_state

     def _infer_helper(self):
         helper = seq2seq.GreedyEmbeddingHelper(self.embedding,
                                                start_tokens=tf.fill([self._batch_size], self._start_token_id),
                                                end_token=tf.constant(self._end_token_id, dtype=tf.int32))
         def sample_fn(time, outputs, state):
             sample_ids = tf.multinomial(tf.exp(outputs / self._temperature), 1)
-            sample_ids = tf.cast(tf.reshape(sample_ids, [self._batch_size]), dtype=tf.int32)
+            sample_ids = tf.cast(tf.reshape(sample_ids, [self._batch_size * 1]), dtype=tf.int32)
             return sample_ids

         return seq2seq.CustomHelper(initialize_fn=helper.initialize,
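
The hunk is truncated at the CustomHelper call. For reference, tf.contrib.seq2seq.CustomHelper takes exactly three callbacks, so presumably the remaining arguments wire in sample_fn and the greedy helper's next_inputs, roughly as follows (a sketch, not the verbatim file):

return seq2seq.CustomHelper(initialize_fn=helper.initialize,
                            sample_fn=sample_fn,
                            next_inputs_fn=helper.next_inputs)

This keeps GreedyEmbeddingHelper's start-token initialization and embedding lookup while swapping greedy argmax for temperature-scaled multinomial sampling.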
train.py (9 changes: 5 additions & 4 deletions)

@@ -89,11 +89,12 @@ def main(_):
             res = model.step(sess, x, y, input_lengths, output_lengths, state, summaries)
             summary_writer.add_summary(res["summary_out"], current_step)
             loss = res["loss"]
+            perplexity = np.exp(loss)
             count += 1
-            print("{0}/{1}({2}), loss {3}".format(current_step + 1,
-                                                  FLAGS.num_epochs * data_loader.num_batches,
-                                                  e,
-                                                  loss))
+            print("{0}/{1}({2}), perplexity {3}".format(current_step + 1,
+                                                        FLAGS.num_epochs * data_loader.num_batches,
+                                                        e,
+                                                        perplexity))
             state = res["final_state"]

             if (current_step + 1) % 2000 == 0:
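
Note: the reported perplexity is just the exponential of the mean per-token cross-entropy loss (in nats), a more interpretable training metric. A quick worked example:

import numpy as np

loss = 4.6                 # mean cross-entropy in nats
perplexity = np.exp(loss)  # ~99.5: roughly as uncertain as a uniform choice over ~100 tokens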