Skip to content

Commit

Permalink
Tweak probunet input
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Apr 15, 2024
1 parent 06b7ec6 commit 8e0491c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
2 changes: 1 addition & 1 deletion platipy/imaging/cnn/prob_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __init__(
self.no_convs_fcomb = no_convs_fcomb
self.initializers = {"w": "he_normal", "b": "normal"}
self.z_prior_sample = 0

unet_input_channels = input_channels
if use_structure_context:
unet_input_channels += 1
Expand Down
15 changes: 11 additions & 4 deletions platipy/imaging/cnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,18 +579,25 @@ def validate(
def training_step(self, batch, _):
x, c, y, m, _ = batch

# Concat input mask if we are using the structure as context
if self.use_structure_context:
x = torch.cat((x, y), dim=1)

print(f"y.shape1 {y.shape}")
# Add background layer for one-hot encoding
not_y = 1 - y.max(axis=1).values
not_y = torch.unsqueeze(not_y, dim=1)
y = torch.cat((not_y, y), dim=1).float()
print(f"y.shape2 {y.shape}")

print(f"x.shape0 {x.shape}")

print(f"c.shape {c.shape}")

# Concat context map to image if we have one
if c.numel() > 0:
x = torch.cat((x, c), dim=1)

# Concat input mask if we are using the structure as context
if self.use_structure_context:
x = torch.cat((x, y), dim=1)
print(f"x.shape 1 {x.shape}")

# self.prob_unet.forward(x, y, training=True)
if self.hparams.prob_type == "prob":
Expand Down

0 comments on commit 8e0491c

Please sign in to comment.