diff --git a/notebooks/mmd_grud_utils.py b/notebooks/mmd_grud_utils.py index 1109aa6..9b8f533 100644 --- a/notebooks/mmd_grud_utils.py +++ b/notebooks/mmd_grud_utils.py @@ -154,7 +154,8 @@ def step(self, x, x_last_obsv, x_mean, h, mask, delta): combined = torch.cat((x, h, mask), 1) z = torch.sigmoid(self.zl(combined)) #sigmoid(W_z*x_t + U_z*h_{t-1} + V_z*m_t + bz) r = torch.sigmoid(self.rl(combined)) #sigmoid(W_r*x_t + U_r*h_{t-1} + V_r*m_t + br) - h_tilde = torch.tanh(self.hl(combined)) #tanh(W*x_t +U(r_t*h_{t-1}) + V*m_t) + b + combined_new = torch.cat((x, r*h, mask), 1) + h_tilde = torch.tanh(self.hl(combined_new)) #tanh(W*x_t +U(r_t*h_{t-1}) + V*m_t) + b h = (1 - z) * h + z * h_tilde return h @@ -437,4 +438,4 @@ def predict_proba(model, dataloader): probabilities.append(prob.detach().cpu().data.numpy()) labels.append(label.detach().cpu().data.numpy()) - return probabilities, labels \ No newline at end of file + return probabilities, labels