Skip to content

Commit

Permalink
Merge pull request #36 from carusyte/carusyte_fix
Browse files Browse the repository at this point in the history
fix shape issues, improve performance, make it pip installable, etc
  • Loading branch information
dm-jrae authored Aug 6, 2018
2 parents a4debae + 06db1b1 commit 22deec1
Show file tree
Hide file tree
Showing 13 changed files with 68 additions and 25 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,8 @@ ENV/
/site

# mypy
.mypy_cache/
.mypy_cache/

# vscode and its extensions
.vscode/*
.history/*
Empty file added dnc/__init__.py
Empty file.
6 changes: 3 additions & 3 deletions access.py → dnc/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import sonnet as snt
import tensorflow as tf

import addressing
import util
from dnc import addressing
from dnc import util

AccessState = collections.namedtuple('AccessState', (
'memory', 'read_weights', 'write_weights', 'linkage', 'usage'))
Expand Down Expand Up @@ -53,7 +53,7 @@ def _erase_and_write(memory, address, reset_weights, values):
expand_address = tf.expand_dims(address, 3)
reset_weights = tf.expand_dims(reset_weights, 2)
weighted_resets = expand_address * reset_weights
reset_gate = tf.reduce_prod(1 - weighted_resets, [1])
reset_gate = util.reduce_prod(1 - weighted_resets, 1)
memory *= reset_gate

with tf.name_scope('additive_write', values=[memory, address, values]):
Expand Down
4 changes: 2 additions & 2 deletions access_test.py → dnc/access_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import tensorflow as tf
from tensorflow.python.ops import rnn

import access
import util
from dnc import access
from dnc import util

BATCH_SIZE = 2
MEMORY_SIZE = 20
Expand Down
12 changes: 6 additions & 6 deletions addressing.py → dnc/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import sonnet as snt
import tensorflow as tf

import util
from dnc import util

# Ensure values are greater than epsilon to avoid numerical instability.
_EPSILON = 1e-6
Expand All @@ -32,7 +32,7 @@


def _vector_norms(m):
squared_norms = tf.reduce_sum(m * m, axis=2, keep_dims=True)
squared_norms = tf.reduce_sum(m * m, axis=2, keepdims=True)
return tf.sqrt(squared_norms + _EPSILON)


Expand Down Expand Up @@ -202,7 +202,7 @@ def _link(self, prev_link, prev_precedence_weights, write_weights):
containing the new link graphs for each write head.
"""
with tf.name_scope('link'):
batch_size = prev_link.get_shape()[0].value
batch_size = tf.shape(prev_link)[0]
write_weights_i = tf.expand_dims(write_weights, 3)
write_weights_j = tf.expand_dims(write_weights, 2)
prev_precedence_weights_j = tf.expand_dims(prev_precedence_weights, 2)
Expand Down Expand Up @@ -236,7 +236,7 @@ def _precedence_weights(self, prev_precedence_weights, write_weights):
new precedence weights.
"""
with tf.name_scope('precedence_weights'):
write_sum = tf.reduce_sum(write_weights, 2, keep_dims=True)
write_sum = tf.reduce_sum(write_weights, 2, keepdims=True)
return (1 - write_sum) * prev_precedence_weights + write_weights

@property
Expand Down Expand Up @@ -351,7 +351,7 @@ def _usage_after_write(self, prev_usage, write_weights):
"""
with tf.name_scope('usage_after_write'):
# Calculate the aggregated effect of all write heads
write_weights = 1 - tf.reduce_prod(1 - write_weights, [1])
write_weights = 1 - util.reduce_prod(1 - write_weights, 1)
return prev_usage + (1 - prev_usage) * write_weights

def _usage_after_read(self, prev_usage, free_gate, read_weights):
Expand All @@ -370,7 +370,7 @@ def _usage_after_read(self, prev_usage, free_gate, read_weights):
with tf.name_scope('usage_after_read'):
free_gate = tf.expand_dims(free_gate, -1)
free_read_weights = free_gate * read_weights
phi = tf.reduce_prod(1 - free_read_weights, [1], name='phi')
phi = util.reduce_prod(1 - free_read_weights, 1, name='phi')
return prev_usage * phi

def _allocation(self, usage):
Expand Down
4 changes: 2 additions & 2 deletions addressing_test.py → dnc/addressing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
import sonnet as snt
import tensorflow as tf

import addressing
import util
from dnc import addressing
from dnc import util


class WeightedSoftmaxTest(tf.test.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions dnc.py → dnc/dnc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import sonnet as snt
import tensorflow as tf

import access
from dnc import access

DNCState = collections.namedtuple('DNCState', ('access_output', 'access_state',
'controller_state'))
Expand Down Expand Up @@ -110,7 +110,7 @@ def _build(self, inputs, prev_state):
controller_input, prev_controller_state)

controller_output = self._clip_if_enabled(controller_output)
controller_state = snt.nest.map(self._clip_if_enabled, controller_state)
controller_state = tf.contrib.framework.nest.map_structure(self._clip_if_enabled, controller_state)

access_output, access_state = self._access(controller_output,
prev_access_state)
Expand Down
File renamed without changes
File renamed without changes.
39 changes: 33 additions & 6 deletions util.py → dnc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,48 @@
def batch_invert_permutation(permutations):
"""Returns batched `tf.invert_permutation` for every row in `permutations`."""
with tf.name_scope('batch_invert_permutation', values=[permutations]):
unpacked = tf.unstack(permutations)
inverses = [tf.invert_permutation(permutation) for permutation in unpacked]
return tf.stack(inverses)
perm = tf.cast(permutations, tf.float32)
dim = int(perm.get_shape()[-1])
size = tf.cast(tf.shape(perm)[0], tf.float32)
delta = tf.cast(tf.shape(perm)[-1], tf.float32)
rg = tf.range(0, size * delta, delta, dtype=tf.float32)
rg = tf.expand_dims(rg, 1)
rg = tf.tile(rg, [1, dim])
perm = tf.add(perm, rg)
flat = tf.reshape(perm, [-1])
perm = tf.invert_permutation(tf.cast(flat, tf.int32))
perm = tf.reshape(perm, [-1, dim])
return tf.subtract(perm, tf.cast(rg, tf.int32))


def batch_gather(values, indices):
"""Returns batched `tf.gather` for every row in the input."""
with tf.name_scope('batch_gather', values=[values, indices]):
unpacked = zip(tf.unstack(values), tf.unstack(indices))
result = [tf.gather(value, index) for value, index in unpacked]
return tf.stack(result)
idx = tf.expand_dims(indices, -1)
size = tf.shape(indices)[0]
rg = tf.range(size, dtype=tf.int32)
rg = tf.expand_dims(rg, -1)
rg = tf.tile(rg, [1, int(indices.get_shape()[-1])])
rg = tf.expand_dims(rg, -1)
gidx = tf.concat([rg, idx], -1)
return tf.gather_nd(values, gidx)


def one_hot(length, index):
"""Return an nd array of given `length` filled with 0s and a 1 at `index`."""
result = np.zeros(length)
result[index] = 1
return result

def reduce_prod(x, axis, name=None):
"""Efficient reduce product over axis.
Uses tf.cumprod and tf.gather_nd as a workaround to the poor performance of calculating tf.reduce_prod's gradient on CPU.
"""
with tf.name_scope(name, 'util_reduce_prod', values=[x]):
cp = tf.cumprod(x, axis, reverse=True)
size = tf.shape(cp)[0]
idx1 = tf.range(tf.cast(size, tf.float32), dtype=tf.float32)
idx2 = tf.zeros([size], tf.float32)
indices = tf.stack([idx1, idx2], 1)
return tf.gather_nd(cp, tf.cast(indices, tf.int32))
2 changes: 1 addition & 1 deletion util_test.py → dnc/util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
import tensorflow as tf

import util
from dnc import util


class BatchInvertPermutation(tf.test.TestCase):
Expand Down
12 changes: 12 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from setuptools import setup

setup(
name='dnc',
version='0.0.2',
description='This package provides an implementation of the Differentiable Neural Computer, as published in Nature.',
license='Apache Software License 2.0',
packages=['dnc'],
author='DeepMind',
keywords=['tensorflow', 'differentiable neural computer', 'dnc', 'deepmind', 'deep mind', 'sonnet', 'dm-sonnet', 'machine learning'],
url='https://github.com/deepmind/dnc'
)
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
import tensorflow as tf
import sonnet as snt

import dnc
import repeat_copy
from dnc import dnc
from dnc import repeat_copy

FLAGS = tf.flags.FLAGS

Expand Down

0 comments on commit 22deec1

Please sign in to comment.