Skip to content

Commit

Permalink
fix grud bug
Browse files Browse the repository at this point in the history
  • Loading branch information
bnestor authored Oct 1, 2021
1 parent 420bfc4 commit 0e83501
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions notebooks/mmd_grud_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,8 @@ def step(self, x, x_last_obsv, x_mean, h, mask, delta):
combined = torch.cat((x, h, mask), 1)
z = torch.sigmoid(self.zl(combined)) #sigmoid(W_z*x_t + U_z*h_{t-1} + V_z*m_t + bz)
r = torch.sigmoid(self.rl(combined)) #sigmoid(W_r*x_t + U_r*h_{t-1} + V_r*m_t + br)
h_tilde = torch.tanh(self.hl(combined)) #tanh(W*x_t +U(r_t*h_{t-1}) + V*m_t) + b
combined_new = torch.cat((x, r*h, mask), 1)
h_tilde = torch.tanh(self.hl(combined_new)) #tanh(W*x_t +U(r_t*h_{t-1}) + V*m_t) + b
h = (1 - z) * h + z * h_tilde

return h
Expand Down Expand Up @@ -437,4 +438,4 @@ def predict_proba(model, dataloader):
probabilities.append(prob.detach().cpu().data.numpy())
labels.append(label.detach().cpu().data.numpy())

return probabilities, labels
return probabilities, labels

0 comments on commit 0e83501

Please sign in to comment.