First commit
indiejoseph committed Apr 21, 2017
0 parents commit bc739f9
Showing 10 changed files with 731 additions and 0 deletions.
21 changes: 21 additions & 0 deletions .editorconfig
@@ -0,0 +1,21 @@
# EditorConfig helps developers define and maintain consistent
# coding styles between different editors and IDEs
# editorconfig.org

root = true


[*]

# Change these settings to your own preference
indent_style = space
indent_size = 2

# We recommend you keep these unchanged
end_of_line = lf
charset = utf-8
trim_trailing_whitespace = true
insert_final_newline = true

[*.md]
trim_trailing_whitespace = false
6 changes: 6 additions & 0 deletions .gitignore
@@ -0,0 +1,6 @@
*.pyc
.DS_Store
checkpoint
data
!data/wiki/vocab.pkl
log
6 changes: 6 additions & 0 deletions .ipynb_checkpoints/chatbot-checkpoint.ipynb
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 2
}
18 changes: 18 additions & 0 deletions .pylintrc
@@ -0,0 +1,18 @@
# My pylintrc for use with atom.io's linter-pylint
[MESSAGES CONTROL]
disable=W0311,W1201,W0702,W0611,W0621,E1101,C0111,C0103,R0902

# checks for :
# * unauthorized constructions
# * strict indentation
# * line length
# * use of <> instead of !=
#
[FORMAT]
# Maximum number of characters on a single line.
max-line-length=128
# Maximum number of lines in a module
max-module-lines=1000
# String used as indentation unit. This is usually "    " (4 spaces) or "\t"
# (1 tab). In this repo it is 2 spaces.
indent-string='  '
10 changes: 10 additions & 0 deletions README.md
@@ -0,0 +1,10 @@
# TF Seq2Seq Chatbot

### Requirements
* tensorflow r1.1

### References
- [Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation](https://arxiv.org/pdf/1406.1078.pdf)
- [A Neural Conversational Model](https://arxiv.org/pdf/1506.05869.pdf)
- [A Hierarchical Recurrent Encoder-Decoder for Generative Context-Aware Query Suggestion](https://arxiv.org/pdf/1507.02221.pdf)
- [Attention with Intention for a Neural Network Conversation Model](https://arxiv.org/pdf/1510.08565.pdf)
65 changes: 65 additions & 0 deletions grucell_cond.py
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops.rnn_cell_impl import _RNNCell as RNNCell
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import tanh, sigmoid
from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear, _checked_scope


class CondWrapper(RNNCell):
  """Wraps a conditional cell so it exposes the standard two-argument RNNCell
  interface, passing a fixed context tensor to the wrapped cell at every step."""

  def __init__(self, cell, context):
    self._context = context
    self._cell = cell
    self._output_size = self._cell.output_size

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    output, res_state = self._cell(inputs, state, self._context)
    return output, res_state


class GRUCellCond(RNNCell):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078),
  conditioned on an extra context vector that is fed into the gates."""

  def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, context, scope=None):
    """Gated recurrent unit (GRU) with nunits cells."""
    with _checked_scope(self, scope or "gru_cell", reuse=self._reuse):
      with vs.variable_scope("gates"):  # Reset gate and update gate.
        # We start with bias of 1.0 to not reset and not update.
        value = sigmoid(_linear(
            [inputs, state, context], 2 * self._num_units, True, 1.0))
        r, u = array_ops.split(
            value=value,
            num_or_size_splits=2,
            axis=1)
      with vs.variable_scope("candidate"):
        c = self._activation(_linear([inputs, r * state],
                                     self._num_units, True))
      new_h = u * state + (1 - u) * c
    return new_h, new_h
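A minimal usage sketch, not part of this commit, showing how GRUCellCond and CondWrapper might be wired into tf.nn.dynamic_rnn so that a fixed context vector conditions every step; the placeholder names and sizes below are illustrative assumptions only.

# Sketch (assumed usage, not from the repository): condition a GRU on a
# per-example context vector, e.g. an encoder summary.
import tensorflow as tf
from grucell_cond import GRUCellCond, CondWrapper

batch_size, seq_len, emb_size, memory_size = 32, 20, 128, 256  # illustrative sizes

inputs = tf.placeholder(tf.float32, [batch_size, seq_len, emb_size])
context = tf.placeholder(tf.float32, [batch_size, memory_size])  # assumed context source

# CondWrapper gives the conditional cell the standard (inputs, state) call
# signature expected by dynamic_rnn, injecting `context` at every step.
cell = CondWrapper(GRUCellCond(memory_size), context)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)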
82 changes: 82 additions & 0 deletions infer.py
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-

import tensorflow as tf
import numpy as np
import pprint
import _pickle as cPickle
from model import DialogueModel
from utils import TextLoader, UNK_ID, PAD_ID
from glob import glob

checkpoint = "/tmp/model.ckpt"

pp = pprint.PrettyPrinter()

flags = tf.app.flags
flags.DEFINE_string("checkpoint", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("logdir", "log", "Log directory [log]")
FLAGS = flags.FLAGS


def main(_):
  config = cPickle.load(open(FLAGS.logdir + "/hyperparams.pkl", 'rb'))
  pp.pprint(config)

  try:
    # pre-trained chars embedding
    emb = np.load("./data/emb.npy")
    chars = cPickle.load(open("./data/vocab.pkl", 'rb'))
    vocab_size, emb_size = np.shape(emb)
    data_loader = TextLoader('./data', 1, chars)
  except Exception:
    data_loader = TextLoader('./data', 1)
    emb_size = config["emb_size"]
    vocab_size = data_loader.vocab_size

  checkpoint = FLAGS.checkpoint + '/model.ckpt'

  model = DialogueModel(batch_size=1, max_seq_length=data_loader.seq_length,
                        vocab_size=vocab_size, pad_token_id=0, unk_token_id=UNK_ID,
                        emb_size=emb_size, memory_size=config["memory_size"],
                        keep_prob=config["keep_prob"], learning_rate=config["learning_rate"],
                        grad_clip=config["grad_clip"], infer=True)

  init = tf.global_variables_initializer()
  saver = tf.train.Saver()

  with tf.Session() as sess:
    sess.run(init)

    if len(glob(checkpoint + "*")) > 0:
      saver.restore(sess, checkpoint)
    else:
      print("No model found!")
      return

    ## -- debug --
    #np.set_printoptions(threshold=np.inf)
    #for v in tf.trainable_variables():
    #  print(v.name)
    #  print(sess.run(v))
    #  print()
    #return

    while True:
      try:
        input_ = input('in> ')
      except EOFError:
        print("\nBye!")
        break

      input_ids, input_len = data_loader.parse_input(input_)

      feed = {
        model.input_data: np.expand_dims(input_ids, 0),
        model.input_lengths: [input_len]
      }

      output_ids, state = sess.run([model.output_ids, model.final_state], feed_dict=feed)

      print(data_loader.compose_output(output_ids[0]))


if __name__ == "__main__":
  tf.app.run()
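infer.py reads hyperparameters from hyperparams.pkl in the log directory, restores model.ckpt from the checkpoint directory, and then answers lines typed at the in> prompt. A hypothetical invocation, assuming a separate training step has already written those files (the training code is not among the files shown here):

python infer.py --checkpoint checkpoint --logdir log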