From d95b1c33e730f14c962fd3d45963dd997f183e44 Mon Sep 17 00:00:00 2001 From: Phillip Chlap Date: Mon, 4 Dec 2023 12:00:57 +1100 Subject: [PATCH] Support context map for augmentation --- platipy/imaging/generation/augment.py | 135 ++++++++++++++++++-------- 1 file changed, 94 insertions(+), 41 deletions(-) diff --git a/platipy/imaging/generation/augment.py b/platipy/imaging/generation/augment.py index 46332692..bb4b4484 100644 --- a/platipy/imaging/generation/augment.py +++ b/platipy/imaging/generation/augment.py @@ -44,8 +44,8 @@ logger = logging.getLogger(__name__) -def apply_augmentation(image, augmentation, masks=[]): +def apply_augmentation(image, augmentation, context_map=None, masks=[]): if not isinstance(image, sitk.Image): raise AttributeError("image should be a SimpleITK.Image") @@ -62,7 +62,6 @@ def apply_augmentation(image, augmentation, masks=[]): transform = None dvf = None for aug in augmentation: - if not isinstance(aug, DeformableAugment): raise AttributeError("Each augmentation must be of type DeformableAugment") @@ -90,21 +89,32 @@ def apply_augmentation(image, augmentation, masks=[]): masks_deformed = [] for mask in masks: def_mask = apply_transform( - mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor + mask, + transform=transform, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, ) def_mask = sitk.BinaryMorphologicalClosing(def_mask, [3, 3, 3]) masks_deformed.append(def_mask) + cmap_deformed = None + if context_map is not None: + cmap_deformed = apply_transform( + context_map, + transform=transform, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, + ) + if masks: - return image_deformed, masks_deformed, dvf + return image_deformed, cmap_deformed, masks_deformed, dvf - return image_deformed, dvf + return image_deformed, cmap_deformed, dvf def generate_random_augmentation(ct_image, masks, augmentation_types): - augmentation = [] probabilities = [a["probability"] for a in augmentation_types] @@ -114,7 +124,9 @@ def generate_random_augmentation(ct_image, masks, augmentation_types): prob_none = 0 for mask in masks: - aug = random.choices(augmentation_types + [None], weights=probabilities + [prob_none])[0] + aug = random.choices( + augmentation_types + [None], weights=probabilities + [prob_none] + )[0] if aug is None: continue @@ -122,10 +134,8 @@ def generate_random_augmentation(ct_image, masks, augmentation_types): aug_class = aug["class"] aug_args = {} for arg in aug["args"]: - value = aug["args"][arg] if isinstance(value, list): - # Randomly sample for each dim result = [] for rng in value: @@ -146,20 +156,17 @@ def generate_random_augmentation(ct_image, masks, augmentation_types): class DeformableAugment(ABC): @abstractmethod def augment(self): - # return deformation pass class ShiftAugment(DeformableAugment): def __init__(self, mask, vector_shift=(10, 10, 10), gaussian_smooth=5): - self.mask = mask self.vector_shift = vector_shift self.gaussian_smooth = gaussian_smooth def augment(self): - _, transform, dvf = generate_field_shift( self.mask, self.vector_shift, @@ -172,15 +179,15 @@ def __str__(self): class ExpandAugment(DeformableAugment): - def __init__(self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False): - + def __init__( + self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False + ): self.mask = mask self.vector_expand = vector_expand self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask def augment(self): - _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, @@ -191,19 +198,23 @@ def augment(self): return transform, dvf def __str__(self): - return f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + return ( + f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}" + ) class ContractAugment(DeformableAugment): - def __init__(self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False): - + def __init__( + self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False + ): self.mask = mask - self.vector_contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())] + self.vector_contract = [ + int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing()) + ] self.gaussian_smooth = gaussian_smooth self.bone_mask = bone_mask def augment(self): - _, transform, dvf = generate_field_expand( self.mask, bone_mask=self.bone_mask, @@ -217,7 +228,6 @@ def __str__(self): def augment_data(args): - random.seed(args.seed) augmentation_types = [] @@ -285,23 +295,32 @@ def augment_data(args): data = { case: { "image": data_dir.joinpath(args.image_glob.format(case=case)), - "label": [i for sl in [list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob] for i in sl], + "context_map": data_dir.joinpath(args.context_map_glob.format(case=case)), + "label": [ + i + for sl in [ + list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob + ] + for i in sl + ], } for case in cases } for case in cases: - logger.info(f"Augmenting for case: {case}") ct_image_original = sitk.ReadImage(str(data[case]["image"])) + cmap_original = None + if data[case]["context_map"]: + cmap_original = sitk.ReadImage(str(data[case]["context_map"])) + # Get list of structures to generate augmentations off logger.debug("Collecting structures") all_masks = [] all_names = [] for structure_path in data[case]["label"]: - mask = sitk.ReadImage(str(structure_path)) all_masks.append(mask) @@ -316,24 +335,25 @@ def augment_data(args): all_masks[m] = crop_to_roi(mask, size, index) if args.enable_fill_holes: - logger.debug("Finding holes") label_image, labels = detect_holes(ct_image) # Generate x random augmentations per case for i in range(args.augmentations_per_case): - logger.debug(f"Generating augmentation {i}") ct_image = sitk.ReadImage(str(data[case]["image"])) ct_image = crop_to_roi(ct_image, size, index) - if args.enable_fill_holes: + cmap = None + if data[case]["context_map"]: + cmap = sitk.ReadImage(str(data[case]["context_map"])) + cmap = crop_to_roi(cmap, size, index) + if args.enable_fill_holes: logger.debug("Filling holes") for label in labels[1:]: # Skip first hole since likely air around body - if random.random() > args.fill_probability: continue @@ -357,7 +377,9 @@ def augment_data(args): augmented_case_path.mkdir(exist_ok=True, parents=True) logger.debug("Generating random augmentations") - augmentation = generate_random_augmentation(ct_image, all_masks, augmentation_types) + augmentation = generate_random_augmentation( + ct_image, all_masks, augmentation_types + ) dvf = None @@ -369,12 +391,17 @@ def augment_data(args): augmented_image = ct_image augmented_masks = all_masks else: - logger.debug("Applying augmentation") - augmented_image, augmented_masks, dvf = apply_augmentation( - ct_image, augmentation, masks=all_masks + ( + augmented_image, + augmented_cmap, + augmented_masks, + dvf, + ) = apply_augmentation( + ct_image, augmentation, context_map=cmap, masks=all_masks ) + # Save off image augmented_image_path = augmented_case_path.joinpath("CT.nii.gz") ct_image_original[ index[0] : index[0] + size[0], @@ -383,17 +410,35 @@ def augment_data(args): ] = augmented_image sitk.WriteImage(ct_image_original, str(augmented_image_path)) + # Save off context map if we have one + augmented_cmap_path = augmented_case_path.joinpath("context_map.nii.gz") + cmap_original[ + index[0] : index[0] + size[0], + index[1] : index[1] + size[1], + index[2] : index[2] + size[2], + ] = augmented_cmap + sitk.WriteImage(cmap_original, str(augmented_cmap_path)) + vis = ImageVisualiser(image=ct_image, figure_size_in=6) vis.add_comparison_overlay(augmented_image) if dvf is not None: vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12)) - for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks): - vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}) + for mask_name, mask, augmented_mask in zip( + all_names, all_masks, augmented_masks + ): + vis.add_contour( + {f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask} + ) logger.debug(f"Applying augmentation to mask: {mask_name}") - augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz") + augmented_mask_path = augmented_case_path.joinpath( + f"{mask_name}.nii.gz" + ) augmented_mask = sitk.Resample( - augmented_mask, ct_image_original, sitk.Transform(), sitk.sitkNearestNeighbor + augmented_mask, + ct_image_original, + sitk.Transform(), + sitk.sitkNearestNeighbor, ) sitk.WriteImage(augmented_mask, str(augmented_mask_path)) @@ -405,14 +450,18 @@ def augment_data(args): if __name__ == "__main__": - arg_parser = ArgumentParser() - arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed") + arg_parser.add_argument( + "--seed", type=int, default=42, help="an integer to use as seed" + ) arg_parser.add_argument("--data_dir", type=str, default="./data") arg_parser.add_argument("--output_dir", type=str, default="./augment") arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz") arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz") - arg_parser.add_argument("--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz") + arg_parser.add_argument( + "--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz" + ) + arg_parser.add_argument("--context_map_glob", type=str, default=None) arg_parser.add_argument( "--augmentations_per_case", type=int, @@ -431,7 +480,9 @@ def augment_data(args): arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10]) - arg_parser.add_argument("--expand_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument( + "--expand_smooth_range", nargs="+", type=int, default=[3, 5] + ) arg_parser.add_argument("--expand_bone_mask", type=bool, default=True) arg_parser.add_argument("--expand_probability", type=float, default=0.5) @@ -439,7 +490,9 @@ def augment_data(args): arg_parser.add_argument("--contract_x_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10]) arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10]) - arg_parser.add_argument("--contract_smooth_range", nargs="+", type=int, default=[3, 5]) + arg_parser.add_argument( + "--contract_smooth_range", nargs="+", type=int, default=[3, 5] + ) arg_parser.add_argument("--contract_bone_mask", type=bool, default=True) arg_parser.add_argument("--contract_probability", type=float, default=0.5)