Skip to content

Commit

Permalink
Merge pull request #112 from pyplati/visualisation-updates
Browse files Browse the repository at this point in the history
Visualisation updates
  • Loading branch information
pchlap authored Feb 1, 2022
2 parents 6b358d9 + a32080b commit 945b67b
Show file tree
Hide file tree
Showing 9 changed files with 211 additions and 136 deletions.
7 changes: 5 additions & 2 deletions platipy/imaging/dose/dvh.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def calculate_dvh(dose_grid, label, bins=1001):

# Calculate the actual DVH values
values = np.cumsum(counts[::-1])[::-1]
values = values / values.max()
if np.all(values == 0):
return bins, values
else:
values = values / values.max()

return bins, values

Expand Down Expand Up @@ -142,7 +145,7 @@ def calculate_d_x(dvh, x, label=None):

def calculate_v_x(dvh, x, label=None):
"""Get the volume (in cc) which receives x dose
Args:
dvh (pandas.DataFrame): DVH DataFrame as produced by calculate_dvh_for_labels
x (float): The dose to get the volume for.
Expand Down
12 changes: 12 additions & 0 deletions platipy/imaging/label/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from platipy.imaging.utils.crop import label_to_roi, crop_to_roi


def compute_volume(label):
"""Computes the volume in cubic centimetres
Expand Down Expand Up @@ -232,6 +233,12 @@ def compute_metric_masd(label_a, label_b, auto_crop=True):
Returns:
float: The mean absolute surface distance
"""
if (
sitk.GetArrayViewFromImage(label_a).sum() == 0
or sitk.GetArrayViewFromImage(label_b).sum() == 0
):
return np.nan

if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)
Expand Down Expand Up @@ -267,6 +274,11 @@ def compute_metric_hd(label_a, label_b, auto_crop=True):
Returns:
float: The maximum Hausdorff distance
"""
if (
sitk.GetArrayViewFromImage(label_a).sum() == 0
or sitk.GetArrayViewFromImage(label_b).sum() == 0
):
return np.nan
if auto_crop:
largest_region = (label_a + label_b) > 0
crop_box_size, crop_box_index = label_to_roi(largest_region)
Expand Down
23 changes: 13 additions & 10 deletions platipy/imaging/label/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from platipy.imaging.utils.math import gen_primes


