From 9124951681239e899619cf68edf65d266d8eadf1 Mon Sep 17 00:00:00 2001 From: rnfinnegan Date: Wed, 2 Jun 2021 17:11:37 +1000 Subject: [PATCH 1/2] updates --- platipy/imaging/generation/mask.py | 39 ++++++++++++++ platipy/imaging/label/comparison.py | 42 +++++++++++++-- platipy/imaging/projects/cardiac/run.py | 60 ++++++++++++++++++--- platipy/imaging/radiotherapy/dose.py | 39 ++++++++++++++ platipy/imaging/registration/linear.py | 45 +++++++++++----- platipy/imaging/registration/utils.py | 7 +++ platipy/imaging/visualisation/visualiser.py | 20 +++---- 7 files changed, 216 insertions(+), 36 deletions(-) create mode 100644 platipy/imaging/radiotherapy/dose.py diff --git a/platipy/imaging/generation/mask.py b/platipy/imaging/generation/mask.py index 4f3a5f29..68c68a5c 100644 --- a/platipy/imaging/generation/mask.py +++ b/platipy/imaging/generation/mask.py @@ -102,3 +102,42 @@ def get_external_mask( body_mask_hull.CopyInformation(body_mask) return body_mask_hull + + +def extend_mask(mask, direction=("ax", "sup"), extension_mm=10, interior_mm_shape=10): + + arr = sitk.GetArrayViewFromImage(mask) + vals = np.unique(arr[arr > 0]) + if len(vals) > 2: + # There is more than one value! We need to threshold (at the median) + cutoff = np.median(vals) + mask_binary = sitk.BinaryThreshold(mask, cutoff, np.max(vals).astype(float)) + else: + mask_binary = mask + + arr = sitk.GetArrayFromImage(mask_binary) + + if direction[0] == "ax": + inferior_slice = np.where(arr)[0].min() + superior_slice = np.where(arr)[0].max() + + n_slices_ext = int(extension_mm / mask.GetSpacing()[2]) + n_slices_est = int(interior_mm_shape / mask.GetSpacing()[2]) + + if direction[1] == "sup": + max_index = min([arr.shape[0], superior_slice + 1 + n_slices_ext]) + for s_in in range(superior_slice + 1 - n_slices_est, max_index): + arr[s_in, :, :] = np.max( + arr[superior_slice - n_slices_est : superior_slice, :, :], axis=0 + ) + if direction[1] == "inf": + min_index = max([arr.shape[0], inferior_slice - n_slices_ext + n_slices_est]) + for s_in in range(min_index, inferior_slice): + arr[s_in, :, :] = np.max( + arr[inferior_slice + n_slices_est : inferior_slice, :, :], axis=0 + ) + + mask_ext = sitk.GetImageFromArray(arr) + mask_ext.CopyInformation(mask) + + return mask_ext diff --git a/platipy/imaging/label/comparison.py b/platipy/imaging/label/comparison.py index 68c80186..70c0b5c2 100644 --- a/platipy/imaging/label/comparison.py +++ b/platipy/imaging/label/comparison.py @@ -28,6 +28,8 @@ import numpy as np import SimpleITK as sitk +from platipy.imaging.utils.crop import label_to_roi, crop_to_roi + def compute_surface_metrics(label_a, label_b, verbose=False): """Compute surface distance metrics between two labels. Surface metrics computed are: @@ -141,7 +143,7 @@ def compute_volume_metrics(label_a, label_b): return result -def compute_metric_dsc(label_a, label_b): +def compute_metric_dsc(label_a, label_b, auto_crop=True): """Compute the Dice Similarity Coefficient between two labels Args: @@ -151,13 +153,19 @@ def compute_metric_dsc(label_a, label_b): Returns: float: The Dice Similarity Coefficient """ + if auto_crop: + largest_region = (label_a + label_b) > 0 + crop_box_size, crop_box_index = label_to_roi(largest_region) + + label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index) + label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index) arr_a = sitk.GetArrayFromImage(label_a).astype(bool) arr_b = sitk.GetArrayFromImage(label_b).astype(bool) return 2 * ((arr_a & arr_b).sum()) / (arr_a.sum() + arr_b.sum()) -def compute_metric_specificity(label_a, label_b): +def compute_metric_specificity(label_a, label_b, auto_crop=True): """Compute the specificity between two labels Args: @@ -167,6 +175,12 @@ def compute_metric_specificity(label_a, label_b): Returns: float: The specificity between the two labels """ + if auto_crop: + largest_region = (label_a + label_b) > 0 + crop_box_size, crop_box_index = label_to_roi(largest_region) + + label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index) + label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index) arr_a = sitk.GetArrayFromImage(label_a).astype(bool) arr_b = sitk.GetArrayFromImage(label_b).astype(bool) @@ -180,7 +194,7 @@ def compute_metric_specificity(label_a, label_b): return float((1.0 * true_neg) / (true_neg + false_pos)) -def compute_metric_sensitivity(label_a, label_b): +def compute_metric_sensitivity(label_a, label_b, auto_crop=True): """Compute the sensitivity between two labels Args: @@ -190,6 +204,12 @@ def compute_metric_sensitivity(label_a, label_b): Returns: float: The sensitivity between the two labels """ + if auto_crop: + largest_region = (label_a + label_b) > 0 + crop_box_size, crop_box_index = label_to_roi(largest_region) + + label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index) + label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index) arr_a = sitk.GetArrayFromImage(label_a).astype(bool) arr_b = sitk.GetArrayFromImage(label_b).astype(bool) @@ -202,7 +222,7 @@ def compute_metric_sensitivity(label_a, label_b): return float((1.0 * true_pos) / (true_pos + false_neg)) -def compute_metric_masd(label_a, label_b): +def compute_metric_masd(label_a, label_b, auto_crop=True): """Compute the mean absolute distance between two labels Args: @@ -212,6 +232,12 @@ def compute_metric_masd(label_a, label_b): Returns: float: The mean absolute surface distance """ + if auto_crop: + largest_region = (label_a + label_b) > 0 + crop_box_size, crop_box_index = label_to_roi(largest_region) + + label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index) + label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index) mean_sd_list = [] num_points = [] @@ -231,7 +257,7 @@ def compute_metric_masd(label_a, label_b): return float(mean_surf_dist) -def compute_metric_hd(label_a, label_b): +def compute_metric_hd(label_a, label_b, auto_crop=True): """Compute the Hausdorff distance between two labels Args: @@ -241,6 +267,12 @@ def compute_metric_hd(label_a, label_b): Returns: float: The maximum Hausdorff distance """ + if auto_crop: + largest_region = (label_a + label_b) > 0 + crop_box_size, crop_box_index = label_to_roi(largest_region) + + label_a = crop_to_roi(label_a, size=crop_box_size, index=crop_box_index) + label_b = crop_to_roi(label_b, size=crop_box_size, index=crop_box_index) hausdorff_distance = sitk.HausdorffDistanceImageFilter() hausdorff_distance.Execute(label_a, label_b) diff --git a/platipy/imaging/projects/cardiac/run.py b/platipy/imaging/projects/cardiac/run.py index e6c93653..da254f06 100644 --- a/platipy/imaging/projects/cardiac/run.py +++ b/platipy/imaging/projects/cardiac/run.py @@ -40,6 +40,8 @@ from platipy.imaging.utils.crop import label_to_roi, crop_to_roi +from platipy.imaging.generation.mask import extend_mask + ATLAS_PATH = "/atlas" if "ATLAS_PATH" in os.environ: ATLAS_PATH = os.environ["ATLAS_PATH"] @@ -54,6 +56,7 @@ "crop_atlas_to_structures": False, "crop_atlas_expansion_mm": (10, 10, 10), "guide_structure_name": "WHOLEHEART", + "superior_extension": 30, }, "auto_crop_target_image_settings": { "expansion_mm": [2, 2, 2], @@ -231,7 +234,7 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING img_crop = crop_to_roi(img, crop_box_size, crop_box_index) guide_structure = crop_to_roi(guide_structure, crop_box_size, crop_box_index) - target_reg_structure = convert_mask_to_reg_structure(guide_structure) + target_reg_structure = convert_mask_to_reg_structure(guide_structure, expansion=0) else: quick_reg_settings = { @@ -303,7 +306,7 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING guide_structure_name = settings["atlas_settings"]["guide_structure_name"] target_reg_image = target_reg_structure atlas_reg_image = convert_mask_to_reg_structure( - atlas_set[atlas_id]["Original"][guide_structure_name] + atlas_set[atlas_id]["Original"][guide_structure_name], expansion=0 ) else: @@ -328,6 +331,21 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING interpolator=sitk.sitkNearestNeighbor, ) + expanded_atlas_guide_structure = extend_mask( + atlas_set[atlas_id]["Original"][guide_structure_name], + direction=("ax", "sup"), + extension_mm=settings["atlas_settings"]["superior_extension"], + interior_mm_shape=settings["atlas_settings"]["superior_extension"] / 2, + ) + + atlas_set[atlas_id]["RIR"][guide_structure_name + "EXPANDED"] = apply_transform( + input_image=expanded_atlas_guide_structure, + reference_image=img_crop, + transform=initial_tfm, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, + ) + atlas_set[atlas_id]["RIR"]["CT Image"] = apply_transform( input_image=atlas_set[atlas_id]["Original"]["CT Image"], reference_image=img_crop, @@ -383,6 +401,14 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING interpolator=sitk.sitkLinear, ) + atlas_set[atlas_id]["DIR_STRUCT"][guide_structure_name + "EXPANDED"] = apply_transform( + input_image=atlas_set[atlas_id]["RIR"][guide_structure_name + "EXPANDED"], + reference_image=img_crop, + transform=struct_guided_tfm, + default_value=0, + interpolator=sitk.sitkNearestNeighbor, + ) + # sitk.WriteImage(deform_image, f"./DIR_STRUCT_{atlas_id}.nii.gz") for struct in atlas_structure_list: @@ -399,7 +425,7 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING # Settings deformable_registration_settings = settings["deformable_registration_settings"] - logger.info("Running DIR to register atlas images") + logger.info("Running DIR to refine atlas image registration") for atlas_id in atlas_id_list: @@ -408,13 +434,31 @@ def run_cardiac_segmentation(img, guide_structure=None, settings=CARDIAC_SETTING # Register the atlases atlas_set[atlas_id]["DIR"] = {} + atlas_reg_image = atlas_set[atlas_id]["DIR_STRUCT"]["CT Image"] + target_reg_image = img_crop + if guide_structure: - target_reg_image = target_reg_structure - atlas_reg_image = atlas_set[atlas_id]["DIR_STRUCT"]["Reg Mask"] + expanded_atlas_mask = atlas_set[atlas_id]["DIR_STRUCT"][ + guide_structure_name + "EXPANDED" + ] + expanded_target_mask = extend_mask( + guide_structure, + direction=("ax", "sup"), + extension_mm=settings["atlas_settings"]["superior_extension"], + interior_mm_shape=settings["atlas_settings"]["superior_extension"] / 2, + ) - else: - atlas_reg_image = atlas_set[atlas_id]["RIR"]["CT Image"] - target_reg_image = sitk.Mask(img_crop, atlas_reg_image > -1000, outsideValue=-1000) + combined_mask = sitk.Maximum(expanded_atlas_mask, expanded_target_mask) + + atlas_reg_image = sitk.Mask(atlas_reg_image, combined_mask, outsideValue=-1000) + atlas_reg_image = sitk.Mask( + atlas_reg_image, atlas_reg_image > -400, outsideValue=-1000 + ) + + target_reg_image = sitk.Mask(target_reg_image, combined_mask, outsideValue=-1000) + target_reg_image = sitk.Mask( + target_reg_image, atlas_reg_image > -400, outsideValue=-1000 + ) deform_image, dir_tfm, _ = fast_symmetric_forces_demons_registration( target_reg_image, diff --git a/platipy/imaging/radiotherapy/dose.py b/platipy/imaging/radiotherapy/dose.py new file mode 100644 index 00000000..d75dff45 --- /dev/null +++ b/platipy/imaging/radiotherapy/dose.py @@ -0,0 +1,39 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import SimpleITK as sitk + + +def calculate_dvh(dose_grid, label, bins=1001): + + if dose_grid.GetSize() != label.GetSize(): + print("Dose grid size does not match label, automatically resampling.") + dose_grid = sitk.Resample(dose_grid, label) + + dose_arr = sitk.GetArrayViewFromImage(dose_grid) + label_arr = sitk.GetArrayViewFromImage(label) + + dose_vals = dose_arr[np.where(label_arr)] + + counts, bin_edges = np.histogram(dose_vals, bins=bins) + + # Get mid-points of bins + dose_points = (bin_edges[1:] + bin_edges[:-1]) / 2.0 + + # Calculate the actual DVH values + counts = np.cumsum(counts[::-1])[::-1] + counts = counts / counts.max() + + return dose_points, counts \ No newline at end of file diff --git a/platipy/imaging/registration/linear.py b/platipy/imaging/registration/linear.py index 707be88d..2dcdb15b 100644 --- a/platipy/imaging/registration/linear.py +++ b/platipy/imaging/registration/linear.py @@ -162,22 +162,39 @@ def linear_registration( if fixed_structure: registration.SetMetricFixedMask(fixed_structure) - if reg_method.lower() == "translation": - registration.SetInitialTransform(sitk.TranslationTransform(3)) - elif reg_method.lower() == "similarity": - registration.SetInitialTransform(sitk.Similarity3DTransform()) - elif reg_method.lower() == "affine": - registration.SetInitialTransform(sitk.AffineTransform(3)) - elif reg_method.lower() == "rigid": - registration.SetInitialTransform(sitk.VersorRigid3DTransform()) - elif reg_method.lower() == "scaleversor": - registration.SetInitialTransform(sitk.ScaleVersor3DTransform()) - elif reg_method.lower() == "scaleskewversor": - registration.SetInitialTransform(sitk.ScaleSkewVersor3DTransform()) + if isinstance(reg_method, str): + if reg_method.lower() == "translation": + registration.SetInitialTransform(sitk.TranslationTransform(3)) + elif reg_method.lower() == "similarity": + registration.SetInitialTransform(sitk.Similarity3DTransform()) + elif reg_method.lower() == "affine": + registration.SetInitialTransform(sitk.AffineTransform(3)) + elif reg_method.lower() == "rigid": + registration.SetInitialTransform(sitk.VersorRigid3DTransform()) + elif reg_method.lower() == "scaleversor": + registration.SetInitialTransform(sitk.ScaleVersor3DTransform()) + elif reg_method.lower() == "scaleskewversor": + registration.SetInitialTransform(sitk.ScaleSkewVersor3DTransform()) + else: + raise ValueError( + "You have selected a registration method that does not exist.\n Please select from " + "Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor" + ) + elif ( + isinstance(reg_method, sitk.CompositeTransform) + or isinstance(reg_method, sitk.Transform) + or isinstance(reg_method, sitk.TranslationTransform) + or isinstance(reg_method, sitk.Similarity3DTransform) + or isinstance(reg_method, sitk.AffineTransform) + or isinstance(reg_method, sitk.VersorRigid3DTransform) + or isinstance(reg_method, sitk.ScaleVersor3DTransform) + or isinstance(reg_method, sitk.ScaleSkewVersor3DTransform) + ): + registration.SetInitialTransform(reg_method) else: raise ValueError( - "You have selected a registration method that does not exist.\n Please select from " - "Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor" + "'reg_method' must be either a string (see docs for acceptable registration names), " + "or a custom sitk.CompositeTransform." ) if optimiser.lower() == "lbfgsb": diff --git a/platipy/imaging/registration/utils.py b/platipy/imaging/registration/utils.py index dee6d6cb..58ef45d4 100644 --- a/platipy/imaging/registration/utils.py +++ b/platipy/imaging/registration/utils.py @@ -278,6 +278,13 @@ def convert_mask_to_distance_map(mask, squared_distance=False, normalise=False): Returns: [SimpleITK.Image]: The distance map as an image. """ + arr = sitk.GetArrayViewFromImage(mask) + vals = np.unique(arr[arr > 0]) + if len(vals) > 2: + # There is more than one value! We need to threshold at the median + cutoff = np.median(vals) + mask = sitk.BinaryThreshold(mask, cutoff, np.max(vals).astype(float)) + raw_map = sitk.SignedMaurerDistanceMap( mask, insideIsPositive=True, diff --git a/platipy/imaging/visualisation/visualiser.py b/platipy/imaging/visualisation/visualiser.py index a5dfdc55..78ae5817 100644 --- a/platipy/imaging/visualisation/visualiser.py +++ b/platipy/imaging/visualisation/visualiser.py @@ -280,10 +280,12 @@ def add_contour( name = "input" self.__show_legend = False - else: - self.__show_legend = True - - visualise_contour = VisualiseContour(contour, name, color=color, linewidth=linewidth) + visualise_contour = VisualiseContour( + contour, + name, + color=color, + linewidth=linewidth, + ) self.__contours.append(visualise_contour) else: @@ -1359,9 +1361,9 @@ def overlay_scalar_field(self): cax = self.__figure.add_axes( ( ax_box.x1 + 0.02 + (cbar_width + 0.1) * scalar_index, - ax_box.y0, + ax_box.y0 * 1.025, cbar_width, - ax_box.height, + ax_box.height - ax_box.y0 * 0.05, ) ) @@ -1548,14 +1550,14 @@ def overlay_vector_field(self): x_pos_legend = max_xpos + 0.025 else: - x_pos_legend = ax_box.x0 + 0.025 + x_pos_legend = ax_box.x1 + 0.025 cax = self.__figure.add_axes( ( x_pos_legend, - ax_box.y0, + ax_box.y0 * 1.025, cbar_width, - ax_box.height, + ax_box.height - ax_box.y0 * 0.05, ) ) From 218babfce332af0d413cf95694d099e19c063477 Mon Sep 17 00:00:00 2001 From: rnfinnegan Date: Tue, 8 Jun 2021 11:10:37 +1000 Subject: [PATCH 2/2] update --- .pylintrc | 595 ++++++++++++++++++ platipy/imaging/__init__.py | 2 +- .../{radiotherapy/dose.py => dose/dvh.py} | 13 +- platipy/imaging/generation/image.py | 69 +- platipy/imaging/generation/mask.py | 16 + platipy/imaging/label/fusion.py | 6 +- platipy/imaging/label/utils.py | 7 +- platipy/imaging/registration/linear.py | 25 +- platipy/imaging/utils/geometry.py | 75 +++ platipy/imaging/utils/valve.py | 235 ++++--- platipy/imaging/utils/vessel.py | 136 ++-- platipy/imaging/visualisation/animation.py | 58 +- 12 files changed, 1066 insertions(+), 171 deletions(-) create mode 100644 .pylintrc rename platipy/imaging/{radiotherapy/dose.py => dose/dvh.py} (74%) create mode 100644 platipy/imaging/utils/geometry.py diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..7f7aa386 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,595 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +extension-pkg-whitelist= + +# Specify a score threshold to be exceeded before program exits with error. +fail-under=10.0 + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use. +jobs=1 + +# Control the amount of potential inferred values when inferring a single +# object. This can help the performance when dealing with large functions or +# complex, nested conditions. +limit-inference-results=100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins=pylint_flask_sqlalchemy + +# Pickle collected data for later comparisons. +persistent=yes + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable=print-statement, + parameter-unpacking, + unpacking-in-except, + old-raise-syntax, + backtick, + long-suffix, + old-ne-operator, + old-octal-literal, + import-star-module-level, + non-ascii-bytes-literal, + raw-checker-failed, + bad-inline-option, + locally-disabled, + file-ignored, + suppressed-message, + useless-suppression, + deprecated-pragma, + use-symbolic-message-instead, + apply-builtin, + basestring-builtin, + buffer-builtin, + cmp-builtin, + coerce-builtin, + execfile-builtin, + file-builtin, + long-builtin, + raw_input-builtin, + reduce-builtin, + standarderror-builtin, + unicode-builtin, + xrange-builtin, + coerce-method, + delslice-method, + getslice-method, + setslice-method, + no-absolute-import, + old-division, + dict-iter-method, + dict-view-method, + next-method-called, + metaclass-assignment, + indexing-exception, + raising-string, + reload-builtin, + oct-method, + hex-method, + nonzero-method, + cmp-method, + input-builtin, + round-builtin, + intern-builtin, + unichr-builtin, + map-builtin-not-iterating, + zip-builtin-not-iterating, + range-builtin-not-iterating, + filter-builtin-not-iterating, + using-cmp-argument, + eq-without-hash, + div-method, + idiv-method, + rdiv-method, + exception-message-attribute, + invalid-str-codec, + sys-max-int, + bad-python3-import, + deprecated-string-function, + deprecated-str-translate-call, + deprecated-itertools-function, + deprecated-types-field, + next-method-defined, + dict-items-not-iterating, + dict-keys-not-iterating, + dict-values-not-iterating, + deprecated-operator-function, + deprecated-urllib-function, + xreadlines-attribute, + deprecated-sys-function, + exception-escape, + comprehension-escape, + C0330, + C0114, + W0102, + W0105 + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'error', 'warning', 'refactor', and 'convention' +# which contain the number of messages in each category, as well as 'statement' +# which is the total number of statements analyzed. This score is used by the +# global evaluation report (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, e.g. +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages. +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=sys.exit + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid defining new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expected to +# not be used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore. +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + +# Regular expression of note tags to take in consideration. +#notes-rgx= + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=4 + + +[STRING] + +# This flag controls whether inconsistent-quotes generates a warning when the +# character used as a quote delimiter is used inconsistently within a module. +check-quote-consistency=no + +# This flag controls whether the implicit-str-concat should generate a warning +# on implicit string concatenation in sequences defined over several lines. +check-str-concat-over-line-jumps=no + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it work, +# install the python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +spelling-store-unknown-words=no + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=99 + +# Maximum number of lines in a module. +max-module-lines=1000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[LOGGING] + +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style=old + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules=logging + + +[BASIC] + +# Naming style matching correct argument names. +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style. +argument-rgx=[a-z_][a-z0-9_]{0,40}$ + +# Naming style matching correct attribute names. +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma. +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +bad-names-rgxs= + +# Naming style matching correct class attribute names. +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. +#class-attribute-rgx= + +# Naming style matching correct class names. +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming- +# style. +#class-rgx= + +# Naming style matching correct constant names. +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style. +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names. +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style. +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma. +good-names=i, + j, + k, + ex, + Run, + _ + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +good-names-rgxs= + +# Include a hint for the correct naming format with invalid-name. +include-naming-hint=no + +# Naming style matching correct inline iteration names. +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. +#inlinevar-rgx= + +# Naming style matching correct method names. +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style. +#method-rgx= + +# Naming style matching correct module names. +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style. +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +# These decorators are taken in consideration only for invalid-name. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names. +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style. +variable-rgx=[a-z_][a-z0-9_]{0,40}$ + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# Tells whether to warn about missing members when the owner of the attribute +# is inferred to be None. +ignore-none=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis). It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules=vtk + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + +# List of decorators that change the signature of a decorated function. +signature-mutators= + + +[DESIGN] + +# Maximum number of arguments for function / method. +max-args=5 + +# Maximum number of attributes for a class (see R0902). +max-attributes=7 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr=5 + +# Maximum number of branch for function / method body. +max-branches=12 + +# Maximum number of locals for function / method body. +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body. +max-returns=6 + +# Maximum number of statements in function / method body. +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[CLASSES] + +# Warn about protected attribute access inside special methods +check-protected-access-in-special-methods=no + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp, + __post_init__ + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=cls + + +[IMPORTS] + +# List of modules that can be imported at any level, not just the top level +# one. +allow-any-import-level= + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled). +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled). +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled). +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + +# Couples of modules and preferred modules, separated by a comma. +preferred-modules= + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "BaseException, Exception". +overgeneral-exceptions=BaseException, + Exception diff --git a/platipy/imaging/__init__.py b/platipy/imaging/__init__.py index 320039ce..f1968b81 100644 --- a/platipy/imaging/__init__.py +++ b/platipy/imaging/__init__.py @@ -1 +1 @@ -from .visualisation.visualiser import ImageVisualiser \ No newline at end of file +from .visualisation.visualiser import ImageVisualiser diff --git a/platipy/imaging/radiotherapy/dose.py b/platipy/imaging/dose/dvh.py similarity index 74% rename from platipy/imaging/radiotherapy/dose.py rename to platipy/imaging/dose/dvh.py index d75dff45..14106b99 100644 --- a/platipy/imaging/radiotherapy/dose.py +++ b/platipy/imaging/dose/dvh.py @@ -17,6 +17,17 @@ def calculate_dvh(dose_grid, label, bins=1001): + """Calculates a dose-volume histogram + + Args: + dose_grid (SimpleITK.Image): The dose grid. + label (SimpleITK.Image): The (binary) label defining a structure. + bins (int | list | np.ndarray, optional): Passed to np.histogram, + can be an int (number of bins), or a list (specifying bin edges). Defaults to 1001. + + Returns: + SimpleITK.Image: [description] + """ if dose_grid.GetSize() != label.GetSize(): print("Dose grid size does not match label, automatically resampling.") @@ -36,4 +47,4 @@ def calculate_dvh(dose_grid, label, bins=1001): counts = np.cumsum(counts[::-1])[::-1] counts = counts / counts.max() - return dose_points, counts \ No newline at end of file + return dose_points, counts diff --git a/platipy/imaging/generation/image.py b/platipy/imaging/generation/image.py index 975729d1..6e0c98db 100644 --- a/platipy/imaging/generation/image.py +++ b/platipy/imaging/generation/image.py @@ -17,7 +17,7 @@ def insert_sphere(arr, sp_radius=4, sp_centre=(0, 0, 0)): - """Insert a sphere into the give array + """Insert a sphere into the given array Args: arr (np.array): Array in which to insert sphere @@ -48,14 +48,46 @@ def insert_sphere(arr, sp_radius=4, sp_centre=(0, 0, 0)): return arr_copy +def insert_cylinder(arr, cyl_radius=4, cyl_height=2, cyl_centre=(0, 0, 0)): + """ + Insert a cylinder into the given array. + The cylinder vertical extent is +/- 0.5 * height + + Args: + arr (np.ndarray): The array into which the cylinder is inserted + cyl_radius (int, optional): Cylinder radius. Defaults to 4. + cyl_height (int, optional): Cylinder height. Defaults to 2. + cyl_centre (tuple, optional): Cylinder centre. Defaults to (0, 0, 0). + + Returns: + np.ndarray: The original array with a cylinder (value 1) + """ + arr_copy = arr[:] + + x, y, z = np.indices(arr.shape) + + if not hasattr(cyl_radius, "__iter__"): + cyl_radius = [cyl_radius] * 2 + + condition_radial = ( + ((z - cyl_centre[0]) / cyl_radius[0]) ** 2 + ((y - cyl_centre[1]) / cyl_radius[1]) ** 2 + ) <= 1 + condition_height = np.abs((x - cyl_centre[2]) / (0.5 * cyl_height)) <= 1 + + arr_copy[condition_radial & condition_height] = 1 + + return arr_copy + + def insert_sphere_image(image, sp_radius, sp_centre): - """Insert a sphere into a blank image with the same size as image + """Insert a sphere into an image Args: image (sitk.Image): Image in which to insert sphere - sp_radius (int | list, optional): The radius of the sphere. Can also be defined as a vector. Defaults to 4. + sp_radius (int | list, optional): The radius of the sphere. + Can also be defined as a vector. Defaults to 4. sp_centre (tuple, optional): The position at which the sphere should be inserted. Defaults - to (0, 0, 0). + to (0, 0, 0). Returns: np.array: An array with the sphere inserted @@ -74,3 +106,32 @@ def insert_sphere_image(image, sp_radius, sp_centre): image_sphere.CopyInformation(image) return image_sphere + + +def insert_cylinder_image(image, cyl_radius=(5, 5), cyl_height=10, cyl_centre=(0, 0, 0)): + """Insert a cylinder into an image + + Args: + image (SimpleITK.Image): + cyl_radius (tuple, optional): Cylinder radius, can be defined as a single value + or a tuple (will generate an ellipsoid). Defaults to (5,5). + cyl_height (int, optional): Cylinder height. Defaults to 10. + cyl_centre (tuple, optional): Cylinder centre. Defaults to (0,0,0). + + Returns: + SimpleITK.Image: Image with cylinder inserted + """ + if not hasattr(cyl_radius, "__iter__"): + cyl_radius = [cyl_radius] * 2 + + cyl_radius_image = [i / j for i, j in zip(cyl_radius, image.GetSpacing()[1::-1])] + cyl_height_image = cyl_height / image.GetSpacing()[2] + + arr = sitk.GetArrayFromImage(image) + + arr = insert_cylinder(arr, cyl_radius_image, cyl_height_image, cyl_centre) + + image_cylinder = sitk.GetImageFromArray(arr) + image_cylinder.CopyInformation(image) + + return image_cylinder diff --git a/platipy/imaging/generation/mask.py b/platipy/imaging/generation/mask.py index 68c68a5c..ff4a4038 100644 --- a/platipy/imaging/generation/mask.py +++ b/platipy/imaging/generation/mask.py @@ -105,7 +105,23 @@ def get_external_mask( def extend_mask(mask, direction=("ax", "sup"), extension_mm=10, interior_mm_shape=10): + """ + Extends a binary label (mask) a number of slices. + PROTOTYPE! + Currently can only extend in axial directions (superior of inferior). + The shape of the extended part is based on some number of interior slices, as defined. + + Args: + mask (SimpleITK.Image): The input binary label (mask). + direction (tuple, optional): The direction as a tuple. First element is axis, second + element is direction. Defaults to ("ax", "sup"). + extension_mm (int, optional): The extension in millimeters. Defaults to 10. + interior_mm_shape (int, optional): The length on which to base the extension shape. + Defaults to 10. + Returns: + SimpleITK.Image: The output (extended mask). + """ arr = sitk.GetArrayViewFromImage(mask) vals = np.unique(arr[arr > 0]) if len(vals) > 2: diff --git a/platipy/imaging/label/fusion.py b/platipy/imaging/label/fusion.py index 355c2673..46f1be69 100644 --- a/platipy/imaging/label/fusion.py +++ b/platipy/imaging/label/fusion.py @@ -130,11 +130,9 @@ def combine_labels(atlas_set, structure_name, label="DIR", threshold=1e-4, smoot combined_label_dict = {} - for structure_name in structure_name_list: + for s_name in structure_name_list: # Find the cases which have the strucure (in case some cases do not) - valid_case_id_list = [ - i for i in case_id_list if structure_name in atlas_set[i][label].keys() - ] + valid_case_id_list = [i for i in case_id_list if s_name in atlas_set[i][label].keys()] # Get valid weight images weight_image_list = [ diff --git a/platipy/imaging/label/utils.py b/platipy/imaging/label/utils.py index 8eb6b413..676b9acb 100644 --- a/platipy/imaging/label/utils.py +++ b/platipy/imaging/label/utils.py @@ -3,7 +3,7 @@ from scipy.ndimage.measurements import center_of_mass -def get_com(label, real_coords=False): +def get_com(label, as_int=True, real_coords=False): """ Get centre of mass of a SimpleITK.Image """ @@ -11,10 +11,11 @@ def get_com(label, real_coords=False): com = center_of_mass(arr) if real_coords: - com = label.TransformContinuousIndexToPhysicalPoint(com) + com = label.TransformContinuousIndexToPhysicalPoint(com[::-1]) else: - com = [int(i) for i in com] + if as_int: + com = [int(i) for i in com] return com diff --git a/platipy/imaging/registration/linear.py b/platipy/imaging/registration/linear.py index 2dcdb15b..27857be9 100644 --- a/platipy/imaging/registration/linear.py +++ b/platipy/imaging/registration/linear.py @@ -177,18 +177,21 @@ def linear_registration( registration.SetInitialTransform(sitk.ScaleSkewVersor3DTransform()) else: raise ValueError( - "You have selected a registration method that does not exist.\n Please select from " - "Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor" + "You have selected a registration method that does not exist.\n Please select from" + " Translation, Similarity, Affine, Rigid, ScaleVersor, ScaleSkewVersor" ) - elif ( - isinstance(reg_method, sitk.CompositeTransform) - or isinstance(reg_method, sitk.Transform) - or isinstance(reg_method, sitk.TranslationTransform) - or isinstance(reg_method, sitk.Similarity3DTransform) - or isinstance(reg_method, sitk.AffineTransform) - or isinstance(reg_method, sitk.VersorRigid3DTransform) - or isinstance(reg_method, sitk.ScaleVersor3DTransform) - or isinstance(reg_method, sitk.ScaleSkewVersor3DTransform) + elif isinstance( + reg_method, + ( + sitk.CompositeTransform, + sitk.Transform, + sitk.TranslationTransform, + sitk.Similarity3DTransform, + sitk.AffineTransform, + sitk.VersorRigid3DTransform, + sitk.ScaleVersor3DTransform, + sitk.ScaleSkewVersor3DTransform, + ), ): registration.SetInitialTransform(reg_method) else: diff --git a/platipy/imaging/utils/geometry.py b/platipy/imaging/utils/geometry.py new file mode 100644 index 00000000..93602421 --- /dev/null +++ b/platipy/imaging/utils/geometry.py @@ -0,0 +1,75 @@ +# Copyright 2020 University of New South Wales, University of Sydney, Ingham Institute + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import SimpleITK as sitk + + +def vector_angle(v1, v2): + """Return the angle between two vectors + + Args: + v1 ([np.array]): A three-dimensional vector + v2 ([np.array]): A three-dimensional vector + + Returns: + [float]: The angle in radians + """ + v1 = np.array(v1) + v2 = np.array(v2) + v1_norm = v1 / np.linalg.norm(v1) + v2_norm = v2 / np.linalg.norm(v2) + dot_product = np.dot(v1_norm, v2_norm) + angle = np.arccos(dot_product) + return angle + + +def rotate_image( + img, + rotation_centre=(0, 0, 0), + rotation_axis=(1, 0, 0), + rotation_angle_radians=0, + interpolation=sitk.sitkNearestNeighbor, + default_value=0, +): + """Rotates an image + + Args: + img (SimpleITK.Image): The image to rotate + rotation_centre (tuple, optional): The centre of rotation (in image coordinates). + Defaults to (0, 0, 0). + rotation_axis (tuple, optional): The axis of rotation. Defaults to (1, 0, 0). + rotation_angle_radians (float, optional): The angle of rotation. Defaults to 0. + interpolation (int, optional): Final interpolation. Defaults to sitk.sitkNearestNeighbor. + default_value (int, optional): Default value. Defaults to 0. + + Returns: + SimpleITK.Image: + """ + + # Define the transform, using predefined centre of rotation and given angle + rotation_transform = sitk.VersorRigid3DTransform() + rotation_transform.SetCenter(rotation_centre) + rotation_transform.SetRotation(rotation_axis, rotation_angle_radians) + + # Resample the image using the rotation transform + resampled_image = sitk.Resample( + img, + rotation_transform, + interpolation, + default_value, + img.GetPixelID(), + ) + + return resampled_image \ No newline at end of file diff --git a/platipy/imaging/utils/valve.py b/platipy/imaging/utils/valve.py index b6b713c7..9f9c01a6 100644 --- a/platipy/imaging/utils/valve.py +++ b/platipy/imaging/utils/valve.py @@ -18,105 +18,170 @@ from platipy.imaging.label.utils import get_com -from platipy.imaging.generation.image import insert_sphere_image - -""" -Generate valves -First example - meeting point of LA/LV (mitral valve) -""" - - -def generate_valves_from_chambers(img_chamber_1, img_chamber_2, radius_mm=10): - - # Find the mid-point of the chambers - com_1 = np.array(get_com(img_chamber_1)) - com_2 = np.array(get_com(img_chamber_2)) - - midpoint = com_1 + 0.5 * (com_2 - com_1) +from platipy.imaging.generation.image import insert_cylinder_image + +from platipy.imaging.utils.geometry import vector_angle, rotate_image + + +def generate_valve_from_great_vessel( + label_great_vessel, label_ventricle, initial_dilation=(4, 4, 5), final_erosion=(3, 3, 3) +): + """ + Generates a geometrically-defined valve. + This function is suitable for the pulmonic and aortic valves. + + Args: + label_great_vessel (SimpleITK.Image): The binary mask for the great vessel + (pulmonary artery or ascending aorta) + label_ventricle (SimpleITK.Image): The binary mask for the ventricle (left or right) + initial_dilation (tuple, optional): Initial dilation, larger values increase the valve + size. Defaults to (4, 4, 5). + final_erosion (tuple, optional): Final erosion, larger values decrease the valve size. + Defaults to (3, 3, 3). + + Returns: + SimpleITK.Image: The geometric valve, as a binary mask. + """ + # Dilate the great vessel and ventricle + label_great_vessel_dilate = sitk.BinaryDilate(label_great_vessel, initial_dilation) + label_ventricle_dilate = sitk.BinaryDilate(label_ventricle, initial_dilation) + + # Find the overlap (of these dilated volumes) + overlap = label_great_vessel_dilate & label_ventricle_dilate + + # Create a mask, first we calculate the union + dilation = 1 + union_vol = 0 + while union_vol <= 2000: + union = sitk.BinaryDilate(label_great_vessel, (dilation,) * 3) | sitk.BinaryDilate( + label_ventricle, (dilation,) * 3 + ) + union_vol = np.sum(sitk.GetArrayFromImage(union) * np.product(union.GetSpacing())) + dilation += 1 + mask = sitk.Mask(overlap, union) + + label_valve = sitk.BinaryMorphologicalClosing(mask) + label_valve = sitk.BinaryErode(label_valve, final_erosion) + + return label_valve + + +def generate_valve_using_cylinder( + label_atrium, + label_ventricle, + label_wh, + radius_mm=15, + height_mm=10, + shift_parameters=[ + [7.63383999e-01, -1.15883572e00, 2.12311297e00], + [4.21062525e-03, -3.95014189e-04, 1.13108043e-03], + ], +): + """ + Generates a geometrically-defined valve. + This function is suitable for the tricuspid and mitral valves. + + Note: the shift parameters have been determined empirically. + For the mitral valve, use the defaults. + For the tricuspid valve, use np.zeros((2,3)) + + Args: + label_atrium (SimpleITK.Image): The binary mask for the (left or right) atrium. + label_ventricle (SimpleITK.Image): The binary mask for the (left or right) ventricle. + label_wh (SimpleITK.Image): The binary mask for the whole heart. Used to scale the shift. + radius_mm (int, optional): The valve radius, in mm. Defaults to 15. + height_mm (int, optional): The valve height (i.e. perpendicular extend), in mm. + Defaults to 10. + shift_parameters (list, optional): + Shift parameters, which are the intercept (first row) and gradient (second row) + of a linear function that maps whole heart volume to 3D shift + (axial, coronal, sagittal). Set to zero to not use. + Defaults to + [ [7.63383999e-01, -1.15883572e00, 2.12311297e00], + [4.21062525e-03, -3.95014189e-04, 1.13108043e-03], ]. + + Returns: + SimpleITK.Image: The geometrically defined valve + """ # Define the overlap region (using binary dilation) # Increment overlap to make sure we have enough voxels dilation = 1 overlap_vol = 0 - while overlap_vol <= 2000: - overlap = sitk.BinaryDilate(img_chamber_1, (dilation,) * 3) & sitk.BinaryDilate( - img_chamber_2, (dilation,) * 3 + while overlap_vol <= 10000: + overlap = sitk.BinaryDilate(label_atrium, (dilation,) * 3) & sitk.BinaryDilate( + label_ventricle, (dilation,) * 3 ) overlap_vol = np.sum(sitk.GetArrayFromImage(overlap) * np.product(overlap.GetSpacing())) dilation += 1 - print(f"Sufficient overlap found at dilation = {dilation} [V = {overlap_vol/1000:.2f} cm^3]") + com_overlap = get_com(overlap, as_int=False) + + # Use empirical model to shift + wh_vol = sitk.GetArrayFromImage(label_wh).sum() * np.product(label_wh.GetSpacing()) / 1000 + shift = np.dot([1, wh_vol], shift_parameters) + com_overlap_shifted = np.array(com_overlap) - shift + + # Create a small expanded overlap region + overlap = sitk.BinaryDilate(label_atrium, (1,) * 3) & sitk.BinaryDilate( + label_ventricle, (1,) * 3 + ) - # Find the point in the overlap region closest to the mid-point + # Find the point in this small overlap region closest to the shifted location separation_vector_pixels = ( - np.stack(np.where(sitk.GetArrayFromImage(overlap))) - midpoint[:, None] + np.stack(np.where(sitk.GetArrayFromImage(overlap))) - com_overlap_shifted[:, None] ) ** 2 - spacing = np.array(img_chamber_1.GetSpacing()) + spacing = np.array(label_atrium.GetSpacing()) separation_vector_mm = separation_vector_pixels / spacing[:, None] separation_mm = np.sum(separation_vector_mm, axis=0) closest_overlap_point = np.argmin(separation_mm) - com_valve = np.stack(np.where(sitk.GetArrayFromImage(overlap)))[:, closest_overlap_point] - - # Define the valve as a sphere - auto_valve = insert_sphere_image(0 * overlap, sp_radius=radius_mm, sp_centre=com_valve) - - return auto_valve - - -def generate_valves_from_vessels(vessel, thickness_mm=4, erosion_mm=2): - - # Thickness can be defined by (inferior_thickness, superior_thickness), - # or just total_thickness - if hasattr(thickness_mm, "__iter__"): - thickness_inferior, thickness_superior = thickness_mm - else: - thickness_inferior, thickness_superior = thickness_mm * 0.5, thickness_mm * 0.5 - - # Get the most inferior slice - arr_vessel = sitk.GetArrayFromImage(vessel) - inferior_slice = np.where(arr_vessel)[0].min() - - # Get interior and superior limits - filled_superior_slice = np.ceil( - np.where(arr_vessel)[0].min() + (thickness_superior / vessel.GetSpacing()[2]) - ).astype(int) - filled_inferior_slice = np.floor( - np.where(arr_vessel)[0].min() - (thickness_inferior / vessel.GetSpacing()[2]) - ).astype(int) - - # Create vessel interior - vessel_interior = sitk.BinaryErode(vessel, (erosion_mm, erosion_mm, 0)) - arr = sitk.GetArrayFromImage(vessel_interior) - - # Erase upper slices - arr[filled_superior_slice:, :, :] = 0 - - # Copy down (using mirroring about inferior slice) - for s_in, s_out in zip( - range(filled_inferior_slice, inferior_slice), - range(inferior_slice, filled_superior_slice)[::-1], - ): - arr[s_in, :, :] = arr[s_out, :, :] - - # Define the valve - auto_valve = sitk.GetImageFromArray(arr) - auto_valve.CopyInformation(vessel) - - # Post-processing - # 1. Extend the actual vessel downwards (continued) - arr_vessel[:inferior_slice, :, :] = arr_vessel[inferior_slice, :, :] - continued_vessel = sitk.GetImageFromArray(arr_vessel) - continued_vessel.CopyInformation(vessel) - - # 2. Erode this continued vessel - continued_vessel = sitk.BinaryErode(continued_vessel, (erosion_mm, erosion_mm, 0)) - - # 3. Mask - auto_valve = sitk.Mask(auto_valve, continued_vessel) - - # 4. Fill small holes - auto_valve = sitk.BinaryMorphologicalClosing(auto_valve, (1, 1, 1)) - - return auto_valve \ No newline at end of file + # Now we can calculate the location of the valve + valve_loc = np.stack(np.where(sitk.GetArrayFromImage(overlap)))[:, closest_overlap_point] + valve_loc_real = label_ventricle.TransformContinuousIndexToPhysicalPoint( + valve_loc.tolist()[::-1] + ) + + # Now we create a cylinder with the user_defined parameters + cylinder = insert_cylinder_image(0 * label_ventricle, radius_mm, height_mm, valve_loc[::-1]) + + # Now we compute the first principal moment (long axis) of the larger chamber (2) + # f = sitk.LabelShapeStatisticsImageFilter() + # f.Execute(label_ventricle) + # orientation_vector = f.GetPrincipalAxes(1) + + # A more robust method is to use the COM offset from the chambers + # as a proxy for the long axis of the LV/RV + # orientation_vector = np.array(get_com(label_ventricle, real_coords=True)) - np.array( + # get_com(label_atrium, real_coords=True) + # ) + + # Another method is to compute the third principal moment of the overlap region + f = sitk.LabelShapeStatisticsImageFilter() + f.Execute(overlap) + orientation_vector = f.GetPrincipalAxes(1)[:3] + + # Get the rotation parameters + rotation_angle = vector_angle(orientation_vector, (0, 0, 1)) + rotation_axis = np.cross(orientation_vector, (0, 0, 1)) + + # Rotate the cylinder to define the valve + label_valve = rotate_image( + cylinder, + rotation_centre=valve_loc_real, + rotation_axis=rotation_axis, + rotation_angle_radians=rotation_angle, + interpolation=sitk.sitkNearestNeighbor, + default_value=0, + ) + + # Now we want to trim any parts of the valve too close to the edge of the chambers + # combined_chambers = sitk.BinaryDilate(label_atrium, (3,) * 3) | sitk.BinaryDilate( + # label_ventricle, (3,) * 3 + # ) + # combined_chambers = sitk.BinaryErode(combined_chambers, (6, 6, 6)) + + # label_valve = sitk.Mask(label_valve, combined_chambers) + + return label_valve diff --git a/platipy/imaging/utils/vessel.py b/platipy/imaging/utils/vessel.py index 9d271e5e..aa6c880b 100644 --- a/platipy/imaging/utils/vessel.py +++ b/platipy/imaging/utils/vessel.py @@ -176,14 +176,14 @@ def tube_from_com_list(com_list, radius): spline = vtk.vtkParametricSpline() spline.SetPoints(points) - functionSource = vtk.vtkParametricFunctionSource() - functionSource.SetParametricFunction(spline) - functionSource.SetUResolution(10 * points.GetNumberOfPoints()) - functionSource.Update() + function_source = vtk.vtkParametricFunctionSource() + function_source.SetParametricFunction(spline) + function_source.SetUResolution(10 * points.GetNumberOfPoints()) + function_source.Update() # Generate the radius scalars tube_radius = vtk.vtkDoubleArray() - n = functionSource.GetOutput().GetNumberOfPoints() + n = function_source.GetOutput().GetNumberOfPoints() tube_radius.SetNumberOfTuples(n) tube_radius.SetName("TubeRadius") for i in range(n): @@ -193,14 +193,14 @@ def tube_from_com_list(com_list, radius): tube_radius.SetTuple1(i, radius) # Add the scalars to the polydata - tubePolyData = vtk.vtkPolyData() - tubePolyData = functionSource.GetOutput() - tubePolyData.GetPointData().AddArray(tube_radius) - tubePolyData.GetPointData().SetActiveScalars("TubeRadius") + tube_poly_data = vtk.vtkPolyData() + tube_poly_data = function_source.GetOutput() + tube_poly_data.GetPointData().AddArray(tube_radius) + tube_poly_data.GetPointData().SetActiveScalars("TubeRadius") # Create the tubes tuber = vtk.vtkTubeFilter() - tuber.SetInputData(tubePolyData) + tuber.SetInputData(tube_poly_data) tuber.SetNumberOfSides(50) tuber.SetVaryRadiusToVaryRadiusByAbsoluteScalar() tuber.Update() @@ -280,18 +280,25 @@ def simpleitk_image_from_vtk_tube(tube, sitk_reference_image): imgstenc.Update() logger.debug("Generating SimpleITK image.") - finalImage = imgstenc.GetOutput() - finalArray = finalImage.GetPointData().GetScalars() - finalArray = vtk_to_numpy(finalArray).reshape(sitk_reference_image.GetSize()[::-1]) - logger.debug(f"Volume = {finalArray.sum()*sum(spacing):.3f} mm^3") - finalImageSITK = sitk.GetImageFromArray(finalArray) - finalImageSITK.CopyInformation(sitk_reference_image) + final_image = imgstenc.GetOutput() + final_array = final_image.GetPointData().GetScalars() + final_array = vtk_to_numpy(final_array).reshape(sitk_reference_image.GetSize()[::-1]) + logger.debug(f"Volume = {final_array.sum()*sum(spacing):.3f} mm^3") + final_image_sitk = sitk.GetImageFromArray(final_array) + final_image_sitk.CopyInformation(sitk_reference_image) - return finalImageSITK + return final_image_sitk def convert_simpleitk_to_vtk(img): - """""" + """Converts from SimpleITK to VTK representation + + Args: + img (SimpleITK.Image): The input image. + + Returns: + The VTK image. + """ size = list(img.GetSize()) origin = list(img.GetOrigin()) spacing = list(img.GetSpacing()) @@ -302,21 +309,21 @@ def convert_simpleitk_to_vtk(img): arr_string = arr.tostring() # send the numpy array to VTK with a vtkImageImport object - dataImporter = vtk.vtkImageImport() + data_importer = vtk.vtkImageImport() - dataImporter.CopyImportVoidPointer(arr_string, len(arr_string)) - dataImporter.SetDataScalarTypeToUnsignedChar() - dataImporter.SetNumberOfScalarComponents(ncomp) + data_importer.CopyImportVoidPointer(arr_string, len(arr_string)) + data_importer.SetDataScalarTypeToUnsignedChar() + data_importer.SetNumberOfScalarComponents(ncomp) # Set the new VTK image's parameters - dataImporter.SetDataExtent(0, size[0] - 1, 0, size[1] - 1, 0, size[2] - 1) - dataImporter.SetWholeExtent(0, size[0] - 1, 0, size[1] - 1, 0, size[2] - 1) - dataImporter.SetDataOrigin(origin) - dataImporter.SetDataSpacing(spacing) + data_importer.SetDataExtent(0, size[0] - 1, 0, size[1] - 1, 0, size[2] - 1) + data_importer.SetWholeExtent(0, size[0] - 1, 0, size[1] - 1, 0, size[2] - 1) + data_importer.SetDataOrigin(origin) + data_importer.SetDataSpacing(spacing) - dataImporter.Update() + data_importer.Update() - vtk_image = dataImporter.GetOutput() + vtk_image = data_importer.GetOutput() return vtk_image @@ -330,13 +337,48 @@ def vessel_spline_generation( scan_direction_dict, atlas_label="DIR", ): - """""" - splinedVessels = {} + """Generates a splined vessel from a list of binary masks. + + + Args: + reference_image (SimpleITK.Image): The reference image to copy information from. + atlas_set (dict): A dictionary conforming to the following format: + { atlas_id: + { + LABEL: + { + structure_a: SimpleITK.Image, + structure_b: SimpleITK.Image, + structure_c: SimpleITK.Image, + ... + } + } + ... + } + where LABEL should be passed as the atlas_label argument, and atlas_id is some label + for each atlas. + vessel_name_list (list): The list of vessels to generate splines for. + vessel_radius_mm_dict (list): A dictionary specifying the radius for each vessel. + stop_condition_type_dict (dict): A dictionary specifying the stopping condition for each + vessel. Available options are "count" - stopping at a particular number of atlases, or + "area" - stopping at a particular total area (in square pixels). Recommendation is to + use "count" with a stop_condition_value (see below) of ~1/3 the number of atlases. + stop_condition_value_dict (dict): A dictionary specifying the stopping value for each + vessel. + scan_direction_dict (dict): A dictionary specifying the direction to spline each vessel. + "x": sagittal, "y": coronal, "z": axial. + atlas_label (str, optional): The atlas label. Defaults to "DIR". + + Returns: + dict: The output dictionary, with keys as the vessel names and values as the binary labels + defining the splined vessels + """ + splined_vessels = {} if isinstance(vessel_name_list, str): vessel_name_list = [vessel_name_list] - for vesselName in vessel_name_list: + for vessel_name in vessel_name_list: # We must set the image direction to identity # This is because it is not possible to modify VTK Image directions @@ -344,35 +386,35 @@ def vessel_spline_generation( initial_image_direction = reference_image.GetDirection() - imageList = [atlas_set[i][atlas_label][vesselName] for i in atlas_set.keys()] - for im in imageList: + image_list = [atlas_set[i][atlas_label][vessel_name] for i in atlas_set.keys()] + for im in image_list: im.SetDirection((1, 0, 0, 0, 1, 0, 0, 0, 1)) - vesselRadius = vessel_radius_mm_dict[vesselName] - stopcondition_type = stop_condition_type_dict[vesselName] - stopcondition_value = stop_condition_value_dict[vesselName] - scan_direction = scan_direction_dict[vesselName] + vessel_radius = vessel_radius_mm_dict[vessel_name] + stop_condition_type = stop_condition_type_dict[vessel_name] + stop_condition_value = stop_condition_value_dict[vessel_name] + scan_direction = scan_direction_dict[vessel_name] - pointArray = com_from_image_list( - imageList, - condition_type=stopcondition_type, - condition_value=stopcondition_value, + point_array = com_from_image_list( + image_list, + condition_type=stop_condition_type, + condition_value=stop_condition_value, scan_direction=scan_direction, ) - tube = tube_from_com_list(pointArray, radius=vesselRadius) + tube = tube_from_com_list(point_array, radius=vessel_radius) - SITKReferenceImage = imageList[0] + reference_image = image_list[0] - vessel_delineation = simpleitk_image_from_vtk_tube(tube, SITKReferenceImage) + vessel_delineation = simpleitk_image_from_vtk_tube(tube, reference_image) vessel_delineation.SetDirection(initial_image_direction) - splinedVessels[vesselName] = vessel_delineation + splined_vessels[vessel_name] = vessel_delineation # We also have to reset the direction to whatever it was # This is because SimpleITK doesn't use deep copying # And it isn't necessary here as we can save some sweet, sweet memory - for im in imageList: + for im in image_list: im.SetDirection(initial_image_direction) - return splinedVessels \ No newline at end of file + return splined_vessels diff --git a/platipy/imaging/visualisation/animation.py b/platipy/imaging/visualisation/animation.py index f7e28ff7..4e7ebeca 100644 --- a/platipy/imaging/visualisation/animation.py +++ b/platipy/imaging/visualisation/animation.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from matplotlib import rcParams - import pathlib import numpy as np import SimpleITK as sitk import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec import matplotlib.animation as animation -from platipy.imaging.visualisation.utils import project_onto_arbitrary_plane +from matplotlib import rcParams def generate_animation_from_image_sequence( @@ -33,15 +30,46 @@ def generate_animation_from_image_sequence( contour_list=False, scalar_list=False, figure_size_in=6, - image_cmap=plt.cm.Greys_r, - contour_cmap=plt.cm.jet, - scalar_cmap=plt.cm.magma, + image_cmap=plt.cm.get_cmap("Greys_r"), + contour_cmap=plt.cm.get_cmap("jet"), + scalar_cmap=plt.cm.get_cmap("magma"), image_window=[-1000, 800], scalar_min=False, scalar_max=False, scalar_alpha=0.5, image_origin="lower", ): + """Generates an animation from a list of images, with optional scalar overlay and contours. + + Args: + image_list (list (SimpleITK.Image)): A list of SimpleITK (2D) images. + output_file (str, optional): The name of the output file. Defaults to "animation.gif". + fps (int, optional): Frames per second. Defaults to 10. + contour_list (list (SimpleITK.Image), optional): A list of SimpleITK (2D) images + (overlay as scalar field). Defaults to False. + scalar_list (list (SimpleITK.Image), optional): A list of SimpleITK (2D) images + (overlay as contours). Defaults to False. + figure_size_in (int, optional): Size of the figure. Defaults to 6. + image_cmap (matplotlib.colors.ListedColormap, optional): Colormap to use for the image. + Defaults to plt.cm.get_cmap("Greys_r"). + contour_cmap (matplotlib.colors.ListedColormap, optional): Colormap to use for contours. + Defaults to plt.cm.get_cmap("jet"). + scalar_cmap (matplotlib.colors.ListedColormap, optional): Colormap to use for scalar field. + Defaults to plt.cm.get_cmap("magma"). + image_window (list, optional): Image intensity window (mininmum, range). + Defaults to [-1000, 800]. + scalar_min (bool, optional): Minimum scalar value to show. Defaults to False. + scalar_max (bool, optional): Maximum scalar value to show. Defaults to False. + scalar_alpha (float, optional): Alpha (transparency) for scalar field. Defaults to 0.5. + image_origin (str, optional): Image origin. Defaults to "lower". + + Raises: + RuntimeError: If ImageMagick isn't installed you cannot use this function! + ValueError: The list of images must be of type SimpleITK.Image + + Returns: + matplotlib.animation: The animation. + """ # We need to check for ImageMagick # There may be other tools that can be used @@ -51,8 +79,8 @@ def generate_animation_from_image_sequence( if not convert_path.exists(): raise RuntimeError("To use this function you need ImageMagick.") - if type(image_list[0]) is not sitk.Image: - raise ValueError("Each image must be a SimplITK image (sitk.Image).") + if not all(isinstance(i, sitk.Image) for i in image_list): + raise ValueError("Each image must be a SimpleITK image (sitk.Image).") # Get the image information x_size, y_size = image_list[0].GetSize() @@ -79,7 +107,7 @@ def generate_animation_from_image_sequence( # These can be given as a list of sitk.Image objects or a list of dicts {"name":sitk.Image} if contour_list is not False: - if type(contour_list[0]) is not dict: + if not isinstance(contour_list[0], dict): plot_dict = {"_": contour_list[0]} contour_labels = False else: @@ -143,7 +171,7 @@ def animate(i): except ValueError: pass - if type(contour_list[i]) is not dict: + if not isinstance(contour_list[i], dict): plot_dict = {"_": contour_list[i]} else: plot_dict = contour_list[i] @@ -152,7 +180,7 @@ def animate(i): for index, contour in enumerate(plot_dict.values()): - display_contours = ax.contour( + ax.contour( sitk.GetArrayFromImage(contour), colors=[color_map[index]], levels=[0], @@ -166,7 +194,7 @@ def animate(i): return (display_image,) # create animation using the animate() function with no repeat - myAnimation = animation.FuncAnimation( + my_animation = animation.FuncAnimation( fig, animate, frames=np.arange(0, len(image_list), 1), @@ -176,6 +204,6 @@ def animate(i): ) # save animation at 30 frames per second - myAnimation.save(output_file, writer="imagemagick", fps=fps) + my_animation.save(output_file, writer="imagemagick", fps=fps) - return myAnimation \ No newline at end of file + return my_animation