diff --git a/platipy/imaging/cnn/prob_unet.py b/platipy/imaging/cnn/prob_unet.py index 94107c9..31db04f 100644 --- a/platipy/imaging/cnn/prob_unet.py +++ b/platipy/imaging/cnn/prob_unet.py @@ -222,7 +222,7 @@ def __init__( self.no_convs_fcomb = no_convs_fcomb self.initializers = {"w": "he_normal", "b": "normal"} self.z_prior_sample = 0 - self.dist = Independent(Normal(loc=torch.zeros(latent_dim), scale=torch.ones(latent_dim)), 1) + self.latent_dim = latent_dim self.use_structure_context = use_structure_context unet_input_channels = input_channels @@ -339,8 +339,10 @@ def kl_divergence(self): """ if self.prior_latent_space is None: - self.dist.to(self.posterior_latent_space.base_dist.stddev.device) - kl_div = kl.kl_divergence(self.posterior_latent_space, self.dist) + + device = self.posterior_latent_space.base_dist.stddev.device + dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim)).to(device), 1) + kl_div = kl.kl_divergence(self.posterior_latent_space, dist) else: kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)