# model.py
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Dense, LSTM
from typing import Union
class DNC(tf.keras.Model):
def __init__(self,
output_dim: int,
memory_shape: tuple = (100, 20),
n_read: int = 3,
name: str = 'dnc'
) -> None:
"""
Initialize DNC object.
Parameters
----------
output_dim
Size of output vector.
memory_shape
Shape of memory matrix (rows, cols).
n_read
Number of read heads.
name
Name of DNC.
"""
super(DNC, self).__init__(name=name)
# define output data size
self.output_dim = output_dim # Y
# define size of memory matrix
self.N, self.W = memory_shape # N, W
# define number of read heads
self.R = n_read # R
# size of output vector from controller that defines interactions with memory matrix:
# R read keys + R read strengths + write key + write strength + erase vector +
# write vector + R free gates + allocation gate + write gate + R read modes
self.interface_dim = self.R * self.W + 3 * self.W + 5 * self.R + 3 # I
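# e.g. with the defaults memory_shape=(100, 20) and n_read=3:
#     I = 3*20 + 3*20 + 5*3 + 3 = 138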
# neural net output = output of controller + interface vector with memory
self.controller_dim = self.output_dim + self.interface_dim # Y+I
# initialize controller output and interface vectors from a truncated normal distribution
self.output_v = tf.random.truncated_normal([1, self.output_dim], stddev=0.1) # [1,Y]
self.interface = tf.random.truncated_normal([1, self.interface_dim], stddev=0.1) # [1,I]
# initialize memory matrix with zeros
self.M = tf.zeros(memory_shape) # [N,W]
# usage vector records which locations in the memory are used and which are free
self.usage = tf.fill([self.N, 1], 1e-6) # [N,1]
# temporal link matrix L[i,j] records the degree to which location i was written to after j
self.L = tf.zeros([self.N, self.N]) # [N,N]
# precedence vector determines degree to which a memory row was written to at t-1
self.W_precedence = tf.zeros([self.N, 1]) # [N,1]
# initialize R read weights and vectors and write weights
self.W_read = tf.fill([self.N, self.R], 1e-6) # [N,R]
self.W_write = tf.fill([self.N, 1], 1e-6) # [N,1]
self.read_v = tf.fill([self.R, self.W], 1e-6) # [R,W]
# controller variables
# initialize controller hidden state
self.h = tf.Variable(tf.random.truncated_normal([1, self.controller_dim], stddev=0.1), name='dnc_h') # [1,Y+I]
self.c = tf.Variable(tf.random.truncated_normal([1, self.controller_dim], stddev=0.1), name='dnc_c') # [1,Y+I]
# initialize Dense and LSTM layers of the controller
self.dense = Dense(self.W, activation=None)
self.lstm = LSTM(
self.controller_dim,
return_sequences=False,
return_state=True,
name='dnc_controller'
)
# define and initialize weights for controller output and interface vectors
self.W_output = tf.Variable( # [Y+I,Y]
tf.random.truncated_normal([self.controller_dim, self.output_dim], stddev=0.1),
name='dnc_net_output_weights'
)
self.W_interface = tf.Variable( # [Y+I,I]
tf.random.truncated_normal([self.controller_dim, self.interface_dim], stddev=0.1),
name='dnc_interface_weights'
)
# output y = v + W_read_out[r(1), ..., r(R)]
self.W_read_out = tf.Variable( # [R*W,Y]
tf.random.truncated_normal([self.R * self.W, self.output_dim], stddev=0.1),
name='dnc_read_vector_weights'
)
def content_lookup(self, key: tf.Tensor, strength: tf.Tensor) -> tf.Tensor:
"""
Attention mechanism: content based addressing to read from and write to the memory.
Parameters
----------
key
Key vector emitted by the controller and used to calculate row-by-row
cosine similarity with the memory matrix.
strength
Strength scalar attached to each key vector (1x1 or 1xR).
Returns
-------
Similarity measure for each row in the memory, used by the read heads for
associative recall or by the write head to modify a vector in memory.
"""
# l2-normalize each key and each row of the memory matrix
norm_mem = tf.nn.l2_normalize(self.M, 1) # [N,W]
norm_key = tf.nn.l2_normalize(key, 1) # [1,W] for write or [R,W] for read
# get similarity measure between both vectors, transpose before multiplication
# write: [N,W]*[W,1] -> [N,1]
# read: [N,W]*[W,R] -> [N,R]
sim = tf.matmul(norm_mem, norm_key, transpose_b=True)
return tf.nn.softmax(sim * strength, 0) # [N,1] or [N,R]
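# In equation form (restating the code above): for each memory row i,
#     C(M, k, beta)[i] = softmax_i( beta * cosine(M[i, :], k) )
# i.e. rows whose content is closest to the key receive the largest weights.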
def allocation_weighting(self) -> tf.Tensor:
"""
Memory needs to be freed up and allocated in a differentiable way.
The usage vector shows how much each memory row is used.
Unused rows can be written to. Usage of a row increases if
we write to it and can decrease if we read from it, depending on the free gates.
Allocation weights are then derived from the usage vector.
Returns
-------
Allocation weights for each row in the memory.
"""
# sort usage vector in ascending order and keep original indices of sorted usage vector
sorted_usage, free_list = tf.nn.top_k(-1 * tf.transpose(self.usage), k=self.N)
sorted_usage *= -1
cumprod = tf.math.cumprod(sorted_usage, axis=1, exclusive=True)
unorder = (1 - sorted_usage) * cumprod
W_alloc = tf.zeros([self.N])
I = tf.constant(np.identity(self.N, dtype=np.float32))
# for each usage vec
for pos, idx in enumerate(tf.unstack(free_list[0])):
# flatten
m = tf.squeeze(tf.slice(I, [idx, 0], [1, -1]))
# add to allocation weight matrix
W_alloc += m * unorder[0, pos]
# return the allocation weighting for each row in memory
return tf.reshape(W_alloc, [self.N, 1])
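# In equation form, with phi the list of row indices sorted by ascending usage:
#     a[phi[j]] = (1 - u[phi[j]]) * prod_{i<j} u[phi[i]]
# so the least used rows receive the largest allocation weights.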
def controller(self, x: tf.Tensor) -> None:
""" Update the hidden state of the LSTM controller. """
# flatten input and pass through dense layer to avoid shape mismatch
x = tf.reshape(x, [1, -1])
x = self.dense(x) # [1,W]
# concatenate input with read vectors
x_in = tf.expand_dims(Concatenate(axis=0)([x, self.read_v]), axis=0) # [1,R+1,W]
# LSTM controller
initial_state = [self.h, self.c]
_, self.h, self.c = self.lstm(x_in, initial_state=initial_state)
def partition_interface(self):
"""
Partition the interface vector into the read and write keys and strengths,
the free, allocation and write gates, the read modes and the erase and write vectors.
"""
# convert interface vector into a set of read write vectors
partition = tf.constant([[0] * (self.R * self.W) + [1] * self.R +
[2] * self.W + [3] + [4] * self.W + [5] * self.W +
[6] * self.R + [7] + [8] + [9] * (self.R * 3)],
dtype=tf.int32)
(k_read, b_read, k_write, b_write, erase, write_v, free_gates, alloc_gate,
write_gate, read_modes) = tf.dynamic_partition(self.interface, partition, 10)
# R read keys and strengths
k_read = tf.reshape(k_read, [self.R, self.W]) # [R,W]
b_read = 1 + tf.nn.softplus(tf.expand_dims(b_read, 0)) # [1,R]
# write key, strength, erase and write vectors
k_write = tf.expand_dims(k_write, 0) # [1,W]
b_write = 1 + tf.nn.softplus(tf.expand_dims(b_write, 0)) # [1,1]
erase = tf.nn.sigmoid(tf.expand_dims(erase, 0)) # [1,W]
write_v = tf.expand_dims(write_v, 0) # [1,W]
# the degree to which locations at read heads will be freed
free_gates = tf.nn.sigmoid(tf.expand_dims(free_gates, 0)) # [1,R]
# the fraction of writing that is being allocated in a new location
alloc_gate = tf.reshape(tf.nn.sigmoid(alloc_gate), [1]) # 1
# the amount of information to be written to memory
write_gate = tf.reshape(tf.nn.sigmoid(write_gate), [1]) # 1
# softmax distribution over the 3 read modes (backward, content lookup, forward)
read_modes = tf.reshape(read_modes, [3, self.R]) # [3,R]
read_modes = tf.nn.softmax(read_modes, axis=0)
return (k_read, b_read, k_write, b_write, erase, write_v,
free_gates, alloc_gate, write_gate, read_modes)
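# Sanity check on the partition sizes with the defaults R=3, W=20:
#     R*W + R + W + 1 + W + W + R + 1 + 1 + 3*R = 60+3+20+1+20+20+3+1+1+9 = 138 = interface_dim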
def write(self,
free_gates: tf.Tensor,
alloc_gate: tf.Tensor,
write_gate: tf.Tensor,
k_write: tf.Tensor,
b_write: tf.Tensor,
erase: tf.Tensor,
write_v: tf.Tensor
):
""" Write to the memory matrix. """
# the memory retention vector represents how much each location will not be freed by the free gates
retention = tf.reduce_prod(1 - free_gates * self.W_read, axis=1)
retention = tf.reshape(retention, [self.N, 1]) # [N,1]
# update usage vector which is used to dynamically allocate memory
self.usage = (self.usage + self.W_write - self.usage * self.W_write) * retention
# compute allocation weights using dynamic memory allocation
W_alloc = self.allocation_weighting() # [N,1]
# apply content lookup for the write vector to figure out where to write to
W_lookup = self.content_lookup(k_write, b_write)
W_lookup = tf.reshape(W_lookup, [self.N, 1]) # [N,1]
# define our write weights now that we know how much space to allocate for them and where to write to
self.W_write = write_gate * (alloc_gate * W_alloc + (1 - alloc_gate) * W_lookup)
# update memory matrix: erase memory and write using the write weights and vector
self.M = (self.M * (1 - tf.matmul(self.W_write, erase)) + tf.matmul(self.W_write, write_v))
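# In equation form, with E a matrix of ones, e the erase vector and v the write vector:
#     M_t = M_{t-1} o (E - w_w e^T) + w_w v^T
# where o denotes elementwise multiplication and w_w the write weights computed above.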
def read(self,
k_read: tf.Tensor,
b_read: tf.Tensor,
read_modes: tf.Tensor
):
""" Read from the memory matrix. """
# update memory link matrix used later for the forward and backward read modes
W_write_cast = tf.matmul(self.W_write, tf.ones([1, self.N])) # [N,N]
self.L = ((1 - W_write_cast - tf.transpose(W_write_cast)) * self.L +
tf.matmul(self.W_write, self.W_precedence, transpose_b=True)) # [N,N]
self.L *= (tf.ones([self.N, self.N]) - tf.constant(np.identity(self.N, dtype=np.float32)))
# update precedence vector which determines degree to which a memory row was written to at t-1
self.W_precedence = ((1 - tf.reduce_sum(self.W_write, axis=0)) * self.W_precedence + self.W_write)
# apply content lookup for the read vector(s) to figure out where to read from
W_lookup = self.content_lookup(k_read, b_read)
W_lookup = tf.reshape(W_lookup, [self.N, self.R]) # [N,R]
# compute forward and backward read weights using the link matrix
# forward weights recall information written in sequence and backward weights in reverse
W_fwd = tf.matmul(self.L, self.W_read) # [N,N]*[N,R] -> [N,R]
W_bwd = tf.matmul(self.L, self.W_read, transpose_a=True) # [N,R]
# 3 modes: forward, backward and content lookup
fwd_mode = read_modes[2] * W_fwd
lookup_mode = read_modes[1] * W_lookup
bwd_mode = read_modes[0] * W_bwd
# read weights = backward + content lookup + forward mode weights
self.W_read = bwd_mode + lookup_mode + fwd_mode # [N,R]
# create read vectors by applying read weights to memory matrix
self.read_v = tf.transpose(tf.matmul(self.M, self.W_read, transpose_a=True)) # ([W,N]*[N,R])^T -> [R,W]
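# In equation form, for each read head i:
#     w_r^i = pi_1^i * b^i + pi_2^i * c^i + pi_3^i * f^i   and   r^i = M^T w_r^i
# with b^i/f^i the backward/forward weights from the link matrix and c^i the content lookup weights.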
def step(self, x: tf.Tensor) -> tf.Tensor:
"""
Update the controller, compute the controller output and interface vectors,
write to and read from memory, and compute the final output.
"""
# update controller
self.controller(x)
# compute output and interface vectors
self.output_v = tf.matmul(self.h, self.W_output) # [1,Y+I] * [Y+I,Y] -> [1,Y]
self.interface = tf.matmul(self.h, self.W_interface) # [1,Y+I] * [Y+I,I] -> [1,I]
# partition the interface vector
(k_read, b_read, k_write, b_write, erase, write_v,
free_gates, alloc_gate, write_gate, read_modes) = self.partition_interface()
# write to memory
self.write(free_gates, alloc_gate, write_gate, k_write, b_write, erase, write_v)
# read from memory
self.read(k_read, b_read, read_modes)
# flatten the read vectors and multiply them with the read output weights before adding to the controller output
read_v_out = tf.matmul(tf.reshape(self.read_v, [1, self.R * self.W]),
self.W_read_out) # [1,RW]*[RW,Y] -> [1,Y]
# compute output
y = self.output_v + read_v_out
return y
def call(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor:
""" Unstack the input, run through the DNC and return the stacked output. """
y = []
for x_seq in tf.unstack(x, axis=0):
x_seq = tf.expand_dims(x_seq, axis=0)
y_seq = self.step(x_seq)
y.append(y_seq)
return tf.squeeze(tf.stack(y, axis=0))
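# Minimal usage sketch: the sequence length, input dimension and hyperparameter
# values below are illustrative assumptions, not values prescribed by the model.
if __name__ == '__main__':
    seq_len, input_dim, output_dim = 6, 4, 5
    dnc = DNC(output_dim=output_dim, memory_shape=(16, 8), n_read=2)
    x = tf.random.normal([seq_len, input_dim])  # one sequence, processed step by step
    y = dnc(x)                                  # stacked step outputs -> [seq_len, output_dim]
    print(y.shape)                              # (6, 5)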