Skip to content

Commit

Permalink
Merge branch 'lib-updates' into img-vis-clean
Browse files Browse the repository at this point in the history
  • Loading branch information
pchlap committed Jun 8, 2021
2 parents f518345 + 218babf commit 89cfddf
Show file tree
Hide file tree
Showing 16 changed files with 1,245 additions and 193 deletions.
595 changes: 595 additions & 0 deletions .pylintrc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion platipy/imaging/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .visualisation.visualiser import ImageVisualiser
from .visualisation.visualiser import ImageVisualiser
50 changes: 50 additions & 0 deletions platipy/imaging/dose/dvh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import SimpleITK as sitk


def calculate_dvh(dose_grid, label, bins=1001):
"""Calculates a dose-volume histogram
Args:
dose_grid (SimpleITK.Image): The dose grid.
label (SimpleITK.Image): The (binary) label defining a structure.
bins (int | list | np.ndarray, optional): Passed to np.histogram,
can be an int (number of bins), or a list (specifying bin edges). Defaults to 1001.
Returns:
SimpleITK.Image: [description]
"""

if dose_grid.GetSize() != label.GetSize():
print("Dose grid size does not match label, automatically resampling.")
dose_grid = sitk.Resample(dose_grid, label)

dose_arr = sitk.GetArrayViewFromImage(dose_grid)
label_arr = sitk.GetArrayViewFromImage(label)

dose_vals = dose_arr[np.where(label_arr)]

counts, bin_edges = np.histogram(dose_vals, bins=bins)

# Get mid-points of bins
dose_points = (bin_edges[1:] + bin_edges[:-1]) / 2.0

# Calculate the actual DVH values
counts = np.cumsum(counts[::-1])[::-1]
counts = counts / counts.max()

return dose_points, counts
69 changes: 65 additions & 4 deletions platipy/imaging/generation/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


def insert_sphere(arr, sp_radius=4, sp_centre=(0, 0, 0)):
"""Insert a sphere into the give array
"""Insert a sphere into the given array
Args:
arr (np.array): Array in which to insert sphere
Expand Down Expand Up @@ -48,14 +48,46 @@ def insert_sphere(arr, sp_radius=4, sp_centre=(0, 0, 0)):
return arr_copy


def insert_cylinder(arr, cyl_radius=4, cyl_height=2, cyl_centre=(0, 0, 0)):
"""
Insert a cylinder into the given array.
The cylinder vertical extent is +/- 0.5 * height
Args:
arr (np.ndarray): The array into which the cylinder is inserted
cyl_radius (int, optional): Cylinder radius. Defaults to 4.
cyl_height (int, optional): Cylinder height. Defaults to 2.
cyl_centre (tuple, optional): Cylinder centre. Defaults to (0, 0, 0).
Returns:
np.ndarray: The original array with a cylinder (value 1)
"""
arr_copy = arr[:]

x, y, z = np.indices(arr.shape)

if not hasattr(cyl_radius, "__iter__"):
cyl_radius = [cyl_radius] * 2

condition_radial = (
((z - cyl_centre[0]) / cyl_radius[0]) ** 2 + ((y - cyl_centre[1]) / cyl_radius[1]) ** 2
) <= 1
condition_height = np.abs((x - cyl_centre[2]) / (0.5 * cyl_height)) <= 1

arr_copy[condition_radial & condition_height] = 1

return arr_copy


def insert_sphere_image(image, sp_radius, sp_centre):
"""Insert a sphere into a blank image with the same size as image
"""Insert a sphere into an image
Args:
image (sitk.Image): Image in which to insert sphere
sp_radius (int | list, optional): The radius of the sphere. Can also be defined as a vector. Defaults to 4.
sp_radius (int | list, optional): The radius of the sphere.
Can also be defined as a vector. Defaults to 4.
sp_centre (tuple, optional): The position at which the sphere should be inserted. Defaults
to (0, 0, 0).
to (0, 0, 0).
Returns:
np.array: An array with the sphere inserted
Expand All @@ -74,3 +106,32 @@ def insert_sphere_image(image, sp_radius, sp_centre):
image_sphere.CopyInformation(image)

