Hi guanjq,
I may have found an error in the call `self.q_v_pred_one_timestep(log_vt, t, batch)` inside the `q_v_posterior` function, which is meant to compute q(v_t | v_{t-1}).
Source code on GitHub:
```python
# atom type generative process
def q_v_posterior(self, log_v0, log_vt, t, batch):
    # q(vt-1 | vt, v0) = q(vt | vt-1, x0) * q(vt-1 | x0) / q(vt | x0)
    t_minus_1 = t - 1
    # Remove negative values, will not be used anyway for final decoder
    t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
    log_qvt1_v0 = self.q_v_pred(log_v0, t_minus_1, batch)
    unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)
    log_vt1_given_vt_v0 = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=-1, keepdim=True)
    return log_vt1_given_vt_v0
```
Is there an error in `unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_vt, t, batch)`?
Should it be changed to `unnormed_logprobs = log_qvt1_v0 + self.q_v_pred_one_timestep(log_qvt1_v0, t, batch)`?
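To compare the two versions concretely, here is a minimal pure-Python sketch. It assumes the uniform-noise multinomial transition q(v_t | v_{t-1}) = α_t · onehot(v_{t-1}) + (1 − α_t)/K, which is what `q_v_pred_one_timestep` appears to compute in log space (an assumption, not a restatement of the repository's code), and it works with probabilities rather than log-probabilities. The helper names (`uniform_step`, `posterior_as_in_code`, `posterior_bayes`, `posterior_proposed`), the 1-indexed `t`, and the α schedule are all hypothetical. It computes q(v_{t-1} | v_t, v_0) three ways: mirroring the current line, by brute-force Bayes' rule, and with the proposed change.

```python
from math import isclose, prod


def uniform_step(p, alpha, K):
    """Hypothetical uniform-noise transition: alpha * p + (1 - alpha) / K."""
    return [alpha * pk + (1.0 - alpha) / K for pk in p]


def one_hot(i, K):
    return [1.0 if k == i else 0.0 for k in range(K)]


def normalize(w):
    s = sum(w)
    return [x / s for x in w]


def q_prev_given_v0(v0, t, alphas, K):
    # q(v_{t-1} | v_0), using alpha_bar_{t-1} = alpha_1 * ... * alpha_{t-1}
    # (t is 1-indexed here; prod of an empty list is 1, so t = 1 gives onehot(v0))
    return uniform_step(one_hot(v0, K), prod(alphas[:t - 1]), K)


def posterior_as_in_code(v0, vt, t, alphas, K):
    # Mirrors the current line: q(v_{t-1} | v_0) * one_step(onehot(v_t))
    q_step = uniform_step(one_hot(vt, K), alphas[t - 1], K)
    return normalize([a * b for a, b in zip(q_prev_given_v0(v0, t, alphas, K), q_step)])


def posterior_bayes(v0, vt, t, alphas, K):
    # Brute force: q(v_{t-1}=k | v_0) * q(v_t = vt | v_{t-1} = k), one k at a time
    lik = [uniform_step(one_hot(k, K), alphas[t - 1], K)[vt] for k in range(K)]
    return normalize([a * b for a, b in zip(q_prev_given_v0(v0, t, alphas, K), lik)])


def posterior_proposed(v0, vt, t, alphas, K):
    # The proposed change: one_step applied to q(v_{t-1} | v_0); vt is never used
    q_prev = q_prev_given_v0(v0, t, alphas, K)
    q_step = uniform_step(q_prev, alphas[t - 1], K)
    return normalize([a * b for a, b in zip(q_prev, q_step)])


K, alphas = 4, [0.9, 0.8, 0.7]  # hypothetical number of classes and schedule
for t in (1, 2, 3):
    for v0 in range(K):
        for vt in range(K):
            a = posterior_as_in_code(v0, vt, t, alphas, K)
            b = posterior_bayes(v0, vt, t, alphas, K)
            assert all(isclose(x, y) for x, y in zip(a, b))
```

With this symmetric transition, α_t · 1[v_t = k] + (1 − α_t)/K is the same number whether it is read as q(v_t | v_{t-1} = k) or as the one-step transition applied to onehot(v_t) and evaluated at k, so the first two functions agree for every (v_0, v_t, t) in the loop. The proposed variant never reads v_t, so its output cannot change with the observed v_t.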
Best,