diff --git a/platipy/imaging/cnn/dataset.py b/platipy/imaging/cnn/dataset.py index ad5c118..ccfd211 100644 --- a/platipy/imaging/cnn/dataset.py +++ b/platipy/imaging/cnn/dataset.py @@ -235,7 +235,7 @@ def apply(self, img, context_map, masks=[]): def crop_img_using_localise_model( - img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100] + img, localise_model, spacing=[1, 1, 1], crop_to_grid_size=[100, 100, 100], context_seg=None ): """Crops an image using a LocaliseUNet @@ -246,6 +246,8 @@ def crop_img_using_localise_model( spacing (list, optional): The image spacing (mm) to resample to. Defaults to [1,1,1]. crop_to_grid_size (list, optional): The size of the grid to crop to. Defaults to [100,100,100]. + context_seg (sitk.Image, optional): Use this segmentation instead of localise model if + provided. Defaults to None. Returns: SimpleITK.Image: The cropped image. @@ -254,15 +256,18 @@ def crop_img_using_localise_model( if isinstance(localise_model, str): localise_model = Path(localise_model) - if isinstance(localise_model, Path): - if localise_model.is_dir(): - # Find the first actual model checkpoint in this directory - localise_model = next(localise_model.glob("*.ckpt")) + if context_seg is not None: + localise_pred = context_seg + else: + if isinstance(localise_model, Path): + if localise_model.is_dir(): + # Find the first actual model checkpoint in this directory + localise_model = next(localise_model.glob("*.ckpt")) - localise_model = LocaliseUNet.load_from_checkpoint(localise_model) + localise_model = LocaliseUNet.load_from_checkpoint(localise_model) - localise_model.eval() - localise_pred = localise_model.infer(img) + localise_model.eval() + localise_pred = localise_model.infer(img) img = preprocess_image(img, spacing=spacing, crop_to_grid_size_xy=None) localise_pred = resample_mask_to_image(img, localise_pred) diff --git a/platipy/imaging/cnn/train.py b/platipy/imaging/cnn/train.py index b30b793..7d86529 100644 --- a/platipy/imaging/cnn/train.py +++ b/platipy/imaging/cnn/train.py @@ -283,6 +283,7 @@ def infer( latent_dim=True, spaced_range=[-1.5, 1.5], preprocess=True, + return_latent_space=False ): # sample strategy in "mean", "random", "spaced" @@ -340,6 +341,7 @@ def infer( localise_path, spacing=self.hparams.spacing, crop_to_grid_size=self.hparams.localise_voxel_grid_size, + context_seg=seg ) else: img = preprocess_image( @@ -406,6 +408,9 @@ def infer( else: self.prob_unet.forward(x) + if return_latent_space: + return self.prob_unet.prior_latent_space + for sample in samples: if self.hparams.prob_type == "prob": if sample["name"] == "mean":