def correct_volume_overlap(binary_label_dict):
def correct_volume_overlap(binary_label_dict, assign_overlap_to_largest=True):
"""
Label structures by primes
Smallest prime = largest volume
Expand All @@ -31,26 +31,29 @@ def correct_volume_overlap(binary_label_dict):
volume_dict = {i: f_vol(binary_label_dict[i]) for i in binary_label_dict.keys()}

keys, vals = zip(*volume_dict.items())
volume_rank = np.argsort(vals)[::-1]
if assign_overlap_to_largest:
volume_rank = np.argsort(vals)[::-1]
else:
volume_rank = np.argsort(vals)

# print(keys, volume_rank)

ranked_names = np.array(keys)[volume_rank]

# Get overlap using prime factors
prime_labelled_image = sum(binary_label_dict.values()) > 0
# Get overlap (this is used to reconstruct labels)
combined_label = sum(binary_label_dict.values()) > 0

for p, label in zip(gen_primes(), ranked_names):
prime_labelled_image = prime_labelled_image * (
(p - 1) * binary_label_dict[label] + combined_label
)
# Prime encode each binary label
prime_labelled_image = prime_encode_structure_list(
[binary_label_dict[i] for i in ranked_names]
)

# Remove overlap (by assigning to binary volume)
output_label_dict = {}
for p, label in zip(gen_primes(), ranked_names):
output_label_dict[label] = combined_label * (sitk.Modulus(prime_labelled_image, p) == 0)

combined_label = sitk.Mask(combined_label, output_label_dict[label] == 0)
combined_label = sitk.MaskNegated(combined_label, output_label_dict[label])

return output_label_dict

Expand Down Expand Up @@ -155,14 +158,14 @@ def prime_encode_structure_list(structure_list):
img_size = structure_list[0].GetSize()
prime_encoded_image = sitk.GetImageFromArray(np.ones(img_size[::-1]))
prime_encoded_image = sitk.Cast(prime_encoded_image, sitk.sitkUInt64)
prime_encoded_image.CopyInformation(structure_list[0])

prime_generator = generate_primes()

for s_img, prime in zip(structure_list, prime_generator):
# Cast to int
s_img_int = sitk.Cast(s_img > 0, sitk.sitkUInt64)

print(prime)
# Multiply with the encoded image
prime_encoded_image = (
sitk.MaskNegated(prime_encoded_image, s_img_int)
Expand Down
9 changes: 4 additions & 5 deletions platipy/imaging/projects/cardiac/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,11 @@
"isotropic_resample": True,
"resolution_staging": [
6,
4,
2,
1,
3,
1.5,
], # specify voxel size (mm) since isotropic_resample is set
"iteration_staging": [200, 150, 125, 100],
"smoothing_sigmas": [0, 0, 0, 0],
"iteration_staging": [200, 150, 100],
"smoothing_sigmas": [0, 0, 0],
"ncores": 8,
"default_value": 0,
"verbose": False,
Expand Down
23 changes: 13 additions & 10 deletions platipy/imaging/projects/multiatlas/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from loguru import logger

from platipy.imaging.registration.utils import apply_transform, convert_mask_to_reg_structure
from platipy.imaging.registration.utils import apply_transform

from platipy.imaging.registration.linear import (
linear_registration,
Expand All @@ -37,7 +37,7 @@

from platipy.imaging.utils.crop import label_to_roi, crop_to_roi

from platipy.imaging.label.utils import binary_encode_structure_list, correct_volume_overlap
from platipy.imaging.label.utils import correct_volume_overlap

ATLAS_PATH = "/atlas"
if "ATLAS_PATH" in os.environ:
Expand Down Expand Up @@ -65,7 +65,7 @@
"shrink_factors": [16, 8, 4],
"smooth_sigmas": [0, 0, 0],
"sampling_rate": 0.75,
"default_value": -1000,
"default_value": None,
"number_of_iterations": 50,
"metric": "mean_squares",
"optimiser": "gradient_descent_line_search",
Expand All @@ -75,14 +75,17 @@
"isotropic_resample": True,
"resolution_staging": [
6,
4,
2,
1,
3,
1.5,
], # specify voxel size (mm) since isotropic_resample is set
"iteration_staging": [200, 150, 125, 100],
"smoothing_sigmas": [0, 0, 0, 0],
"iteration_staging": [150, 125, 100],
"smoothing_sigmas": [
0,
0,
0,
],
"ncores": 8,
"default_value": 0,
"default_value": None,
"verbose": False,
},
"label_fusion_settings": {
Expand Down Expand Up @@ -315,7 +318,7 @@ def run_segmentation(img, settings=MUTLIATLAS_SETTINGS_DEFAULTS):
atlas_reg_image = atlas_set[atlas_id]["RIR"]["CT Image"]
target_reg_image = img_crop

deform_image, dir_tfm, _ = fast_symmetric_forces_demons_registration(
_, dir_tfm, _ = fast_symmetric_forces_demons_registration(
target_reg_image,
atlas_reg_image,
**deformable_registration_settings,
Expand Down
13 changes: 12 additions & 1 deletion platipy/imaging/registration/deformable.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def fast_symmetric_forces_demons_registration(
initial_displacement_field=None,
smoothing_sigma_factor=1,
smoothing_sigmas=False,
default_value=-1000,
default_value=None,
ncores=1,
interp_order=2,
verbose=False,
Expand All @@ -176,6 +176,8 @@ def fast_symmetric_forces_demons_registration(
2 = Bi-linear splines
3 = B-Spline (cubic)
default_value (float) : Default voxel value. Defaults to 0 unless image is CT-like.
Returns
registered_image (sitk.Image) : the registered moving image
output_transform : the displacement field transform
Expand Down Expand Up @@ -224,6 +226,15 @@ def fast_symmetric_forces_demons_registration(
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image)
resampler.SetInterpolator(interp_order)

# Try to find default value
if default_value is None:
default_value = 0

# Test if image is CT-like
if sitk.GetArrayViewFromImage(moving_image).min() <= -1000:
default_value = -1000

resampler.SetDefaultPixelValue(default_value)

resampler.SetTransform(output_transform)
Expand Down
13 changes: 11 additions & 2 deletions platipy/imaging/registration/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def linear_registration(
sampling_rate=0.25,
final_interp=2,
number_of_iterations=50,
default_value=-1000,
default_value=None,
verbose=False,
):
"""
Expand Down Expand Up @@ -110,7 +110,7 @@ def linear_registration(
final_interp (int, optional): The final interpolation order. Defaults to 2 (linear).
number_of_iterations (int, optional): Number of iterations in each multi-resolution step.
Defaults to 50.
default_value (int, optional): Default voxel value. Defaults to -1000.
default_value (int, optional): Default voxel value. Defaults to 0 unless image is CT-like.
verbose (bool, optional): Print image registration process information. Defaults to False.
Returns:
Expand Down Expand Up @@ -236,6 +236,15 @@ def linear_registration(
# Combine initial and optimised transform
combined_transform = sitk.CompositeTransform([initial_transform, output_transform])


# Try to find default value
if default_value is None:
default_value = 0

# Test if image is CT-like
if sitk.GetArrayViewFromImage(moving_image).min() <= -1000:
default_value = -1000

registered_image = apply_transform(
input_image=moving_image,
reference_image=fixed_image,
Expand Down
2 changes: 2 additions & 0 deletions platipy/imaging/visualisation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(
mid_ticks=False,
show_colorbar=True,
norm=None,
projection=False,
):
self.image = image
self.name = name
Expand All @@ -71,6 +72,7 @@ def __init__(
self.mid_ticks = mid_ticks
self.show_colorbar = show_colorbar
self.norm = norm
self.projection = projection


class VisualiseVectorOverlay:
Expand Down
Loading

0 comments on commit 945b67b

Please sign in to comment.