Skip to content

Commit

Permalink
Support context map for augmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Dec 4, 2023
1 parent 5bd7b4d commit d95b1c3
Showing 1 changed file with 94 additions and 41 deletions.
135 changes: 94 additions & 41 deletions platipy/imaging/generation/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@

logger = logging.getLogger(__name__)

def apply_augmentation(image, augmentation, masks=[]):

def apply_augmentation(image, augmentation, context_map=None, masks=[]):
if not isinstance(image, sitk.Image):
raise AttributeError("image should be a SimpleITK.Image")

Expand All @@ -62,7 +62,6 @@ def apply_augmentation(image, augmentation, masks=[]):
transform = None
dvf = None
for aug in augmentation:

if not isinstance(aug, DeformableAugment):
raise AttributeError("Each augmentation must be of type DeformableAugment")

Expand Down Expand Up @@ -90,21 +89,32 @@ def apply_augmentation(image, augmentation, masks=[]):
masks_deformed = []
for mask in masks:
def_mask = apply_transform(
mask, transform=transform, default_value=0, interpolator=sitk.sitkNearestNeighbor
mask,
transform=transform,
default_value=0,
interpolator=sitk.sitkNearestNeighbor,
)

def_mask = sitk.BinaryMorphologicalClosing(def_mask, [3, 3, 3])

masks_deformed.append(def_mask)

cmap_deformed = None
if context_map is not None:
cmap_deformed = apply_transform(
context_map,
transform=transform,
default_value=0,
interpolator=sitk.sitkNearestNeighbor,
)

if masks:
return image_deformed, masks_deformed, dvf
return image_deformed, cmap_deformed, masks_deformed, dvf

return image_deformed, dvf
return image_deformed, cmap_deformed, dvf


def generate_random_augmentation(ct_image, masks, augmentation_types):

augmentation = []

probabilities = [a["probability"] for a in augmentation_types]
Expand All @@ -114,18 +124,18 @@ def generate_random_augmentation(ct_image, masks, augmentation_types):
prob_none = 0

for mask in masks:
aug = random.choices(augmentation_types + [None], weights=probabilities + [prob_none])[0]
aug = random.choices(
augmentation_types + [None], weights=probabilities + [prob_none]
)[0]

if aug is None:
continue

aug_class = aug["class"]
aug_args = {}
for arg in aug["args"]:

value = aug["args"][arg]
if isinstance(value, list):

# Randomly sample for each dim
result = []
for rng in value:
Expand All @@ -146,20 +156,17 @@ def generate_random_augmentation(ct_image, masks, augmentation_types):
class DeformableAugment(ABC):
@abstractmethod
def augment(self):

# return deformation
pass


class ShiftAugment(DeformableAugment):
def __init__(self, mask, vector_shift=(10, 10, 10), gaussian_smooth=5):

self.mask = mask
self.vector_shift = vector_shift
self.gaussian_smooth = gaussian_smooth

def augment(self):

_, transform, dvf = generate_field_shift(
self.mask,
self.vector_shift,
Expand All @@ -172,15 +179,15 @@ def __str__(self):


class ExpandAugment(DeformableAugment):
def __init__(self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False):

def __init__(
self, mask, vector_expand=(10, 10, 10), gaussian_smooth=5, bone_mask=False
):
self.mask = mask
self.vector_expand = vector_expand
self.gaussian_smooth = gaussian_smooth
self.bone_mask = bone_mask

def augment(self):

_, transform, dvf = generate_field_expand(
self.mask,
bone_mask=self.bone_mask,
Expand All @@ -191,19 +198,23 @@ def augment(self):
return transform, dvf

def __str__(self):
return f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}"
return (
f"Expand with vector: {self.vector_expand}, smooth: {self.gaussian_smooth}"
)


class ContractAugment(DeformableAugment):
def __init__(self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False):

def __init__(
self, mask, vector_contract=(10, 10, 10), gaussian_smooth=5, bone_mask=False
):
self.mask = mask
self.vector_contract = [int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())]
self.vector_contract = [
int(-x / s) for x, s in zip(vector_contract, mask.GetSpacing())
]
self.gaussian_smooth = gaussian_smooth
self.bone_mask = bone_mask

def augment(self):

