From 9fa4222d89168c24b4c8262c2361c53a4bf7953d Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Mon, 5 Aug 2024 17:35:03 +1200 Subject: [PATCH 01/13] Initial code and test --- .gitignore | 2 +- pyproject.toml | 27 + src/segmentationstitcher/__init__.py | 0 src/segmentationstitcher/annotation.py | 165 +++++ src/segmentationstitcher/segment.py | 86 +++ src/segmentationstitcher/stitcher.py | 118 ++++ tests/resources/vagus-segment1.exf | 810 +++++++++++++++++++++++++ tests/resources/vagus-segment2.exf | 714 ++++++++++++++++++++++ tests/resources/vagus-segment3.exf | 679 +++++++++++++++++++++ tests/test_vagus.py | 70 +++ tests/testutils.py | 8 + 11 files changed, 2678 insertions(+), 1 deletion(-) create mode 100644 pyproject.toml create mode 100644 src/segmentationstitcher/__init__.py create mode 100644 src/segmentationstitcher/annotation.py create mode 100644 src/segmentationstitcher/segment.py create mode 100644 src/segmentationstitcher/stitcher.py create mode 100644 tests/resources/vagus-segment1.exf create mode 100644 tests/resources/vagus-segment2.exf create mode 100644 tests/resources/vagus-segment3.exf create mode 100644 tests/test_vagus.py create mode 100644 tests/testutils.py diff --git a/.gitignore b/.gitignore index 82f9275..7b6caf3 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,4 @@ cython_debug/ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ +.idea/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..46209b4 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["setuptools>=61.0", "setuptools_scm>=8.0"] +build-backend = "setuptools.build_meta" +[tool.setuptools-git-versioning] +enabled = true +[project] +name = "segmentation_stitcher" +dynamic = ["version"] +keywords = ["Medical", "Image", "Segmentation", "Merge", "SPARC"] +readme = "README.md" +license = {file = "LICENSE"} +authors = [ + { name="Richard Christie", email="r.christie@auckland.ac.nz" }, +] +dependencies = [ + "cmlibs.maths>=0.6.2", + "cmlibs.utils>=0.9", + "cmlibs.zinc>=4.1" +] +description = "Utility for stitching segmentations of networks and other features from multiple adjacent blocks" +requires-python = ">=3.7" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache License", + "Operating System :: OS Independent", +] +[tool.setuptools_scm] diff --git a/src/segmentationstitcher/__init__.py b/src/segmentationstitcher/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py new file mode 100644 index 0000000..e62d78f --- /dev/null +++ b/src/segmentationstitcher/annotation.py @@ -0,0 +1,165 @@ +""" +Utility functions and classes for annotations and how they are used by segmentation stitcher. +""" +from enum import Enum +from cmlibs.utils.zinc.field import get_group_list +from cmlibs.utils.zinc.group import group_get_highest_dimension, groups_have_same_local_contents +from cmlibs.zinc.field import Field + + +class AnnotationCategory(Enum): + """ + How to process segmentations with this annotation. + """ + CONNECTED_SIMPLE_NETWORK = 1 # a simple connected network graph + CONNECTED_COMPLEX_NETWORK = 2 # a complex network of connected parallel sections e.g. fascicles. + UNCONNECTED_GENERAL = 3 # contours and other segmentations which are not connected but are included in output + EXCLUDE = 4 # segmentations to exclude from the output + + +class Annotation: + """ + A record of an annotation name/term and how it is used by the stitcher. + """ + + def __init__(self, name: str, term, dimension, category: AnnotationCategory): + """ + :param name: Unique name of annotation for feature. + :param term: Unique string term (e.g. URL) identifying feature in standard term set, or None if unknown. + :param dimension: Dimension of annotation from 0 to 3, but realistically only 0 or 1. + :param category: How to process segmentations with this annotation. + """ + assert 0 <= dimension <= 3 + self._name = name + self._term = term + self._dimension = dimension + self._category = category + + def decode_settings(self, settings_in: dict): + """ + Update segment settings from JSON dict containing serialised settings. + :param settings_in: Dictionary of settings as produced by encode_settings(). + """ + assert (settings_in.get("name") == self._name) and (settings_in.get("term") == self._term) + settings_dimension = settings_in.get("dimension") + if settings_dimension != self._dimension: + print("WARNING: Segmentation Stitcher. Annotation with name", self._name, "term", self._term, + "was dimension ", settings_dimension, "in settings, is now ", self._dimension, + ". Have input files changed?") + settings_in["dimension"] = self._dimension + # update current settings to gain new ones and override old ones + settings = self.encode_settings() + settings.update(settings_in) + self._category = AnnotationCategory[settings["category"]] + + def encode_settings(self) -> dict: + """ + Encode segment data in a dictionary to serialize. + :return: Settings in a dict ready for passing to json.dump. + """ + settings = { + "category": self._category.name, + "dimension": self._dimension, + "name": self._name, + "term": self._term + } + return settings + + def get_category(self): + return self._category + + def set_category(self, category): + self._category = category + + def get_dimension(self): + return self._dimension + + def get_name(self): + return self._name + + def get_term(self): + return self._term + + def set_term(self, term): + """ + Set the term for this annotation; must currently be None. + :param term: New term string e.g. URL + """ + assert self._term is None + self._term = term + + +def region_get_annotations(region, simple_network_keywords, complex_network_keywords, term_keyword="http"): + """ + Get annotation group names and terms from region's non-empty groups. + Groups with names consisting only of numbers are ignored as we're needlessly getting these for part contours. + After sorting for simple and complex networks and terms, remaining annotations are marked as general unconnected. + :param region: Zinc region to analyse groups in. + :param simple_network_keywords: Annotation names containing any of these keywords are marked as simple networks. + Must use lower case-folded. Comparison is case-insensitive. + :param complex_network_keywords: Annotation names containing any of these keywords are marked as complex networks. + Must use lower case-folded. Comparison is case-insensitive. + :param term_keyword: Groups with names containing this keyword are matched to other groups with the same content, + and supply the term name for them instead of making another Annotation. If no matching group is supplied these + are used as names and terms. + :return: list of Annotation. + """ + fieldmodule = region.getFieldmodule() + groups = get_group_list(fieldmodule) + annotations = [] + term_annotations = [] + datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + for group in groups: + # clean up name to remove case and leading/trailing whitespace + name = group.getName().strip() + lower_name = name.casefold() + dimension = group_get_highest_dimension(group) + if dimension < 0: + data_group = group.getNodesetGroup(datapoints) + if data_group.isValid() and (data_group.getSize() > 0): + dimension = 0 + else: + continue # empty group + if lower_name.isdigit(): + continue # ignore as these can never be valid annotation names + category = AnnotationCategory.UNCONNECTED_GENERAL + for keyword in simple_network_keywords: + if keyword in lower_name: + category = AnnotationCategory.CONNECTED_SIMPLE_NETWORK + break + else: + for keyword in complex_network_keywords: + if keyword in lower_name: + category = AnnotationCategory.CONNECTED_COMPLEX_NETWORK + break + annotation = Annotation(name, None, dimension, category) + if term_keyword in lower_name: + term_annotations.append(annotation) + else: + annotations.append(annotation) + for term_annotation in term_annotations: + term = term_annotation.get_name() + term_group = fieldmodule.findFieldByName(term).castGroup() + dimension = term_annotation.get_dimension() + for annotation in annotations: + if annotation.get_term() is not None: + continue + if annotation.get_dimension() != dimension: + continue + name = annotation.get_name() + name_group = fieldmodule.findFieldByName(name).castGroup() + if groups_have_same_local_contents(name_group, term_group): + annotation.set_term(term) + break + else: + print("WARNING: Segmentation Stitcher. Did not find matching annotation name for term", term, + ". Adding separate annotation.") + term_annotation.set_term(term) + index = 0 + for annotation in annotations: + name = annotation.get_name() + if term < name: + break + index += 1 + annotations.insert(index, term_annotation) + return annotations diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py new file mode 100644 index 0000000..86c74e0 --- /dev/null +++ b/src/segmentationstitcher/segment.py @@ -0,0 +1,86 @@ +""" +A segment of the segmentation data, generally from a separate image block. +""" +from cmlibs.zinc.result import RESULT_OK + + +class Segment: + """ + A segment of the segmentation data, generally from a separate image block. + """ + + def __init__(self, name, segmentation_file_name, root_region): + """ + :param name: Unique name of segment, usually derived from the file name. + :param segmentation_file_name: Path and file name of raw segmentation file, in Zinc format. + :param root_region: Zinc root region to create segment region under. + """ + self._name = name + self._segmentationFileName = segmentation_file_name + # print("Create segment", self._name) + self._base_region = root_region.createChild(self._name) + assert self._base_region.isValid(), \ + "Cannot create segment region " + self._name + ". Name may already be in use?" + # the raw region contains the original segment data which is not modified apart from building + # groups to categorise data for stitching a visualisation, including selecting for display. + self._raw_region = self._base_region.createChild("raw") + result = self._raw_region.readFile(segmentation_file_name) + assert result == RESULT_OK, \ + "Could not read segmentation file " + segmentation_file_name + self._rotation = [0.0, 0.0, 0.0] + self._translation = [0.0, 0.0, 0.0] + + def decode_settings(self, settings_in: dict): + """ + Update segment settings from JSON dict containing serialised settings. + :param settings_in: Dictionary of settings as produced by encode_settings(). + """ + assert settings_in.get("name") == self._name + # update current settings to gain new ones and override old ones + settings = self.encode_settings() + settings.update(settings_in) + self._rotation = settings["rotation"] + self._translation = settings["translation"] + + def encode_settings(self) -> dict: + """ + Encode segment data in a dictionary to serialize. + :return: Settings in a dict ready for passing to json.dump. + """ + settings = { + "name": self._name, + "rotation": self._rotation, + "translation": self._translation + } + return settings + + def get_base_region(self): + """ + Get the base region for all segmentation and auxiliary data for this segment. + :return: Zinc Region. + """ + return self._base_region + + def get_name(self): + return self._name + + def get_raw_region(self): + """ + Get the raw region, a child of base region, into which the raw segmentation was loaded. + :return: Zinc Region. + """ + return self._raw_region + + def get_rotation(self): + return self._rotation + + def set_rotation(self, rotation): + assert len(rotation) == 3 + self._rotation = rotation + + def get_translation(self): + return self._translation + + def set_translation(self, translation): + assert len(translation) == 3 + self._translation = translation diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py new file mode 100644 index 0000000..89458d2 --- /dev/null +++ b/src/segmentationstitcher/stitcher.py @@ -0,0 +1,118 @@ +""" +Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. +""" +from cmlibs.zinc.context import Context +from segmentationstitcher.segment import Segment +from segmentationstitcher.annotation import region_get_annotations +from pathlib import Path + + +class Stitcher: + """ + Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. + """ + + def __init__(self, segmentation_file_names: list): + """ + :param segmentation_file_names: List of filenames containing raw segmentations in Zinc format. + """ + self._context = Context("Scaffoldfitter") + self._root_region = self._context.getDefaultRegion() + self._annotations = [] + self._segments = [] + self._version = 1 # increment when new settings added to migrate older serialised settings + for segmentation_file_name in segmentation_file_names: + name = Path(segmentation_file_name).stem + segment = Segment(name, segmentation_file_name, self._root_region) + self._segments.append(segment) + segment_annotations = region_get_annotations( + segment.get_raw_region(), simple_network_keywords=["vagus", "nerve", "trunk", "branch"], + complex_network_keywords=["fascicle"], term_keyword="http") + for segment_annotation in segment_annotations: + name = segment_annotation.get_name() + term = segment_annotation.get_term() + index = 0 + for annotation in self._annotations: + if (annotation.get_name() == name) and (annotation.get_term() == term): + # print("Found annotation name", name, "term", term) + break # exists already + if name > annotation.get_name(): + index += 1 + else: + # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), + # "category", segment_annotation.get_category()) + self._annotations.insert(index, segment_annotation) + + def decode_settings(self, settings_in: dict): + """ + Update stitcher settings from dictionary of serialised settings. + :param settings_in: Dictionary of settings as produced by encode_settings(). + """ + assert settings_in.get("annotations") and settings_in.get("segments") and settings_in.get("version"), \ + "Stitcher.decode_settings: Invalid settings dictionary" + # settings_version = settings_in["version"] + + # update annotations and warn about differences + processed_count = 0 + for annotation_settings in settings_in["annotations"]: + name = annotation_settings["name"] + term = annotation_settings["term"] + for annotation in self._annotations: + if (annotation.get_name() == name) and (annotation.get_term() == term): + annotation.decode_settings(annotation_settings) + processed_count += 1 + break + else: + print("WARNING: Segmentation Stitcher. Annotation with name", name, "term", term, + "in settings not found; ignoring. Have input files changed?") + if processed_count != len(self._annotations): + for annotation in self._annotations: + name = annotation.get_name() + term = annotation.get_term() + for annotation_settings in settings_in["annotations"]: + if (annotation_settings["name"] == name) and (annotation_settings["term"] == term): + break + else: + print("WARNING: Segmentation Stitcher. Annotation with name", name, "term", term, + "not found in settings; using defaults. Have input files changed?") + + # update segment settings and warn about differences + processed_count = 0 + for segment_settings in settings_in["segments"]: + name = segment_settings["name"] + for segment in self._segments: + if segment.get_name() == name: + segment.decode_settings(segment_settings) + processed_count += 1 + break + else: + print("WARNING: Segmentation Stitcher. Segment with name", name, + "in settings not found; ignoring. Have input files changed?") + if processed_count != len(self._segments): + for segment in self._segments: + name = segment.get_name() + for segment_settings in settings_in["segments"]: + if segment_settings["name"] == name: + break + else: + print("WARNING: Segmentation Stitcher. Segment with name", name, + "not found in settings; using defaults. Have input files changed?") + + def encode_settings(self) -> dict: + """ + :return: Dictionary of Stitcher settings ready to serialise to JSON. + """ + settings = { + "annotations": [annotation.encode_settings() for annotation in self._annotations], + "segments": [segment.encode_settings() for segment in self._segments], + "version": self._version + } + return settings + + def get_annotations(self): + return self._annotations + def get_segments(self): + return self._segments + + def get_version(self): + return self._version diff --git a/tests/resources/vagus-segment1.exf b/tests/resources/vagus-segment1.exf new file mode 100644 index 0000000..ec685f3 --- /dev/null +++ b/tests/resources/vagus-segment1.exf @@ -0,0 +1,810 @@ +EX Version: 3 +Region: / +!#nodeset nodes +Define node template: node1 +Shape. Dimension=0 +#Fields=2 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=1 (value) + y. #Values=1 (value) + z. #Values=1 (value) +2) radius, field, rectangular cartesian, real, #Components=1 + 1. #Values=1 (value) +Node template: node1 +Node: 1 + 4.674543239403558e-02 + 4.994332584975342e-03 + -5.142379780295361e-02 + 5.000000000000000e-01 +Node: 2 + 5.088931281898371e-01 + 2.394788272830661e-04 + -1.897722788108855e-02 + 5.000000000000000e-01 +Node: 3 + 1.003531283528280e+00 + -4.884390570605898e-03 + -1.262024506493469e-02 + 5.000000000000000e-01 +Node: 4 + 1.503628414897787e+00 + 2.497113185983284e-02 + -4.156146251267568e-03 + 5.000000000000000e-01 +Node: 5 + 2.003179127455374e+00 + 1.606427200643680e-02 + -5.039788157846552e-03 + 5.000000000000000e-01 +Node: 6 + 2.499311138160670e+00 + -7.125707585743861e-03 + -1.116235969760458e-03 + 5.000000000000000e-01 +Node: 7 + 3.000002538469982e+00 + -2.978453387709680e-02 + -2.446427371756252e-03 + 5.000000000000000e-01 +Node: 8 + 3.499002669458522e+00 + -2.439031312243860e-02 + -1.014279179715817e-02 + 5.000000000000000e-01 +Node: 9 + 3.999419901743434e+00 + -2.705334490213033e-02 + -1.369497030490745e-02 + 5.000000000000000e-01 +Node: 10 + 4.497416430576471e+00 + -2.427646842993988e-02 + -2.703030189629844e-03 + 5.000000000000000e-01 +Node: 11 + 4.940077726837480e+00 + -5.876504996014503e-02 + -6.088281310333100e-03 + 5.000000000000000e-01 +Node: 12 + 9.418003533634166e-02 + -8.578875146990687e-02 + -3.185893809248219e-01 + 1.500000000000000e-01 +Node: 13 + 4.938666237286751e-01 + -3.054728676869334e-02 + -3.130429177043437e-01 + 1.500000000000000e-01 +Node: 14 + 1.003533501758645e+00 + -4.710575200639753e-02 + -3.106450048505804e-01 + 1.500000000000000e-01 +Node: 15 + 1.512394533728358e+00 + -5.584159532468388e-02 + -2.889348904933468e-01 + 1.500000000000000e-01 +Node: 16 + 2.032698077351535e+00 + -7.955598481357226e-02 + -3.151666726418236e-01 + 1.500000000000000e-01 +Node: 17 + 2.508613712658257e+00 + -8.963864883910685e-02 + -2.595148799197792e-01 + 1.500000000000000e-01 +Node: 18 + 3.012595554470824e+00 + -1.048489300589250e-01 + -2.473566078657696e-01 + 1.500000000000000e-01 +Node: 19 + 3.520700494134624e+00 + -1.225098957367806e-01 + -2.455813513836385e-01 + 1.500000000000000e-01 +Node: 20 + 4.011530096692854e+00 + -1.469321187202401e-01 + -2.411968874977056e-01 + 1.500000000000000e-01 +Node: 21 + 4.509767635959474e+00 + -1.643518508242892e-01 + -2.280779222266621e-01 + 1.500000000000000e-01 +Node: 22 + 4.924764596510804e+00 + -2.312486286387353e-01 + -2.986356575759058e-01 + 1.500000000000000e-01 +Node: 23 + 6.246044805234810e-02 + 1.164964549647818e-01 + 1.407224543876318e-01 + 2.000000000000000e-01 +Node: 24 + 5.006627992283177e-01 + 2.053495724385865e-01 + 1.756791035944522e-01 + 2.000000000000000e-01 +Node: 25 + 1.000473561320181e+00 + 1.909670673937150e-01 + 1.569127764129110e-01 + 2.000000000000000e-01 +Node: 26 + 1.501449834402207e+00 + 1.894206225157190e-01 + 1.486693471137240e-01 + 2.000000000000000e-01 +Node: 27 + 2.000000000000000e+00 + 1.600000000000000e-01 + 1.200000000000000e-01 + 2.000000000000000e-01 +Node: 28 + 2.500000000000000e+00 + 1.500000000000000e-01 + 1.000000000000000e-01 + 2.000000000000000e-01 +Node: 29 + 1.303413400633734e+00 + 2.071697993638444e-02 + -1.342074560598966e-01 + 1.200000000000000e-01 +Node: 30 + 1.690907307994362e+00 + 1.001223235571003e-01 + -2.752660238162596e-02 + 1.200000000000000e-01 +Node: 31 + 3.001313976384459e+00 + 2.173809005675096e-01 + 1.705556344738394e-01 + 1.300000000000000e-01 +Node: 32 + 3.504672666453772e+00 + 2.401065072256187e-01 + 1.884807226825300e-01 + 1.300000000000000e-01 +Node: 33 + 4.005616634652601e+00 + 2.339260325218280e-01 + 1.951022482711223e-01 + 1.300000000000000e-01 +Node: 34 + 4.508361305126138e+00 + 2.025510174075223e-01 + 1.800263313495004e-01 + 1.300000000000000e-01 +Node: 35 + 4.937256268871565e+00 + 1.444495277965882e-01 + 1.370343813914974e-01 + 1.300000000000000e-01 +Node: 36 + 2.981007759099746e+00 + -2.915973609955198e-02 + 1.696026255753690e-01 + 1.300000000000000e-01 +Node: 37 + 3.478622097359870e+00 + -2.105533820451407e-02 + 2.052041759309985e-01 + 1.300000000000000e-01 +Node: 38 + 3.979086863024901e+00 + -3.753036612250779e-02 + 2.341611182511475e-01 + 1.300000000000000e-01 +Node: 39 + 4.505253519098527e+00 + -3.939330532523967e-02 + 2.685470596760576e-01 + 1.300000000000000e-01 +Node: 40 + 4.884986427639625e+00 + -1.494304760120796e-01 + 2.739756414649969e-01 + 1.300000000000000e-01 +Node: 41 + 3.320345754623057e+00 + 2.236359081045942e-01 + -2.765344840527691e-01 + 3.000000000000000e-01 +Node: 42 + 3.547055958954830e+00 + 4.793765104993831e-01 + -3.053184316961680e-01 + 3.000000000000000e-01 +Node: 43 + 3.730776105401425e+00 + 7.398628684136312e-01 + -3.161880636510943e-01 + 3.000000000000000e-01 +Node: 44 + 3.971784269855830e+00 + 1.068212107446483e+00 + -3.838705026223965e-01 + 3.000000000000000e-01 +Node: 45 + 3.177352323560384e+00 + 1.876311088000044e-01 + -3.266256049263134e-01 + 1.000000000000000e-01 +Node: 46 + 3.403063662822699e+00 + 4.586915642732161e-01 + -3.541539599454015e-01 + 1.000000000000000e-01 +Node: 47 + 3.638193679758521e+00 + 7.525194313557660e-01 + -3.598551185329084e-01 + 1.000000000000000e-01 +Node: 48 + 3.945774908711690e+00 + 1.112381242034030e+00 + -4.211568462718694e-01 + 1.000000000000000e-01 +Node: 49 + 8.415699375444691e-01 + 1.012672339245308e-01 + -5.776458172597428e-01 + 1.000000000000000e-02 +Node: 50 + 8.528188644777025e-01 + 3.060241695784205e-01 + -2.748122116871037e-01 + 1.000000000000000e-02 +Node: 51 + 8.574824978093748e-01 + 5.234259874285337e-01 + -6.103507563611862e-02 + 1.000000000000000e-02 +Node: 52 + 8.685645662873697e-01 + 6.004633455582893e-01 + 1.185916459881602e-01 + 1.000000000000000e-02 +Node: 53 + 8.655958320683292e-01 + 5.346496547101207e-01 + 3.216121485719327e-01 + 1.000000000000000e-02 +Node: 54 + 8.727874206977626e-01 + 3.732385153257773e-01 + 5.034147536628709e-01 + 1.000000000000000e-02 +Node: 55 + 8.569222898933115e-01 + 1.915504464050390e-02 + 5.451111212353961e-01 + 1.000000000000000e-02 +Node: 56 + 8.590675538808702e-01 + -1.680484106071109e-01 + 4.385032007931779e-01 + 1.000000000000000e-02 +Node: 57 + 8.589388426004948e-01 + -3.292116520378644e-01 + 2.525312848168873e-01 + 1.000000000000000e-02 +Node: 58 + 8.684183291357422e-01 + -4.690445000108720e-01 + 1.026119292615201e-01 + 1.000000000000000e-02 +Node: 59 + 8.599881996115876e-01 + -4.660585167033815e-01 + -1.997834208577165e-01 + 1.000000000000000e-02 +Node: 60 + 8.499416869456555e-01 + -2.658662909358659e-01 + -5.084427936917201e-01 + 1.000000000000000e-02 +Node: 61 + 2.457513856594047e+00 + -1.502188557287938e-01 + -5.804178855490847e-01 + 1.000000000000000e-02 +Node: 62 + 2.489150952086654e+00 + 1.610573758996429e-01 + -4.620512750782058e-01 + 1.000000000000000e-02 +Node: 63 + 2.508153932088234e+00 + 4.018394812867033e-01 + -2.364653362134370e-01 + 1.000000000000000e-02 +Node: 64 + 2.520610561255106e+00 + 5.171822073975094e-01 + 6.792919191687449e-03 + 1.000000000000000e-02 +Node: 65 + 2.503222803977086e+00 + 5.134986298250340e-01 + 2.229434922135527e-01 + 1.000000000000000e-02 +Node: 66 + 2.502419626291860e+00 + 3.806731487485062e-01 + 4.601559769249017e-01 + 1.000000000000000e-02 +Node: 67 + 2.532597986005163e+00 + 1.278584702857989e-01 + 5.305650283313478e-01 + 1.000000000000000e-02 +Node: 68 + 2.516413971505339e+00 + -1.067135651549795e-01 + 4.413763515603156e-01 + 1.000000000000000e-02 +Node: 69 + 2.494263591346698e+00 + -2.642653512355263e-01 + 2.808921748460516e-01 + 1.000000000000000e-02 +Node: 70 + 2.482204423448402e+00 + -4.254886911057046e-01 + 1.162961614831712e-01 + 1.000000000000000e-02 +Node: 71 + 2.491291266907887e+00 + -4.720534110663925e-01 + -1.089923444954435e-01 + 1.000000000000000e-02 +Node: 72 + 2.469114083684929e+00 + -4.511912354844180e-01 + -3.617004683277068e-01 + 1.000000000000000e-02 +Node: 73 + 4.211192594904671e+00 + -2.180913561460371e-01 + -5.483220203120308e-01 + 1.000000000000000e-02 +Node: 74 + 4.282609500750963e+00 + 5.896451220566845e-02 + -5.012385144346297e-01 + 1.000000000000000e-02 +Node: 75 + 4.315960696363735e+00 + 2.446136447629781e-01 + -2.887745037424596e-01 + 1.000000000000000e-02 +Node: 76 + 4.336797796847812e+00 + 3.762921093051989e-01 + -8.317071594138040e-02 + 1.000000000000000e-02 +Node: 77 + 4.326586384799966e+00 + 4.305286314213965e-01 + 1.423146769653955e-01 + 1.000000000000000e-02 +Node: 78 + 4.320381379719021e+00 + 3.931307519604826e-01 + 3.573654979527858e-01 + 1.000000000000000e-02 +Node: 79 + 4.269355282012874e+00 + 1.390923726008407e-01 + 4.945942763192685e-01 + 1.000000000000000e-02 +Node: 80 + 4.202883790665199e+00 + -1.298142100394249e-01 + 4.880356442686678e-01 + 1.000000000000000e-02 +Node: 81 + 4.154726138354027e+00 + -3.287885872249272e-01 + 3.723818654387616e-01 + 1.000000000000000e-02 +Node: 82 + 4.137871351709665e+00 + -4.305890161123914e-01 + 1.698889966601328e-01 + 1.000000000000000e-02 +Node: 83 + 4.117083324460348e+00 + -5.143014291368941e-01 + -9.170540869805206e-02 + 1.000000000000000e-02 +Node: 84 + 4.139147896433698e+00 + -4.364982277527066e-01 + -3.626326906335586e-01 + 1.000000000000000e-02 +Define node template: node2 +Shape. Dimension=0 +#Fields=3 +1) marker coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=1 (value) + y. #Values=1 (value) + z. #Values=1 (value) +2) marker_location, field, element_xi, #Components=1, host mesh=mesh1d, host mesh dimension=1 + 1. #Values=1 (value) +3) marker_name, field, string, #Components=1 + 1. #Values=1 (value) +Node template: node2 +Node: 85 + 9.014003757666196e-01 + -3.826434512807377e-03 + -1.393280944888287e-02 + 2 7.935239999999998e-01 + "landmark 1" +Node: 86 + 2.532597965468619e+00 + 1.278581726279602e-01 + 5.305649151562682e-01 + 64 1.000000000000000e+00 + orientation +!#mesh mesh1d, dimension=1, nodeset=nodes +Define element template: element1 +Shape. Dimension=1, line +#Scale factor sets=0 +#Nodes=2 +#Fields=2 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value + y. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value + z. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value +2) radius, field, rectangular cartesian, real, #Components=1 + 1. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value +Element template: element1 +Element: 1 + Nodes: + 1 2 +Element: 2 + Nodes: + 2 3 +Element: 3 + Nodes: + 3 4 +Element: 4 + Nodes: + 4 5 +Element: 5 + Nodes: + 5 6 +Element: 6 + Nodes: + 6 7 +Element: 7 + Nodes: + 7 8 +Element: 8 + Nodes: + 8 9 +Element: 9 + Nodes: + 9 10 +Element: 10 + Nodes: + 10 11 +Element: 11 + Nodes: + 12 13 +Element: 12 + Nodes: + 13 14 +Element: 13 + Nodes: + 14 15 +Element: 14 + Nodes: + 15 16 +Element: 15 + Nodes: + 16 17 +Element: 16 + Nodes: + 17 18 +Element: 17 + Nodes: + 18 19 +Element: 18 + Nodes: + 19 20 +Element: 19 + Nodes: + 20 21 +Element: 20 + Nodes: + 21 22 +Element: 21 + Nodes: + 23 24 +Element: 22 + Nodes: + 24 25 +Element: 23 + Nodes: + 25 26 +Element: 24 + Nodes: + 26 27 +Element: 25 + Nodes: + 27 28 +Element: 26 + Nodes: + 14 29 +Element: 27 + Nodes: + 29 30 +Element: 28 + Nodes: + 30 27 +Element: 29 + Nodes: + 28 31 +Element: 30 + Nodes: + 31 32 +Element: 31 + Nodes: + 32 33 +Element: 32 + Nodes: + 33 34 +Element: 33 + Nodes: + 34 35 +Element: 34 + Nodes: + 28 36 +Element: 35 + Nodes: + 36 37 +Element: 36 + Nodes: + 37 38 +Element: 37 + Nodes: + 38 39 +Element: 38 + Nodes: + 39 40 +Element: 39 + Nodes: + 7 41 +Element: 40 + Nodes: + 41 42 +Element: 41 + Nodes: + 42 43 +Element: 42 + Nodes: + 43 44 +Element: 43 + Nodes: + 18 45 +Element: 44 + Nodes: + 45 46 +Element: 45 + Nodes: + 46 47 +Element: 46 + Nodes: + 47 48 +Element: 47 + Nodes: + 49 50 +Element: 48 + Nodes: + 50 51 +Element: 49 + Nodes: + 51 52 +Element: 50 + Nodes: + 52 53 +Element: 51 + Nodes: + 53 54 +Element: 52 + Nodes: + 54 55 +Element: 53 + Nodes: + 55 56 +Element: 54 + Nodes: + 56 57 +Element: 55 + Nodes: + 57 58 +Element: 56 + Nodes: + 58 59 +Element: 57 + Nodes: + 59 60 +Element: 58 + Nodes: + 60 49 +Element: 59 + Nodes: + 61 62 +Element: 60 + Nodes: + 62 63 +Element: 61 + Nodes: + 63 64 +Element: 62 + Nodes: + 64 65 +Element: 63 + Nodes: + 65 66 +Element: 64 + Nodes: + 66 67 +Element: 65 + Nodes: + 67 68 +Element: 66 + Nodes: + 68 69 +Element: 67 + Nodes: + 69 70 +Element: 68 + Nodes: + 70 71 +Element: 69 + Nodes: + 71 72 +Element: 70 + Nodes: + 72 61 +Element: 71 + Nodes: + 73 74 +Element: 72 + Nodes: + 74 75 +Element: 73 + Nodes: + 75 76 +Element: 74 + Nodes: + 76 77 +Element: 75 + Nodes: + 77 78 +Element: 76 + Nodes: + 78 79 +Element: 77 + Nodes: + 79 80 +Element: 78 + Nodes: + 80 81 +Element: 79 + Nodes: + 81 82 +Element: 80 + Nodes: + 82 83 +Element: 81 + Nodes: + 83 84 +Element: 82 + Nodes: + 84 73 +Group name: 00001 +!#nodeset nodes +Node group: +49..60 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +47..58 +Group name: 00002 +!#nodeset nodes +Node group: +61..72 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +59..70 +Group name: 00003 +!#nodeset nodes +Node group: +73..84 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +71..82 +Group name: Epineurium +!#nodeset nodes +Node group: +49..84 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +47..82 +Group name: Fascicle +!#nodeset nodes +Node group: +12..40,45..48 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +11..38,43..46 +Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +!#nodeset nodes +Node group: +49..84 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +47..82 +Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +!#nodeset nodes +Node group: +1..11 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..10 +Group name: http://uri.interlex.org/base/ilx_0738426 +!#nodeset nodes +Node group: +12..40,45..48 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +11..38,43..46 +Group name: left A branch 1 +!#nodeset nodes +Node group: +7,41..44 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +39..42 +Group name: left vagus X nerve trunk +!#nodeset nodes +Node group: +1..11 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..10 +Group name: marker +!#nodeset nodes +Node group: +85..86 diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf new file mode 100644 index 0000000..c400f23 --- /dev/null +++ b/tests/resources/vagus-segment2.exf @@ -0,0 +1,714 @@ +EX Version: 3 +Region: / +!#nodeset nodes +Define node template: node1 +Shape. Dimension=0 +#Fields=2 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=1 (value) + y. #Values=1 (value) + z. #Values=1 (value) +2) radius, field, rectangular cartesian, real, #Components=1 + 1. #Values=1 (value) +Node template: node1 +Node: 1 + 5.964812921168261e-02 + -6.273546711369847e-02 + -3.355441365022512e-02 + 5.000000000000000e-01 +Node: 2 + 4.693738101370620e-01 + 4.653915664744340e-02 + -3.323098349765055e-02 + 5.000000000000000e-01 +Node: 3 + 9.387281371900745e-01 + 9.202060653844510e-02 + -5.910871710620801e-02 + 5.000000000000000e-01 +Node: 4 + 1.408491810117268e+00 + 1.430819220874628e-01 + -8.461987977035484e-02 + 5.000000000000000e-01 +Node: 5 + 1.877315226911354e+00 + 1.760407665782467e-01 + -1.226895086938640e-01 + 5.000000000000000e-01 +Node: 6 + 2.350388798778136e+00 + 2.081325257864106e-01 + -1.444901311314077e-01 + 5.000000000000000e-01 +Node: 7 + 2.820408784119046e+00 + 2.485505602798801e-01 + -1.742069085137291e-01 + 5.000000000000000e-01 +Node: 8 + 3.303552238495561e+00 + 2.685259377980394e-01 + -2.094495380833261e-01 + 5.000000000000000e-01 +Node: 9 + 3.765226728491446e+00 + 2.887717624801313e-01 + -2.347640522107229e-01 + 5.000000000000000e-01 +Node: 10 + 4.234094626034139e+00 + 3.355357572985979e-01 + -2.658981515557162e-01 + 5.000000000000000e-01 +Node: 11 + 4.649351322583406e+00 + 3.147843160681483e-01 + -1.929821893185177e-01 + 5.000000000000000e-01 +Node: 12 + 5.895366780392978e-02 + -5.816489437685837e-02 + 2.435963071931103e-01 + 1.300000000000000e-01 +Node: 13 + 4.884047722778075e-01 + -7.755825741335237e-03 + 2.658037975431622e-01 + 1.300000000000000e-01 +Node: 14 + 9.897686143993319e-01 + 3.464610678847231e-02 + 2.657611634243279e-01 + 1.300000000000000e-01 +Node: 15 + 1.476570214872810e+00 + 6.462684193008554e-02 + 2.402417678394919e-01 + 1.300000000000000e-01 +Node: 16 + 1.972157045498206e+00 + 8.972186822651815e-02 + 1.960768677096780e-01 + 1.300000000000000e-01 +Node: 17 + 2.469090578395699e+00 + 6.196203185938773e-02 + 5.621931904079824e-02 + 2.000000000000000e-01 +Node: 18 + 2.946119268478841e+00 + 1.250867213541862e-01 + 2.421606995890878e-02 + 2.000000000000000e-01 +Node: 19 + 3.411879032072762e+00 + 1.660838268607837e-01 + 2.270313000119678e-03 + 2.000000000000000e-01 +Node: 20 + 3.873878170842627e+00 + 1.880720159726681e-01 + -1.858452896759641e-03 + 2.000000000000000e-01 +Node: 21 + 4.330973948784339e+00 + 1.951875037554594e-01 + -2.418877327551888e-02 + 2.000000000000000e-01 +Node: 22 + 4.612390816368989e+00 + 1.512323694124026e-01 + 6.507547243332894e-02 + 2.000000000000000e-01 +Node: 23 + 1.293214970581236e-01 + -2.792048089888567e-01 + -2.164207378381781e-02 + 1.300000000000000e-01 +Node: 24 + 5.179042415278017e-01 + -2.368974901378895e-01 + 2.737480637211859e-02 + 1.300000000000000e-01 +Node: 25 + 1.013317207281856e+00 + -1.990475117211184e-01 + 2.681614581373870e-02 + 1.300000000000000e-01 +Node: 26 + 1.523401555538024e+00 + -1.584205064651123e-01 + -1.104594924310610e-02 + 1.300000000000000e-01 +Node: 27 + 2.008720599171234e+00 + -8.720364285079608e-02 + -7.738027594649231e-03 + 1.300000000000000e-01 +Node: 28 + 4.000186023704735e-02 + 1.270394815034144e-01 + -3.229041442172280e-01 + 1.800000000000000e-01 +Node: 29 + 4.874842948174925e-01 + 2.328230516006755e-01 + -3.393155075654865e-01 + 1.800000000000000e-01 +Node: 30 + 9.471295133813437e-01 + 2.605608519522547e-01 + -3.744744052101734e-01 + 1.800000000000000e-01 +Node: 31 + 1.367332690390529e+00 + 2.725873846507746e-01 + -3.875862185781842e-01 + 1.800000000000000e-01 +Node: 32 + 1.853063543367501e+00 + 2.907778923785880e-01 + -4.208015582070269e-01 + 1.800000000000000e-01 +Node: 33 + 2.296636818582645e+00 + 3.148018303434528e-01 + -4.482494386961102e-01 + 1.800000000000000e-01 +Node: 34 + 2.761957158713819e+00 + 3.211219675945598e-01 + -4.815534443128778e-01 + 1.800000000000000e-01 +Node: 35 + 3.238534610426289e+00 + 3.389379078341463e-01 + -5.149021854645067e-01 + 1.800000000000000e-01 +Node: 36 + 3.710408419405519e+00 + 3.493719415058864e-01 + -5.480836257694687e-01 + 1.800000000000000e-01 +Node: 37 + 4.195747547643650e+00 + 3.561711439874546e-01 + -5.705309854833883e-01 + 1.800000000000000e-01 +Node: 38 + 4.595780357914053e+00 + 3.228584203970787e-01 + -5.152335214220667e-01 + 1.800000000000000e-01 +Node: 39 + 9.032725229067927e-01 + 2.373369856800536e-02 + -6.329727305413341e-01 + 1.000000000000000e-02 +Node: 40 + 8.934709527822703e-01 + 3.818521770357496e-01 + -5.763586753448999e-01 + 1.000000000000000e-02 +Node: 41 + 8.982013888321906e-01 + 5.573132534971721e-01 + -3.302803918872626e-01 + 1.000000000000000e-02 +Node: 42 + 9.062130888568782e-01 + 5.192446029281673e-01 + -6.745393769308183e-02 + 1.000000000000000e-02 +Node: 43 + 9.196787103749551e-01 + 4.397918155707508e-01 + 1.835224672799013e-01 + 1.000000000000000e-02 +Node: 44 + 9.321163536061950e-01 + 2.865431688518749e-01 + 3.653899708284058e-01 + 1.000000000000000e-02 +Node: 45 + 9.455493586398056e-01 + 1.054362036582266e-01 + 5.037059572690106e-01 + 1.000000000000000e-02 +Node: 46 + 9.427083450147042e-01 + -1.457725512763095e-01 + 4.558632234155680e-01 + 1.000000000000000e-02 +Node: 47 + 9.254246049017438e-01 + -3.135757298325889e-01 + 2.975261857604956e-01 + 1.000000000000000e-02 +Node: 48 + 9.167741913094090e-01 + -4.391613818448387e-01 + 1.690338431207718e-01 + 1.000000000000000e-02 +Node: 49 + 8.965003448840203e-01 + -4.671027392838462e-01 + -1.481534733918043e-01 + 1.000000000000000e-02 +Node: 50 + 8.774840798545825e-01 + -2.634915141496693e-01 + -4.426480885463952e-01 + 1.000000000000000e-02 +Node: 51 + 2.412235683349789e+00 + 1.622447999947594e-01 + -7.021304974367505e-01 + 1.000000000000000e-02 +Node: 52 + 2.412116792635590e+00 + 5.462818502596256e-01 + -5.964880994276321e-01 + 1.000000000000000e-02 +Node: 53 + 2.393714590765646e+00 + 6.901448699217729e-01 + -3.339595971939595e-01 + 1.000000000000000e-02 +Node: 54 + 2.396854196436117e+00 + 6.896918057232820e-01 + -5.736187286787107e-02 + 1.000000000000000e-02 +Node: 55 + 2.390833273167956e+00 + 6.135883435827709e-01 + 1.971469871357453e-01 + 1.000000000000000e-02 +Node: 56 + 2.412386856736946e+00 + 4.130600557330022e-01 + 3.735844483404094e-01 + 1.000000000000000e-02 +Node: 57 + 2.430431092495623e+00 + 1.559440909551448e-01 + 4.036389705613910e-01 + 1.000000000000000e-02 +Node: 58 + 2.387281869207619e+00 + -4.291479073807819e-02 + 3.036119830579160e-01 + 1.000000000000000e-02 +Node: 59 + 2.406676886976866e+00 + -2.101563304250731e-01 + 1.711513008649808e-01 + 1.000000000000000e-02 +Node: 60 + 2.430135116146151e+00 + -3.078082198143320e-01 + -8.708152277615835e-03 + 1.000000000000000e-02 +Node: 61 + 2.434815646214388e+00 + -2.874337639404328e-01 + -2.600342278679089e-01 + 1.000000000000000e-02 +Node: 62 + 2.394285361203448e+00 + -1.612091923256298e-01 + -5.496729791957915e-01 + 1.000000000000000e-02 +Node: 63 + 3.824335255298526e+00 + 2.848674373659748e-01 + -7.240511790590889e-01 + 1.000000000000000e-02 +Node: 64 + 3.831011589586607e+00 + 5.725477943597055e-01 + -6.873093235124992e-01 + 1.000000000000000e-02 +Node: 65 + 3.856549867407969e+00 + 6.895302363065451e-01 + -4.489141768624906e-01 + 1.000000000000000e-02 +Node: 66 + 3.868435859467971e+00 + 7.070791892732076e-01 + -1.966254275659377e-01 + 1.000000000000000e-02 +Node: 67 + 3.857922019419580e+00 + 6.385669119119456e-01 + -4.364182697730948e-02 + 1.000000000000000e-02 +Node: 68 + 3.872754482525333e+00 + 4.558434875716680e-01 + 1.321302056224572e-01 + 1.000000000000000e-02 +Node: 69 + 3.868822828963671e+00 + 3.028158184701839e-01 + 2.516984531530910e-01 + 1.000000000000000e-02 +Node: 70 + 3.862867182901765e+00 + 7.339524259965696e-02 + 2.545190781977691e-01 + 1.000000000000000e-02 +Node: 71 + 3.861007770953574e+00 + -7.493155681502191e-02 + 1.114648068722638e-01 + 1.000000000000000e-02 +Node: 72 + 3.838304007828248e+00 + -1.391361458053640e-01 + -1.052604668664126e-01 + 1.000000000000000e-02 +Node: 73 + 3.810823727026837e+00 + -8.733211417605728e-02 + -3.772518501788384e-01 + 1.000000000000000e-02 +Node: 74 + 3.785514504951673e+00 + 3.870986066902640e-02 + -5.955710515963688e-01 + 1.000000000000000e-02 +Define node template: node2 +Shape. Dimension=0 +#Fields=3 +1) marker coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=1 (value) + y. #Values=1 (value) + z. #Values=1 (value) +2) marker_location, field, element_xi, #Components=1, host mesh=mesh1d, host mesh dimension=1 + 1. #Values=1 (value) +3) marker_name, field, string, #Components=1 + 1. #Values=1 (value) +Node template: node2 +Node: 75 + 2.398918879869167e+00 + -1.432597145502751e-01 + 2.241355737421549e-01 + 55 6.000000000000000e-01 + orientation +!#mesh mesh1d, dimension=1, nodeset=nodes +Define element template: element1 +Shape. Dimension=1, line +#Scale factor sets=0 +#Nodes=2 +#Fields=2 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value + y. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value + z. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value +2) radius, field, rectangular cartesian, real, #Components=1 + 1. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value +Element template: element1 +Element: 1 + Nodes: + 1 2 +Element: 2 + Nodes: + 2 3 +Element: 3 + Nodes: + 3 4 +Element: 4 + Nodes: + 4 5 +Element: 5 + Nodes: + 5 6 +Element: 6 + Nodes: + 6 7 +Element: 7 + Nodes: + 7 8 +Element: 8 + Nodes: + 8 9 +Element: 9 + Nodes: + 9 10 +Element: 10 + Nodes: + 10 11 +Element: 11 + Nodes: + 12 13 +Element: 12 + Nodes: + 13 14 +Element: 13 + Nodes: + 14 15 +Element: 14 + Nodes: + 15 16 +Element: 15 + Nodes: + 16 17 +Element: 16 + Nodes: + 17 18 +Element: 17 + Nodes: + 18 19 +Element: 18 + Nodes: + 19 20 +Element: 19 + Nodes: + 20 21 +Element: 20 + Nodes: + 21 22 +Element: 21 + Nodes: + 23 24 +Element: 22 + Nodes: + 24 25 +Element: 23 + Nodes: + 25 26 +Element: 24 + Nodes: + 26 27 +Element: 25 + Nodes: + 27 17 +Element: 26 + Nodes: + 28 29 +Element: 27 + Nodes: + 29 30 +Element: 28 + Nodes: + 30 31 +Element: 29 + Nodes: + 31 32 +Element: 30 + Nodes: + 32 33 +Element: 31 + Nodes: + 33 34 +Element: 32 + Nodes: + 34 35 +Element: 33 + Nodes: + 35 36 +Element: 34 + Nodes: + 36 37 +Element: 35 + Nodes: + 37 38 +Element: 36 + Nodes: + 39 40 +Element: 37 + Nodes: + 40 41 +Element: 38 + Nodes: + 41 42 +Element: 39 + Nodes: + 42 43 +Element: 40 + Nodes: + 43 44 +Element: 41 + Nodes: + 44 45 +Element: 42 + Nodes: + 45 46 +Element: 43 + Nodes: + 46 47 +Element: 44 + Nodes: + 47 48 +Element: 45 + Nodes: + 48 49 +Element: 46 + Nodes: + 49 50 +Element: 47 + Nodes: + 50 39 +Element: 48 + Nodes: + 51 52 +Element: 49 + Nodes: + 52 53 +Element: 50 + Nodes: + 53 54 +Element: 51 + Nodes: + 54 55 +Element: 52 + Nodes: + 55 56 +Element: 53 + Nodes: + 56 57 +Element: 54 + Nodes: + 57 58 +Element: 55 + Nodes: + 58 59 +Element: 56 + Nodes: + 59 60 +Element: 57 + Nodes: + 60 61 +Element: 58 + Nodes: + 61 62 +Element: 59 + Nodes: + 62 51 +Element: 60 + Nodes: + 63 64 +Element: 61 + Nodes: + 64 65 +Element: 62 + Nodes: + 65 66 +Element: 63 + Nodes: + 66 67 +Element: 64 + Nodes: + 67 68 +Element: 65 + Nodes: + 68 69 +Element: 66 + Nodes: + 69 70 +Element: 67 + Nodes: + 70 71 +Element: 68 + Nodes: + 71 72 +Element: 69 + Nodes: + 72 73 +Element: 70 + Nodes: + 73 74 +Element: 71 + Nodes: + 74 63 +Group name: 00001 +!#nodeset nodes +Node group: +39..50 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +36..47 +Group name: 00002 +!#nodeset nodes +Node group: +51..62 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +48..59 +Group name: 00003 +!#nodeset nodes +Node group: +63..74 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +60..71 +Group name: Epineurium +!#nodeset nodes +Node group: +39..74 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +36..71 +Group name: Fascicle +!#nodeset nodes +Node group: +12..38 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +11..35 +Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +!#nodeset nodes +Node group: +39..74 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +36..71 +Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +!#nodeset nodes +Node group: +1..11 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..10 +Group name: http://uri.interlex.org/base/ilx_0738426 +!#nodeset nodes +Node group: +12..38 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +11..35 +Group name: left vagus X nerve trunk +!#nodeset nodes +Node group: +1..11 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..10 +Group name: marker +!#nodeset nodes +Node group: +75 diff --git a/tests/resources/vagus-segment3.exf b/tests/resources/vagus-segment3.exf new file mode 100644 index 0000000..786c820 --- /dev/null +++ b/tests/resources/vagus-segment3.exf @@ -0,0 +1,679 @@ +EX Version: 3 +Region: / +!#nodeset nodes +Define node template: node1 +Shape. Dimension=0 +#Fields=2 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=1 (value) + y. #Values=1 (value) + z. #Values=1 (value) +2) radius, field, rectangular cartesian, real, #Components=1 + 1. #Values=1 (value) +Node template: node1 +Node: 1 + 1.942913940852241e-03 + -6.174096706026299e-03 + -4.184117075191224e-02 + 5.000000000000000e-01 +Node: 2 + 4.448264093299951e-01 + 8.950283861785819e-04 + 8.513537594775918e-03 + 5.000000000000000e-01 +Node: 3 + 9.144594663866782e-01 + -7.856525010683454e-03 + 5.230329122004440e-03 + 5.000000000000000e-01 +Node: 4 + 1.341529169159712e+00 + -3.379731197072129e-03 + 4.432846503988402e-03 + 5.000000000000000e-01 +Node: 5 + 1.761745770384641e+00 + 8.287396815189014e-03 + 4.417953503525773e-03 + 5.000000000000000e-01 +Node: 6 + 2.213753789766584e+00 + 3.492142347275788e-03 + -4.580284671442248e-03 + 5.000000000000000e-01 +Node: 7 + 2.662473125636201e+00 + 2.017812954482604e-03 + -1.525185733770927e-05 + 5.000000000000000e-01 +Node: 8 + 3.105347103809529e+00 + 8.648228296777416e-03 + -4.870345987352635e-04 + 5.000000000000000e-01 +Node: 9 + 3.541403559171137e+00 + 2.293381052435084e-02 + 4.003380068629582e-03 + 5.000000000000000e-01 +Node: 10 + 3.982089419158738e+00 + 1.827226733944906e-02 + -8.564516602682279e-02 + 5.000000000000000e-01 +Node: 11 + 2.447643130510998e+00 + -2.570845179719078e-01 + -1.342412383818372e-01 + 3.000000000000000e-01 +Node: 12 + 2.711081252270924e+00 + -4.501526348728000e-01 + -2.798828374798183e-01 + 3.000000000000000e-01 +Node: 13 + 2.974352619308820e+00 + -6.419436940272879e-01 + -4.174463020710795e-01 + 3.000000000000000e-01 +Node: 14 + 3.218437907900458e+00 + -8.074905606301290e-01 + -5.115771921019351e-01 + 3.000000000000000e-01 +Node: 15 + 3.017897526101733e-03 + 6.726043232147080e-02 + 1.747255010683442e-01 + 1.800000000000000e-01 +Node: 16 + 4.133180274415295e-01 + 1.020102945909715e-01 + 2.319094597468164e-01 + 1.800000000000000e-01 +Node: 17 + 8.795172184976472e-01 + 1.149448459481597e-01 + 2.274915119711656e-01 + 1.800000000000000e-01 +Node: 18 + 1.324185845280540e+00 + 1.128968096159903e-01 + 2.275938907989292e-01 + 1.800000000000000e-01 +Node: 19 + 1.776094753602661e+00 + 1.162543894279857e-01 + 2.159050795679053e-01 + 1.800000000000000e-01 +Node: 20 + 2.220458350027672e+00 + 1.128685896957036e-01 + 2.197708812407930e-01 + 1.800000000000000e-01 +Node: 21 + 2.665377659401191e+00 + 1.139402738887889e-01 + 2.110713760079689e-01 + 1.800000000000000e-01 +Node: 22 + 3.109037245293125e+00 + 1.134024246980764e-01 + 2.245317429568045e-01 + 1.800000000000000e-01 +Node: 23 + 3.554032957525765e+00 + 1.111083384484419e-01 + 2.170664886704722e-01 + 1.800000000000000e-01 +Node: 24 + 4.007873722959625e+00 + 9.455184174585338e-02 + 1.562665357800655e-01 + 1.800000000000000e-01 +Node: 25 + 3.650365824150211e-02 + -2.443646517760476e-01 + -2.484677767585793e-01 + 1.600000000000000e-01 +Node: 26 + 4.444444444444444e-01 + -2.000000000000000e-01 + -2.000000000000000e-01 + 1.600000000000000e-01 +Node: 27 + 8.888888888888888e-01 + -2.000000000000000e-01 + -2.000000000000000e-01 + 1.600000000000000e-01 +Node: 28 + 1.333333333333333e+00 + -2.000000000000000e-01 + -2.000000000000000e-01 + 1.600000000000000e-01 +Node: 29 + 1.777777777777778e+00 + -2.000000000000000e-01 + -2.000000000000000e-01 + 1.600000000000000e-01 +Node: 30 + 2.222222222222222e+00 + -2.000000000000000e-01 + -2.000000000000000e-01 + 1.600000000000000e-01 +Node: 31 + 2.666666666666667e+00 + -2.000000000000000e-01 + -2.000000000000000e-01 + 1.600000000000000e-01 +Node: 32 + 3.119773532806408e+00 + -1.739822778795015e-01 + -1.418776640205003e-01 + 1.600000000000000e-01 +Node: 33 + 3.559610278792930e+00 + -1.809736986882864e-01 + -1.544218854266600e-01 + 1.600000000000000e-01 +Node: 34 + 4.005046501495423e+00 + -1.726020482287798e-01 + -2.188070821305510e-01 + 1.600000000000000e-01 +Node: 35 + 2.593065320266171e+00 + -3.636453436630986e-01 + -2.782545705182799e-01 + 1.200000000000000e-01 +Node: 36 + 2.864977251430427e+00 + -5.788448713763134e-01 + -4.093331980088444e-01 + 1.200000000000000e-01 +Node: 37 + 3.162539960993468e+00 + -7.731751141114022e-01 + -5.689115558349389e-01 + 1.200000000000000e-01 +Node: 38 + 2.109848842979773e+00 + 3.719079207098927e-02 + 1.119288096759290e-01 + 1.200000000000000e-01 +Node: 39 + 2.443602932356886e+00 + -4.187280528600715e-02 + 7.952539783952647e-03 + 1.200000000000000e-01 +Node: 40 + 2.777357021733999e+00 + -1.209364026430036e-01 + -9.602373010802370e-02 + 1.200000000000000e-01 +Node: 41 + 1.188774765516805e+00 + -1.317557143317290e-01 + 2.713028727140580e-01 + 1.000000000000000e-01 +Node: 42 + 1.551693950482933e+00 + -1.544385131235169e-01 + 3.150434946628239e-01 + 1.000000000000000e-01 +Node: 43 + 1.149899874923449e+00 + -9.822446159056160e-02 + -4.870183811724769e-01 + 1.000000000000000e-02 +Node: 44 + 1.142023040483053e+00 + 1.038332976044075e-01 + -3.725593905614107e-01 + 1.000000000000000e-02 +Node: 45 + 1.133758005562625e+00 + 2.852475311665115e-01 + -1.565198362655622e-01 + 1.000000000000000e-02 +Node: 46 + 1.139733192538984e+00 + 3.526427305482840e-01 + 6.783031899931936e-02 + 1.000000000000000e-02 +Node: 47 + 1.135575496303638e+00 + 3.729293473070320e-01 + 3.129020572927896e-01 + 1.000000000000000e-02 +Node: 48 + 1.145027536689751e+00 + 2.476288707625080e-01 + 4.783048734333570e-01 + 1.000000000000000e-02 +Node: 49 + 1.155765499702990e+00 + -8.940822592192652e-03 + 5.466605040604726e-01 + 1.000000000000000e-02 +Node: 50 + 1.151397872884108e+00 + -2.386179401262818e-01 + 4.991814735148105e-01 + 1.000000000000000e-02 +Node: 51 + 1.170126401711106e+00 + -3.978633230645265e-01 + 3.547048653571901e-01 + 1.000000000000000e-02 +Node: 52 + 1.198946099488208e+00 + -5.042422191675737e-01 + 1.075292090735259e-01 + 1.000000000000000e-02 +Node: 53 + 1.183093133189083e+00 + -5.097951359581457e-01 + -1.799070828179922e-01 + 1.000000000000000e-02 +Node: 54 + 1.188338855899369e+00 + -3.442974454082276e-01 + -4.464940574422623e-01 + 1.000000000000000e-02 +Node: 55 + 2.606895397793692e+00 + -1.950598675435381e-01 + -5.361193420133061e-01 + 1.000000000000000e-02 +Node: 56 + 2.651902252839656e+00 + 2.267963645231352e-02 + -4.467891552329343e-01 + 1.000000000000000e-02 +Node: 57 + 2.708886412730180e+00 + 1.964262699177105e-01 + -2.629825160572929e-01 + 1.000000000000000e-02 +Node: 58 + 2.724087185465105e+00 + 4.002766604185694e-01 + -1.698150657838519e-01 + 1.000000000000000e-02 +Node: 59 + 2.762649172380893e+00 + 4.734943912484497e-01 + 1.950103855185695e-02 + 1.000000000000000e-02 +Node: 60 + 2.761204225191120e+00 + 4.531602622637320e-01 + 2.407676855854297e-01 + 1.000000000000000e-02 +Node: 61 + 2.732505873421819e+00 + 2.909626972902289e-01 + 4.507066538115742e-01 + 1.000000000000000e-02 +Node: 62 + 2.656477291205181e+00 + -9.045587239960982e-03 + 5.146054159844166e-01 + 1.000000000000000e-02 +Node: 63 + 2.617549174048286e+00 + -2.682754820764324e-01 + 3.758702852189720e-01 + 1.000000000000000e-02 +Node: 64 + 2.564687839624030e+00 + -4.493650306903538e-01 + 1.135176760617449e-01 + 1.000000000000000e-02 +Node: 65 + 2.539825917415548e+00 + -6.284575422920684e-01 + -8.927549417402995e-02 + 1.000000000000000e-02 +Node: 66 + 2.520444076789123e+00 + -6.730375017164026e-01 + -3.042201528516072e-01 + 1.000000000000000e-02 +Node: 67 + 2.532117137196638e+00 + -6.244798860718002e-01 + -4.486257837097265e-01 + 1.000000000000000e-02 +Node: 68 + 2.555646992814593e+00 + -4.359756419261278e-01 + -5.458579284260732e-01 + 1.000000000000000e-02 +Define node template: node2 +Shape. Dimension=0 +#Fields=3 +1) marker coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. #Values=1 (value) + y. #Values=1 (value) + z. #Values=1 (value) +2) marker_location, field, element_xi, #Components=1, host mesh=mesh1d, host mesh dimension=1 + 1. #Values=1 (value) +3) marker_name, field, string, #Components=1 + 1. #Values=1 (value) +Node template: node2 +Node: 69 + 3.000000000000000e+00 + 0.000000000000000e+00 + 7.000000000000000e-01 + 58 8.299280188212664e-01 + orientation +Node: 70 + 1.599724956533351e+00 + 3.788960603141545e-03 + 4.423695723249146e-03 + 4 6.144349999999998e-01 + "landmark 2" +!#mesh mesh1d, dimension=1, nodeset=nodes +Define element template: element1 +Shape. Dimension=1, line +#Scale factor sets=0 +#Nodes=2 +#Fields=2 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 + x. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value + y. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value + z. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value +2) radius, field, rectangular cartesian, real, #Components=1 + 1. l.Lagrange, no modify, standard node based. + #Nodes=2 + 1. #Values=1 + Value labels: value + 2. #Values=1 + Value labels: value +Element template: element1 +Element: 1 + Nodes: + 1 2 +Element: 2 + Nodes: + 2 3 +Element: 3 + Nodes: + 3 4 +Element: 4 + Nodes: + 4 5 +Element: 5 + Nodes: + 5 6 +Element: 6 + Nodes: + 6 7 +Element: 7 + Nodes: + 7 8 +Element: 8 + Nodes: + 8 9 +Element: 9 + Nodes: + 9 10 +Element: 10 + Nodes: + 6 11 +Element: 11 + Nodes: + 11 12 +Element: 12 + Nodes: + 12 13 +Element: 13 + Nodes: + 13 14 +Element: 14 + Nodes: + 15 16 +Element: 15 + Nodes: + 16 17 +Element: 16 + Nodes: + 17 18 +Element: 17 + Nodes: + 18 19 +Element: 18 + Nodes: + 19 20 +Element: 19 + Nodes: + 20 21 +Element: 20 + Nodes: + 21 22 +Element: 21 + Nodes: + 22 23 +Element: 22 + Nodes: + 23 24 +Element: 23 + Nodes: + 25 26 +Element: 24 + Nodes: + 26 27 +Element: 25 + Nodes: + 27 28 +Element: 26 + Nodes: + 28 29 +Element: 27 + Nodes: + 29 30 +Element: 28 + Nodes: + 30 31 +Element: 29 + Nodes: + 31 32 +Element: 30 + Nodes: + 32 33 +Element: 31 + Nodes: + 33 34 +Element: 32 + Nodes: + 30 35 +Element: 33 + Nodes: + 35 36 +Element: 34 + Nodes: + 36 37 +Element: 35 + Nodes: + 19 38 +Element: 36 + Nodes: + 38 39 +Element: 37 + Nodes: + 39 40 +Element: 38 + Nodes: + 40 32 +Element: 39 + Nodes: + 41 42 +Element: 40 + Nodes: + 43 44 +Element: 41 + Nodes: + 44 45 +Element: 42 + Nodes: + 45 46 +Element: 43 + Nodes: + 46 47 +Element: 44 + Nodes: + 47 48 +Element: 45 + Nodes: + 48 49 +Element: 46 + Nodes: + 49 50 +Element: 47 + Nodes: + 50 51 +Element: 48 + Nodes: + 51 52 +Element: 49 + Nodes: + 52 53 +Element: 50 + Nodes: + 53 54 +Element: 51 + Nodes: + 54 43 +Element: 52 + Nodes: + 55 56 +Element: 53 + Nodes: + 56 57 +Element: 54 + Nodes: + 57 58 +Element: 55 + Nodes: + 58 59 +Element: 56 + Nodes: + 59 60 +Element: 57 + Nodes: + 60 61 +Element: 58 + Nodes: + 61 62 +Element: 59 + Nodes: + 62 63 +Element: 60 + Nodes: + 63 64 +Element: 61 + Nodes: + 64 65 +Element: 62 + Nodes: + 65 66 +Element: 63 + Nodes: + 66 67 +Element: 64 + Nodes: + 67 68 +Element: 65 + Nodes: + 68 55 +Group name: 00001 +!#nodeset nodes +Node group: +43..54 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +40..51 +Group name: 00002 +!#nodeset nodes +Node group: +55..68 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +52..65 +Group name: Epineurium +!#nodeset nodes +Node group: +43..68 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +40..65 +Group name: Fascicle +!#nodeset nodes +Node group: +15..40 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +14..38 +Group name: http://purl.obolibrary.org/obo/UBERON_0000124 +!#nodeset nodes +Node group: +43..68 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +40..65 +Group name: http://purl.obolibrary.org/obo/UBERON_0035020 +!#nodeset nodes +Node group: +1..10 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..9 +Group name: http://uri.interlex.org/base/ilx_0738426 +!#nodeset nodes +Node group: +15..40 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +14..38 +Group name: left A branch 2 +!#nodeset nodes +Node group: +6,11..14 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +10..13 +Group name: left vagus X nerve trunk +!#nodeset nodes +Node group: +1..10 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +1..9 +Group name: marker +!#nodeset nodes +Node group: +69..70 +Group name: unknown +!#nodeset nodes +Node group: +41..42 +!#mesh mesh1d, dimension=1, nodeset=nodes +Element group: +39 diff --git a/tests/test_vagus.py b/tests/test_vagus.py new file mode 100644 index 0000000..b7f9c4b --- /dev/null +++ b/tests/test_vagus.py @@ -0,0 +1,70 @@ +import os +import unittest +from segmentationstitcher.annotation import AnnotationCategory +from segmentationstitcher.stitcher import Stitcher +from tests.testutils import assertAlmostEqualList + +here = os.path.abspath(os.path.dirname(__file__)) + + +class StitchVagusTestCase(unittest.TestCase): + + def test_io_vagus1(self): + """ + Test loading, modifying and serialising synthetic vagus nerve/fascicle segmentations. + """ + resource_names = [ + "vagus-segment1.exf", + "vagus-segment2.exf", + "vagus-segment3.exf", + ] + TOL = 1.0E-7 + zero = [0.0, 0.0, 0.0] + new_translation = [5.0, 0.5, 0.1] + segmentation_file_names = [os.path.join(here, "resources", resource_name) for resource_name in resource_names] + stitcher1 = Stitcher(segmentation_file_names) + segments1 = stitcher1.get_segments() + self.assertEqual(3, len(segments1)) + segment12 = segments1[1] + self.assertEqual("vagus-segment2", segment12.get_name()) + assertAlmostEqualList(self, zero, segment12.get_translation(), delta=TOL) + segment12.set_translation(new_translation) + annotations1 = stitcher1.get_annotations() + self.assertEqual(7, len(annotations1)) + self.assertEqual(1, stitcher1.get_version()) + annotation11 = annotations1[0] + self.assertEqual("Epineurium", annotation11.get_name()) + self.assertEqual("http://purl.obolibrary.org/obo/UBERON_0000124", annotation11.get_term()) + self.assertEqual(AnnotationCategory.UNCONNECTED_GENERAL, annotation11.get_category()) + annotation12 = annotations1[1] + self.assertEqual("Fascicle", annotation12.get_name()) + self.assertEqual("http://uri.interlex.org/base/ilx_0738426", annotation12.get_term()) + self.assertEqual(AnnotationCategory.CONNECTED_COMPLEX_NETWORK, annotation12.get_category()) + annotation15 = annotations1[4] + self.assertEqual("left vagus X nerve trunk", annotation15.get_name()) + self.assertEqual('http://purl.obolibrary.org/obo/UBERON_0035020', annotation15.get_term()) + self.assertEqual(AnnotationCategory.CONNECTED_SIMPLE_NETWORK, annotation15.get_category()) + annotation17 = annotations1[6] + self.assertEqual("unknown", annotation17.get_name()) + self.assertEqual(AnnotationCategory.UNCONNECTED_GENERAL, annotation17.get_category()) + annotation17.set_category(AnnotationCategory.EXCLUDE) + + settings = stitcher1.encode_settings() + self.assertEqual(3, len(settings["segments"])) + self.assertEqual(7, len(settings["annotations"])) + self.assertEqual(1, settings["version"]) + assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) + self.assertEqual(AnnotationCategory.EXCLUDE.name, settings["annotations"][6]["category"]) + + stitcher2 = Stitcher(segmentation_file_names) + stitcher2.decode_settings(settings) + segments2 = stitcher2.get_segments() + segment22 = segments2[1] + assertAlmostEqualList(self, new_translation, segment22.get_translation(), delta=TOL) + annotations2 = stitcher2.get_annotations() + annotation27 = annotations2[6] + self.assertEqual(AnnotationCategory.EXCLUDE, annotation27.get_category()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testutils.py b/tests/testutils.py new file mode 100644 index 0000000..4eb9d27 --- /dev/null +++ b/tests/testutils.py @@ -0,0 +1,8 @@ +""" +Utility function for tests. +""""" + +def assertAlmostEqualList(testcase, actualList, expectedList, delta): + assert len(actualList) == len(expectedList) + for actual, expected in zip(actualList, expectedList): + testcase.assertAlmostEqual(actual, expected, delta=delta) From eef3eb6d4bb8f220d45d21df589053b7f3dcb66d Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Wed, 14 Aug 2024 10:04:47 +1200 Subject: [PATCH 02/13] Improve category names and group management --- src/segmentationstitcher/annotation.py | 70 +++++++++++++++++------- src/segmentationstitcher/segment.py | 74 ++++++++++++++++++++++++++ src/segmentationstitcher/stitcher.py | 48 +++++++++++++++-- tests/resources/vagus-segment1.exf | 3 +- tests/resources/vagus-segment2.exf | 3 +- tests/resources/vagus-segment3.exf | 3 +- tests/test_vagus.py | 30 ++++++++--- 7 files changed, 198 insertions(+), 33 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index e62d78f..a858e61 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -11,10 +11,18 @@ class AnnotationCategory(Enum): """ How to process segmentations with this annotation. """ - CONNECTED_SIMPLE_NETWORK = 1 # a simple connected network graph - CONNECTED_COMPLEX_NETWORK = 2 # a complex network of connected parallel sections e.g. fascicles. - UNCONNECTED_GENERAL = 3 # contours and other segmentations which are not connected but are included in output - EXCLUDE = 4 # segmentations to exclude from the output + EXCLUDE = 0 # segmentations to exclude from the output + GENERAL = 1 # for segmentations which are not connected but are included in output + INDEPENDENT_NETWORK = 2 # networks which only connect with the same annotation + NETWORK_GROUP_1 = 3 # network group 1, any segmentations with this category may connect + NETWORK_GROUP_2 = 4 # network group 2, any segmentations with this category may connect + + def get_group_name(self): + """ + Get name of Zinc group to put all segmentations with this category. + :return: String name. + """ + return '.' + self.name class Annotation: @@ -34,6 +42,7 @@ def __init__(self, name: str, term, dimension, category: AnnotationCategory): self._term = term self._dimension = dimension self._category = category + self._category_change_callback = None def decode_settings(self, settings_in: dict): """ @@ -69,7 +78,22 @@ def get_category(self): return self._category def set_category(self, category): - self._category = category + old_category = self._category + if category != old_category: + self._category = category + if self._category_change_callback: + self._category_change_callback(self, old_category) + + def set_category_change_callback(self, category_change_callback): + """ + Set up client to be informed when annotation category is changed. + Typically used to update category groups for user interface. + :param category_change_callback: Callable with signature (annotation, old_category) + """ + self._category_change_callback = category_change_callback + + def set_category_by_name(self, category_name): + self.set_category(AnnotationCategory[category_name]) def get_dimension(self): return self._dimension @@ -89,19 +113,19 @@ def set_term(self, term): self._term = term -def region_get_annotations(region, simple_network_keywords, complex_network_keywords, term_keyword="http"): +def region_get_annotations(region, network_group1_keywords, network_group2_keywords, term_keywords): """ Get annotation group names and terms from region's non-empty groups. Groups with names consisting only of numbers are ignored as we're needlessly getting these for part contours. - After sorting for simple and complex networks and terms, remaining annotations are marked as general unconnected. + After sorting for network groups and terms, remaining annotations are marked as general unconnected. :param region: Zinc region to analyse groups in. - :param simple_network_keywords: Annotation names containing any of these keywords are marked as simple networks. - Must use lower case-folded. Comparison is case-insensitive. - :param complex_network_keywords: Annotation names containing any of these keywords are marked as complex networks. - Must use lower case-folded. Comparison is case-insensitive. - :param term_keyword: Groups with names containing this keyword are matched to other groups with the same content, - and supply the term name for them instead of making another Annotation. If no matching group is supplied these - are used as names and terms. + :param network_group1_keywords: Annotation names with any of these keywords are put in network group 1 category. + Must be lower case for comparison. + :param network_group2_keywords: Annotation names with any of these keywords are put in network group 2 category. + Must use lower case for comparison. + :param term_keywords: Annotation names containing any of these keywords are considered ontological term ids. These + are matched to other groups with the same content, and supply the term name for them instead of making another + Annotation. If no matching group is supplied these are used as names and terms. :return: list of Annotation. """ fieldmodule = region.getFieldmodule() @@ -122,18 +146,24 @@ def region_get_annotations(region, simple_network_keywords, complex_network_keyw continue # empty group if lower_name.isdigit(): continue # ignore as these can never be valid annotation names - category = AnnotationCategory.UNCONNECTED_GENERAL - for keyword in simple_network_keywords: + category = AnnotationCategory.GENERAL + for keyword in network_group1_keywords: if keyword in lower_name: - category = AnnotationCategory.CONNECTED_SIMPLE_NETWORK + category = AnnotationCategory.NETWORK_GROUP_1 break else: - for keyword in complex_network_keywords: + for keyword in network_group2_keywords: if keyword in lower_name: - category = AnnotationCategory.CONNECTED_COMPLEX_NETWORK + category = AnnotationCategory.NETWORK_GROUP_2 break annotation = Annotation(name, None, dimension, category) - if term_keyword in lower_name: + is_term = False + if category == AnnotationCategory.GENERAL: + for keyword in term_keywords: + if keyword in lower_name: + is_term = True + break + if is_term: term_annotations.append(annotation) else: annotations.append(annotation) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 86c74e0..bc20d50 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -1,7 +1,11 @@ """ A segment of the segmentation data, generally from a separate image block. """ +from cmlibs.utils.zinc.group import group_add_group_local_contents, group_remove_group_local_contents +from cmlibs.utils.zinc.general import ChangeManager +from cmlibs.zinc.field import Field from cmlibs.zinc.result import RESULT_OK +from segmentationstitcher.annotation import AnnotationCategory class Segment: @@ -27,6 +31,14 @@ def __init__(self, name, segmentation_file_name, root_region): result = self._raw_region.readFile(segmentation_file_name) assert result == RESULT_OK, \ "Could not read segmentation file " + segmentation_file_name + # ensure category groups exist: + fieldmodule = self._raw_region.getFieldmodule() + with ChangeManager(fieldmodule): + for category in AnnotationCategory: + group_name = category.get_group_name() + group = fieldmodule.createFieldGroup() + group.setName(group_name) + group.setManaged(True) self._rotation = [0.0, 0.0, 0.0] self._translation = [0.0, 0.0, 0.0] @@ -61,6 +73,30 @@ def get_base_region(self): """ return self._base_region + def get_annotation_group(self, annotation): + """ + Get Zinc group containing segmentations for the supplied annotation + :param annotation: An Annotation object. + :return: Zinc FieldGroup in the segment's raw region, or None if not present in segment. + """ + fieldmodule = self._raw_region.getFieldmodule() + annotation_group = fieldmodule.findFieldByName(annotation.get_name()).castGroup() + if annotation_group.isValid(): + return annotation_group + return None + + def get_category_group(self, category): + """ + Get Zinc group in which segmentations with the supplied annotation category are maintained + for visualisation. + :param category: The AnnotationCategory to query. + :return: Zinc FieldGroup in the segment's raw region. + """ + fieldmodule = self._raw_region.getFieldmodule() + group_name = category.get_group_name() + group = fieldmodule.findFieldByName(group_name).castGroup() + return group + def get_name(self): return self._name @@ -84,3 +120,41 @@ def get_translation(self): def set_translation(self, translation): assert len(translation) == 3 self._translation = translation + + def update_annotation_category(self, annotation, old_category=AnnotationCategory.EXCLUDE): + """ + Ensures special groups representing annotion categories contain via addition or removal the + correct contents for this annotation. + :param annotation: The annotation to update category group for. Ensures its local contents + (elements, nodes, datapoints) are in its category group. + :param old_category: The old category for this annotation, i.e. category group to remove from. + """ + new_category = annotation.get_category() + if new_category == old_category: + return + annotation_group = self.get_annotation_group(annotation) + if not annotation_group: + return # not present in this segment + fieldmodule = self._raw_region.getFieldmodule() + with ChangeManager(fieldmodule): + old_category_group = self.get_category_group(old_category) + group_remove_group_local_contents(old_category_group, annotation_group) + new_category_group = self.get_category_group(new_category) + group_add_group_local_contents(new_category_group, annotation_group) + + def reset_annotation_category_groups(self, annotations): + """ + Rebuild all annotation category groups e.g. after loading settings. + :param annotations: List of all annotations from stitcher. + """ + fieldmodule = self._raw_region.getFieldmodule() + with ChangeManager(fieldmodule): + # clear all category groups + for category in AnnotationCategory: + category_group = self.get_category_group(category) + category_group.clear() + for annotation in annotations: + annotation_group = self.get_annotation_group(annotation) + if annotation_group: + category_group = self.get_category_group(annotation.get_category()) + group_add_group_local_contents(category_group, annotation_group) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 89458d2..e69d1b6 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -1,9 +1,12 @@ """ Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. """ +from cmlibs.utils.zinc.general import HierarchicalChangeManager from cmlibs.zinc.context import Context from segmentationstitcher.segment import Segment from segmentationstitcher.annotation import region_get_annotations + +import copy from pathlib import Path @@ -12,13 +15,20 @@ class Stitcher: Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. """ - def __init__(self, segmentation_file_names: list): + def __init__(self, segmentation_file_names: list, network_group1_keywords, network_group2_keywords): """ :param segmentation_file_names: List of filenames containing raw segmentations in Zinc format. + :param network_group1_keywords: List of keywords. Segmented networks annotated with any of these keywords are + initially assigned to network group 1, allowing them to be stitched together. + :param network_group2_keywords: List of keywords. Segmented networks annotated with any of these keywords are + initially assigned to network group 2, allowing them to be stitched together. """ - self._context = Context("Scaffoldfitter") + self._context = Context("Segmentation Stitcher") self._root_region = self._context.getDefaultRegion() self._annotations = [] + self._network_group1_keywords = copy.deepcopy(network_group1_keywords) + self._network_group2_keywords = copy.deepcopy(network_group2_keywords) + self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._segments = [] self._version = 1 # increment when new settings added to migrate older serialised settings for segmentation_file_name in segmentation_file_names: @@ -26,8 +36,8 @@ def __init__(self, segmentation_file_names: list): segment = Segment(name, segmentation_file_name, self._root_region) self._segments.append(segment) segment_annotations = region_get_annotations( - segment.get_raw_region(), simple_network_keywords=["vagus", "nerve", "trunk", "branch"], - complex_network_keywords=["fascicle"], term_keyword="http") + segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, + self._term_keywords) for segment_annotation in segment_annotations: name = segment_annotation.get_name() term = segment_annotation.get_term() @@ -42,6 +52,11 @@ def __init__(self, segmentation_file_names: list): # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) + with HierarchicalChangeManager(self._root_region): + for segment in self._segments: + segment.reset_annotation_category_groups(self._annotations) + for annotation in self._annotations: + annotation.set_category_change_callback(self._annotation_change) def decode_settings(self, settings_in: dict): """ @@ -97,6 +112,10 @@ def decode_settings(self, settings_in: dict): else: print("WARNING: Segmentation Stitcher. Segment with name", name, "not found in settings; using defaults. Have input files changed?") + with HierarchicalChangeManager(self._root_region): + for segment in self._segments: + segment.reset_annotation_category_groups(self._annotations) + def encode_settings(self) -> dict: """ @@ -109,10 +128,31 @@ def encode_settings(self) -> dict: } return settings + def _annotation_change(self, annotation, old_category): + """ + Callback from annotation that its category has changed. + Update segment category groups. + :param annotation: Annotation that has changed category. + :param old_category: The old category to remove segmentations with annotation from. + """ + with HierarchicalChangeManager(self._root_region): + for segment in self._segments: + segment.update_annotation_category(annotation, old_category) + def get_annotations(self): return self._annotations + + def get_context(self): + return self._context + + def get_root_region(self): + return self._root_region + def get_segments(self): return self._segments def get_version(self): return self._version + + def write_output_segmentation_file(self, file_name): + pass diff --git a/tests/resources/vagus-segment1.exf b/tests/resources/vagus-segment1.exf index ec685f3..e24407c 100644 --- a/tests/resources/vagus-segment1.exf +++ b/tests/resources/vagus-segment1.exf @@ -431,6 +431,7 @@ Node: 84 -4.364982277527066e-01 -3.626326906335586e-01 1.000000000000000e-02 +!#nodeset datapoints Define node template: node2 Shape. Dimension=0 #Fields=3 @@ -805,6 +806,6 @@ Node group: Element group: 1..10 Group name: marker -!#nodeset nodes +!#nodeset datapoints Node group: 85..86 diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf index c400f23..80254c9 100644 --- a/tests/resources/vagus-segment2.exf +++ b/tests/resources/vagus-segment2.exf @@ -381,6 +381,7 @@ Node: 74 3.870986066902640e-02 -5.955710515963688e-01 1.000000000000000e-02 +!#nodeset datapoints Define node template: node2 Shape. Dimension=0 #Fields=3 @@ -709,6 +710,6 @@ Node group: Element group: 1..10 Group name: marker -!#nodeset nodes +!#nodeset datapoints Node group: 75 diff --git a/tests/resources/vagus-segment3.exf b/tests/resources/vagus-segment3.exf index 786c820..025b988 100644 --- a/tests/resources/vagus-segment3.exf +++ b/tests/resources/vagus-segment3.exf @@ -351,6 +351,7 @@ Node: 68 -4.359756419261278e-01 -5.458579284260732e-01 1.000000000000000e-02 +!#nodeset datapoints Define node template: node2 Shape. Dimension=0 #Fields=3 @@ -667,7 +668,7 @@ Node group: Element group: 1..9 Group name: marker -!#nodeset nodes +!#nodeset datapoints Node group: 69..70 Group name: unknown diff --git a/tests/test_vagus.py b/tests/test_vagus.py index b7f9c4b..31cf5e6 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -22,7 +22,9 @@ def test_io_vagus1(self): zero = [0.0, 0.0, 0.0] new_translation = [5.0, 0.5, 0.1] segmentation_file_names = [os.path.join(here, "resources", resource_name) for resource_name in resource_names] - stitcher1 = Stitcher(segmentation_file_names) + network_group1_keywords = ["vagus", "nerve", "trunk", "branch"] + network_group2_keywords = ["fascicle"] + stitcher1 = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) segments1 = stitcher1.get_segments() self.assertEqual(3, len(segments1)) segment12 = segments1[1] @@ -35,19 +37,35 @@ def test_io_vagus1(self): annotation11 = annotations1[0] self.assertEqual("Epineurium", annotation11.get_name()) self.assertEqual("http://purl.obolibrary.org/obo/UBERON_0000124", annotation11.get_term()) - self.assertEqual(AnnotationCategory.UNCONNECTED_GENERAL, annotation11.get_category()) + self.assertEqual(AnnotationCategory.GENERAL, annotation11.get_category()) annotation12 = annotations1[1] self.assertEqual("Fascicle", annotation12.get_name()) self.assertEqual("http://uri.interlex.org/base/ilx_0738426", annotation12.get_term()) - self.assertEqual(AnnotationCategory.CONNECTED_COMPLEX_NETWORK, annotation12.get_category()) + self.assertEqual(AnnotationCategory.NETWORK_GROUP_2, annotation12.get_category()) annotation15 = annotations1[4] self.assertEqual("left vagus X nerve trunk", annotation15.get_name()) self.assertEqual('http://purl.obolibrary.org/obo/UBERON_0035020', annotation15.get_term()) - self.assertEqual(AnnotationCategory.CONNECTED_SIMPLE_NETWORK, annotation15.get_category()) + self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation15.get_category()) annotation17 = annotations1[6] self.assertEqual("unknown", annotation17.get_name()) - self.assertEqual(AnnotationCategory.UNCONNECTED_GENERAL, annotation17.get_category()) + self.assertEqual(AnnotationCategory.GENERAL, annotation17.get_category()) + + # test changing category and that category groups are updated + segment13 = segments1[2] + mesh1d = segment13.get_raw_region().getFieldmodule().findMeshByDimension(1) + exclude13_group = segment13.get_category_group(AnnotationCategory.EXCLUDE) + exclude13_mesh_group = exclude13_group.getMeshGroup(mesh1d) + general13_group = segment13.get_category_group(AnnotationCategory.GENERAL) + general13_mesh_group = general13_group.getMeshGroup(mesh1d) + self.assertFalse(exclude13_mesh_group.isValid()) + self.assertEqual(27, general13_mesh_group.getSize()) + annotation17_group = segment13.get_annotation_group(annotation17) + annotation17_mesh_group = annotation17_group.getMeshGroup(mesh1d) + self.assertEqual(1, annotation17_mesh_group.getSize()) annotation17.set_category(AnnotationCategory.EXCLUDE) + exclude13_mesh_group = exclude13_group.getMeshGroup(mesh1d) + self.assertEqual(1, exclude13_mesh_group.getSize()) + self.assertEqual(26, general13_mesh_group.getSize()) settings = stitcher1.encode_settings() self.assertEqual(3, len(settings["segments"])) @@ -56,7 +74,7 @@ def test_io_vagus1(self): assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) self.assertEqual(AnnotationCategory.EXCLUDE.name, settings["annotations"][6]["category"]) - stitcher2 = Stitcher(segmentation_file_names) + stitcher2 = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) stitcher2.decode_settings(settings) segments2 = stitcher2.get_segments() segment22 = segments2[1] From 72561714b82001ae3ce536f28cb827c7964135be Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 19 Sep 2024 13:03:50 +1200 Subject: [PATCH 03/13] Add connection class --- src/segmentationstitcher/connection.py | 73 ++++++++++++++++++++++++++ src/segmentationstitcher/segment.py | 2 +- src/segmentationstitcher/stitcher.py | 41 +++++++++++++++ tests/test_vagus.py | 4 ++ 4 files changed, 119 insertions(+), 1 deletion(-) create mode 100644 src/segmentationstitcher/connection.py diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py new file mode 100644 index 0000000..102d71c --- /dev/null +++ b/src/segmentationstitcher/connection.py @@ -0,0 +1,73 @@ +""" +A connection between segments in the segmentation data. +""" +from cmlibs.utils.zinc.general import ChangeManager +from segmentationstitcher.annotation import AnnotationCategory + + +class Connection: + """ + A connection between segments in the segmentation data. + """ + _separator = " - " + + def __init__(self, segments, root_region): + """ + :param segments: List of 2 Stitcher Segment objects. + :param root_region: Zinc root region to create segment region under. + """ + assert len(segments) == 2, "Only supports connections between 2 segments" + self._name = self._separator.join(segment.get_name() for segment in segments) + self._segments = segments + self._region = root_region.createChild(self._name) + assert self._region.isValid(), \ + "Cannot create connection region " + self._name + ". Name may already be in use?" + # ensure category groups exist: + fieldmodule = self._region.getFieldmodule() + with ChangeManager(fieldmodule): + for category in AnnotationCategory: + group_name = category.get_group_name() + group = fieldmodule.createFieldGroup() + group.setName(group_name) + group.setManaged(True) + self._linked_nodes = [] # (segment0_node_identifier, segment1_node_identifier) + + def decode_settings(self, settings_in: dict): + """ + Update segment settings from JSON dict containing serialised settings. + :param settings_in: Dictionary of settings as produced by encode_settings(). + :param all_segments: List of all segments in Stitcher. + """ + settings_name = self._separator.join(settings_in["segments"]) + assert settings_name == self._name + # update current settings to gain new ones and override old ones + settings = self.encode_settings() + settings.update(settings_in) + self._linked_nodes = settings["linked nodes"] + + def encode_settings(self) -> dict: + """ + Encode segment data in a dictionary to serialize. + :return: Settings in a dict ready for passing to json.dump. + """ + settings = { + "segments": [segment.get_name() for segment in self._segments], + "linked nodes": self._linked_nodes + } + return settings + + def get_name(self): + return self._name + + def get_region(self): + """ + Get the region containing any UI visualisation data for connection. + :return: Zinc Region. + """ + return self._region + + def get_segments(self): + """ + :return: List of segments joined by this connection. + """ + return self._segments diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index bc20d50..e421874 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -20,7 +20,7 @@ def __init__(self, name, segmentation_file_name, root_region): :param root_region: Zinc root region to create segment region under. """ self._name = name - self._segmentationFileName = segmentation_file_name + self._segmentation_file_name = segmentation_file_name # print("Create segment", self._name) self._base_region = root_region.createChild(self._name) assert self._base_region.isValid(), \ diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index e69d1b6..fac03a8 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -3,6 +3,7 @@ """ from cmlibs.utils.zinc.general import HierarchicalChangeManager from cmlibs.zinc.context import Context +from segmentationstitcher.connection import Connection from segmentationstitcher.segment import Segment from segmentationstitcher.annotation import region_get_annotations @@ -30,6 +31,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._network_group2_keywords = copy.deepcopy(network_group2_keywords) self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._segments = [] + self._connections = [] self._version = 1 # increment when new settings added to migrate older serialised settings for segmentation_file_name in segmentation_file_names: name = Path(segmentation_file_name).stem @@ -116,6 +118,22 @@ def decode_settings(self, settings_in: dict): for segment in self._segments: segment.reset_annotation_category_groups(self._annotations) + # create connections from stitcher settings' connection serialisations + assert len(self._connections) == 0, "Cannot decode connections after any exist" + for connection_settings in settings_in["connections"]: + connection_segments = [] + for segment_name in connection_settings["segments"]: + for segment in self._segments: + if segment.get_name() == segment_name: + connection_segments.append(segment) + break + else: + print("WARNING: Segmentation Stitcher. Segment with name", segment_name, + "in connection settings not found; ignoring. Have input files changed?") + if len(connection_segments) >= 2: + connection = self.create_connection(connection_segments) + connection.decode_settings(connection_settings) + def encode_settings(self) -> dict: """ @@ -123,6 +141,7 @@ def encode_settings(self) -> dict: """ settings = { "annotations": [annotation.encode_settings() for annotation in self._annotations], + "connections": [connection.encode_settings() for connection in self._connections], "segments": [segment.encode_settings() for segment in self._segments], "version": self._version } @@ -142,6 +161,28 @@ def _annotation_change(self, annotation, old_category): def get_annotations(self): return self._annotations + def create_connection(self, segments): + """ + :param segments: List of 2 Stitcher Segment objects to connect. + :return: Connection object or None if invalid segments or connection between segments already exists + """ + if len(segments) != 2: + print("Only supports connections between 2 segments") + return None + for connection in self._connections: + if all(segment in connection.get_segments() for segment in segments): + print("Stitcher.create_connection: Already have a connection between segments") + return None + connection = Connection(segments, self._root_region) + self._connections.append(connection) + return connection + + def get_connections(self): + return self._connections + + def remove_connection(self, connection): + self._connections.remove(connection) + def get_context(self): return self._context diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 31cf5e6..84eefe9 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -50,6 +50,10 @@ def test_io_vagus1(self): self.assertEqual("unknown", annotation17.get_name()) self.assertEqual(AnnotationCategory.GENERAL, annotation17.get_category()) + connection = stitcher1.create_connection([segments1[0], segments1[1]]) + connections = stitcher1.get_connections() + self.assertEqual(1, len(connections)) + # test changing category and that category groups are updated segment13 = segments1[2] mesh1d = segment13.get_raw_region().getFieldmodule().findMeshByDimension(1) From 7d1c0519d18a4319af4288bd9452087ef771ef7d Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 8 Nov 2024 14:45:37 +1300 Subject: [PATCH 04/13] Find end point directions --- src/segmentationstitcher/segment.py | 339 ++++++++++++++++++++++++++- src/segmentationstitcher/stitcher.py | 19 +- 2 files changed, 349 insertions(+), 9 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index e421874..319ddd0 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -1,11 +1,19 @@ """ A segment of the segmentation data, generally from a separate image block. """ +from builtins import enumerate + +from cmlibs.maths.vectorops import cross, dot, magnitude, matrix_mult, mult, normalize, set_magnitude, sub +from cmlibs.utils.zinc.field import ( + get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element) +from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range from cmlibs.utils.zinc.group import group_add_group_local_contents, group_remove_group_local_contents from cmlibs.utils.zinc.general import ChangeManager from cmlibs.zinc.field import Field +from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK from segmentationstitcher.annotation import AnnotationCategory +import math class Segment: @@ -32,15 +40,35 @@ def __init__(self, name, segmentation_file_name, root_region): assert result == RESULT_OK, \ "Could not read segmentation file " + segmentation_file_name # ensure category groups exist: - fieldmodule = self._raw_region.getFieldmodule() - with ChangeManager(fieldmodule): + self._raw_fieldmodule = self._raw_region.getFieldmodule() + with ChangeManager(self._raw_fieldmodule): for category in AnnotationCategory: group_name = category.get_group_name() - group = fieldmodule.createFieldGroup() + group = self._raw_fieldmodule.createFieldGroup() group.setName(group_name) group.setManaged(True) self._rotation = [0.0, 0.0, 0.0] self._translation = [0.0, 0.0, 0.0] + self._group_element_node_ids = {} + self._group_node_element_ids = {} + self._raw_fieldcache = self._raw_fieldmodule.createFieldcache() + self._raw_coordinates = self._raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() + self._raw_radius = self._raw_fieldmodule.findFieldByName("radius").castFiniteElement() + self._raw_mesh1d = self._raw_fieldmodule.findMeshByDimension(1) + self._raw_nodes = self._raw_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + self._raw_minimums, self._raw_maximums = evaluate_field_nodeset_range(self._raw_coordinates, self._raw_nodes) + self._working_region = self._base_region.createChild("working") + self._working_fieldmodule = self._working_region.getFieldmodule() + self._working_datapoints = self._working_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + self._working_coordinates = find_or_create_field_coordinates(self._working_fieldmodule) + self._working_radius_direction = find_or_create_field_finite_element( + self._working_fieldmodule, "radius_direction", 3) + self._working_best_fit_line_orientation = find_or_create_field_finite_element( + self._working_fieldmodule, "best_fit_line_orientation", 9) + self._element_node_ids, self._node_element_ids = self._get_element_node_maps() + self._raw_groups = get_group_list(self._raw_fieldmodule) + self._raw_mesh_groups = [group.getMeshGroup(self._raw_mesh1d) for group in self._raw_groups] + self._end_node_ids = self._get_end_node_ids() def decode_settings(self, settings_in: dict): """ @@ -66,9 +94,208 @@ def encode_settings(self) -> dict: } return settings + def _get_element_node_maps(self): + """ + Get maps from 1-D elements to nodes and nodes to elements for the raw data. + All elements are assumed to have the same linear interpolation. + :return: dict elementid -> list(nodeids), dict nodeid -> list(elementid) + """ + element_node_ids = {} + node_element_ids = {} + elem_iter = self._raw_mesh1d.createElementiterator() + element = elem_iter.next() + eft = element.getElementfieldtemplate(self._raw_coordinates, -1) # all elements assumed to use this + while element.isValid(): + element_id = element.getIdentifier() + node_ids = [] + for ln in range(1, 3): + node = element.getNode(eft, ln) + node_id = node.getIdentifier() + node_ids.append(node_id) + element_ids = node_element_ids.get(node_id) + if not element_ids: + element_ids = node_element_ids[node_id] = [] + element_ids.append(element_id) + element_node_ids[element_id] = node_ids + element = elem_iter.next() + return element_node_ids, node_element_ids + + def _get_end_node_ids(self): + """ + :return: List of identifiers of nodes at end points i.e. in only 1 element. + """ + end_node_ids = [] + for node_id, element_ids in self._node_element_ids.items(): + if len(element_ids) == 1: + end_node_ids.append(node_id) + return end_node_ids + + def _element_id_to_group(self, element_id): + """ + Get the first (should be only) Zinc Group containing raw element of supplied identifier. + :param node_id: Identifier of [end] node to query. + :return: Zinc Group, MeshGroup or None, None if not found. + """ + element = self._raw_mesh1d.findElementByIdentifier(element_id) + for i, mesh_group in enumerate(self._raw_mesh_groups): + if mesh_group.containsElement(element): + return self._raw_groups[i], mesh_group + return None, None + + def _track_segment(self, start_node_id, start_element_id, max_distance=None): + """ + Get coordinates and radii along segment from start_node_id in start_element_id, proceeding + first to other local node in element, until junction, end point or max_distance is tracked. + :param start_node_id: First node in path. + :param start_element_id: Element containing start_node_id and another node to be added. + :param max_distance: Maximum distance to track to from first node coordinates, or None for no limit. + :return: coordinates list, radius list, node id list, endElementId + """ + self._element_node_ids, self._node_element_ids + node_id = start_node_id + element_id = start_element_id + path_coordinates = [] + path_radii = [] + path_node_ids = [] + lastNode = False + while True: + node = self._raw_nodes.findNodeByIdentifier(node_id) + self._raw_fieldcache.setNode(node) + result, x = self._raw_coordinates.evaluateReal(self._raw_fieldcache, 3) + if result != RESULT_OK: + continue + path_node_ids.append(node_id) + path_coordinates.append(x) + result, r = self._raw_radius.evaluateReal(self._raw_fieldcache, 1) + if result != RESULT_OK: + r = 1.0 + path_radii.append(r) + if lastNode: + break + if (len(path_coordinates) > 1) and (max_distance is not None): + distance = magnitude(sub(x, path_coordinates[0])) + if distance > max_distance: + break + node_ids = self._element_node_ids[element_id] + node_id = node_ids[1] if (node_ids[0] == node_id) else node_ids[0] + element_ids = self._node_element_ids[node_id] + if len(element_ids) != 2: + lastNode = True + continue + element_id = element_ids[1] if (element_ids[0] == element_id) else element_ids[0] + return path_coordinates, path_radii, path_node_ids, element_id + + def _track_path(self, end_node_id, max_distance=None): + """ + Get coordinates and radii along path from end_node_id, continuing along + branches if in similar direction. + :param group_name: Group to use node-element maps for. + :param end_node_id: End node identifier to track from. Must be in only one element. + :param max_distance: Maximum distance to track to, or None for no limit. + :return: coordinates list, radius list, path node ids, start_x, end_x, mean_r + """ + element_ids = self._node_element_ids[end_node_id] + assert len(element_ids) == 1 + path_group = self._element_id_to_group(element_ids[0])[0] + path_coordinates = [] + path_radii = [] + path_node_ids = [] + path_mean_r = None + stop_node_id = end_node_id + stop_element_id = None + start_x = None + end_x = None + mean_r = None + remaining_max_distance = max_distance + last_direction = None + while (not path_coordinates) or (len(element_ids) > 2): + add_path_coordinates = None + add_path_radii = None + add_path_node_ids = None + add_path_error = None + add_element_id = None + for element_id in element_ids: + if element_id == stop_element_id: + continue + segment_group = self._element_id_to_group(element_id)[0] + if path_group and (segment_group != path_group): + continue + segment_coordinates, segment_radii, segment_node_ids, segment_stop_element_id =\ + self._track_segment(stop_node_id, element_id, remaining_max_distance) + segment_stop_node_id = segment_node_ids[-1] + if segment_stop_node_id in path_node_ids: + continue # avoid loops + if last_direction: + add_start_x, add_end_x, add_mean_r = fit_line(trial_coordinates, trial_radii)[0:3] + add_direction = sub(add_end_x, add_start_x) + if dot(normalize(last_direction), normalize(add_direction)) < 0.8: + continue # avoid sudden changes in direction + add_segment_coordinates = segment_coordinates if not path_coordinates else segment_coordinates[1:] + add_segment_radii = segment_radii if not path_radii else segment_radii[1:] + add_segment_node_ids = segment_node_ids if not path_node_ids else segment_node_ids[1:] + add_segment_mean_r = sum(add_segment_radii) / len(add_segment_radii) + trial_coordinates = path_coordinates + add_segment_coordinates + trial_radii = path_radii + add_segment_radii + trial_start_x, trial_end_x, trial_mean_r, trial_mean_projection_error =\ + fit_line(trial_coordinates, trial_radii) + radius_difference = math.fabs(add_segment_mean_r - path_mean_r) if (path_mean_r is not None) else 0.0 + trial_error = trial_mean_projection_error + radius_difference + if (add_element_id is None) or (trial_error < add_path_error): + add_path_coordinates = add_segment_coordinates + add_path_radii = add_segment_radii + add_path_node_ids = add_segment_node_ids + add_path_error = trial_error + add_node_id = segment_stop_node_id + add_element_id = segment_stop_element_id + start_x, end_x, mean_r = trial_start_x, trial_end_x, trial_mean_r + if not add_path_coordinates: + break + path_coordinates += add_path_coordinates + path_radii += add_path_radii + path_node_ids += add_path_node_ids + path_mean_r = sum(path_radii) / len(path_radii) + stop_node_id = add_node_id + stop_element_id = add_element_id + if max_distance: + remaining_max_distance = max_distance - magnitude(sub(path_coordinates[-1], path_coordinates[0])) + element_ids = self._node_element_ids[stop_node_id] + last_direction = sub(end_x, start_x) + # 2nd iteration of fit line removes outliers: + start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.5)[0:3] + return path_coordinates, path_radii, path_node_ids, start_x, end_x, mean_r + + def create_end_point_directions(self, max_distance): + """ + Track mean directions of network end points and create working objects for visualisation. + :param max_distance: Maximum length to track back from end point. + """ + nodetemplate = self._working_datapoints.createNodetemplate() + nodetemplate.defineField(self._working_coordinates) + nodetemplate.defineField(self._working_radius_direction) + nodetemplate.defineField(self._working_best_fit_line_orientation) + fieldcache = self._working_fieldmodule.createFieldcache() + for end_node_id in self._end_node_ids: + path_coordinates, path_radii, path_node_ids, start_x, end_x, mean_r =( + self._track_path(end_node_id, max_distance)) + # Future: want to extend length to be equivalent to path_coordinates + node = self._working_datapoints.createNode(-1, nodetemplate) + fieldcache.setNode(node) + radius_direction = set_magnitude(sub(start_x, end_x), mean_r) + self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) + self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, + radius_direction) + direction1 = sub(end_x, start_x) + axis = [1.0, 0.0, 0.0] + if dot(normalize(direction1), axis) < 0.1: + axis = [0.0, 1.0, 0.0] + direction2 = set_magnitude(cross(axis, direction1), mean_r) + direction3 = set_magnitude(cross(direction1, direction2), mean_r) + self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, + direction1 + direction2 + direction3) + def get_base_region(self): """ - Get the base region for all segmentation and auxiliary data for this segment. + Get the base region for all segmentation and working data for this segment. :return: Zinc Region. """ return self._base_region @@ -97,9 +324,22 @@ def get_category_group(self, category): group = fieldmodule.findFieldByName(group_name).castGroup() return group + def get_end_point_fields(self): + """ + :return: End point coordinates, direction (out) and radius fields in working region. + """ + return self._working_coordinates, self._working_radius_direction, self._working_best_fit_line_orientation + def get_name(self): return self._name + def get_max_range(self): + """ + :return: Maximum range of raw coordinates on any axis x, y, z. + """ + raw_range = [self._raw_maximums[c] - self._raw_minimums[c] for c in range(3)] + return max(raw_range) + def get_raw_region(self): """ Get the raw region, a child of base region, into which the raw segmentation was loaded. @@ -121,6 +361,13 @@ def set_translation(self, translation): assert len(translation) == 3 self._translation = translation + def get_working_region(self): + """ + Get the working region, a child of base region, into which the non-raw visualisation objects go. + :return: Zinc Region. + """ + return self._working_region + def update_annotation_category(self, annotation, old_category=AnnotationCategory.EXCLUDE): """ Ensures special groups representing annotion categories contain via addition or removal the @@ -158,3 +405,87 @@ def reset_annotation_category_groups(self, annotations): if annotation_group: category_group = self.get_category_group(annotation.get_category()) group_add_group_local_contents(category_group, annotation_group) + + +def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0.0): + """ + Compute best fit line to path coordinates, and mean radius of unfiltered points. + :param path_coordinates: List of coordinates along path to get best fit line to. + :param path_radii: List of radius values along path to get mean of. + :param x1: Initial start point for line. Default is first point coordinates. + :param x2: Initial end point for line. Default is last point coordinates. + :param filter_proportion: Proportion of data points to eliminate in order of + greatest projection normal to line. Default is no filtering. + :return: start_x, end_x, mean_r (of unfiltered points), mean_projection_error (of all points) + """ + assert len(path_coordinates) > 1 + if len(path_coordinates) == 2: + # avoid singular matrix + return path_coordinates[0], path_coordinates[-1], sum(path_radii) / len(path_radii), 0.0 + # project points onto line + start_coordinates = x1 if x1 else path_coordinates[0] + end_coordinates = x2 if x2 else path_coordinates[-1] + v = sub(end_coordinates, start_coordinates) + mag_v = magnitude(v) + d1 = mult(v, 1.0 / (mag_v * mag_v)) + # get 2 unit vectors normal to d1 + dt = [1.0, 0.0, 0.0] if (magnitude(cross(normalize(d1), [1.0, 0.0, 0.0])) > 0.1) else [0.0, 1.0, 0.0] + d2 = normalize(cross(dt, d1)) + d3 = normalize(cross(d1, d2)) + points_count = len(path_coordinates) + path_xi = [] # from 0.0 to 1.0 + path_projection_error = [] # magnitude of normal projection for filtering + filter_count = int(filter_proportion * points_count) + # need at least 2 points or no solution + if (points_count - filter_count) < 2: + filter_count = points_count - 2 + filter_indexes = [] + sum_projection_error = 0.0 + for d in range(points_count): + rx = sub(path_coordinates[d], start_coordinates) + dp = dot(rx, d1) + xi = min(1.0, max(dp, 0.0)) + path_xi.append(xi) + n2 = dot(rx, d2) + n3 = dot(rx, d3) + mag_normal = math.sqrt(n2 * n2 + n3 * n3) + sum_projection_error += mag_normal + path_projection_error.append(mag_normal) + for i in range(len(filter_indexes)): + if mag_normal > path_projection_error[filter_indexes[i]]: + filter_indexes.insert(i, d) + if len(filter_indexes) > filter_count: + filter_indexes.pop() + break + else: + if len(filter_indexes) < filter_count: + filter_indexes.append(d) + mean_projection_error = sum_projection_error / points_count + sum_r = 0.0 + a = [[0.0, 0.0], [0.0, 0.0]] # matrix + b = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]] # RHS for each component + for d in range(points_count): + if d in filter_indexes: + continue + xi = path_xi[d] + phi1 = (1.0 - xi) + phi2 = xi + a[0][0] += phi1 * phi1 + a[0][1] += phi1 * phi2 + a[1][0] += phi2 * phi1 + a[1][1] += phi2 * phi2 + for c in range(3): + b[c][0] += phi1 * path_coordinates[d][c] + b[c][1] += phi2 * path_coordinates[d][c] + sum_r += path_radii[d] + mean_r = sum_r / (points_count - filter_count) + # invert matrix: + det_a = a[0][0] * a[1][1] - a[0][1] * a[1][0] + a_inv = [[a[1][1] / det_a, -a[0][1] / det_a], [-a[1][0] / det_a, a[0][0] / det_a]] + start_x = [a_inv[0][0] * rhs[0] + a_inv[0][1] * rhs[1] for rhs in b] + end_x = [a_inv[1][0] * rhs[0] + a_inv[1][1] * rhs[1] for rhs in b] + # print([a_inv[0][0] * a[0][0] + a_inv[0][1] * a[1][0], + # a_inv[0][0] * a[0][1] + a_inv[0][1] * a[1][1]], + # [a_inv[1][0] * a[0][0] + a_inv[1][1] * a[1][0], + # a_inv[1][0] * a[0][1] + a_inv[1][1] * a[1][1]]) + return start_x, end_x, mean_r, mean_projection_error diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index fac03a8..9c6659a 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -33,9 +33,11 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._segments = [] self._connections = [] self._version = 1 # increment when new settings added to migrate older serialised settings + max_range_reciprocal_sum = 0.0 for segmentation_file_name in segmentation_file_names: name = Path(segmentation_file_name).stem segment = Segment(name, segmentation_file_name, self._root_region) + max_range_reciprocal_sum += 1.0 / segment.get_max_range() self._segments.append(segment) segment_annotations = region_get_annotations( segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, @@ -54,6 +56,11 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) + if self._segments: + # ask segments to track end distances using a global mean max_distance + max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum + for segment in self._segments: + segment.create_end_point_directions(max_distance) with HierarchicalChangeManager(self._root_region): for segment in self._segments: segment.reset_annotation_category_groups(self._annotations) @@ -68,10 +75,12 @@ def decode_settings(self, settings_in: dict): assert settings_in.get("annotations") and settings_in.get("segments") and settings_in.get("version"), \ "Stitcher.decode_settings: Invalid settings dictionary" # settings_version = settings_in["version"] + settings = self.encode_settings() + settings.update(settings_in) # update annotations and warn about differences processed_count = 0 - for annotation_settings in settings_in["annotations"]: + for annotation_settings in settings["annotations"]: name = annotation_settings["name"] term = annotation_settings["term"] for annotation in self._annotations: @@ -86,7 +95,7 @@ def decode_settings(self, settings_in: dict): for annotation in self._annotations: name = annotation.get_name() term = annotation.get_term() - for annotation_settings in settings_in["annotations"]: + for annotation_settings in settings["annotations"]: if (annotation_settings["name"] == name) and (annotation_settings["term"] == term): break else: @@ -95,7 +104,7 @@ def decode_settings(self, settings_in: dict): # update segment settings and warn about differences processed_count = 0 - for segment_settings in settings_in["segments"]: + for segment_settings in settings["segments"]: name = segment_settings["name"] for segment in self._segments: if segment.get_name() == name: @@ -108,7 +117,7 @@ def decode_settings(self, settings_in: dict): if processed_count != len(self._segments): for segment in self._segments: name = segment.get_name() - for segment_settings in settings_in["segments"]: + for segment_settings in settings["segments"]: if segment_settings["name"] == name: break else: @@ -120,7 +129,7 @@ def decode_settings(self, settings_in: dict): # create connections from stitcher settings' connection serialisations assert len(self._connections) == 0, "Cannot decode connections after any exist" - for connection_settings in settings_in["connections"]: + for connection_settings in settings["connections"]: connection_segments = [] for segment_name in connection_settings["segments"]: for segment in self._segments: From c69e3b4b0841f04a76720fd32b946671c8860105 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Mon, 11 Nov 2024 15:58:05 +1300 Subject: [PATCH 05/13] Add connection delete --- pyproject.toml | 2 +- src/segmentationstitcher/stitcher.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 46209b4..ba19b3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [tool.setuptools-git-versioning] enabled = true [project] -name = "segmentation_stitcher" +name = "segmentationstitcher" dynamic = ["version"] keywords = ["Medical", "Image", "Segmentation", "Merge", "SPARC"] readme = "README.md" diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 9c6659a..b35d9b9 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -186,6 +186,13 @@ def create_connection(self, segments): self._connections.append(connection) return connection + def delete_connection(self, connection): + """ + Delete the connection from the stitcher's list. + :param connection: Connection to delete. + """ + self._connections.remove(connection) + def get_connections(self): return self._connections From def7e1d1c74343f73311ad613169b24ede2711e5 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 22 Nov 2024 16:26:06 +1300 Subject: [PATCH 06/13] Initial orientation optimisation --- pyproject.toml | 3 +- src/segmentationstitcher/annotation.py | 3 + src/segmentationstitcher/connection.py | 422 ++++++++++++++++++++++++- src/segmentationstitcher/segment.py | 176 ++++++++--- src/segmentationstitcher/stitcher.py | 47 +-- 5 files changed, 585 insertions(+), 66 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ba19b3a..1954b1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ authors = [ dependencies = [ "cmlibs.maths>=0.6.2", "cmlibs.utils>=0.9", - "cmlibs.zinc>=4.1" + "cmlibs.zinc>=4.1", + "scipy" ] description = "Utility for stitching segmentations of networks and other features from multiple adjacent blocks" requires-python = ">=3.7" diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index a858e61..3c5918d 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -24,6 +24,9 @@ def get_group_name(self): """ return '.' + self.name + def is_connectable(self): + return self in (self.INDEPENDENT_NETWORK, self.NETWORK_GROUP_1, self.NETWORK_GROUP_2) + class Annotation: """ diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 102d71c..40e1977 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -1,8 +1,17 @@ """ A connection between segments in the segmentation data. """ +from cmlibs.maths.vectorops import ( + add, cross, dot, div, euler_to_rotation_matrix, magnitude, matrix_inv, matrix_vector_mult, mult, normalize, sub) +from cmlibs.utils.zinc.field import ( + find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) from cmlibs.utils.zinc.general import ChangeManager +from cmlibs.utils.zinc.group import group_add_group_local_contents +from cmlibs.zinc.element import Element, Elementbasis +from cmlibs.zinc.field import Field +from scipy.optimize import minimize from segmentationstitcher.annotation import AnnotationCategory +import math class Connection: @@ -11,10 +20,12 @@ class Connection: """ _separator = " - " - def __init__(self, segments, root_region): + def __init__(self, segments, root_region, annotations, max_distance): """ :param segments: List of 2 Stitcher Segment objects. :param root_region: Zinc root region to create segment region under. + :param annotations: List of all annotations from stitcher. + :param max_distance: Maximum distance directions are tracked along. Used to decide tolerance for distances. """ assert len(segments) == 2, "Only supports connections between 2 segments" self._name = self._separator.join(segment.get_name() for segment in segments) @@ -22,28 +33,43 @@ def __init__(self, segments, root_region): self._region = root_region.createChild(self._name) assert self._region.isValid(), \ "Cannot create connection region " + self._name + ". Name may already be in use?" + self._annotations = annotations + self._max_distance = max_distance # ensure category groups exist: fieldmodule = self._region.getFieldmodule() with ChangeManager(fieldmodule): + self._coordinates = find_or_create_field_coordinates(fieldmodule) + self._radius = find_or_create_field_finite_element(fieldmodule, "radius", 1, managed=True) for category in AnnotationCategory: group_name = category.get_group_name() group = fieldmodule.createFieldGroup() group.setName(group_name) group.setManaged(True) - self._linked_nodes = [] # (segment0_node_identifier, segment1_node_identifier) + self._linked_nodes = {} # dict: annotation name --> list of [segment0_node_identifier, segment1_node_identifier]] + for segment in self._segments: + segment.add_transformation_change_callback(self._segment_transformation_change) + + def detach(self): + """ + Need to call before destroying as segment callbacks maintain a handle to self. + """ + for segment in self._segments: + segment.remove_transformation_change_callback(self._segment_transformation_change) + self._region.getParent().removeChild(self._region) def decode_settings(self, settings_in: dict): """ Update segment settings from JSON dict containing serialised settings. :param settings_in: Dictionary of settings as produced by encode_settings(). - :param all_segments: List of all segments in Stitcher. """ settings_name = self._separator.join(settings_in["segments"]) assert settings_name == self._name # update current settings to gain new ones and override old ones settings = self.encode_settings() settings.update(settings_in) - self._linked_nodes = settings["linked nodes"] + linked_nodes = settings.get("linked nodes") + if isinstance(linked_nodes, dict): + self._linked_nodes = linked_nodes def encode_settings(self) -> dict: """ @@ -56,6 +82,35 @@ def encode_settings(self) -> dict: } return settings + def printLog(self): + logger = self._region.getContext().getLogger() + for index in range(logger.getNumberOfMessages()): + print(logger.getMessageTextAtIndex(index)) + + def get_annotation_group(self, annotation): + """ + Get Zinc group containing segmentations for the supplied annotation. + :param annotation: An Annotation object. + :return: Zinc FieldGroup in the connections' region, or None if not present. + """ + fieldmodule = self._region.getFieldmodule() + annotation_group = fieldmodule.findFieldByName(annotation.get_name()).castGroup() + if annotation_group.isValid(): + return annotation_group + return None + + def get_category_group(self, category): + """ + Get Zinc group in which segmentations with the supplied annotation category are maintained + for visualisation. + :param category: The AnnotationCategory to query. + :return: Zinc FieldGroup in the segment's raw region. + """ + fieldmodule = self._region.getFieldmodule() + group_name = category.get_group_name() + group = fieldmodule.findFieldByName(group_name).castGroup() + return group + def get_name(self): return self._name @@ -71,3 +126,362 @@ def get_segments(self): :return: List of segments joined by this connection. """ return self._segments + + def _segment_transformation_change(self, segment): + self.build_links() + self.update_annotation_category_groups(self._annotations) + + def add_linked_nodes(self, annotation, node_id0, node_id1): + """ + :param annotation: Annotation to use for link. + :param node_id0: Node identifier to link from segment[0]. + :param node_id1: Node identifier to link from segment[1]. + """ + annotation_name = annotation.get_name() + annotation_linked_nodes = self._linked_nodes.get(annotation_name) + if not annotation_linked_nodes: + self._linked_nodes[annotation_name] = annotation_linked_nodes = [] + annotation_linked_nodes.append([node_id0, node_id1]) + + def optimise_transformation(self): + """ + Optimise transformation of second segment to align with position and direction of nearest points between + both segments. + """ + segment_end_point_data = [] + initial_rotation = [] + initial_rotation_matrix = [] + for s, segment in enumerate(self._segments): + translation = segment.get_translation() + rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + initial_rotation.append(rotation) + rotation_matrix = euler_to_rotation_matrix(rotation) if (rotation != [0.0, 0.0, 0.0]) else None + initial_rotation_matrix.append(rotation_matrix) + end_point_data = [] + raw_end_point_data = segment.get_end_point_data() + for node_id, data in raw_end_point_data.items(): + coordinates, direction, radius, annotation = data + transformed_coordinates = coordinates + if (annotation is not None) and annotation.get_category().is_connectable(): + if rotation_matrix: + transformed_coordinates = matrix_vector_mult(rotation_matrix, transformed_coordinates) + transformed_coordinates = add(transformed_coordinates, translation) + end_point_data.append((node_id, transformed_coordinates, coordinates, direction, radius, annotation)) + segment_end_point_data.append(end_point_data) + + mean_coordinates = [] + mean_directions = [] + for s, segment in enumerate(self._segments): + total_weight = 0.0 + distances = [] + max_distance = None + for node_id0, transformed_coordinates0, _, _, _, annotation0 in segment_end_point_data[s]: + category0 = annotation0.get_category() + distance = None + for node_id1, transformed_coordinates1, _, _, _, annotation1 in segment_end_point_data[s - 1]: + category1 = annotation1.get_category() + if (category0 != category1) or ( + (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): + continue # end points are not allowed to join + tmp_distance = magnitude(sub(transformed_coordinates0, transformed_coordinates1)) + if (distance is None) or (tmp_distance < distance): + distance = tmp_distance + if (distance is not None) and ((max_distance is None) or (distance > max_distance)): + max_distance = distance + distances.append(distance) + if max_distance is None: + print("Segmentation Stitcher. No connectable points to optimise transformation with") + return + nearby_proportion = 0.1 # proportion of max distance under which distance weighting is the same + nearby_distance = max_distance * nearby_proportion + sum_coordinates = [0.0, 0.0, 0.0] + sum_transformed_coordinates = [0.0, 0.0, 0.0] + sum_direction = [0.0, 0.0, 0.0] + sum_transformed_direction = [0.0, 0.0, 0.0] + total_weight = 0.0 + for p, data in enumerate(segment_end_point_data[s]): + distance = distances[p] + if distance is None: + continue + _, transformed_coordinates, coordinates, direction, radius, annotation = data + if distance < nearby_distance: + distance = nearby_distance + weight = radius / (distance * distance) + sum_coordinates = add(sum_coordinates, mult(coordinates, weight)) + sum_direction = add(sum_direction, mult(direction, weight)) + total_weight += weight + mean_coordinates.append(div(sum_coordinates, total_weight)) + mean_directions.append(div(sum_direction, total_weight)) + unit_mean_directions = [normalize(v) for v in mean_directions] + mean_transformed_coordinates = [] + unit_mean_transformed_directions = [] + for s, segment in enumerate(self._segments): + x = mean_coordinates[s] + d = mean_directions[s] + if initial_rotation_matrix[s]: + x = matrix_vector_mult(initial_rotation_matrix[s], x) + d = matrix_vector_mult(initial_rotation_matrix[s], d) + x = add(x, segment.get_translation()) + mean_transformed_coordinates.append(x) + unit_mean_transformed_directions.append(normalize(d)) + + # optimise transformation of second segment so mean coordinates and directions coincide + + def rotation_objective(rotation, *args): + target_direction, source_direction, target_side_direction, source_side_direction = args + rotation_matrix = euler_to_rotation_matrix(rotation) + trans_direction = matrix_vector_mult(rotation_matrix, source_direction) + trans_side_direction = matrix_vector_mult(rotation_matrix, source_side_direction) + return dot(trans_direction, target_direction) + dot(target_side_direction, trans_side_direction) + + # note the result is dependent on the initial position, but final optimisation should reduced effect + # get a side direction to minimise the unconstrained twist from the current direction + axis = [1.0, 0.0, 0.0] + if dot(unit_mean_transformed_directions[0], axis) < 0.1: + axis = [0.0, 1.0, 0.0] + target_side = normalize(cross(unit_mean_transformed_directions[0], axis)) + source_side = normalize( + cross(cross(target_side, unit_mean_transformed_directions[1]), unit_mean_transformed_directions[1])) + if initial_rotation_matrix[1]: + transformed_source_side = source_side + inverse_rotation_matrix = matrix_inv(initial_rotation_matrix[1]) + source_side = matrix_vector_mult(inverse_rotation_matrix, transformed_source_side) + initial_angles = [math.radians(angle_degrees) for angle_degrees in self._segments[1].get_rotation()] + side_weight = 0.01 # so side has only a small effect on objective + res = minimize(rotation_objective, initial_angles, + args=(unit_mean_transformed_directions[0], unit_mean_directions[1], + mult(target_side, side_weight), mult(source_side, side_weight)), + method='Nelder-Mead', tol=0.001) + if not res.success: + print("Segmentation Stitcher. Could not optimise initial rotation") + return + rotation = [math.degrees(angle_radians) for angle_radians in res.x] + rotation_matrix = euler_to_rotation_matrix(res.x) + rotated_mean_coordinates = matrix_vector_mult(rotation_matrix, mean_coordinates[1]) + translation = sub(mean_transformed_coordinates[0], rotated_mean_coordinates) + # update transformed_coordinates in second segment data + for p, data in enumerate(segment_end_point_data[1]): + coordinates = data[2] + transformed_coordinates = add(matrix_vector_mult(rotation_matrix, coordinates), translation) + segment_end_point_data[1][p] = (data[0], transformed_coordinates, data[2], data[3], data[4], data[5]) + unit_transformed_direction = matrix_vector_mult(rotation_matrix, unit_mean_directions[1]) + # translate along unit_transformed_direction so no overlap between points + total_overlap = 0.0 + for s, segment in enumerate(self._segments): + max_overlap = 0.0 + for data in segment_end_point_data[s]: + overlap = dot(sub(data[1], mean_transformed_coordinates[0]), unit_transformed_direction) + if s == 0: + overlap = -overlap + if overlap > max_overlap: + max_overlap = overlap + total_overlap += max_overlap + translation = sub(translation, mult(unit_transformed_direction, total_overlap)) + self._segments[1].set_rotation(rotation, notify=False) + self._segments[1].set_translation(translation, notify=False) + + # GRC temp + # score = self.build_links(build_link_objects=False) + # print("part 1 rotation", rotation, "translation", translation, "score", score) + + # optimise angles and translation + def links_objective(rotation_translation, *args): + rotation = list(rotation_translation[:3]) + translation = list(rotation_translation[3:]) + self._segments[1].set_rotation(rotation, notify=False) + self._segments[1].set_translation(translation, notify=False) + score = self.build_links(build_link_objects=False) + # print("rotation", rotation, "translation", translation, "score", score) + return score + + initial_parameters = rotation + translation + initial_score = links_objective(initial_parameters, ()) + TOL = initial_score * 1.0E-4 + # method='Nelder-Mead' + res = minimize(links_objective, initial_parameters, method='Powell', tol=TOL) + if not res.success: + print("Segmentation Stitcher. Could not optimise final rotation and translation") + return + rotation = list(res.x[:3]) + translation = list(res.x[3:]) + self._segments[1].set_rotation(rotation, notify=False) + # this will invoke build_links: + self._segments[1].set_translation(translation) + + def build_links(self, build_link_objects=True): + """ + Build links between nodes from connected segments. + :param build_link_objects: Set to False to defer building visualization objects. + :return: Total link score. + """ + total_score = 0.0 + remaining_radius_factor = 0.25 + self._linked_nodes = {} + # filter, transform and sort end point data from largest to smallest radius + segment_sorted_end_point_data = [] + for s, segment in enumerate(self._segments): + translation = segment.get_translation() + rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation_matrix = euler_to_rotation_matrix(rotation) if (rotation != [0.0, 0.0, 0.0]) else None + + sorted_end_point_data = [] + end_point_data = segment.get_end_point_data() + for node_id, data in end_point_data.items(): + coordinates, direction, radius, annotation = data + if (annotation is not None) and annotation.get_category().is_connectable(): + if rotation_matrix: + coordinates = matrix_vector_mult(rotation_matrix, coordinates) + direction = matrix_vector_mult(rotation_matrix, direction) + coordinates = add(coordinates, translation) + + for i, data in enumerate(sorted_end_point_data): + if radius > data[3]: + break + else: + i = len(sorted_end_point_data) + sorted_end_point_data.insert(i, (node_id, coordinates, direction, radius, annotation)) + segment_sorted_end_point_data.append(sorted_end_point_data) + sorted_end_point_data0 = segment_sorted_end_point_data[0] + sorted_end_point_data1 = segment_sorted_end_point_data[1] + + while len(sorted_end_point_data0): + end_point_data0 = sorted_end_point_data0[0] + node_id0, coordinates0, direction0, radius0, annotation0 = end_point_data0 + category0 = annotation0.get_category() + best_index1 = None + lowest_score = 0.0 + for index1, end_point_data1 in enumerate(sorted_end_point_data1): + node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 + category1 = annotation1.get_category() + if (category0 != category1) or ( + (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): + continue # end points are not allowed to join + direction_score = math.fabs(1.0 + dot(direction0, direction1)) + if direction_score > 0.5: + continue # end points are not pointing towards each other + delta_coordinates = sub(coordinates1, coordinates0) + mag_delta_coordinates = magnitude(delta_coordinates) + tdistance = dot(direction0, delta_coordinates) + ndistance = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - tdistance * tdistance) + if mag_delta_coordinates > (0.5 * self._max_distance): + continue # point is too far away + distance_score = ((tdistance * tdistance + 50.0 * ndistance * ndistance) / + (self._max_distance * self._max_distance)) + tfactor = math.exp(-100.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor + penetration_distance_score = ((tfactor * tdistance * tdistance) / + (self._max_distance * self._max_distance)) + delta_radius = (radius0 - radius1) / self._max_distance # GRC temporary - use a different scale + radius_score = delta_radius * delta_radius + score = radius0 * (10.0 * direction_score + distance_score + radius_score) + if (best_index1 is None) or (score < lowest_score): + best_index1 = index1 + lowest_score = score + penetration_distance_score + if best_index1 is not None: + # if category0 != AnnotationCategory.NETWORK_GROUP_1: + total_score += lowest_score + node_id1, coordinates1, direction1, radius1, annotation1 = sorted_end_point_data1[best_index1] + self.add_linked_nodes(annotation1, node_id0, node_id1) + remaining_radius = math.sqrt(math.fabs(radius0 * radius0 - radius1 * radius1)) + if (radius0 > radius1) and (remaining_radius > remaining_radius_factor * radius0): + for i in range(1, len(sorted_end_point_data0)): + if remaining_radius > sorted_end_point_data0[i][3]: + break + # sorted_end_point_data0.insert(i, (node_id0, coordinates0, direction0, remaining_radius, annotation0)) + elif remaining_radius > (remaining_radius_factor * radius1): + for i in range(best_index1, len(sorted_end_point_data1)): + if remaining_radius > sorted_end_point_data1[i][3]: + break + # sorted_end_point_data1.insert(i, (node_id1, coordinates1, direction1, remaining_radius, annotation1)) + sorted_end_point_data1.pop(best_index1) + else: + total_score += radius0 * 20.0 # arbitrary factor + sorted_end_point_data0.pop(0) + + if build_link_objects: + self._build_link_objects() + + return total_score + + def _build_link_objects(self): + """ + Make link nodes/elements for visualisation. + """ + fieldmodule = self._region.getFieldmodule() + nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + nodetemplate = nodes.createNodetemplate() + nodetemplate.defineField(self._coordinates) + nodetemplate.defineField(self._radius) + mesh1d = fieldmodule.findMeshByDimension(1) + elementtemplate = mesh1d.createElementtemplate() + elementtemplate.setElementShapeType(Element.SHAPE_TYPE_LINE) + linear_basis = fieldmodule.createElementbasis(1, Elementbasis.FUNCTION_TYPE_LINEAR_LAGRANGE) + eft = mesh1d.createElementfieldtemplate(linear_basis) + elementtemplate.defineField(self._coordinates, -1, eft) + elementtemplate.defineField(self._radius, -1, eft) + fieldcache = fieldmodule.createFieldcache() + + snodes, sfieldcache, scoordinates, sradius = [], [], [], [] + snode_id_to_cnode_id = [] + for s, segment in enumerate(self._segments): + sfieldmodule = segment.get_raw_region().getFieldmodule() + snodes.append(sfieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES)) + sfieldcache.append(sfieldmodule.createFieldcache()) + tr_coordinates = sfieldmodule.findFieldByName("coordinates").castFiniteElement() + rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + if rotation != [0.0, 0.0, 0.0]: + rotation_matrix = euler_to_rotation_matrix(rotation) + tr_coordinates = sfieldmodule.createFieldMatrixMultiply( + 3, sfieldmodule.createFieldConstant(rotation_matrix[0] + rotation_matrix[1] + rotation_matrix[2]), + tr_coordinates) + translation = segment.get_translation() + if translation != [0.0, 0.0, 0.0]: + tr_coordinates = tr_coordinates + sfieldmodule.createFieldConstant(translation) + scoordinates.append(tr_coordinates) + sradius.append(sfieldmodule.findFieldByName("radius").castFiniteElement()) + snode_id_to_cnode_id.append({}) # map from segment node identifier to connection node identifier + + node_identifier = 1 + element_identifier = 1 + with (ChangeManager(fieldmodule)): + mesh1d.destroyAllElements() + nodes.destroyAllNodes() + for group_name, linked_nodes_list in self._linked_nodes.items(): + group = find_or_create_field_group(fieldmodule, group_name) + nodeset_group = group.getOrCreateNodesetGroup(nodes) + mesh_group = group.getOrCreateMeshGroup(mesh1d) + for linked_nodes in linked_nodes_list: + cnode_ids = [None, None] + for s, snode_id in enumerate(linked_nodes): + cnode_ids[s] = snode_id_to_cnode_id[s].get(snode_id) + if not cnode_ids[s]: + snode = snodes[s].findNodeByIdentifier(snode_id) + sfieldcache[s].setNode(snode) + _, x = scoordinates[s].evaluateReal(sfieldcache[s], 3) + _, r = sradius[s].evaluateReal(sfieldcache[s], 1) + cnode = nodeset_group.createNode(node_identifier, nodetemplate) + fieldcache.setNode(cnode) + self._coordinates.assignReal(fieldcache, x) + self._radius.assignReal(fieldcache, r) + cnode_ids[s] = node_identifier + snode_id_to_cnode_id[s][snode_id] = cnode_ids[s] + node_identifier += 1 + element = mesh_group.createElement(element_identifier, elementtemplate) + element.setNodesByIdentifier(eft, cnode_ids) + element_identifier += 1 + + def update_annotation_category_groups(self, annotations): + """ + Rebuild all annotation category groups e.g. after loading settings. + :param annotations: List of all annotations from stitcher. + """ + fieldmodule = self._region.getFieldmodule() + with ChangeManager(fieldmodule): + # clear all category groups + for category in AnnotationCategory: + category_group = self.get_category_group(category) + category_group.clear() + for annotation in annotations: + annotation_group = self.get_annotation_group(annotation) + if annotation_group: + category_group = self.get_category_group(annotation.get_category()) + group_add_group_local_contents(category_group, annotation_group) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 319ddd0..f6c06c7 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -49,8 +49,7 @@ def __init__(self, name, segmentation_file_name, root_region): group.setManaged(True) self._rotation = [0.0, 0.0, 0.0] self._translation = [0.0, 0.0, 0.0] - self._group_element_node_ids = {} - self._group_node_element_ids = {} + self._transformation_change_callbacks = [] self._raw_fieldcache = self._raw_fieldmodule.createFieldcache() self._raw_coordinates = self._raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() self._raw_radius = self._raw_fieldmodule.findFieldByName("radius").castFiniteElement() @@ -66,9 +65,8 @@ def __init__(self, name, segmentation_file_name, root_region): self._working_best_fit_line_orientation = find_or_create_field_finite_element( self._working_fieldmodule, "best_fit_line_orientation", 9) self._element_node_ids, self._node_element_ids = self._get_element_node_maps() - self._raw_groups = get_group_list(self._raw_fieldmodule) - self._raw_mesh_groups = [group.getMeshGroup(self._raw_mesh1d) for group in self._raw_groups] self._end_node_ids = self._get_end_node_ids() + self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) def decode_settings(self, settings_in: dict): """ @@ -130,25 +128,33 @@ def _get_end_node_ids(self): end_node_ids.append(node_id) return end_node_ids - def _element_id_to_group(self, element_id): + def _element_id_to_group(self, element_id, annotations): """ - Get the first (should be only) Zinc Group containing raw element of supplied identifier. + Get the first Annotation zinc Group containing raw element of supplied identifier. :param node_id: Identifier of [end] node to query. + :param annotations: Global list of all annotations. :return: Zinc Group, MeshGroup or None, None if not found. """ element = self._raw_mesh1d.findElementByIdentifier(element_id) - for i, mesh_group in enumerate(self._raw_mesh_groups): - if mesh_group.containsElement(element): - return self._raw_groups[i], mesh_group + for annotation in annotations: + group = self._raw_fieldmodule.findFieldByName(annotation.get_name()).castGroup() + if group.isValid(): + mesh_group = group.getMeshGroup(self._raw_mesh1d) + if mesh_group.isValid() and mesh_group.containsElement(element): + return group, mesh_group return None, None - def _track_segment(self, start_node_id, start_element_id, max_distance=None): + def _track_segment(self, start_node_id, start_element_id, + max_length=None, min_element_count=None, min_aspect_ratio=None): """ Get coordinates and radii along segment from start_node_id in start_element_id, proceeding first to other local node in element, until junction, end point or max_distance is tracked. + Can finish earlier if min_element_count, min_aspect_ratio reached, but both must be reached if both in use. :param start_node_id: First node in path. :param start_element_id: Element containing start_node_id and another node to be added. - :param max_distance: Maximum distance to track to from first node coordinates, or None for no limit. + :param max_length: Maximum length to track from first node coordinates, or None for no limit. + :param min_element_count: Minimum number of elements to track, or None to not test. + :param min_aspect_ratio: Minimum ratio of length / mean radius to end tracking, or None to not test. :return: coordinates list, radius list, node id list, endElementId """ self._element_node_ids, self._node_element_ids @@ -158,24 +164,39 @@ def _track_segment(self, start_node_id, start_element_id, max_distance=None): path_radii = [] path_node_ids = [] lastNode = False + sum_r = 0.0 while True: + if node_id in path_node_ids: + print("Segmentation Stitcher. Tracking found a loop. Stopping.") + break node = self._raw_nodes.findNodeByIdentifier(node_id) self._raw_fieldcache.setNode(node) result, x = self._raw_coordinates.evaluateReal(self._raw_fieldcache, 3) if result != RESULT_OK: - continue + break path_node_ids.append(node_id) path_coordinates.append(x) result, r = self._raw_radius.evaluateReal(self._raw_fieldcache, 1) if result != RESULT_OK: r = 1.0 path_radii.append(r) + sum_r += r if lastNode: break - if (len(path_coordinates) > 1) and (max_distance is not None): - distance = magnitude(sub(x, path_coordinates[0])) - if distance > max_distance: + point_count = len(path_coordinates) + if point_count > 1: + length = magnitude(sub(x, path_coordinates[0])) + if (max_length is not None) and (length > max_length): break + if (min_element_count is not None) or (min_aspect_ratio is not None): + if (min_element_count is None) or (point_count > min_element_count): + if min_aspect_ratio is None: + break + mean_r = sum_r / point_count + if mean_r > 0.0: + aspect_ratio = length / mean_r + if aspect_ratio >= min_aspect_ratio: + break node_ids = self._element_node_ids[element_id] node_id = node_ids[1] if (node_ids[0] == node_id) else node_ids[0] element_ids = self._node_element_ids[node_id] @@ -183,20 +204,21 @@ def _track_segment(self, start_node_id, start_element_id, max_distance=None): lastNode = True continue element_id = element_ids[1] if (element_ids[0] == element_id) else element_ids[0] + return path_coordinates, path_radii, path_node_ids, element_id - def _track_path(self, end_node_id, max_distance=None): + def _track_path(self, end_node_id, annotations, max_length=None): """ Get coordinates and radii along path from end_node_id, continuing along branches if in similar direction. - :param group_name: Group to use node-element maps for. :param end_node_id: End node identifier to track from. Must be in only one element. - :param max_distance: Maximum distance to track to, or None for no limit. - :return: coordinates list, radius list, path node ids, start_x, end_x, mean_r + :param annotations: Global list of all annotations. + :param max_length: Maximum length to track along, or None for no limit. + :return: coordinates list, radius list, path node ids, path group, start_x, end_x, mean_r """ element_ids = self._node_element_ids[end_node_id] assert len(element_ids) == 1 - path_group = self._element_id_to_group(element_ids[0])[0] + path_group = self._element_id_to_group(element_ids[0], annotations)[0] path_coordinates = [] path_radii = [] path_node_ids = [] @@ -206,29 +228,38 @@ def _track_path(self, end_node_id, max_distance=None): start_x = None end_x = None mean_r = None - remaining_max_distance = max_distance last_direction = None - while (not path_coordinates) or (len(element_ids) > 2): + length = 0.0 + element_count = 0 + min_element_count = 10 + aspect_ratio = 0.0 + min_aspect_ratio = 4.0 + while (((not path_coordinates) or (len(element_ids) > 2)) and (length < max_length) and + ((element_count < min_element_count) or (aspect_ratio < min_aspect_ratio))): add_path_coordinates = None add_path_radii = None add_path_node_ids = None add_path_error = None + add_path_length = 0.0 add_element_id = None for element_id in element_ids: if element_id == stop_element_id: continue - segment_group = self._element_id_to_group(element_id)[0] + segment_group = self._element_id_to_group(element_id, annotations)[0] if path_group and (segment_group != path_group): continue - segment_coordinates, segment_radii, segment_node_ids, segment_stop_element_id =\ - self._track_segment(stop_node_id, element_id, remaining_max_distance) + segment_coordinates, segment_radii, segment_node_ids, segment_stop_element_id = self._track_segment( + stop_node_id, element_id, + max_length=max_length - length, + min_element_count=min_element_count - element_count, + min_aspect_ratio=min_aspect_ratio - aspect_ratio) segment_stop_node_id = segment_node_ids[-1] if segment_stop_node_id in path_node_ids: continue # avoid loops if last_direction: - add_start_x, add_end_x, add_mean_r = fit_line(trial_coordinates, trial_radii)[0:3] - add_direction = sub(add_end_x, add_start_x) - if dot(normalize(last_direction), normalize(add_direction)) < 0.8: + add_start_x, add_end_x, add_mean_r = fit_line(segment_coordinates, segment_radii)[0:3] + add_direction = normalize(sub(add_end_x, add_start_x)) + if dot(last_direction, add_direction) < 0.8: # arbitrary factor continue # avoid sudden changes in direction add_segment_coordinates = segment_coordinates if not path_coordinates else segment_coordinates[1:] add_segment_radii = segment_radii if not path_radii else segment_radii[1:] @@ -238,8 +269,8 @@ def _track_path(self, end_node_id, max_distance=None): trial_radii = path_radii + add_segment_radii trial_start_x, trial_end_x, trial_mean_r, trial_mean_projection_error =\ fit_line(trial_coordinates, trial_radii) - radius_difference = math.fabs(add_segment_mean_r - path_mean_r) if (path_mean_r is not None) else 0.0 - trial_error = trial_mean_projection_error + radius_difference + delta_radius = math.fabs(add_segment_mean_r - path_mean_r) if (path_mean_r is not None) else 0.0 + trial_error = trial_mean_projection_error + 4.0 * (delta_radius * delta_radius) # arbitrary factor if (add_element_id is None) or (trial_error < add_path_error): add_path_coordinates = add_segment_coordinates add_path_radii = add_segment_radii @@ -247,6 +278,8 @@ def _track_path(self, end_node_id, max_distance=None): add_path_error = trial_error add_node_id = segment_stop_node_id add_element_id = segment_stop_element_id + add_path_length = magnitude(sub(segment_coordinates[-1], segment_coordinates[0])) + add_path_mean_r = add_segment_mean_r start_x, end_x, mean_r = trial_start_x, trial_end_x, trial_mean_r if not add_path_coordinates: break @@ -256,31 +289,46 @@ def _track_path(self, end_node_id, max_distance=None): path_mean_r = sum(path_radii) / len(path_radii) stop_node_id = add_node_id stop_element_id = add_element_id - if max_distance: - remaining_max_distance = max_distance - magnitude(sub(path_coordinates[-1], path_coordinates[0])) element_ids = self._node_element_ids[stop_node_id] - last_direction = sub(end_x, start_x) + last_direction = normalize(sub(end_x, start_x)) + + element_count += len(path_coordinates) - 1 + length += add_path_length + if add_path_mean_r > 0.0: + aspect_ratio += add_path_length / add_path_mean_r # 2nd iteration of fit line removes outliers: start_x, end_x, mean_r = fit_line(path_coordinates, path_radii, start_x, end_x, 0.5)[0:3] - return path_coordinates, path_radii, path_node_ids, start_x, end_x, mean_r + return path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r - def create_end_point_directions(self, max_distance): + def create_end_point_directions(self, annotations, max_distance): """ Track mean directions of network end points and create working objects for visualisation. - :param max_distance: Maximum length to track back from end point. + :param annotations: Global list of all annotations. + :param max_distance: Maximum length to track back from end point. Stored for link tolerance. """ nodetemplate = self._working_datapoints.createNodetemplate() nodetemplate.defineField(self._working_coordinates) nodetemplate.defineField(self._working_radius_direction) nodetemplate.defineField(self._working_best_fit_line_orientation) fieldcache = self._working_fieldmodule.createFieldcache() + self._end_point_data = {} for end_node_id in self._end_node_ids: - path_coordinates, path_radii, path_node_ids, start_x, end_x, mean_r =( - self._track_path(end_node_id, max_distance)) + path_coordinates, path_radii, path_node_ids, path_group, start_x, end_x, mean_r =( + self._track_path(end_node_id, annotations, max_distance)) # Future: want to extend length to be equivalent to path_coordinates + direction = sub(start_x, end_x) + annotation = None + annotation_group_name = path_group.getName() if path_group else None + if annotation_group_name: + for tmp_annotation in annotations: + if tmp_annotation.get_name() == annotation_group_name: + annotation = tmp_annotation + break + self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) + # set up visualization objects: node = self._working_datapoints.createNode(-1, nodetemplate) fieldcache.setNode(node) - radius_direction = set_magnitude(sub(start_x, end_x), mean_r) + radius_direction = set_magnitude(direction, mean_r) self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) self._working_radius_direction.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, radius_direction) @@ -293,6 +341,12 @@ def create_end_point_directions(self, max_distance): self._working_best_fit_line_orientation.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, direction1 + direction2 + direction3) + def get_end_point_data(self): + """ + :return: dict node_id -> (coordinates, direction, radius, annotation) + """ + return self._end_point_data + def get_base_region(self): """ Get the base region for all segmentation and working data for this segment. @@ -302,7 +356,7 @@ def get_base_region(self): def get_annotation_group(self, annotation): """ - Get Zinc group containing segmentations for the supplied annotation + Get Zinc group containing segmentations for the supplied annotation. :param annotation: An Annotation object. :return: Zinc FieldGroup in the segment's raw region, or None if not present in segment. """ @@ -347,19 +401,55 @@ def get_raw_region(self): """ return self._raw_region + def add_transformation_change_callback(self, transformation_change_callback): + """ + Set up client to be informed when segment transformation is changed. + Typically used to update connections. + :param transformation_change_callback: Callable with signature (segment) + """ + self._transformation_change_callbacks.append(transformation_change_callback) + + def remove_transformation_change_callback(self, transformation_change_callback): + """ + Remove transformation change callback set previously. + :param transformation_change_callback: Callable to remove + """ + self._transformation_change_callbacks.remove(transformation_change_callback) + + def _transformation_change(self): + """ + Inform clients of transformation change. + """ + for transformation_change_callback in self._transformation_change_callbacks: + transformation_change_callback(self) + def get_rotation(self): return self._rotation - def set_rotation(self, rotation): + def set_rotation(self, rotation, notify=True): + """ + Set segment rotation, which applies before translation. + :param rotation: Rotation as list of 3 Euler angles in degrees. + :param notify: Set to False to avoid notification to clients if setting translation afterwards. + """ assert len(rotation) == 3 self._rotation = rotation + if notify: + self._transformation_change() def get_translation(self): return self._translation - def set_translation(self, translation): + def set_translation(self, translation, notify=True): + """ + Set segment transformation, which applies after rotation. + :param translation: New translation. + :param notify: Set to False to avoid notification to clients if setting rotation afterwards. + """ assert len(translation) == 3 self._translation = translation + if notify: + self._transformation_change() def get_working_region(self): """ @@ -389,7 +479,7 @@ def update_annotation_category(self, annotation, old_category=AnnotationCategory new_category_group = self.get_category_group(new_category) group_add_group_local_contents(new_category_group, annotation_group) - def reset_annotation_category_groups(self, annotations): + def update_annotation_category_groups(self, annotations): """ Rebuild all annotation category groups e.g. after loading settings. :param annotations: List of all annotations from stitcher. diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index b35d9b9..1537360 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -56,16 +56,15 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) + self._max_distance = 0.0 if self._segments: - # ask segments to track end distances using a global mean max_distance - max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum - for segment in self._segments: - segment.create_end_point_directions(max_distance) - with HierarchicalChangeManager(self._root_region): - for segment in self._segments: - segment.reset_annotation_category_groups(self._annotations) + with HierarchicalChangeManager(self._root_region): + self._max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum + for segment in self._segments: + segment.create_end_point_directions(self._annotations, self._max_distance) + segment.update_annotation_category_groups(self._annotations) for annotation in self._annotations: - annotation.set_category_change_callback(self._annotation_change) + annotation.set_category_change_callback(self._annotation_category_change) def decode_settings(self, settings_in: dict): """ @@ -123,9 +122,6 @@ def decode_settings(self, settings_in: dict): else: print("WARNING: Segmentation Stitcher. Segment with name", name, "not found in settings; using defaults. Have input files changed?") - with HierarchicalChangeManager(self._root_region): - for segment in self._segments: - segment.reset_annotation_category_groups(self._annotations) # create connections from stitcher settings' connection serialisations assert len(self._connections) == 0, "Cannot decode connections after any exist" @@ -140,9 +136,13 @@ def decode_settings(self, settings_in: dict): print("WARNING: Segmentation Stitcher. Segment with name", segment_name, "in connection settings not found; ignoring. Have input files changed?") if len(connection_segments) >= 2: - connection = self.create_connection(connection_segments) - connection.decode_settings(connection_settings) + connection = self.create_connection(connection_segments, connection_settings) + with HierarchicalChangeManager(self._root_region): + for segment in self._segments: + segment.update_annotation_category_groups(self._annotations) + for connection in self._connections: + connection.update_annotation_category_groups(self._annotations) def encode_settings(self) -> dict: """ @@ -156,7 +156,7 @@ def encode_settings(self) -> dict: } return settings - def _annotation_change(self, annotation, old_category): + def _annotation_category_change(self, annotation, old_category): """ Callback from annotation that its category has changed. Update segment category groups. @@ -166,24 +166,34 @@ def _annotation_change(self, annotation, old_category): with HierarchicalChangeManager(self._root_region): for segment in self._segments: segment.update_annotation_category(annotation, old_category) + for connection in self._connections: + connection.build_links(self._max_distance) + connection.update_annotation_category_groups(self._annotations) def get_annotations(self): return self._annotations - def create_connection(self, segments): + def create_connection(self, segments, connection_settings={}): """ :param segments: List of 2 Stitcher Segment objects to connect. + :param connection_settings: Optional serialisation of connection to read before building links. :return: Connection object or None if invalid segments or connection between segments already exists """ if len(segments) != 2: - print("Only supports connections between 2 segments") + print("Segmentation Stitcher: Only supports connections between 2 segments") + return None + if segments[0] == segments[1]: + print("Segmentation Stitcher: Can't make a connection between a segment and itself") return None for connection in self._connections: if all(segment in connection.get_segments() for segment in segments): - print("Stitcher.create_connection: Already have a connection between segments") + print("Segmentation Stitcher: Already have a connection between segments") return None - connection = Connection(segments, self._root_region) + connection = Connection(segments, self._root_region, self._annotations, self._max_distance) + if connection_settings: + connection.decode_settings(connection_settings) self._connections.append(connection) + connection.build_links() return connection def delete_connection(self, connection): @@ -191,6 +201,7 @@ def delete_connection(self, connection): Delete the connection from the stitcher's list. :param connection: Connection to delete. """ + connection.detach() self._connections.remove(connection) def get_connections(self): From ae599671c7ae25072804bbff83d81031af552ba4 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 28 Nov 2024 14:56:38 +1300 Subject: [PATCH 07/13] Add annotation align weight --- src/segmentationstitcher/annotation.py | 10 ++++++++++ src/segmentationstitcher/connection.py | 14 ++++++++------ src/segmentationstitcher/segment.py | 10 +++++----- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/segmentationstitcher/annotation.py b/src/segmentationstitcher/annotation.py index 3c5918d..220afbd 100644 --- a/src/segmentationstitcher/annotation.py +++ b/src/segmentationstitcher/annotation.py @@ -45,6 +45,7 @@ def __init__(self, name: str, term, dimension, category: AnnotationCategory): self._term = term self._dimension = dimension self._category = category + self._align_weight = 1.0 self._category_change_callback = None def decode_settings(self, settings_in: dict): @@ -62,6 +63,7 @@ def decode_settings(self, settings_in: dict): # update current settings to gain new ones and override old ones settings = self.encode_settings() settings.update(settings_in) + self._align_weight = settings["align weight"] self._category = AnnotationCategory[settings["category"]] def encode_settings(self) -> dict: @@ -70,6 +72,7 @@ def encode_settings(self) -> dict: :return: Settings in a dict ready for passing to json.dump. """ settings = { + "align weight": self._align_weight, "category": self._category.name, "dimension": self._dimension, "name": self._name, @@ -77,6 +80,13 @@ def encode_settings(self) -> dict: } return settings + def get_align_weight(self): + return self._align_weight + + def set_align_weight(self, align_weight): + if align_weight >= 0.0: + self._align_weight = align_weight + def get_category(self): return self._category diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 40e1977..ff553ab 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -296,9 +296,9 @@ def links_objective(rotation_translation, *args): initial_parameters = rotation + translation initial_score = links_objective(initial_parameters, ()) - TOL = initial_score * 1.0E-4 + # TOL = initial_score * 1.0E-6 # method='Nelder-Mead' - res = minimize(links_objective, initial_parameters, method='Powell', tol=TOL) + res = minimize(links_objective, initial_parameters, method='Powell') # , tol=TOL) if not res.success: print("Segmentation Stitcher. Could not optimise final rotation and translation") return @@ -350,6 +350,7 @@ def build_links(self, build_link_objects=True): category0 = annotation0.get_category() best_index1 = None lowest_score = 0.0 + weight = None for index1, end_point_data1 in enumerate(sorted_end_point_data1): node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 category1 = annotation1.get_category() @@ -357,7 +358,7 @@ def build_links(self, build_link_objects=True): (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): continue # end points are not allowed to join direction_score = math.fabs(1.0 + dot(direction0, direction1)) - if direction_score > 0.5: + if direction_score > 0.5: # arbitrary factor continue # end points are not pointing towards each other delta_coordinates = sub(coordinates1, coordinates0) mag_delta_coordinates = magnitude(delta_coordinates) @@ -365,9 +366,9 @@ def build_links(self, build_link_objects=True): ndistance = math.sqrt(mag_delta_coordinates * mag_delta_coordinates - tdistance * tdistance) if mag_delta_coordinates > (0.5 * self._max_distance): continue # point is too far away - distance_score = ((tdistance * tdistance + 50.0 * ndistance * ndistance) / + distance_score = ((tdistance * tdistance + 100.0 * ndistance * ndistance) / (self._max_distance * self._max_distance)) - tfactor = math.exp(-100.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor + tfactor = math.exp(-1000.0 * tdistance / self._max_distance) + 1.0 # arbitrary factor penetration_distance_score = ((tfactor * tdistance * tdistance) / (self._max_distance * self._max_distance)) delta_radius = (radius0 - radius1) / self._max_distance # GRC temporary - use a different scale @@ -375,10 +376,11 @@ def build_links(self, build_link_objects=True): score = radius0 * (10.0 * direction_score + distance_score + radius_score) if (best_index1 is None) or (score < lowest_score): best_index1 = index1 + weight = 0.5 * (annotation0.get_align_weight() + annotation1.get_align_weight()) lowest_score = score + penetration_distance_score if best_index1 is not None: # if category0 != AnnotationCategory.NETWORK_GROUP_1: - total_score += lowest_score + total_score += weight * lowest_score node_id1, coordinates1, direction1, radius1, annotation1 = sorted_end_point_data1[best_index1] self.add_linked_nodes(annotation1, node_id0, node_id1) remaining_radius = math.sqrt(math.fabs(radius0 * radius0 - radius1 * radius1)) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index f6c06c7..89f24e6 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -259,7 +259,7 @@ def _track_path(self, end_node_id, annotations, max_length=None): if last_direction: add_start_x, add_end_x, add_mean_r = fit_line(segment_coordinates, segment_radii)[0:3] add_direction = normalize(sub(add_end_x, add_start_x)) - if dot(last_direction, add_direction) < 0.8: # arbitrary factor + if dot(last_direction, add_direction) < 0.5: # arbitrary factor continue # avoid sudden changes in direction add_segment_coordinates = segment_coordinates if not path_coordinates else segment_coordinates[1:] add_segment_radii = segment_radii if not path_radii else segment_radii[1:] @@ -506,7 +506,7 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 :param x2: Initial end point for line. Default is last point coordinates. :param filter_proportion: Proportion of data points to eliminate in order of greatest projection normal to line. Default is no filtering. - :return: start_x, end_x, mean_r (of unfiltered points), mean_projection_error (of all points) + :return: start_x, end_x, harmonic mean_r (of all points), mean_projection_error (of all points) """ assert len(path_coordinates) > 1 if len(path_coordinates) == 2: @@ -551,10 +551,11 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 if len(filter_indexes) < filter_count: filter_indexes.append(d) mean_projection_error = sum_projection_error / points_count - sum_r = 0.0 + sum_1__r = 0.0 a = [[0.0, 0.0], [0.0, 0.0]] # matrix b = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]] # RHS for each component for d in range(points_count): + sum_1__r += 1.0 / path_radii[d] if d in filter_indexes: continue xi = path_xi[d] @@ -567,8 +568,7 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 for c in range(3): b[c][0] += phi1 * path_coordinates[d][c] b[c][1] += phi2 * path_coordinates[d][c] - sum_r += path_radii[d] - mean_r = sum_r / (points_count - filter_count) + mean_r = points_count / sum_1__r # invert matrix: det_a = a[0][0] * a[1][1] - a[0][1] * a[1][0] a_inv = [[a[1][1] / det_a, -a[0][1] / det_a], [-a[1][0] / det_a, a[0][0] / det_a]] From 22eb686174f7dbc5653b0f8b580fa048ab952ce1 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 28 Nov 2024 15:11:01 +1300 Subject: [PATCH 08/13] Skip zero radius in fit_line mean --- src/segmentationstitcher/segment.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index 89f24e6..b1d743f 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -554,8 +554,12 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 sum_1__r = 0.0 a = [[0.0, 0.0], [0.0, 0.0]] # matrix b = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]] # RHS for each component + point_radius_count = 0 for d in range(points_count): - sum_1__r += 1.0 / path_radii[d] + r = path_radii[d] + if r > 0.0: + sum_1__r += 1.0 / r + point_radius_count += 1 if d in filter_indexes: continue xi = path_xi[d] @@ -568,7 +572,7 @@ def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0 for c in range(3): b[c][0] += phi1 * path_coordinates[d][c] b[c][1] += phi2 * path_coordinates[d][c] - mean_r = points_count / sum_1__r + mean_r = (point_radius_count / sum_1__r) if point_radius_count else 0.0 # invert matrix: det_a = a[0][0] * a[1][1] - a[0][1] * a[1][0] a_inv = [[a[1][1] / det_a, -a[0][1] / det_a], [-a[1][0] / det_a, a[0][0] / det_a]] From 61a335c7055478a0f0c93ef2bdd6cc67d5b1279d Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Tue, 10 Dec 2024 20:46:47 +1300 Subject: [PATCH 09/13] Stitch output Exclude annotations without terms by default --- src/segmentationstitcher/connection.py | 16 +- src/segmentationstitcher/stitcher.py | 263 ++++++++++++++++++++++++- tests/resources/vagus-segment1.exf | 10 +- tests/resources/vagus-segment2.exf | 8 +- tests/resources/vagus-segment3.exf | 10 +- 5 files changed, 282 insertions(+), 25 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index ff553ab..94ed6bd 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -143,6 +143,12 @@ def add_linked_nodes(self, annotation, node_id0, node_id1): self._linked_nodes[annotation_name] = annotation_linked_nodes = [] annotation_linked_nodes.append([node_id0, node_id1]) + def get_linked_nodes(self): + """ + :return: Map annotation name -> list of paired nodes from segment1 and segment2 + """ + return self._linked_nodes + def optimise_transformation(self): """ Optimise transformation of second segment to align with position and direction of nearest points between @@ -206,7 +212,7 @@ def optimise_transformation(self): _, transformed_coordinates, coordinates, direction, radius, annotation = data if distance < nearby_distance: distance = nearby_distance - weight = radius / (distance * distance) + weight = annotation.get_align_weight() * radius * radius / (distance * distance) sum_coordinates = add(sum_coordinates, mult(coordinates, weight)) sum_direction = add(sum_direction, mult(direction, weight)) total_weight += weight @@ -227,11 +233,11 @@ def optimise_transformation(self): # optimise transformation of second segment so mean coordinates and directions coincide - def rotation_objective(rotation, *args): + def rotation_objective(trial_rotation, *args): target_direction, source_direction, target_side_direction, source_side_direction = args - rotation_matrix = euler_to_rotation_matrix(rotation) - trans_direction = matrix_vector_mult(rotation_matrix, source_direction) - trans_side_direction = matrix_vector_mult(rotation_matrix, source_side_direction) + trial_rotation_matrix = euler_to_rotation_matrix(trial_rotation) + trans_direction = matrix_vector_mult(trial_rotation_matrix, source_direction) + trans_side_direction = matrix_vector_mult(trial_rotation_matrix, source_side_direction) return dot(trans_direction, target_direction) + dot(target_side_direction, trans_side_direction) # note the result is dependent on the initial position, but final optimisation should reduced effect diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 1537360..bcc9ddf 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -1,13 +1,21 @@ """ Interface for stitching segmentation data from and calculating transformations between adjacent image blocks. """ -from cmlibs.utils.zinc.general import HierarchicalChangeManager +from cmlibs.maths.vectorops import add, matrix_vector_mult, euler_to_rotation_matrix +from cmlibs.utils.zinc.field import ( + find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group, + find_or_create_field_stored_string, get_group_list) +from cmlibs.utils.zinc.general import ChangeManager, HierarchicalChangeManager from cmlibs.zinc.context import Context +from cmlibs.zinc.element import Element, Elementbasis +from cmlibs.zinc.field import Field +from cmlibs.zinc.node import Node from segmentationstitcher.connection import Connection from segmentationstitcher.segment import Segment -from segmentationstitcher.annotation import region_get_annotations +from segmentationstitcher.annotation import AnnotationCategory, region_get_annotations import copy +import math from pathlib import Path @@ -26,6 +34,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo """ self._context = Context("Segmentation Stitcher") self._root_region = self._context.getDefaultRegion() + self._stitch_region = self._root_region.createRegion() self._annotations = [] self._network_group1_keywords = copy.deepcopy(network_group1_keywords) self._network_group2_keywords = copy.deepcopy(network_group2_keywords) @@ -35,7 +44,7 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._version = 1 # increment when new settings added to migrate older serialised settings max_range_reciprocal_sum = 0.0 for segmentation_file_name in segmentation_file_names: - name = Path(segmentation_file_name).stem + name = Path(segmentation_file_name).name segment = Segment(name, segmentation_file_name, self._root_region) max_range_reciprocal_sum += 1.0 / segment.get_max_range() self._segments.append(segment) @@ -47,8 +56,13 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo term = segment_annotation.get_term() index = 0 for annotation in self._annotations: - if (annotation.get_name() == name) and (annotation.get_term() == term): - # print("Found annotation name", name, "term", term) + if annotation.get_name() == name: + existing_term = annotation.get_term() + if term != existing_term: + print("Warning: Found existing annotation with name", name, + "but existing term", existing_term, "does not equal new term", term) + if term and (existing_term is None): + annotation.set_term(term) break # exists already if name > annotation.get_name(): index += 1 @@ -56,6 +70,11 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) + # by default put all annotations without terms into the EXCLUDE category + for annotation in self._annotations: + if not annotation.get_term(): + # print("Exclude annotation", annotation.get_name(),"with no term") + annotation.set_category(AnnotationCategory.EXCLUDE) self._max_distance = 0.0 if self._segments: with HierarchicalChangeManager(self._root_region): @@ -222,5 +241,237 @@ def get_segments(self): def get_version(self): return self._version + def _stitch(self, region): + """ + :param region: Target region to stitch segmentations into. + """ + fieldmodule = region.getFieldmodule() + with ChangeManager(fieldmodule): + coordinates = find_or_create_field_coordinates(fieldmodule) + radius = find_or_create_field_finite_element(fieldmodule, "radius", 1, managed=True) + if self._segments and self._segments[0].get_raw_region().getFieldmodule().findFieldByName("rgb").isValid(): + rgb = find_or_create_field_finite_element(fieldmodule, "rgb", 3, managed=True) + else: + rgb = None + marker_name = find_or_create_field_stored_string(fieldmodule, "marker_name", managed=True) + nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + nodetemplate = nodes.createNodetemplate() + nodetemplate.defineField(coordinates) + nodetemplate.defineField(radius) + if rgb: + nodetemplate.defineField(rgb) + marker_nodetemplate = datapoints.createNodetemplate() + marker_nodetemplate.defineField(coordinates) + marker_nodetemplate.defineField(marker_name) + marker_nodetemplate.defineField(radius) + if rgb: + marker_nodetemplate.defineField(rgb) + mesh = fieldmodule.findMeshByDimension(1) + elementtemplate = mesh.createElementtemplate() + elementtemplate.setElementShapeType(Element.SHAPE_TYPE_LINE) + linear_basis = fieldmodule.createElementbasis(1, Elementbasis.FUNCTION_TYPE_LINEAR_LAGRANGE) + eft = mesh.createElementfieldtemplate(linear_basis) + elementtemplate.defineField(coordinates, -1, eft) + elementtemplate.defineField(radius, -1, eft) + if rgb: + elementtemplate.defineField(rgb, -1, eft) + fieldcache = fieldmodule.createFieldcache() + node_identifier = 1 + datapoint_identifier = 1 + element_identifier = 1 + # create annotation groups in output: + annotation_groups = {} # map from annotation name to list of Zinc groups (2nd is term group) + for annotation in self._annotations: + if annotation.get_category() != AnnotationCategory.EXCLUDE: + name = annotation.get_name() + groups = [find_or_create_field_group(fieldmodule, name)] + term = annotation.get_term() + if term: + groups.append(find_or_create_field_group(fieldmodule, term)) + annotation_groups[name] = groups + marker_group = find_or_create_field_group(fieldmodule, "marker") + marker_datapoint_group = marker_group.getOrCreateNodesetGroup(datapoints) + processed_segments = [] + segment_node_maps = [{} for segment in self._segments] # maps from segment node id to output node id + + # stitch segments in order of connections, followed by unconnected segments + for connection in self._connections: + segment_node_map_pair = [segment_node_maps[self._segments.index(segment)] + for segment in connection.get_segments()] + for segment, segment_node_map in zip(connection.get_segments(), segment_node_map_pair): + output_segment_elements = False + if segment not in processed_segments: + node_identifier, datapoint_identifier = _output_segment_nodes_and_markers( + segment, segment_node_map, annotation_groups, + fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, + nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier) + output_segment_elements = True + processed_segments.append(segment) + if segment is connection.get_segments()[1]: + element_identifier = _output_connection_elements( + connection, segment_node_map_pair, annotation_groups, + fieldmodule, fieldcache, coordinates, + eft, elementtemplate, element_identifier) + if output_segment_elements: + element_identifier = _output_segment_elements( + segment, segment_node_map, annotation_groups, + fieldmodule, fieldcache, coordinates, + eft, elementtemplate, element_identifier) + # output any unconnected segments + for segment, segment_node_map in zip(self._segments, segment_node_maps): + if segment not in processed_segments: + node_identifier, datapoint_identifier = _output_segment_nodes_and_markers( + segment, segment_node_map, annotation_groups, + fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, + nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier) + element_identifier = _output_segment_elements( + segment, segment_node_map, annotation_groups, + fieldmodule, fieldcache, coordinates, + eft, elementtemplate, element_identifier) + processed_segments.append(segment) + def write_output_segmentation_file(self, file_name): - pass + self._stitch(self._stitch_region) + self._stitch_region.writeFile(file_name) + + +def _output_segment_nodes_and_markers( + segment, segment_node_map, annotation_groups, + fieldmodule, fieldcache, coordinates, radius, rgb, marker_name, marker_datapoint_group, + nodetemplate, marker_nodetemplate, node_identifier, datapoint_identifier): + raw_region = segment.get_raw_region() + raw_fieldmodule = raw_region.getFieldmodule() + raw_coordinates = raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() + raw_radius = raw_fieldmodule.findFieldByName("radius").castFiniteElement() + raw_rgb = raw_fieldmodule.findFieldByName("rgb").castFiniteElement() if rgb else None + raw_marker_name = raw_fieldmodule.findFieldByName("marker_name").castStoredString() + raw_nodes = raw_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + rotation = [math.radians(angle_degrees) for angle_degrees in segment.get_rotation()] + rotation_matrix = euler_to_rotation_matrix(rotation) + translation = segment.get_translation() + raw_groups = get_group_list(raw_fieldmodule) + raw_nodeset_groups = [] + nodeset_group_lists = [] + nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + segment_group = find_or_create_field_group(fieldmodule, segment.get_name()) + segment_node_group = segment_group.getOrCreateNodesetGroup(nodes) + segment_datapoint_group = segment_group.getOrCreateNodesetGroup(datapoints) + for raw_group in raw_groups: + group_name = raw_group.getName() + groups = annotation_groups.get(group_name) + if groups: + raw_nodeset_group = raw_group.getNodesetGroup(raw_nodes) + if raw_nodeset_group.isValid() and (raw_nodeset_group.getSize() > 0): + raw_nodeset_groups.append(raw_nodeset_group) + nodeset_group_lists.append([group.getOrCreateNodesetGroup(nodes) for group in groups]) + raw_fieldcache = raw_fieldmodule.createFieldcache() + raw_nodeiterator = raw_nodes.createNodeiterator() + raw_node = raw_nodeiterator.next() + while raw_node.isValid(): + node = None + for raw_nodeset_group, nodeset_group_list in zip(raw_nodeset_groups, nodeset_group_lists): + if raw_nodeset_group.containsNode(raw_node): + if not node: + raw_node_identifier = raw_node.getIdentifier() + node = nodes.createNode(node_identifier, nodetemplate) + raw_fieldcache.setNode(raw_node) + fieldcache.setNode(node) + result, raw_x = raw_coordinates.evaluateReal(raw_fieldcache, 3) + x = add(matrix_vector_mult(rotation_matrix, raw_x), translation) + coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) + result, r = raw_radius.evaluateReal(raw_fieldcache, 1) + radius.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, r) + if rgb: + result, rgb_value = raw_rgb.evaluateReal(raw_fieldcache, 3) + rgb.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, rgb_value) + segment_node_map[raw_node_identifier] = node_identifier + segment_node_group.addNode(node) + node_identifier += 1 + for nodeset_group in nodeset_group_list: + nodeset_group.addNode(node) + raw_node = raw_nodeiterator.next() + raw_datapoints = raw_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + raw_dataiterator = raw_datapoints.createNodeiterator() + raw_datapoint = raw_dataiterator.next() + while raw_datapoint.isValid(): + datapoint = datapoints.createNode(datapoint_identifier, marker_nodetemplate) + raw_fieldcache.setNode(raw_datapoint) + fieldcache.setNode(datapoint) + result, raw_x = raw_coordinates.evaluateReal(raw_fieldcache, 3) + x = add(matrix_vector_mult(rotation_matrix, raw_x), translation) + coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, x) + result, r = raw_radius.evaluateReal(raw_fieldcache, 1) + radius.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, r) + if rgb: + result, rgb_value = raw_rgb.evaluateReal(raw_fieldcache, 3) + rgb.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, rgb_value) + name = raw_marker_name.evaluateString(raw_fieldcache) + marker_name.assignString(fieldcache, name) + marker_datapoint_group.addNode(datapoint) + segment_datapoint_group.addNode(datapoint) + datapoint_identifier += 1 + raw_datapoint = raw_dataiterator.next() + return node_identifier, datapoint_identifier + + +def _output_segment_elements(segment, segment_node_map, annotation_groups, + fieldmodule, fieldcache, coordinates, + eft, elementtemplate, element_identifier): + raw_region = segment.get_raw_region() + raw_fieldmodule = raw_region.getFieldmodule() + raw_coordinates = raw_fieldmodule.findFieldByName("coordinates").castFiniteElement() + raw_mesh = raw_fieldmodule.findMeshByDimension(1) + segment_group = find_or_create_field_group(fieldmodule, segment.get_name()) + mesh = fieldmodule.findMeshByDimension(1) + segment_mesh_group = segment_group.getOrCreateMeshGroup(mesh) + raw_groups = get_group_list(raw_fieldmodule) + raw_mesh_groups = [] + mesh_group_lists = [] + for raw_group in raw_groups: + group_name = raw_group.getName() + groups = annotation_groups.get(group_name) + if groups: + raw_mesh_group = raw_group.getMeshGroup(raw_mesh) + if raw_mesh_group.isValid() and (raw_mesh_group.getSize() > 0): + raw_mesh_groups.append(raw_mesh_group) + mesh_group_lists.append([group.getOrCreateMeshGroup(mesh) for group in groups]) + raw_elementiterator = raw_mesh.createElementiterator() + raw_element = raw_elementiterator.next() + raw_eft = raw_element.getElementfieldtemplate(raw_coordinates, -1) + while raw_element.isValid(): + element = None + for raw_mesh_group, mesh_group_list in zip(raw_mesh_groups, mesh_group_lists): + if raw_mesh_group.containsElement(raw_element): + if not element: + element = mesh.createElement(element_identifier, elementtemplate) + element.setNodesByIdentifier( + eft, [segment_node_map[raw_element.getNode(raw_eft, ln).getIdentifier()] + for ln in [1, 2]]) + segment_mesh_group.addElement(element) + element_identifier += 1 + for mesh_group in mesh_group_list: + mesh_group.addElement(element) + raw_element = raw_elementiterator.next() + return element_identifier + + +def _output_connection_elements(connection, segment_node_maps, annotation_groups, + fieldmodule, fieldcache, coordinates, + eft, elementtemplate, element_identifier): + connection_group = find_or_create_field_group(fieldmodule, connection.get_name()) + mesh = fieldmodule.findMeshByDimension(1) + connection_mesh_group = connection_group.getOrCreateMeshGroup(mesh) + linked_nodes = connection.get_linked_nodes() + for annotation_name, annotation_linked_nodes in linked_nodes.items(): + groups = annotation_groups.get(annotation_name) + mesh_groups = [group.getOrCreateMeshGroup(mesh) for group in groups] + mesh_groups.append(connection_mesh_group) + for segment_node_identifiers in annotation_linked_nodes: + element = mesh.createElement(element_identifier, elementtemplate) + element.setNodesByIdentifier(eft, [segment_node_maps[n][segment_node_identifiers[n]] for n in range(2)]) + for mesh_group in mesh_groups: + mesh_group.addElement(element) + element_identifier += 1 + return element_identifier diff --git a/tests/resources/vagus-segment1.exf b/tests/resources/vagus-segment1.exf index e24407c..d34997a 100644 --- a/tests/resources/vagus-segment1.exf +++ b/tests/resources/vagus-segment1.exf @@ -435,27 +435,27 @@ Node: 84 Define node template: node2 Shape. Dimension=0 #Fields=3 -1) marker coordinates, coordinate, rectangular cartesian, real, #Components=3 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 x. #Values=1 (value) y. #Values=1 (value) z. #Values=1 (value) -2) marker_location, field, element_xi, #Components=1, host mesh=mesh1d, host mesh dimension=1 +2) marker_name, field, string, #Components=1 1. #Values=1 (value) -3) marker_name, field, string, #Components=1 +3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 Node: 85 9.014003757666196e-01 -3.826434512807377e-03 -1.393280944888287e-02 - 2 7.935239999999998e-01 "landmark 1" + 0.01 Node: 86 2.532597965468619e+00 1.278581726279602e-01 5.305649151562682e-01 - 64 1.000000000000000e+00 orientation + 0.01 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line diff --git a/tests/resources/vagus-segment2.exf b/tests/resources/vagus-segment2.exf index 80254c9..03d340a 100644 --- a/tests/resources/vagus-segment2.exf +++ b/tests/resources/vagus-segment2.exf @@ -385,21 +385,21 @@ Node: 74 Define node template: node2 Shape. Dimension=0 #Fields=3 -1) marker coordinates, coordinate, rectangular cartesian, real, #Components=3 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 x. #Values=1 (value) y. #Values=1 (value) z. #Values=1 (value) -2) marker_location, field, element_xi, #Components=1, host mesh=mesh1d, host mesh dimension=1 +2) marker_name, field, string, #Components=1 1. #Values=1 (value) -3) marker_name, field, string, #Components=1 +3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 Node: 75 2.398918879869167e+00 -1.432597145502751e-01 2.241355737421549e-01 - 55 6.000000000000000e-01 orientation + 0.01 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line diff --git a/tests/resources/vagus-segment3.exf b/tests/resources/vagus-segment3.exf index 025b988..55e9809 100644 --- a/tests/resources/vagus-segment3.exf +++ b/tests/resources/vagus-segment3.exf @@ -355,27 +355,27 @@ Node: 68 Define node template: node2 Shape. Dimension=0 #Fields=3 -1) marker coordinates, coordinate, rectangular cartesian, real, #Components=3 +1) coordinates, coordinate, rectangular cartesian, real, #Components=3 x. #Values=1 (value) y. #Values=1 (value) z. #Values=1 (value) -2) marker_location, field, element_xi, #Components=1, host mesh=mesh1d, host mesh dimension=1 +2) marker_name, field, string, #Components=1 1. #Values=1 (value) -3) marker_name, field, string, #Components=1 +3) radius, field, rectangular cartesian, real, #Components=1 1. #Values=1 (value) Node template: node2 Node: 69 3.000000000000000e+00 0.000000000000000e+00 7.000000000000000e-01 - 58 8.299280188212664e-01 orientation + 0.01 Node: 70 1.599724956533351e+00 3.788960603141545e-03 4.423695723249146e-03 - 4 6.144349999999998e-01 "landmark 2" + 0.01 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line From 4cb1af71a77892660f921a35e129049c3711e4c0 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Thu, 12 Dec 2024 13:56:33 +1300 Subject: [PATCH 10/13] Add align and stitch test --- src/segmentationstitcher/stitcher.py | 11 +-- tests/test_vagus.py | 103 ++++++++++++++++++++++++--- 2 files changed, 99 insertions(+), 15 deletions(-) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index bcc9ddf..d600c94 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -70,10 +70,11 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), # "category", segment_annotation.get_category()) self._annotations.insert(index, segment_annotation) - # by default put all annotations without terms into the EXCLUDE category + # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" for annotation in self._annotations: - if not annotation.get_term(): - # print("Exclude annotation", annotation.get_name(),"with no term") + if ((annotation.get_category() == AnnotationCategory.GENERAL) and (not annotation.get_term()) and + (annotation.get_name() != "marker")): + # print("Exclude general annotation", annotation.get_name(), "with no term") annotation.set_category(AnnotationCategory.EXCLUDE) self._max_distance = 0.0 if self._segments: @@ -241,7 +242,7 @@ def get_segments(self): def get_version(self): return self._version - def _stitch(self, region): + def stitch(self, region): """ :param region: Target region to stitch segmentations into. """ @@ -332,7 +333,7 @@ def _stitch(self, region): processed_segments.append(segment) def write_output_segmentation_file(self, file_name): - self._stitch(self._stitch_region) + self.stitch(self._stitch_region) self._stitch_region.writeFile(file_name) diff --git a/tests/test_vagus.py b/tests/test_vagus.py index 84eefe9..422872f 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -1,5 +1,8 @@ +import math import os import unittest +from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range +from cmlibs.zinc.field import Field from segmentationstitcher.annotation import AnnotationCategory from segmentationstitcher.stitcher import Stitcher from tests.testutils import assertAlmostEqualList @@ -28,7 +31,7 @@ def test_io_vagus1(self): segments1 = stitcher1.get_segments() self.assertEqual(3, len(segments1)) segment12 = segments1[1] - self.assertEqual("vagus-segment2", segment12.get_name()) + self.assertEqual("vagus-segment2.exf", segment12.get_name()) assertAlmostEqualList(self, zero, segment12.get_translation(), delta=TOL) segment12.set_translation(new_translation) annotations1 = stitcher1.get_annotations() @@ -48,9 +51,9 @@ def test_io_vagus1(self): self.assertEqual(AnnotationCategory.NETWORK_GROUP_1, annotation15.get_category()) annotation17 = annotations1[6] self.assertEqual("unknown", annotation17.get_name()) - self.assertEqual(AnnotationCategory.GENERAL, annotation17.get_category()) + self.assertEqual(AnnotationCategory.EXCLUDE, annotation17.get_category()) - connection = stitcher1.create_connection([segments1[0], segments1[1]]) + stitcher1.create_connection([segments1[0], segments1[1]]) connections = stitcher1.get_connections() self.assertEqual(1, len(connections)) @@ -61,22 +64,26 @@ def test_io_vagus1(self): exclude13_mesh_group = exclude13_group.getMeshGroup(mesh1d) general13_group = segment13.get_category_group(AnnotationCategory.GENERAL) general13_mesh_group = general13_group.getMeshGroup(mesh1d) - self.assertFalse(exclude13_mesh_group.isValid()) - self.assertEqual(27, general13_mesh_group.getSize()) + indep13_group = segment13.get_category_group(AnnotationCategory.INDEPENDENT_NETWORK) + indep13_mesh_group = indep13_group.getMeshGroup(mesh1d) + self.assertEqual(1, exclude13_mesh_group.getSize()) + self.assertEqual(26, general13_mesh_group.getSize()) + self.assertFalse(indep13_mesh_group.isValid()) annotation17_group = segment13.get_annotation_group(annotation17) annotation17_mesh_group = annotation17_group.getMeshGroup(mesh1d) self.assertEqual(1, annotation17_mesh_group.getSize()) - annotation17.set_category(AnnotationCategory.EXCLUDE) - exclude13_mesh_group = exclude13_group.getMeshGroup(mesh1d) - self.assertEqual(1, exclude13_mesh_group.getSize()) + annotation17.set_category(AnnotationCategory.INDEPENDENT_NETWORK) + indep13_mesh_group = indep13_group.getMeshGroup(mesh1d) + self.assertEqual(0, exclude13_mesh_group.getSize()) self.assertEqual(26, general13_mesh_group.getSize()) + self.assertEqual(1, indep13_mesh_group.getSize()) settings = stitcher1.encode_settings() self.assertEqual(3, len(settings["segments"])) self.assertEqual(7, len(settings["annotations"])) self.assertEqual(1, settings["version"]) assertAlmostEqualList(self, new_translation, settings["segments"][1]["translation"], delta=TOL) - self.assertEqual(AnnotationCategory.EXCLUDE.name, settings["annotations"][6]["category"]) + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK.name, settings["annotations"][6]["category"]) stitcher2 = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) stitcher2.decode_settings(settings) @@ -85,8 +92,84 @@ def test_io_vagus1(self): assertAlmostEqualList(self, new_translation, segment22.get_translation(), delta=TOL) annotations2 = stitcher2.get_annotations() annotation27 = annotations2[6] - self.assertEqual(AnnotationCategory.EXCLUDE, annotation27.get_category()) + self.assertEqual(AnnotationCategory.INDEPENDENT_NETWORK, annotation27.get_category()) + + def test_align_stitch_vagus1(self): + """ + Test adding connections between segments, auto-aligning them and outputting stitched segmentation. + """ + resource_names = [ + "vagus-segment1.exf", + "vagus-segment2.exf", + "vagus-segment3.exf", + ] + TOL = 1.0E-5 + segmentation_file_names = [os.path.join(here, "resources", resource_name) for resource_name in resource_names] + network_group1_keywords = ["vagus", "nerve", "trunk", "branch"] + network_group2_keywords = ["fascicle"] + stitcher = Stitcher(segmentation_file_names, network_group1_keywords, network_group2_keywords) + segments = stitcher.get_segments() + + segments[1].set_rotation([0.0, -10.0, -60.0]) + segments[1].set_translation([5.0, 0.0, 0.0]) + segments[2].set_translation([10.0, 0.0, 0.5]) + + expected_fascicle_sizes = [32, 25, 25] + expected_vagus_sizes = [10, 10, 9] + for s in range(3): + fieldmodule = segments[s].get_raw_region().getFieldmodule() + fascicle = fieldmodule.findFieldByName("Fascicle").castGroup() + self.assertTrue(fascicle.isValid()) + fascicle_mesh_group = fascicle.getMeshGroup(fieldmodule.findMeshByDimension(1)) + self.assertEqual(fascicle_mesh_group.getSize(), expected_fascicle_sizes[s]) + vagus = fieldmodule.findFieldByName("left vagus X nerve trunk").castGroup() + self.assertTrue(vagus.isValid()) + vagus_mesh_group = vagus.getMeshGroup(fieldmodule.findMeshByDimension(1)) + self.assertEqual(vagus_mesh_group.getSize(), expected_vagus_sizes[s]) + + connection01 = stitcher.create_connection([segments[0], segments[1]]) + connection12 = stitcher.create_connection([segments[1], segments[2]]) + + connection01.optimise_transformation() + assertAlmostEqualList(self, [-2.894576, -5.574263, -63.93093], segments[1].get_rotation(), delta=TOL) + assertAlmostEqualList(self, [4.88866, -0.01213587, 0.01357185], segments[1].get_translation(), delta=TOL) + linked_nodes01 = connection01.get_linked_nodes() + self.assertEqual(linked_nodes01, { + "Fascicle": [[22, 28], [35, 12], [40, 23]], + "left vagus X nerve trunk": [[11, 1]]}) + + connection12.optimise_transformation() + assertAlmostEqualList(self, [-4.919549, -2.280625, -13.52467], segments[2].get_rotation(), delta=TOL) + assertAlmostEqualList(self, [9.543171, -0.3494296, 0.03930248], segments[2].get_translation(), delta=TOL) + linked_nodes12 = connection12.get_linked_nodes() + self.assertEqual(linked_nodes12, { + "Fascicle": [[22, 15], [38, 25]], + "left vagus X nerve trunk": [[11, 1]]}) + + output_region = stitcher.get_root_region().createRegion() + stitcher.stitch(output_region) + + fieldmodule = output_region.getFieldmodule() + coordinates = fieldmodule.findFieldByName("coordinates").castFiniteElement() + nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + mesh = fieldmodule.findMeshByDimension(1) + minimums, maximums = evaluate_field_nodeset_range(coordinates, nodes) + assertAlmostEqualList(self, [0.04674543239403558, -1.5276719288528786, -0.5804178855490847], minimums, delta=TOL) + assertAlmostEqualList(self, [13.538987060134247, 1.11238124203403, 0.6470665850902932], maximums, delta=TOL) + fascicle = fieldmodule.findFieldByName("Fascicle").castGroup() + self.assertTrue(fascicle.isValid()) + fascicle_mesh_group = fascicle.getMeshGroup(mesh) + self.assertEqual(fascicle_mesh_group.getSize(), sum(expected_fascicle_sizes) + 5) + vagus = fieldmodule.findFieldByName("left vagus X nerve trunk").castGroup() + self.assertTrue(vagus.isValid()) + vagus_mesh_group = vagus.getMeshGroup(mesh) + self.assertEqual(vagus_mesh_group.getSize(), sum(expected_vagus_sizes) + 2) + marker = fieldmodule.findFieldByName("marker").castGroup() + self.assertTrue(marker.isValid()) + marker_datapoint_group = marker.getNodesetGroup(datapoints) + self.assertEqual(marker_datapoint_group.getSize(), 5) if __name__ == "__main__": unittest.main() From 145f5a7141ed0b896fb810f14447cda96d8b85b5 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 13 Dec 2024 10:37:36 +1300 Subject: [PATCH 11/13] Inter-segment links are only between same annotation --- src/segmentationstitcher/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/segmentationstitcher/connection.py b/src/segmentationstitcher/connection.py index 94ed6bd..892c520 100644 --- a/src/segmentationstitcher/connection.py +++ b/src/segmentationstitcher/connection.py @@ -360,8 +360,8 @@ def build_links(self, build_link_objects=True): for index1, end_point_data1 in enumerate(sorted_end_point_data1): node_id1, coordinates1, direction1, radius1, annotation1 = end_point_data1 category1 = annotation1.get_category() - if (category0 != category1) or ( - (category0 == AnnotationCategory.INDEPENDENT_NETWORK) and (annotation0 != annotation1)): + # inter-segment links are only to the same annotation; links within category will be done separately + if annotation0 != annotation1: continue # end points are not allowed to join direction_score = math.fabs(1.0 + dot(direction0, direction1)) if direction_score > 0.5: # arbitrary factor From c30b3dbc002121d4804a532245ade265eabbadb0 Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Fri, 13 Dec 2024 14:49:42 +1300 Subject: [PATCH 12/13] Update category groups of new connections --- src/segmentationstitcher/stitcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index d600c94..3ba7064 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -214,6 +214,7 @@ def create_connection(self, segments, connection_settings={}): connection.decode_settings(connection_settings) self._connections.append(connection) connection.build_links() + connection.update_annotation_category_groups(self._annotations) return connection def delete_connection(self, connection): From 0cc18b03234ce28fcea8305935c7d9287598b1ea Mon Sep 17 00:00:00 2001 From: Richard Christie Date: Mon, 16 Dec 2024 16:07:49 +1300 Subject: [PATCH 13/13] Maintain working end points group --- src/segmentationstitcher/segment.py | 58 +++++++++++++++---- src/segmentationstitcher/stitcher.py | 87 ++++++++++++++-------------- 2 files changed, 91 insertions(+), 54 deletions(-) diff --git a/src/segmentationstitcher/segment.py b/src/segmentationstitcher/segment.py index b1d743f..c9c6927 100644 --- a/src/segmentationstitcher/segment.py +++ b/src/segmentationstitcher/segment.py @@ -5,7 +5,7 @@ from cmlibs.maths.vectorops import cross, dot, magnitude, matrix_mult, mult, normalize, set_magnitude, sub from cmlibs.utils.zinc.field import ( - get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element) + get_group_list, find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group) from cmlibs.utils.zinc.finiteelement import evaluate_field_nodeset_range from cmlibs.utils.zinc.group import group_add_group_local_contents, group_remove_group_local_contents from cmlibs.utils.zinc.general import ChangeManager @@ -58,12 +58,14 @@ def __init__(self, name, segmentation_file_name, root_region): self._raw_minimums, self._raw_maximums = evaluate_field_nodeset_range(self._raw_coordinates, self._raw_nodes) self._working_region = self._base_region.createChild("working") self._working_fieldmodule = self._working_region.getFieldmodule() - self._working_datapoints = self._working_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) - self._working_coordinates = find_or_create_field_coordinates(self._working_fieldmodule) - self._working_radius_direction = find_or_create_field_finite_element( - self._working_fieldmodule, "radius_direction", 3) - self._working_best_fit_line_orientation = find_or_create_field_finite_element( - self._working_fieldmodule, "best_fit_line_orientation", 9) + with ChangeManager(self._working_fieldmodule): + self._working_datapoints = self._working_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + self._working_coordinates = find_or_create_field_coordinates(self._working_fieldmodule) + self._working_radius_direction = find_or_create_field_finite_element( + self._working_fieldmodule, "radius_direction", 3) + self._working_best_fit_line_orientation = find_or_create_field_finite_element( + self._working_fieldmodule, "best_fit_line_orientation", 9) + self._working_end_group = find_or_create_field_group(self._working_fieldmodule, "active_ends") self._element_node_ids, self._node_element_ids = self._get_element_node_maps() self._end_node_ids = self._get_end_node_ids() self._end_point_data = {} # dict node_id -> (coordinates, direction, radius, annotation) @@ -325,8 +327,8 @@ def create_end_point_directions(self, annotations, max_distance): annotation = tmp_annotation break self._end_point_data[end_node_id] = (start_x, normalize(direction), mean_r, annotation) - # set up visualization objects: - node = self._working_datapoints.createNode(-1, nodetemplate) + # set up visualization objects. End direction datapoints have same identifiers as raw end nodes + node = self._working_datapoints.createNode(end_node_id, nodetemplate) fieldcache.setNode(node) radius_direction = set_magnitude(direction, mean_r) self._working_coordinates.setNodeParameters(fieldcache, -1, Node.VALUE_LABEL_VALUE, 1, start_x) @@ -458,6 +460,13 @@ def get_working_region(self): """ return self._working_region + def get_working_end_group(self): + """ + Get group from working region containing connectable end points in segment. + :return: Zinc group containing connectable end points. + """ + return self._working_end_group + def update_annotation_category(self, annotation, old_category=AnnotationCategory.EXCLUDE): """ Ensures special groups representing annotion categories contain via addition or removal the @@ -478,14 +487,14 @@ def update_annotation_category(self, annotation, old_category=AnnotationCategory group_remove_group_local_contents(old_category_group, annotation_group) new_category_group = self.get_category_group(new_category) group_add_group_local_contents(new_category_group, annotation_group) + self._update_working_end_group() def update_annotation_category_groups(self, annotations): """ Rebuild all annotation category groups e.g. after loading settings. :param annotations: List of all annotations from stitcher. """ - fieldmodule = self._raw_region.getFieldmodule() - with ChangeManager(fieldmodule): + with ChangeManager(self._raw_fieldmodule): # clear all category groups for category in AnnotationCategory: category_group = self.get_category_group(category) @@ -495,7 +504,34 @@ def update_annotation_category_groups(self, annotations): if annotation_group: category_group = self.get_category_group(annotation.get_category()) group_add_group_local_contents(category_group, annotation_group) + self._update_working_end_group() + def _update_working_end_group(self): + """ + Ensure working end group contains all connectable end points. + """ + connectable_node_groups = [] + for category in AnnotationCategory: + if category.is_connectable(): + category_group = self.get_category_group(category) + node_group = category_group.getNodesetGroup(self._raw_nodes) + if node_group.isValid() and (node_group.getSize() > 0): + connectable_node_groups.append(node_group) + with ChangeManager(self._working_fieldmodule): + self._working_end_group.clear() + working_datapoints = \ + self._working_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) + working_node_group = self._working_end_group.getOrCreateNodesetGroup(working_datapoints) + working_nodeiterator = working_datapoints.createNodeiterator() + working_node = working_nodeiterator.next() + while working_node.isValid(): + node_identifier = working_node.getIdentifier() + raw_node = self._raw_nodes.findNodeByIdentifier(node_identifier) + for node_group in connectable_node_groups: + if node_group.containsNode(raw_node): + working_node_group.addNode(working_node) + break; + working_node = working_nodeiterator.next() def fit_line(path_coordinates, path_radii, x1=None, x2=None, filter_proportion=0.0): """ diff --git a/src/segmentationstitcher/stitcher.py b/src/segmentationstitcher/stitcher.py index 3ba7064..eefa454 100644 --- a/src/segmentationstitcher/stitcher.py +++ b/src/segmentationstitcher/stitcher.py @@ -41,50 +41,51 @@ def __init__(self, segmentation_file_names: list, network_group1_keywords, netwo self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._segments = [] self._connections = [] - self._version = 1 # increment when new settings added to migrate older serialised settings - max_range_reciprocal_sum = 0.0 - for segmentation_file_name in segmentation_file_names: - name = Path(segmentation_file_name).name - segment = Segment(name, segmentation_file_name, self._root_region) - max_range_reciprocal_sum += 1.0 / segment.get_max_range() - self._segments.append(segment) - segment_annotations = region_get_annotations( - segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, - self._term_keywords) - for segment_annotation in segment_annotations: - name = segment_annotation.get_name() - term = segment_annotation.get_term() - index = 0 - for annotation in self._annotations: - if annotation.get_name() == name: - existing_term = annotation.get_term() - if term != existing_term: - print("Warning: Found existing annotation with name", name, - "but existing term", existing_term, "does not equal new term", term) - if term and (existing_term is None): - annotation.set_term(term) - break # exists already - if name > annotation.get_name(): - index += 1 - else: - # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), - # "category", segment_annotation.get_category()) - self._annotations.insert(index, segment_annotation) - # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" - for annotation in self._annotations: - if ((annotation.get_category() == AnnotationCategory.GENERAL) and (not annotation.get_term()) and - (annotation.get_name() != "marker")): - # print("Exclude general annotation", annotation.get_name(), "with no term") - annotation.set_category(AnnotationCategory.EXCLUDE) self._max_distance = 0.0 - if self._segments: - with HierarchicalChangeManager(self._root_region): - self._max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum - for segment in self._segments: - segment.create_end_point_directions(self._annotations, self._max_distance) - segment.update_annotation_category_groups(self._annotations) - for annotation in self._annotations: - annotation.set_category_change_callback(self._annotation_category_change) + self._version = 1 # increment when new settings added to migrate older serialised settings + with HierarchicalChangeManager(self._root_region): + max_range_reciprocal_sum = 0.0 + for segmentation_file_name in segmentation_file_names: + name = Path(segmentation_file_name).name + segment = Segment(name, segmentation_file_name, self._root_region) + max_range_reciprocal_sum += 1.0 / segment.get_max_range() + self._segments.append(segment) + segment_annotations = region_get_annotations( + segment.get_raw_region(), self._network_group1_keywords, self._network_group2_keywords, + self._term_keywords) + for segment_annotation in segment_annotations: + name = segment_annotation.get_name() + term = segment_annotation.get_term() + index = 0 + for annotation in self._annotations: + if annotation.get_name() == name: + existing_term = annotation.get_term() + if term != existing_term: + print("Warning: Found existing annotation with name", name, + "but existing term", existing_term, "does not equal new term", term) + if term and (existing_term is None): + annotation.set_term(term) + break # exists already + if name > annotation.get_name(): + index += 1 + else: + # print("Add annoation name", name, "term", term, "dim", segment_annotation.get_dimension(), + # "category", segment_annotation.get_category()) + self._annotations.insert(index, segment_annotation) + # by default put all GENERAL annotations without terms into the EXCLUDE category, except "marker" + for annotation in self._annotations: + if ((annotation.get_category() == AnnotationCategory.GENERAL) and (not annotation.get_term()) and + (annotation.get_name() != "marker")): + # print("Exclude general annotation", annotation.get_name(), "with no term") + annotation.set_category(AnnotationCategory.EXCLUDE) + if self._segments: + with HierarchicalChangeManager(self._root_region): + self._max_distance = 0.25 * len(self._segments) / max_range_reciprocal_sum + for segment in self._segments: + segment.create_end_point_directions(self._annotations, self._max_distance) + segment.update_annotation_category_groups(self._annotations) + for annotation in self._annotations: + annotation.set_category_change_callback(self._annotation_category_change) def decode_settings(self, settings_in: dict): """