return image_sphere


def insert_cylinder_image(image, cyl_radius=(5, 5), cyl_height=10, cyl_centre=(0, 0, 0)):
"""Insert a cylinder into an image
Args:
image (SimpleITK.Image):
cyl_radius (tuple, optional): Cylinder radius, can be defined as a single value
or a tuple (will generate an ellipsoid). Defaults to (5,5).
cyl_height (int, optional): Cylinder height. Defaults to 10.
cyl_centre (tuple, optional): Cylinder centre. Defaults to (0,0,0).
Returns:
SimpleITK.Image: Image with cylinder inserted
"""
if not hasattr(cyl_radius, "__iter__"):
cyl_radius = [cyl_radius] * 2

cyl_radius_image = [i / j for i, j in zip(cyl_radius, image.GetSpacing()[1::-1])]
cyl_height_image = cyl_height / image.GetSpacing()[2]

arr = sitk.GetArrayFromImage(image)

arr = insert_cylinder(arr, cyl_radius_image, cyl_height_image, cyl_centre)

image_cylinder = sitk.GetImageFromArray(arr)
image_cylinder.CopyInformation(image)

return image_cylinder
55 changes: 55 additions & 0 deletions platipy/imaging/generation/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,58 @@ def get_external_mask(
body_mask_hull.CopyInformation(body_mask)

return body_mask_hull


def extend_mask(mask, direction=("ax", "sup"), extension_mm=10, interior_mm_shape=10):
"""
Extends a binary label (mask) a number of slices.
PROTOTYPE!
Currently can only extend in axial directions (superior of inferior).
The shape of the extended part is based on some number of interior slices, as defined.
Args:
mask (SimpleITK.Image): The input binary label (mask).
direction (tuple, optional): The direction as a tuple. First element is axis, second
element is direction. Defaults to ("ax", "sup").
extension_mm (int, optional): The extension in millimeters. Defaults to 10.
interior_mm_shape (int, optional): The length on which to base the extension shape.
Defaults to 10.
Returns:
SimpleITK.Image: The output (extended mask).
"""
arr = sitk.GetArrayViewFromImage(mask)
vals = np.unique(arr[arr > 0])
if len(vals) > 2:
# There is more than one value! We need to threshold (at the median)
cutoff = np.median(vals)
mask_binary = sitk.BinaryThreshold(mask, cutoff, np.max(vals).astype(float))
else:
mask_binary = mask

arr = sitk.GetArrayFromImage(mask_binary)

if direction[0] == "ax":
inferior_slice = np.where(arr)[0].min()
superior_slice = np.where(arr)[0].max()

n_slices_ext = int(extension_mm / mask.GetSpacing()[2])
n_slices_est = int(interior_mm_shape / mask.GetSpacing()[2])

if direction[1] == "sup":
max_index = min([arr.shape[0], superior_slice + 1 + n_slices_ext])
for s_in in range(superior_slice + 1 - n_slices_est, max_index):
arr[s_in, :, :] = np.max(
arr[superior_slice - n_slices_est : superior_slice, :, :], axis=0
)
if direction[1] == "inf":
min_index = max([arr.shape[0], inferior_slice - n_slices_ext + n_slices_est])
for s_in in range(min_index, inferior_slice):
arr[s_in, :, :] = np.max(
arr[inferior_slice + n_slices_est : inferior_slice, :, :], axis=0
)

mask_ext = sitk.GetImageFromArray(arr)
mask_ext.CopyInformation(mask)

return mask_ext
42 changes: 37 additions & 5 deletions platipy/imaging/label/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import numpy as np
import SimpleITK as sitk

from platipy.imaging.utils.crop import label_to_roi, crop_to_roi


def compute_surface_metrics(label_a, label_b, verbose=False):
"""Compute surface distance metrics between two labels. Surface metrics computed are:
Expand Down Expand Up @@ -141,7 +143,7 @@ def compute_volume_metrics(label_a, label_b):
return result


def compute_metric_dsc(label_a, label_b):
def compute_metric_dsc(label_a, label_b, auto_crop=True):
"""Compute the Dice Similarity Coefficient between two labels
Args:
Expand All @@ -151,13 +153,19 @@ def compute_metric_dsc(label_a, label_b):
Returns:
float: The Dice Similarity Coefficient
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
arr_b = sitk.GetArrayFromImage(label_b).astype(bool)
return 2 * ((arr_a & arr_b).sum()) / (arr_a.sum() + arr_b.sum())


def compute_metric_specificity(label_a, label_b):
def compute_metric_specificity(label_a, label_b, auto_crop=True):
"""Compute the specificity between two labels
Args:
Expand All @@ -167,6 +175,12 @@ def compute_metric_specificity(label_a, label_b):
Returns:
float: The specificity between the two labels
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
arr_b = sitk.GetArrayFromImage(label_b).astype(bool)
Expand All @@ -180,7 +194,7 @@ def compute_metric_specificity(label_a, label_b):
return float((1.0 * true_neg) / (true_neg + false_pos))


def compute_metric_sensitivity(label_a, label_b):
def compute_metric_sensitivity(label_a, label_b, auto_crop=True):
"""Compute the sensitivity between two labels
Args:
Expand All @@ -190,6 +204,12 @@ def compute_metric_sensitivity(label_a, label_b):
Returns:
float: The sensitivity between the two labels
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

arr_a = sitk.GetArrayFromImage(label_a).astype(bool)
arr_b = sitk.GetArrayFromImage(label_b).astype(bool)
Expand All @@ -202,7 +222,7 @@ def compute_metric_sensitivity(label_a, label_b):
return float((1.0 * true_pos) / (true_pos + false_neg))


def compute_metric_masd(label_a, label_b):
def compute_metric_masd(label_a, label_b, auto_crop=True):
"""Compute the mean absolute distance between two labels
Args:
Expand All @@ -212,6 +232,12 @@ def compute_metric_masd(label_a, label_b):
Returns:
float: The mean absolute surface distance
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

mean_sd_list = []
num_points = []
Expand All @@ -231,7 +257,7 @@ def compute_metric_masd(label_a, label_b):
return float(mean_surf_dist)


def compute_metric_hd(label_a, label_b):
def compute_metric_hd(label_a, label_b, auto_crop=True):
"""Compute the Hausdorff distance between two labels
Args:
Expand All @@ -241,6 +267,12 @@ def compute_metric_hd(label_a, label_b):
Returns:
float: The maximum Hausdorff distance
"""
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)

label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index)
label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index)

hausdorff_distance = sitk.HausdorffDistanceImageFilter()
hausdorff_distance.Execute(label_a, label_b)
Expand Down
6 changes: 2 additions & 4 deletions platipy/imaging/label/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,9 @@ def combine_labels(atlas_set, structure_name, label="DIR", threshold=1e-4, smoot

combined_label_dict = {}

for structure_name in structure_name_list:
for s_name in structure_name_list:
# Find the cases which have the strucure (in case some cases do not)
valid_case_id_list = [
i for i in case_id_list if structure_name in atlas_set[i][label].keys()
]
valid_case_id_list = [i for i in case_id_list if s_name in atlas_set[i][label].keys()]

# Get valid weight images
weight_image_list = [
Expand Down
7 changes: 4 additions & 3 deletions platipy/imaging/label/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
from scipy.ndimage.measurements import center_of_mass


def get_com(label, real_coords=False):
def get_com(label, as_int=True, real_coords=False):
"""
Get centre of mass of a SimpleITK.Image
"""
arr = sitk.GetArrayFromImage(label)
com = center_of_mass(arr)

if real_coords:
com = label.TransformContinuousIndexToPhysicalPoint(com)
com = label.TransformContinuousIndexToPhysicalPoint(com[::-1])

else:
com = [int(i) for i in com]
if as_int:
com = [int(i) for i in com]

return com

Expand Down
Loading

0 comments on commit 89cfddf

Please sign in to comment.