Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
rnfinnegan committed Jun 2, 2021
1 parent efb80bb commit 9124951
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 36 deletions.
39 changes: 39 additions & 0 deletions platipy/imaging/generation/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,42 @@ def get_external_mask(
body_mask_hull.CopyInformation(body_mask)

return body_mask_hull


def extend_mask(mask, direction=("ax", "sup"), extension_mm=10, interior_mm_shape=10):

arr = sitk.GetArrayViewFromImage(mask)
vals = np.unique(arr[arr > 0])
if len(vals) > 2:
# There is more than one value! We need to threshold (at the median)
cutoff = np.median(vals)
mask_binary = sitk.BinaryThreshold(mask, cutoff, np.max(vals).astype(float))
else:
mask_binary = mask

arr = sitk.GetArrayFromImage(mask_binary)

if direction[0] == "ax":
inferior_slice = np.where(arr)[0].min()
superior_slice = np.where(arr)[0].max()

n_slices_ext = int(extension_mm / mask.GetSpacing()[2])
n_slices_est = int(interior_mm_shape / mask.GetSpacing()[2])

if direction[1] == "sup":
max_index = min([arr.shape[0], superior_slice + 1 + n_slices_ext])
for s_in in range(superior_slice + 1 - n_slices_est, max_index):
arr[s_in, :, :] = np.max(
arr[superior_slice - n_slices_est : superior_slice, :, :], axis=0
)
if direction[1] == "inf":
min_index = max([arr.shape[0], inferior_slice - n_slices_ext + n_slices_est])
for s_in in range(min_index, inferior_slice):
arr[s_in, :, :] = np.max(
arr[inferior_slice + n_slices_est : inferior_slice, :, :], axis=0
)

mask_ext = sitk.GetImageFromArray(arr)
mask_ext.CopyInformation(mask)

return mask_ext
42 changes: 37 additions & 5 deletions platipy/imaging/label/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import numpy as np
import SimpleITK as sitk

from platipy.imaging.utils.crop import label_to_roi, crop_to_roi


def compute_surface_metrics(label_a, label_b, verbose=False):
"""Compute surface distance metrics between two labels. Surface metrics computed are:
Expand Down Expand Up @@ -141,7 +143,7 @@ def compute_volume_metrics(label_a, label_b):
return result


def compute_metric_dsc(label_a, label_b):
def compute_metric_dsc(label_a, label_b, auto_crop=True):
"""Compute the Dice Similarity Coefficient between two labels
Args:
Expand All @@ -151,13 +153,19 @@ def compute_metric_dsc(label_a, label_b):
Returns:
float: The Dice Similarity Coefficient
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
arr_b = sitk.GetArrayFromImage(label_b).astype(bool)
return 2 * ((arr_a & arr_b).sum()) / (arr_a.sum() + arr_b.sum())


def compute_metric_specificity(label_a, label_b):
def compute_metric_specificity(label_a, label_b, auto_crop=True):
"""Compute the specificity between two labels
Args:
Expand All @@ -167,6 +175,12 @@ def compute_metric_specificity(label_a, label_b):
Returns:
float: The specificity between the two labels
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
arr_b = sitk.GetArrayFromImage(label_b).astype(bool)
Expand All @@ -180,7 +194,7 @@ def compute_metric_specificity(label_a, label_b):
return float((1.0 * true_neg) / (true_neg + false_pos))


def compute_metric_sensitivity(label_a, label_b):
def compute_metric_sensitivity(label_a, label_b, auto_crop=True):
"""Compute the sensitivity between two labels
Args:
Expand All @@ -190,6 +204,12 @@ def compute_metric_sensitivity(label_a, label_b):
Returns:
float: The sensitivity between the two labels
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
arr_b = sitk.GetArrayFromImage(label_b).astype(bool)
Expand All @@ -202,7 +222,7 @@ def compute_metric_sensitivity(label_a, label_b):
return float((1.0 * true_pos) / (true_pos + false_neg))


def compute_metric_masd(label_a, label_b):
def compute_metric_masd(label_a, label_b, auto_crop=True):
"""Compute the mean absolute distance between two labels
Args:
Expand All @@ -212,6 +232,12 @@ def compute_metric_masd(label_a, label_b):
Returns:
float: The mean absolute surface distance
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

mean_sd_list = []
num_points = []
Expand All @@ -231,7 +257,7 @@ def compute_metric_masd(label_a, label_b):
return float(mean_surf_dist)


def compute_metric_hd(label_a, label_b):
def compute_metric_hd(label_a, label_b, auto_crop=True):
"""Compute the Hausdorff distance between two labels
Args:
Expand All @@ -241,6 +267,12 @@ def compute_metric_hd(label_a, label_b):
Returns:
float: The maximum Hausdorff distance
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

hausdorff_distance = sitk.HausdorffDistanceImageFilter()
hausdorff_distance.Execute(label_a, label_b)
Expand Down
60 changes: 52 additions & 8 deletions platipy/imaging/projects/cardiac/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@

from platipy.imaging.utils.crop import label_to_roi, crop_to_roi

from platipy.imaging.generation.mask import extend_mask

ATLAS_PATH = "/atlas"
if "ATLAS_PATH" in os.environ:
ATLAS_PATH = os.environ["ATLAS_PATH"]
Expand All @@ -54,6 +56,7 @@
"crop_atlas_to_structures": False,
"crop_atlas_expansion_mm": (10, 10, 10),
"guide_structure_name": "WHOLEHEART",
"superior_extension": 30,
},
"auto_crop_target_image_settings": {
"expansion_mm": [2, 2, 2],
Expand Down Expand Up @@ -231,7 +234,7 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING
img_crop = crop_to_roi(img, crop_box_size, crop_box_index)

guide_structure = crop_to_roi(guide_structure, crop_box_size, crop_box_index)
target_reg_structure = convert_mask_to_reg_structure(guide_structure)
target_reg_structure = convert_mask_to_reg_structure(guide_structure, expansion=0)

else:
quick_reg_settings = {
Expand Down Expand Up @@ -303,7 +306,7 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING
guide_structure_name = settings["atlas_settings"]["guide_structure_name"]
target_reg_image = target_reg_structure
atlas_reg_image = convert_mask_to_reg_structure(
atlas_set[atlas_id]["Original"][guide_structure_name]
atlas_set[atlas_id]["Original"][guide_structure_name], expansion=0
)

else:
Expand All @@ -328,6 +331,21 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING
interpolator=sitk.sitkNearestNeighbor,
)

expanded_atlas_guide_structure = extend_mask(
atlas_set[atlas_id]["Original"][guide_structure_name],
direction=("ax", "sup"),
extension_mm=settings["atlas_settings"]["superior_extension"],
interior_mm_shape=settings["atlas_settings"]["superior_extension"] / 2,
)

atlas_set[atlas_id]["RIR"][guide_structure_name + "EXPANDED"] = apply_transform(
input_image=expanded_atlas_guide_structure,
reference_image=img_crop,
transform=initial_tfm,
default_value=0,
interpolator=sitk.sitkNearestNeighbor,
)

atlas_set[atlas_id]["RIR"]["CT Image"] = apply_transform(
input_image=atlas_set[atlas_id]["Original"]["CT Image"],
reference_image=img_crop,
Expand Down Expand Up @@ -383,6 +401,14 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING
interpolator=sitk.sitkLinear,
)

atlas_set[atlas_id]["DIR_STRUCT"][guide_structure_name + "EXPANDED"] = apply_transform(
input_image=atlas_set[atlas_id]["RIR"][guide_structure_name + "EXPANDED"],
reference_image=img_crop,
transform=struct_guided_tfm,
default_value=0,
interpolator=sitk.sitkNearestNeighbor,
)

# sitk.WriteImage(deform_image, f"./DIR_STRUCT_{atlas_id}.nii.gz")

for struct in atlas_structure_list:
Expand All @@ -399,7 +425,7 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING
# Settings
deformable_registration_settings = settings["deformable_registration_settings"]

logger.info("Running DIR to register atlas images")
logger.info("Running DIR to refine atlas image registration")

for atlas_id in atlas_id_list:

Expand All @@ -408,13 +434,31 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING
# Register the atlases
atlas_set[atlas_id]["DIR"] = {}

atlas_reg_image = atlas_set[atlas_id]["DIR_STRUCT"]["CT Image"]
target_reg_image = img_crop

if guide_structure:
target_reg_image = target_reg_structure
atlas_reg_image = atlas_set[atlas_id]["DIR_STRUCT"]["Reg Mask"]
expanded_atlas_mask = atlas_set[atlas_id]["DIR_STRUCT"][
guide_structure_name + "EXPANDED"
]
expanded_target_mask = extend_mask(
guide_structure,
direction=("ax", "sup"),
extension_mm=settings["atlas_settings"]["superior_extension"],
interior_mm_shape=settings["atlas_settings"]["superior_extension"] / 2,
)

else:
atlas_reg_image = atlas_set[atlas_id]["RIR"]["CT Image"]
target_reg_image = sitk.Mask(img_crop, atlas_reg_image > -1000, outsideValue=-1000)
combined_mask = sitk.Maximum(expanded_atlas_mask, expanded_target_mask)

atlas_reg_image = sitk.Mask(atlas_reg_image, combined_mask, outsideValue=-1000)
atlas_reg_image = sitk.Mask(
atlas_reg_image, atlas_reg_image > -400, outsideValue=-1000
)

target_reg_image = sitk.Mask(target_reg_image, combined_mask, outsideValue=-1000)
target_reg_image = sitk.Mask(
target_reg_image, atlas_reg_image > -400, outsideValue=-1000
)

deform_image, dir_tfm, _ = fast_symmetric_forces_demons_registration(
target_reg_image,
Expand Down
39 changes: 39 additions & 0 deletions platipy/imaging/radiotherapy/dose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import SimpleITK as sitk


def calculate_dvh(dose_grid, label, bins=1001):

if dose_grid.GetSize() != label.GetSize():
print("Dose grid size does not match label, automatically resampling.")
dose_grid = sitk.Resample(dose_grid, label)

dose_arr = sitk.GetArrayViewFromImage(dose_grid)
label_arr = sitk.GetArrayViewFromImage(label)

dose_vals = dose_arr[np.where(label_arr)]

counts, bin_edges = np.histogram(dose_vals, bins=bins)

# Get mid-points of bins
dose_points = (bin_edges[1:] + bin_edges[:-1]) / 2.0

# Calculate the actual DVH values
counts = np.cumsum(counts[::-1])[::-1]
counts = counts / counts.max()

return dose_points, counts
45 changes: 31 additions & 14 deletions platipy/imaging/registration/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,22 +162,39 @@ def linear_registration(
if fixed_structure:
registration.SetMetricFixedMask(fixed_structure)

if reg_method.lower() == "translation":
registration.SetInitialTransform(sitk.TranslationTransform(3))
elif reg_method.lower() == "similarity":
registration.SetInitialTransform(sitk.Similarity3DTransform())
elif reg_method.lower() == "affine":
registration.SetInitialTransform(sitk.AffineTransform(3))
elif reg_method.lower() == "rigid":
registration.SetInitialTransform(sitk.VersorRigid3DTransform())
elif reg_method.lower() == "scaleversor":
registration.SetInitialTransform(sitk.ScaleVersor3DTransform())
elif reg_method.lower() == "scaleskewversor":
registration.SetInitialTransform(sitk.ScaleSkewVersor3DTransform())
if isinstance(reg_method, str):
if reg_method.lower() == "translation":
registration.SetInitialTransform(sitk.TranslationTransform(3))
elif reg_method.lower() == "similarity":
registration.SetInitialTransform(sitk.Similarity3DTransform())
elif reg_method.lower() == "affine":
registration.SetInitialTransform(sitk.AffineTransform(3))
elif reg_method.lower() == "rigid":
registration.SetInitialTransform(sitk.VersorRigid3DTransform())
elif reg_method.lower() == "scaleversor":
registration.SetInitialTransform(sitk.ScaleVersor3DTransform())
elif reg_method.lower() == "scaleskewversor":
registration.SetInitialTransform(sitk.ScaleSkewVersor3DTransform())
else:
raise ValueError(
"You have selected a registration method that does not exist.\n Please select from "
"Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor"
)
elif (
isinstance(reg_method, sitk.CompositeTransform)
or isinstance(reg_method, sitk.Transform)
or isinstance(reg_method, sitk.TranslationTransform)
or isinstance(reg_method, sitk.Similarity3DTransform)
or isinstance(reg_method, sitk.AffineTransform)
or isinstance(reg_method, sitk.VersorRigid3DTransform)
or isinstance(reg_method, sitk.ScaleVersor3DTransform)
or isinstance(reg_method, sitk.ScaleSkewVersor3DTransform)
):
registration.SetInitialTransform(reg_method)
else:
raise ValueError(
"You have selected a registration method that does not exist.\n Please select from "
"Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor"
"'reg_method' must be either a string (see docs for acceptable registration names), "
"or a custom sitk.CompositeTransform."
)

if optimiser.lower() == "lbfgsb":
Expand Down
7 changes: 7 additions & 0 deletions platipy/imaging/registration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,13 @@ def convert_mask_to_distance_map(mask, squared_distance=False, normalise=False):
Returns:
[SimpleITK.Image]: The distance map as an image.
"""
arr = sitk.GetArrayViewFromImage(mask)
vals = np.unique(arr[arr > 0])
if len(vals) > 2:
# There is more than one value! We need to threshold at the median
cutoff = np.median(vals)
mask = sitk.BinaryThreshold(mask, cutoff, np.max(vals).astype(float))

raw_map = sitk.SignedMaurerDistanceMap(
mask,
insideIsPositive=True,
Expand Down
Loading

0 comments on commit 9124951

Please sign in to comment.