-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit bc739f9
Showing
10 changed files
with
731 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# EditorConfig helps developers define and maintain consistent | ||
# coding styles between different editors and IDEs | ||
# editorconfig.org | ||
|
||
root = true | ||
|
||
|
||
[*] | ||
|
||
# Change these settings to your own preference | ||
indent_style = space | ||
indent_size = 2 | ||
|
||
# We recommend you to keep these unchanged | ||
end_of_line = lf | ||
charset = utf-8 | ||
trim_trailing_whitespace = true | ||
insert_final_newline = true | ||
|
||
[*.md] | ||
trim_trailing_whitespace = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
*.pyc | ||
.DS_Store | ||
checkpoint | ||
data | ||
!data/wiki/vocab.pkl | ||
log |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
{ | ||
"cells": [], | ||
"metadata": {}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#My pylintrc for use with atom.io's linter-pylint | ||
[MESSAGES CONTROL] | ||
disable=W0311,W1201,W0702,W0611,W0621,E1101,C0111,C0103,R0902 | ||
|
||
# checks for : | ||
# * unauthorized constructions | ||
# * strict indentation | ||
# * line length | ||
# * use of <> instead of != | ||
# | ||
[FORMAT] | ||
# Maximum number of characters on a single line. | ||
max-line-length=128 | ||
# Maximum number of lines in a module | ||
max-module-lines=1000 | ||
# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 | ||
# tab). In repo it is 2 spaces. | ||
indent-string=' ' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# TF Seq2Seq Chatbot | ||
|
||
### Requirements | ||
* tensorflow r1.1 | ||
|
||
### References | ||
- [Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation](https://arxiv.org/pdf/1406.1078.pdf) | ||
- [A Neural Conversational Model](https://arxiv.org/pdf/1506.05869.pdf) | ||
- [A Hierarchical Recurrent Encoder-Decoder for Generative Context-Aware Query Suggestion](https://arxiv.org/pdf/1507.02221.pdf) | ||
- [Attention with Intention for a Neural Network Conversation Model](https://arxiv.org/pdf/1510.08565.pdf) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# -*- coding: utf-8 -*- | ||
import tensorflow as tf | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell | ||
from tensorflow.python.ops import array_ops | ||
from tensorflow.python.ops import variable_scope as vs | ||
from tensorflow.python.ops.math_ops import tanh, sigmoid | ||
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear, _checked_scope | ||
|
||
class CondWrapper(RNNCell): | ||
def __init__(self, cell, context): | ||
self._context = context | ||
self._cell = cell | ||
self._output_size = self._cell.output_size | ||
|
||
@property | ||
def state_size(self): | ||
return self._cell.state_size | ||
|
||
@property | ||
def output_size(self): | ||
return self._output_size | ||
|
||
def zero_state(self, batch_size, dtype): | ||
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): | ||
return self._cell.zero_state(batch_size, dtype) | ||
|
||
def __call__(self, inputs, state, scope=None): | ||
output, res_state = self._cell(inputs, state, self._context) | ||
return output, res_state | ||
|
||
class GRUCellCond(RNNCell): | ||
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).""" | ||
|
||
def __init__(self, num_units, input_size=None, activation=tanh, reuse=None): | ||
if input_size is not None: | ||
logging.warn("%s: The input_size parameter is deprecated.", self) | ||
self._num_units = num_units | ||
self._activation = activation | ||
self._reuse = reuse | ||
|
||
@property | ||
def state_size(self): | ||
return self._num_units | ||
|
||
@property | ||
def output_size(self): | ||
return self._num_units | ||
|
||
def __call__(self, inputs, state, context, scope=None): | ||
"""Gated recurrent unit (GRU) with nunits cells.""" | ||
with _checked_scope(self, scope or "gru_cell", reuse=self._reuse): | ||
with vs.variable_scope("gates"): # Reset gate and update gate. | ||
# We start with bias of 1.0 to not reset and not update. | ||
value = sigmoid(_linear( | ||
[inputs, state, context], 2 * self._num_units, True, 1.0)) | ||
r, u = array_ops.split( | ||
value=value, | ||
num_or_size_splits=2, | ||
axis=1) | ||
with vs.variable_scope("candidate"): | ||
c = self._activation(_linear([inputs, r * state], | ||
self._num_units, True)) | ||
new_h = u * state + (1 - u) * c | ||
return new_h, new_h |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import tensorflow as tf | ||
import numpy as np | ||
import pprint | ||
import _pickle as cPickle | ||
from model import DialogueModel | ||
from utils import TextLoader, UNK_ID, PAD_ID | ||
from glob import glob | ||
|
||
checkpoint = "/tmp/model.ckpt" | ||
|
||
pp = pprint.PrettyPrinter() | ||
|
||
flags = tf.app.flags | ||
flags.DEFINE_string("checkpoint", "checkpoint", "Directory name to save the checkpoints [checkpoint]") | ||
flags.DEFINE_string("logdir", "log", "Log directory [log]") | ||
FLAGS = flags.FLAGS | ||
|
||
def main(_): | ||
config = cPickle.load(open(FLAGS.logdir + "/hyperparams.pkl", 'rb')) | ||
pp.pprint(config) | ||
|
||
try: | ||
# pre-trained chars embedding | ||
emb = np.load("./data/emb.npy") | ||
chars = cPickle.load(open("./data/vocab.pkl", 'rb')) | ||
vocab_size, emb_size = np.shape(emb) | ||
data_loader = TextLoader('./data', 1, chars) | ||
except Exception: | ||
data_loader = TextLoader('./data', 1) | ||
emb_size = config["emb_size"] | ||
vocab_size = data_loader.vocab_size | ||
|
||
checkpoint = FLAGS.checkpoint + '/model.ckpt' | ||
|
||
model = DialogueModel(batch_size=1, max_seq_length=data_loader.seq_length, | ||
vocab_size=vocab_size, pad_token_id=0, unk_token_id=UNK_ID, | ||
emb_size=emb_size, memory_size=config["memory_size"], | ||
keep_prob=config["keep_prob"], learning_rate=config["learning_rate"], | ||
grad_clip=config["grad_clip"], infer=True) | ||
|
||
init = tf.global_variables_initializer() | ||
saver = tf.train.Saver() | ||
|
||
with tf.Session() as sess: | ||
sess.run(init) | ||
|
||
if len(glob(checkpoint + "*")) > 0: | ||
saver.restore(sess, checkpoint) | ||
else: | ||
print("No model found!") | ||
return | ||
|
||
## -- debug -- | ||
#np.set_printoptions(threshold=np.inf) | ||
#for v in tf.trainable_variables(): | ||
# print(v.name) | ||
# print(sess.run(v)) | ||
# print() | ||
#return | ||
|
||
while True: | ||
try: | ||
input_ = input('in> ') | ||
except EOFError: | ||
print("\nBye!") | ||
break | ||
|
||
input_ids, input_len = data_loader.parse_input(input_) | ||
|
||
feed = { | ||
model.input_data: np.expand_dims(input_ids, 0), | ||
model.input_lengths: [input_len] | ||
} | ||
|
||
output_ids, state = sess.run([model.output_ids, model.final_state], feed_dict=feed) | ||
|
||
print(data_loader.compose_output(output_ids[0])) | ||
|
||
if __name__ == "__main__": | ||
tf.app.run() |
Oops, something went wrong.