Skip to content

Commit

Permalink
Fix move to device
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Apr 15, 2024
1 parent 0a76cb9 commit 926825b
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions platipy/imaging/cnn/prob_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __init__(
self.no_convs_fcomb = no_convs_fcomb
self.initializers = {"w": "he_normal", "b": "normal"}
self.z_prior_sample = 0
self.dist = Independent(Normal(loc=torch.zeros(latent_dim), scale=torch.ones(latent_dim)), 1)
self.latent_dim = latent_dim
self.use_structure_context = use_structure_context

unet_input_channels = input_channels
Expand Down Expand Up @@ -339,8 +339,10 @@ def kl_divergence(self):
"""

if self.prior_latent_space is None:
self.dist.to(self.posterior_latent_space.base_dist.stddev.device)
kl_div = kl.kl_divergence(self.posterior_latent_space, self.dist)

device = self.posterior_latent_space.base_dist.stddev.device
dist = Independent(Normal(loc=torch.zeros(self.latent_dim).to(device), scale=torch.ones(self.latent_dim)).to(device), 1)
kl_div = kl.kl_divergence(self.posterior_latent_space, dist)
else:
kl_div = kl.kl_divergence(self.posterior_latent_space, self.prior_latent_space)

Expand Down

0 comments on commit 926825b

Please sign in to comment.