Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
[#31] just updating cell state, not hidden state
Browse files Browse the repository at this point in the history
  • Loading branch information
jsadler2 committed Aug 17, 2020
1 parent c1ba47a commit 7150217
Showing 1 changed file with 8 additions and 31 deletions.
39 changes: 8 additions & 31 deletions river_dl/RGCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
w_initializer = tf.random_normal_initializer(stddev=0.02,
seed=rand_seed)

# was Wg1
self.W_graph_h = self.add_weight(shape=[hidden_size, hidden_size],
initializer=w_initializer,
name='W_graph_h')
# was bg1
self.b_graph_h = self.add_weight(shape=[hidden_size],
initializer='zeros', name='b_graph_h')
# was Wg2
self.W_graph_c = self.add_weight(shape=[hidden_size, hidden_size],
initializer=w_initializer,
Expand All @@ -50,17 +43,6 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
self.b_graph_c = self.add_weight(shape=[hidden_size],
initializer='zeros', name='b_graph_c')

# was Wa1
self.W_h_cur = self.add_weight(shape=[hidden_size, hidden_size],
initializer=w_initializer,
name='W_h_cur')
# was Wa2
self.W_h_prev = self.add_weight(shape=[hidden_size, hidden_size],
initializer=w_initializer,
name='W_h_prev')
# was ba
self.b_h = self.add_weight(shape=[hidden_size], initializer='zeros',
name='b_h')

# was Wc1
self.W_c_cur = self.add_weight(shape=[hidden_size, hidden_size],
Expand Down Expand Up @@ -101,42 +83,37 @@ def __init__(self, hidden_size, A, flow_in_temp=False, rand_seed=None):
@tf.function
def call(self, inputs, **kwargs):
graph_size = self.A.shape[0]
hidden_state_prev, cell_state_prev = (tf.zeros([graph_size,
hidden_state, cell_state_prev = (tf.zeros([graph_size,
self.hidden_size]),
tf.zeros([graph_size,
self.hidden_size]))
out = []
n_steps = inputs.shape[1]
for t in range(n_steps):
h_graph = tf.nn.tanh(tf.matmul(self.A, tf.matmul(hidden_state_prev,
self.W_graph_h)
+ self.b_graph_h))
c_graph = tf.nn.tanh(tf.matmul(self.A, tf.matmul(cell_state_prev,
self.W_graph_c)
+ self.b_graph_c))

seq, state = self.lstm(inputs[:, t, :], states=[hidden_state_prev,
seq, state = self.lstm(inputs[:, t, :], states=[hidden_state,
cell_state_prev])
hidden_state_cur, cell_state_cur = state
hidden_state, cell_state_cur = state

h_update = tf.nn.sigmoid(tf.matmul(hidden_state_cur, self.W_h_cur)
+ tf.matmul(h_graph, self.W_h_prev)
+ self.b_h)
c_update = tf.nn.sigmoid(tf.matmul(cell_state_cur, self.W_c_cur)
+ tf.matmul(c_graph, self.W_c_prev)
+ self.b_c)

if self.flow_in_temp:
out_pred_q = tf.matmul(h_update, self.W_out_flow) + self.b_out_flow
out_pred_t = tf.matmul(tf.concat([h_update, out_pred_q], axis=1),
out_pred_q = tf.matmul(hidden_state, self.W_out_flow) +\
self.b_out_flow
out_pred_t = tf.matmul(tf.concat([hidden_state, out_pred_q],
axis=1),
self.W_out_temp) + self.b_out_temp
out_pred = tf.concat([out_pred_t, out_pred_q], axis=1)
else:
out_pred = tf.matmul(h_update, self.W_out) + self.b_out
out_pred = tf.matmul(hidden_state, self.W_out) + self.b_out

out.append(out_pred)

hidden_state_prev = h_update
cell_state_prev = c_update
out = tf.stack(out)
out = tf.transpose(out, [1, 0, 2])
Expand Down

0 comments on commit 7150217

Please sign in to comment.