Skip to content

Commit

Permalink
Context based inference
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Nov 16, 2024
1 parent de6628e commit 2b9b5e0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
21 changes: 13 additions & 8 deletions platipy/imaging/cnn/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def apply(self, img, context_map, masks=[]):


def crop_img_using_localise_model(
img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100]
img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100], context_seg=None
):
"""Crops an image using a LocaliseUNet
Expand All @@ -246,6 +246,8 @@ def crop_img_using_localise_model(
spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1].
crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to
[100,100,100].
context_seg (sitk.Image, optional): Use this segmentation instead of localise model if
provided. Defaults to None.
Returns:
SimpleITK.Image: The cropped image.
Expand All @@ -254,15 +256,18 @@ def crop_img_using_localise_model(
if isinstance(localise_model, str):
localise_model = Path(localise_model)

if isinstance(localise_model, Path):
if localise_model.is_dir():
# Find the first actual model checkpoint in this directory
localise_model = next(localise_model.glob("*.ckpt"))
if context_seg is not None:
localise_pred = context_seg
else:
if isinstance(localise_model, Path):
if localise_model.is_dir():
# Find the first actual model checkpoint in this directory
localise_model = next(localise_model.glob("*.ckpt"))

localise_model = LocaliseUNet.load_from_checkpoint(localise_model)
localise_model = LocaliseUNet.load_from_checkpoint(localise_model)

localise_model.eval()
localise_pred = localise_model.infer(img)
localise_model.eval()
localise_pred = localise_model.infer(img)

img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None)
localise_pred = resample_mask_to_image(img, localise_pred)
Expand Down
5 changes: 5 additions & 0 deletions platipy/imaging/cnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def infer(
latent_dim=True,
spaced_range=[-1.5, 1.5],
preprocess=True,
return_latent_space=False
):
# sample strategy in "mean", "random", "spaced"

Expand Down Expand Up @@ -340,6 +341,7 @@ def infer(
localise_path,
spacing=self.hparams.spacing,
crop_to_grid_size=self.hparams.localise_voxel_grid_size,
context_seg=seg
)
else:
img = preprocess_image(
Expand Down Expand Up @@ -406,6 +408,9 @@ def infer(
else:
self.prob_unet.forward(x)

if return_latent_space:
return self.prob_unet.prior_latent_space

for sample in samples:
if self.hparams.prob_type == "prob":
if sample["name"] == "mean":
Expand Down

0 comments on commit 2b9b5e0

Please sign in to comment.