Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make manual alignment distinct from auto alignment #37

Merged
merged 18 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Test library

on:
pull_request:

jobs:
build-and-test:
runs-on: ubuntu-22.04
name: Run tests
steps:
- name: Clone source
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Install dependencies
run: |
sudo apt install libopengl0 libglu1-mesa -y
- name: Install library
shell: bash
run: |
pip install https://github.com/cmlibs/zinc/releases/download/v4.2.0/cmlibs.zinc-4.2.0-cp311-cp311-linux_x86_64.whl
pip install .
- name: Run tests
shell: bash
run: |
python -m unittest discover -s tests/
7 changes: 5 additions & 2 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,13 @@ def load(self):
"""
self._clearFields()
self._region = self._context.createRegion()
self._region.setName("model_region")
self._fieldmodule = self._region.getFieldmodule()
self._rawDataRegion = self._region.createChild("raw_data")
self._loadModel()
self._loadData()
self._defineDataProjectionFields()
# get centre and scale of data coordinates to manage fitting tolerances and steps
# Get centre and scale of data coordinates to manage fitting tolerances and steps.
datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
minimums, maximums = evaluate_field_nodeset_range(self._dataCoordinatesField, datapoints)
self._dataCentre = mult(add(minimums, maximums), 0.5)
Expand Down Expand Up @@ -495,11 +496,12 @@ def _loadData(self):
self._discoverDataCoordinatesField()
self._discoverMarkerGroup()

def run(self, endStep=None, modelFileNameStem=None, reorder = False):
def run(self, endStep=None, modelFileNameStem=None, reorder=False):
"""
Run either all remaining fitter steps or up to specified end step.
:param endStep: Last fitter step to run, or None to run all.
:param modelFileNameStem: File name stem for writing intermediate model files.
:param reorder: Reload if reordering.
:return: True if reloaded (so scene changed), False if not.
"""
if not endStep:
Expand Down Expand Up @@ -543,6 +545,7 @@ def _discoverDataCoordinatesField(self):
"""
self._dataCoordinatesField = None
field = None

if self._dataCoordinatesFieldName:
field = self._fieldmodule.findFieldByName(self._dataCoordinatesFieldName)
if not (field and field.isValid()):
Expand Down
178 changes: 129 additions & 49 deletions src/scaffoldfitter/fitterstepalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ def __init__(self):
super(FitterStepAlign, self).__init__()
self._alignGroups = False
self._alignMarkers = False
self._alignManually = False
self._rotation = None
self._scale = None
self._scaleProportion = None
self._translation = None
self._init_fit_parameters()

def _init_fit_parameters(self):
self._rotation = [0.0, 0.0, 0.0]
self._scale = 1.0
self._scaleProportion = 1.0
Expand All @@ -78,6 +86,7 @@ def decodeSettingsJSONDict(self, dctIn: dict):
dct.update(dctIn)
self._alignGroups = dct["alignGroups"]
self._alignMarkers = dct["alignMarkers"]
self._alignManually = dct["alignManually"]
self._rotation = dct["rotation"]
self._scale = dct["scale"]
scaleProportion = dct.get("scaleProportion")
Expand All @@ -94,6 +103,7 @@ def encodeSettingsJSONDict(self) -> dict:
dct.update({
"alignGroups": self._alignGroups,
"alignMarkers": self._alignMarkers,
"alignManually": self._alignManually,
"rotation": self._rotation,
"scale": self._scale,
"scaleProportion": self._scaleProportion,
Expand Down Expand Up @@ -132,6 +142,97 @@ def setAlignMarkers(self, alignMarkers):
return True
return False

def isAlignManually(self):
return self._alignManually

def setAlignManually(self, alignManually):
if alignManually != self._alignManually:
self._alignManually = alignManually
return True
return False

def _alignable_group_count(self):
count = 0
fieldmodule = self._fitter.getFieldmodule()
groups = get_group_list(fieldmodule)
with ChangeManager(fieldmodule):
for group in groups:
dataGroup = self._fitter.getGroupDataProjectionNodesetGroup(group)
if not dataGroup:
continue
meshGroup = self._fitter.getGroupDataProjectionMeshGroup(group)
if not meshGroup:
continue
count += 1

return count

def _match_markers(self):
writeDiagnostics = self.getDiagnosticLevel() > 0
matches = {}

markerGroup = self._fitter.getMarkerGroup()
if markerGroup is None:
if writeDiagnostics:
print("Align: No marker group to align with.")
return matches

markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields()
if markerNodeGroup is None or markerCoordinates is None or markerName is None:
if writeDiagnostics:
print("Align: No marker group, coordinates or name fields.")

return matches

markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields()
if markerDataGroup is None or markerDataCoordinates is None or markerDataName is None:
if writeDiagnostics:
print("Align: No marker data group, coordinates or name fields.")

return matches

modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName)
dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName)

# match model and data markers, warn of unmatched markers
for modelName in modelMarkers:
# name match allows case and whitespace differences
matchName = modelName.strip().casefold()
for dataName in dataMarkers:
if dataName.strip().casefold() == matchName:
entry_name = f"{modelName}_marker"
matches[entry_name] = (modelMarkers[modelName], dataMarkers[dataName])
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' found in data" +
(" as '" + dataName + "'" if (dataName != modelName) else ""))
dataMarkers.pop(dataName)
break
else:
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' not found in data")
if writeDiagnostics:
for dataName in dataMarkers:
print("Align: Data marker '" + dataName + "' not found in model")