_, transform, dvf = generate_field_expand(
self.mask,
bone_mask=self.bone_mask,
Expand All @@ -217,7 +228,6 @@ def __str__(self):


def augment_data(args):

random.seed(args.seed)

augmentation_types = []
Expand Down Expand Up @@ -285,23 +295,32 @@ def augment_data(args):
data = {
case: {
"image": data_dir.joinpath(args.image_glob.format(case=case)),
"label": [i for sl in [list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob] for i in sl],
"context_map": data_dir.joinpath(args.context_map_glob.format(case=case)),
"label": [
i
for sl in [
list(data_dir.glob(lg.format(case=case))) for lg in args.label_glob
]
for i in sl
],
}
for case in cases
}

for case in cases:

logger.info(f"Augmenting for case: {case}")

ct_image_original = sitk.ReadImage(str(data[case]["image"]))

cmap_original = None
if data[case]["context_map"]:
cmap_original = sitk.ReadImage(str(data[case]["context_map"]))

# Get list of structures to generate augmentations off
logger.debug("Collecting structures")
all_masks = []
all_names = []
for structure_path in data[case]["label"]:

mask = sitk.ReadImage(str(structure_path))

all_masks.append(mask)
Expand All @@ -316,24 +335,25 @@ def augment_data(args):
all_masks[m] = crop_to_roi(mask, size, index)

if args.enable_fill_holes:

logger.debug("Finding holes")
label_image, labels = detect_holes(ct_image)

# Generate x random augmentations per case
for i in range(args.augmentations_per_case):

logger.debug(f"Generating augmentation {i}")

ct_image = sitk.ReadImage(str(data[case]["image"]))
ct_image = crop_to_roi(ct_image, size, index)

if args.enable_fill_holes:
cmap = None
if data[case]["context_map"]:
cmap = sitk.ReadImage(str(data[case]["context_map"]))
cmap = crop_to_roi(cmap, size, index)

if args.enable_fill_holes:
logger.debug("Filling holes")

for label in labels[1:]: # Skip first hole since likely air around body

if random.random() > args.fill_probability:
continue

Expand All @@ -357,7 +377,9 @@ def augment_data(args):
augmented_case_path.mkdir(exist_ok=True, parents=True)

logger.debug("Generating random augmentations")
augmentation = generate_random_augmentation(ct_image, all_masks, augmentation_types)
augmentation = generate_random_augmentation(
ct_image, all_masks, augmentation_types
)

dvf = None

Expand All @@ -369,12 +391,17 @@ def augment_data(args):
augmented_image = ct_image
augmented_masks = all_masks
else:

logger.debug("Applying augmentation")
augmented_image, augmented_masks, dvf = apply_augmentation(
ct_image, augmentation, masks=all_masks
(
augmented_image,
augmented_cmap,
augmented_masks,
dvf,
) = apply_augmentation(
ct_image, augmentation, context_map=cmap, masks=all_masks
)

# Save off image
augmented_image_path = augmented_case_path.joinpath("CT.nii.gz")
ct_image_original[
index[0] : index[0] + size[0],
Expand All @@ -383,17 +410,35 @@ def augment_data(args):
] = augmented_image
sitk.WriteImage(ct_image_original, str(augmented_image_path))

# Save off context map if we have one
augmented_cmap_path = augmented_case_path.joinpath("context_map.nii.gz")
cmap_original[
index[0] : index[0] + size[0],
index[1] : index[1] + size[1],
index[2] : index[2] + size[2],
] = augmented_cmap
sitk.WriteImage(cmap_original, str(augmented_cmap_path))

vis = ImageVisualiser(image=ct_image, figure_size_in=6)
vis.add_comparison_overlay(augmented_image)
if dvf is not None:
vis.add_vector_overlay(dvf, arrow_scale=1, subsample=(4, 12, 12))
for mask_name, mask, augmented_mask in zip(all_names, all_masks, augmented_masks):
vis.add_contour({f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask})
for mask_name, mask, augmented_mask in zip(
all_names, all_masks, augmented_masks
):
vis.add_contour(
{f"{mask_name}": mask, f"{mask_name} (augmented)": augmented_mask}
)

logger.debug(f"Applying augmentation to mask: {mask_name}")
augmented_mask_path = augmented_case_path.joinpath(f"{mask_name}.nii.gz")
augmented_mask_path = augmented_case_path.joinpath(
f"{mask_name}.nii.gz"
)
augmented_mask = sitk.Resample(
augmented_mask, ct_image_original, sitk.Transform(), sitk.sitkNearestNeighbor
augmented_mask,
ct_image_original,
sitk.Transform(),
sitk.sitkNearestNeighbor,
)
sitk.WriteImage(augmented_mask, str(augmented_mask_path))

Expand All @@ -405,14 +450,18 @@ def augment_data(args):


if __name__ == "__main__":

arg_parser = ArgumentParser()
arg_parser.add_argument("--seed", type=int, default=42, help="an integer to use as seed")
arg_parser.add_argument(
"--seed", type=int, default=42, help="an integer to use as seed"
)
arg_parser.add_argument("--data_dir", type=str, default="./data")
arg_parser.add_argument("--output_dir", type=str, default="./augment")
arg_parser.add_argument("--case_glob", type=str, default="images/*.nii.gz")
arg_parser.add_argument("--image_glob", type=str, default="images/{case}.nii.gz")
arg_parser.add_argument("--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz")
arg_parser.add_argument(
"--label_glob", nargs="+", type=str, default="labels/{case}_*.nii.gz"
)
arg_parser.add_argument("--context_map_glob", type=str, default=None)
arg_parser.add_argument(
"--augmentations_per_case",
type=int,
Expand All @@ -431,15 +480,19 @@ def augment_data(args):
arg_parser.add_argument("--expand_x_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--expand_y_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--expand_z_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--expand_smooth_range", nargs="+", type=int, default=[3, 5])
arg_parser.add_argument(
"--expand_smooth_range", nargs="+", type=int, default=[3, 5]
)
arg_parser.add_argument("--expand_bone_mask", type=bool, default=True)
arg_parser.add_argument("--expand_probability", type=float, default=0.5)

arg_parser.add_argument("--enable_contract", type=bool, default=True)
arg_parser.add_argument("--contract_x_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--contract_y_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--contract_z_range", nargs="+", type=int, default=[0, 10])
arg_parser.add_argument("--contract_smooth_range", nargs="+", type=int, default=[3, 5])
arg_parser.add_argument(
"--contract_smooth_range", nargs="+", type=int, default=[3, 5]
)
arg_parser.add_argument("--contract_bone_mask", type=bool, default=True)
arg_parser.add_argument("--contract_probability", type=float, default=0.5)

Expand Down

0 comments on commit d95b1c3

Please sign in to comment.