Skip to content

Commit

Permalink
Merged in rf-working (pull request #27)
Browse files Browse the repository at this point in the history
update devel
  • Loading branch information
rnfinnegan authored and pchlap committed Sep 22, 2020
2 parents c43163d + bbda969 commit a427e54
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 44 deletions.
6 changes: 4 additions & 2 deletions platipy/imaging/atlas/iterative_atlas_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def run_iar(
iteration=0,
single_step=False,
project_on_sphere=False,
label='DIR',
):
"""
Perform iterative atlas removal on the atlas_set
Expand All @@ -174,7 +175,7 @@ def run_iar(

# Generate the surface projections
# 1. Set the consensus surface using the reference volume
probability_label = combine_labels(atlas_set, structure_name)[structure_name]
probability_label = combine_labels(atlas_set, structure_name, label=label)[structure_name]

# Modify resolution for better statistics
if project_on_sphere:
Expand Down Expand Up @@ -202,7 +203,7 @@ def run_iar(
logger.info(" {0}".format(test_id))
# 2. Calculate the distance from the surface to the consensus surface

test_volume = atlas_set[test_id]["DIR"][structure_name]
test_volume = atlas_set[test_id][label][structure_name]

# This next step ensures non-binary labels are treated properly
# We use 0.1 to capture the outer edge of the test delineation, if it is probabilistic
Expand Down Expand Up @@ -394,6 +395,7 @@ def run_iar(
n_factor=n_factor,
iteration=iteration,
project_on_sphere=project_on_sphere,
label=label
)

logger.info(" End point reached. Keeping:\n {0}".format(keep_id_list))
Expand Down
10 changes: 5 additions & 5 deletions platipy/imaging/atlas/label_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def combine_labels_staple(label_list_dict, threshold=1e-4):
return combined_label_dict


def combine_labels(atlas_set, structure_name, threshold=1e-4, smooth_sigma=1.0):
def combine_labels(atlas_set, structure_name, label='DIR', threshold=1e-4, smooth_sigma=1.0):
"""
Combine labels using weight maps
"""
Expand All @@ -123,12 +123,12 @@ def combine_labels(atlas_set, structure_name, threshold=1e-4, smooth_sigma=1.0):
for structure_name in structure_name_list:
# Find the cases which have the strucure (in case some cases do not)
valid_case_id_list = [
i for i in case_id_list if structure_name in atlas_set[i]["DIR"].keys()
i for i in case_id_list if structure_name in atlas_set[i][label].keys()
]

# Get valid weight images
weight_image_list = [
atlas_set[caseId]["DIR"]["Weight Map"] for caseId in valid_case_id_list
atlas_set[caseId][label]["Weight Map"] for caseId in valid_case_id_list
]

# Sum the weight images
Expand All @@ -139,8 +139,8 @@ def combine_labels(atlas_set, structure_name, threshold=1e-4, smooth_sigma=1.0):

# Combine weight map with each label
weighted_labels = [
atlas_set[caseId]["DIR"]["Weight Map"]
* sitk.Cast(atlas_set[caseId]["DIR"][structure_name], sitk.sitkFloat32)
atlas_set[caseId][label]["Weight Map"]
* sitk.Cast(atlas_set[caseId][label][structure_name], sitk.sitkFloat32)
for caseId in valid_case_id_list
]

Expand Down
78 changes: 42 additions & 36 deletions platipy/imaging/projects/cardiac/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from platipy.imaging.projects.cardiac.utils import vesselSplineGeneration

from platipy.imaging.utils.tools import label_to_roi, crop_to_roi

ATLAS_PATH = "/atlas"
if "ATLAS_PATH" in os.environ:
ATLAS_PATH = os.environ["ATLAS_PATH"]
Expand Down Expand Up @@ -91,6 +93,7 @@
"stopCondition": {"LANTDESCARTERY_SPLINE": "count"},
"stopConditionValue": {"LANTDESCARTERY_SPLINE": 1},
},
"returnAsCropped": False
}


Expand All @@ -107,6 +110,7 @@ def run_cardiac_segmentation(img, settings=CARDIAC_SETTINGS_DEFAULTS):
"""

results = {}
return_as_cropped = settings["returnAsCropped"]

"""
Initialisation - Read in atlases
Expand Down Expand Up @@ -191,7 +195,7 @@ def run_cardiac_segmentation(img, settings=CARDIAC_SETTINGS_DEFAULTS):
"""
# Settings
quick_reg_settings = {
"shrinkFactors": [16],
"shrinkFactors": [8],
"smoothSigmas": [0],
"samplingRate": 0.75,
"defaultValue": -1024,
Expand Down Expand Up @@ -235,35 +239,24 @@ def run_cardiac_segmentation(img, settings=CARDIAC_SETTINGS_DEFAULTS):
shape_filter.Execute(combined_image_extent)
bounding_box = np.array(shape_filter.GetBoundingBox(1))

expansion = settings["autoCropSettings"]["expansion"]
expansion_array = expansion * np.array(img.GetSpacing())

# Avoid starting outside the image
crop_box_index = np.max(
[bounding_box[:3] - expansion_array, np.array([0, 0, 0])], axis=0
)

# Avoid ending outside the image
crop_box_size = np.min(
[
np.array(img.GetSize()) - crop_box_index,
bounding_box[3:] + 2 * expansion_array,
],
axis=0,
)
"""
Crop image to region of interest (ROI)
--> Defined by images
"""

crop_box_size = [int(i) for i in crop_box_size]
crop_box_index = [int(i) for i in crop_box_index]
expansion = settings["autoCropSettings"]["expansion"]
expansion_array = expansion * np.array(img.GetSpacing())

crop_box_size, crop_box_index = label_to_roi(img, combined_image_extent, expansion = expansion_array)
img_crop = crop_to_roi(img, crop_box_size, crop_box_index)
logger.info(
f"Calculated crop box\n\
{crop_box_index}\n\
{crop_box_size}\n\n\
Volume reduced by factor {np.product(img.GetSize())/np.product(crop_box_size)}"
)

img_crop = sitk.RegionOfInterest(img, size=crop_box_size, index=crop_box_index)

"""
Step 2 - Rigid registration of target images
- Individual atlas images are registered to the target
Expand Down Expand Up @@ -463,26 +456,39 @@ def run_cardiac_segmentation(img, settings=CARDIAC_SETTINGS_DEFAULTS):

binary_struct = process_probability_image(probability_map, optimal_threshold)

paste_binary_img = sitk.Paste(
template_img_binary,
binary_struct,
binary_struct.GetSize(),
(0, 0, 0),
crop_box_index,
)
if return_as_cropped:
results[structure_name] = binary_struct

else:
paste_binary_img = sitk.Paste(
template_img_binary,
binary_struct,
binary_struct.GetSize(),
(0, 0, 0),
crop_box_index,
)

results[structure_name] = paste_binary_img
results[structure_name] = paste_binary_img

for structure_name in vessel_name_list:
binary_struct = segmented_vessel_dict[structure_name]
paste_img_binary = sitk.Paste(
template_img_binary,
binary_struct,
binary_struct.GetSize(),
(0, 0, 0),
crop_box_index,
)

results[structure_name] = paste_img_binary
if return_as_cropped:
results[structure_name] = binary_struct

else:
paste_img_binary = sitk.Paste(
template_img_binary,
binary_struct,
binary_struct.GetSize(),
(0, 0, 0),
crop_box_index,
)

results[structure_name] = paste_img_binary

if return_as_cropped:
results['CROP_IMAGE'] = img_crop


return results
42 changes: 41 additions & 1 deletion platipy/imaging/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,44 @@ def get_crop_bounding_box(img, mask):
bounding_box[3 + i], img.GetSize()[i] - bounding_box[i]
)

return bounding_box
return bounding_box


def label_to_roi(image, label_list, expansion = [0,0,0]):

label_stats_image_filter = sitk.LabelStatisticsImageFilter()
if type(label_list)==list:
label_stats_image_filter.Execute(image, sum(label_list) > 0)
elif type(label_list)==sitk.Image:
label_stats_image_filter.Execute(image, label_list)
else:
raise ValueError('Second argument must be a SITK image, or list thereof.')

bounding_box = np.array(label_stats_image_filter.GetBoundingBox(1))

index = [bounding_box[x * 2] for x in range(3)]
size = [bounding_box[(x * 2) + 1] - bounding_box[x * 2] for x in range(3)]
expansion = np.array(expansion)

# Avoid starting outside the image
crop_box_index = np.max(
[index - expansion, np.array([0, 0, 0])], axis=0
)

# Avoid ending outside the image
crop_box_size = np.min(
[
np.array(image.GetSize()) - crop_box_index,
np.array(size) + 2*expansion,
],
axis=0,
)

crop_box_size = [int(i) for i in crop_box_size]
crop_box_index = [int(i) for i in crop_box_index]

return crop_box_size, crop_box_index

def crop_to_roi(image, size, index):
return sitk.RegionOfInterest(image, size=size, index=index)

127 changes: 127 additions & 0 deletions platipy/imaging/visualisation/crosshairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Code to display orthogonal image slices with crosshairs displaying the slice position
"""

import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

"""
Settings
Fill out this section
"""

# Define the image
im = sitk.ReadImage('/home/robbie/Work/4_Software/platipy/platipy/imaging/tests/data/dynamic/lung/LCTSC-Train-S1-001/CT.nii.gz')

# Define the cut location - this is where the crosshairs will appear
# Given as [axial, coronal, sagittal] location
cut = [30, 220, 330]

# Define the figure size (inches), colormap, intensity windowing ([min, range])
figSize=6
cmap=plt.cm.Greys_r
window=[-250, 500]

# Output file name
fig_name = './test.png'

"""
Code - shouldn't need to edit this
"""

def returnSlice(axis, index):
if axis == "x":
s = (slice(None), slice(None), index)
if axis == "y":
s = (slice(None), index, slice(None))
if axis == "z":
s = (index, slice(None), slice(None))

return s

nda = sitk.GetArrayFromImage(im)

# Get data for correct visualisation
(AxSize, CorSize, SagSize) = nda.shape
spPlane, _, spSlice = im.GetSpacing()
asp = (1.0 * spSlice) / spPlane


# Set up figure
fSize = (
figSize,
figSize * (asp * AxSize + CorSize) / (1.0 * SagSize + CorSize),
)

fig, ((axAx, blank), (axCor, axSag)) = plt.subplots(
2,
2,
figsize=fSize,
gridspec_kw={"height_ratios": [(CorSize) / (asp * AxSize), 1], "width_ratios": [SagSize, CorSize]},
);
blank.axis("off")

# Get slices
sAx = returnSlice("z", cut[0])
sCor = returnSlice("y", cut[1])
sSag = returnSlice("x", cut[2])

# Display image data
imAx = axAx.imshow(
nda.__getitem__(sAx),
aspect=1.0,
interpolation=None,
cmap=cmap,
clim=(window[0], window[0] + window[1]),
)
imCor = axCor.imshow(
nda.__getitem__(sCor),
origin="lower",
aspect=asp,
interpolation=None,
cmap=cmap,
clim=(window[0], window[0] + window[1]),
)
imSag = axSag.imshow(
nda.__getitem__(sSag),
origin="lower",
aspect=asp,
interpolation=None,
cmap=cmap,
clim=(window[0], window[0] + window[1]),
)

# Display crosshairs
# Axial image, FAKE cut (just for label)
axAx.plot([0, 0], [0, 0], c='yellow', label=f'Axial slice: {cut[0]}')
# Axial image, coronal cut
axAx.plot([0, SagSize], [cut[1], cut[1]], c='r', label=f'Coronal slice: {cut[1]}')
# Axial image, sagittal cut
axAx.plot([cut[2], cut[2]], [0, CorSize], c='orange', label=f'Sagittal slice: {cut[2]}')

axAx.legend(loc='center left', bbox_to_anchor=(1.05,0.5))

# Sag image, ax cut
axSag.plot([0, CorSize], [cut[0], cut[0]], c='yellow')
# Sag image, cor cut
axSag.plot([cut[1], cut[1]], [0, AxSize], c='r')


# Cor image, ax cut
axCor.plot([0, SagSize], [cut[0], cut[0]], c='yellow')
# Cor image, sagittal cut
axCor.plot([cut[2], cut[2]], [0, AxSize], c='orange')


# Turn off axes
axAx.axis("off")
axCor.axis("off")
axSag.axis("off")

# Adjust spacing
fig.subplots_adjust(left=0, right=1, wspace=0.01, hspace=0.01, top=1, bottom=0)

# Save image
#fig.savefig(f'{fig_name}', dpi=300, transparent=True)

0 comments on commit a427e54

Please sign in to comment.