return matches

def matchingMarkerCount(self):
return len(self._match_markers())

def matchingGroupCount(self):
return self._alignable_group_count()

def canAutoAlign(self):
total = self.matchingGroupCount() + self.matchingMarkerCount()
return total > 2

def canAlignGroups(self):
return self._alignable_group_count() > 2

def canAlignMarkers(self):
matches = self._match_markers()
return len(matches) > 2

def getRotation(self):
return self._rotation

Expand Down Expand Up @@ -197,23 +298,32 @@ def run(self, modelFileNameStem=None):
"""
modelCoordinates = self._fitter.getModelCoordinatesField()
assert modelCoordinates, "Align: Missing model coordinates"
if self._alignGroups or self._alignMarkers:
if not self._alignManually and (self._alignGroups or self._alignMarkers):
self._doAutoAlign()
elif not self._alignManually and not (self._alignGroups or self._alignMarkers):
# Nothing is set, so make the fit do nothing by setting the fit parameters to
# their identity values.
self._init_fit_parameters()

self._applyAlignment(modelCoordinates)

self._fitter.calculateDataProjections(self)
if modelFileNameStem:
self._fitter.writeModel(modelFileNameStem + "_align.exf")
self.setHasRun(True)

def _applyAlignment(self, model_coordinates):
fieldmodule = self._fitter.getFieldmodule()
with ChangeManager(fieldmodule):
# rotate, scale and translate model
modelCoordinatesTransformed = createFieldsTransformations(
modelCoordinates, self._rotation, self._scale, self._translation)[0]
fieldassignment = modelCoordinates.createFieldassignment(modelCoordinatesTransformed)
model_coordinates_transformed = createFieldsTransformations(
model_coordinates, self._rotation, self._scale, self._translation)[0]
fieldassignment = model_coordinates.createFieldassignment(model_coordinates_transformed)
result = fieldassignment.assign()
assert result in [RESULT_OK, RESULT_WARNING_PART_DONE], "Align: Failed to transform model"
self._fitter.updateModelReferenceCoordinates()
del fieldassignment
del modelCoordinatesTransformed
self._fitter.calculateDataProjections(self)
if modelFileNameStem:
self._fitter.writeModel(modelFileNameStem + "_align.exf")
self.setHasRun(True)
del model_coordinates_transformed

def _doAutoAlign(self):
"""
Expand All @@ -235,7 +345,7 @@ def _doAutoAlign(self):
meshGroup = self._fitter.getGroupDataProjectionMeshGroup(group)
if not meshGroup:
continue
groupName = group.getName()
groupName = f"{group.getName()}_group"
# use centre of bounding box as middle of data; previous use of mean was affected by uneven density
minDataCoordinates, maxDataCoordinates = evaluate_field_nodeset_range(dataCoordinates, dataGroup)
middleDataCoordinates = mult(add(minDataCoordinates, maxDataCoordinates), 0.5)
Expand All @@ -246,46 +356,16 @@ def _doAutoAlign(self):
del one

if self._alignMarkers:
markerGroup = self._fitter.getMarkerGroup()
assert markerGroup, "Align: No marker group to align with"

markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields()
assert markerNodeGroup and markerCoordinates and markerName, \
"Align: No marker group, coordinates or name fields"
modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName)

markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields()
assert markerDataGroup and markerDataCoordinates and markerDataName, \
"Align: No marker data group, coordinates or name fields"
dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName)

# match model and data markers, warn of unmatched markers
writeDiagnostics = self.getDiagnosticLevel() > 0
for modelName in modelMarkers:
# name match allows case and whitespace differences
matchName = modelName.strip().casefold()
for dataName in dataMarkers:
if dataName.strip().casefold() == matchName:
pointMap[modelName] = (modelMarkers[modelName], dataMarkers[dataName])
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' found in data" +
(" as '" + dataName + "'" if (dataName != modelName) else ""))
dataMarkers.pop(dataName)
break
else:
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' not found in data")
if writeDiagnostics:
for dataName in dataMarkers:
print("Align: Data marker '" + dataName + "' not found in model")
matches = self._match_markers()
pointMap.update(matches)

self._optimiseAlignment(pointMap)

def getTransformationMatrix(self):
'''
"""
:return: 4x4 row-major transformation matrix with first index down rows, second across columns,
suitable for multiplication p' = Mp where p = [ x, y, z, h ].
'''
"""
# apply transformation in order: scale then rotation then translation
if not all((v == 0.0) for v in self._rotation):
rotationMatrix = euler_to_rotation_matrix(self._rotation)
Expand Down Expand Up @@ -313,13 +393,13 @@ def _optimiseAlignment(self, pointMap):
region = self._fitter.getContext().createRegion()
fieldmodule = region.getFieldmodule()

# get centre of mass CM and span of model coordinates and data
# Get centre of mass CM and span of model coordinates and data.
modelsum = [0.0, 0.0, 0.0]
datasum = [0.0, 0.0, 0.0]
modelMin = copy.deepcopy(list(pointMap.values())[0][0])
modelMax = copy.deepcopy(list(pointMap.values())[0][0])
dataMin = copy.deepcopy(list(pointMap.values())[0][1])
dataMax = copy.deepcopy(list(pointMap.values())[0][1])
modelMin = [math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][0])
modelMax = [-math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][0])
dataMin = [math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][1])
dataMax = [-math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][1])
for name, positions in pointMap.items():
modelx = positions[0]
datax = positions[1]
Expand Down
Loading