Skip to content

Commit

Permalink
Merge pull request #37 from hsorby/main
Browse files Browse the repository at this point in the history
Make manual alignment distinct from auto alignment
  • Loading branch information
hsorby authored Dec 12, 2024
2 parents 86f6b05 + 0e7073b commit 1451211
Show file tree
Hide file tree
Showing 5 changed files with 578 additions and 52 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: Test library

on:
pull_request:

jobs:
build-and-test:
runs-on: ubuntu-22.04
name: Run tests
steps:
- name: Clone source
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Install dependencies
run: |
sudo apt install libopengl0 libglu1-mesa -y
- name: Install library
shell: bash
run: |
pip install https://github.com/cmlibs/zinc/releases/download/v4.2.0/cmlibs.zinc-4.2.0-cp311-cp311-linux_x86_64.whl
pip install .
- name: Run tests
shell: bash
run: |
python -m unittest discover -s tests/
7 changes: 5 additions & 2 deletions src/scaffoldfitter/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,13 @@ def load(self):
"""
self._clearFields()
self._region = self._context.createRegion()
self._region.setName("model_region")
self._fieldmodule = self._region.getFieldmodule()
self._rawDataRegion = self._region.createChild("raw_data")
self._loadModel()
self._loadData()
self._defineDataProjectionFields()
# get centre and scale of data coordinates to manage fitting tolerances and steps
# Get centre and scale of data coordinates to manage fitting tolerances and steps.
datapoints = self._fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS)
minimums, maximums = evaluate_field_nodeset_range(self._dataCoordinatesField, datapoints)
self._dataCentre = mult(add(minimums, maximums), 0.5)
Expand Down Expand Up @@ -495,11 +496,12 @@ def _loadData(self):
self._discoverDataCoordinatesField()
self._discoverMarkerGroup()

def run(self, endStep=None, modelFileNameStem=None, reorder = False):
def run(self, endStep=None, modelFileNameStem=None, reorder=False):
"""
Run either all remaining fitter steps or up to specified end step.
:param endStep: Last fitter step to run, or None to run all.
:param modelFileNameStem: File name stem for writing intermediate model files.
:param reorder: Reload if reordering.
:return: True if reloaded (so scene changed), False if not.
"""
if not endStep:
Expand Down Expand Up @@ -543,6 +545,7 @@ def _discoverDataCoordinatesField(self):
"""
self._dataCoordinatesField = None
field = None

if self._dataCoordinatesFieldName:
field = self._fieldmodule.findFieldByName(self._dataCoordinatesFieldName)
if not (field and field.isValid()):
Expand Down
178 changes: 129 additions & 49 deletions src/scaffoldfitter/fitterstepalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ def __init__(self):
super(FitterStepAlign, self).__init__()
self._alignGroups = False
self._alignMarkers = False
self._alignManually = False
self._rotation = None
self._scale = None
self._scaleProportion = None
self._translation = None
self._init_fit_parameters()

def _init_fit_parameters(self):
self._rotation = [0.0, 0.0, 0.0]
self._scale = 1.0
self._scaleProportion = 1.0
Expand All @@ -78,6 +86,7 @@ def decodeSettingsJSONDict(self, dctIn: dict):
dct.update(dctIn)
self._alignGroups = dct["alignGroups"]
self._alignMarkers = dct["alignMarkers"]
self._alignManually = dct["alignManually"]
self._rotation = dct["rotation"]
self._scale = dct["scale"]
scaleProportion = dct.get("scaleProportion")
Expand All @@ -94,6 +103,7 @@ def encodeSettingsJSONDict(self) -> dict:
dct.update({
"alignGroups": self._alignGroups,
"alignMarkers": self._alignMarkers,
"alignManually": self._alignManually,
"rotation": self._rotation,
"scale": self._scale,
"scaleProportion": self._scaleProportion,
Expand Down Expand Up @@ -132,6 +142,97 @@ def setAlignMarkers(self, alignMarkers):
return True
return False

def isAlignManually(self):
return self._alignManually

def setAlignManually(self, alignManually):
if alignManually != self._alignManually:
self._alignManually = alignManually
return True
return False

def _alignable_group_count(self):
count = 0
fieldmodule = self._fitter.getFieldmodule()
groups = get_group_list(fieldmodule)
with ChangeManager(fieldmodule):
for group in groups:
dataGroup = self._fitter.getGroupDataProjectionNodesetGroup(group)
if not dataGroup:
continue
meshGroup = self._fitter.getGroupDataProjectionMeshGroup(group)
if not meshGroup:
continue
count += 1

return count

def _match_markers(self):
writeDiagnostics = self.getDiagnosticLevel() > 0
matches = {}

markerGroup = self._fitter.getMarkerGroup()
if markerGroup is None:
if writeDiagnostics:
print("Align: No marker group to align with.")
return matches

markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields()
if markerNodeGroup is None or markerCoordinates is None or markerName is None:
if writeDiagnostics:
print("Align: No marker group, coordinates or name fields.")

return matches

markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields()
if markerDataGroup is None or markerDataCoordinates is None or markerDataName is None:
if writeDiagnostics:
print("Align: No marker data group, coordinates or name fields.")

return matches

modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName)
dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName)

# match model and data markers, warn of unmatched markers
for modelName in modelMarkers:
# name match allows case and whitespace differences
matchName = modelName.strip().casefold()
for dataName in dataMarkers:
if dataName.strip().casefold() == matchName:
entry_name = f"{modelName}_marker"
matches[entry_name] = (modelMarkers[modelName], dataMarkers[dataName])
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' found in data" +
(" as '" + dataName + "'" if (dataName != modelName) else ""))
dataMarkers.pop(dataName)
break
else:
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' not found in data")
if writeDiagnostics:
for dataName in dataMarkers:
print("Align: Data marker '" + dataName + "' not found in model")

return matches

def matchingMarkerCount(self):
return len(self._match_markers())

def matchingGroupCount(self):
return self._alignable_group_count()

def canAutoAlign(self):
total = self.matchingGroupCount() + self.matchingMarkerCount()
return total > 2

def canAlignGroups(self):
return self._alignable_group_count() > 2

def canAlignMarkers(self):
matches = self._match_markers()
return len(matches) > 2

def getRotation(self):
return self._rotation

Expand Down Expand Up @@ -197,23 +298,32 @@ def run(self, modelFileNameStem=None):
"""
modelCoordinates = self._fitter.getModelCoordinatesField()
assert modelCoordinates, "Align: Missing model coordinates"
if self._alignGroups or self._alignMarkers:
if not self._alignManually and (self._alignGroups or self._alignMarkers):
self._doAutoAlign()
elif not self._alignManually and not (self._alignGroups or self._alignMarkers):
# Nothing is set, so make the fit do nothing by setting the fit parameters to
# their identity values.
self._init_fit_parameters()

self._applyAlignment(modelCoordinates)

self._fitter.calculateDataProjections(self)
if modelFileNameStem:
self._fitter.writeModel(modelFileNameStem + "_align.exf")
self.setHasRun(True)

def _applyAlignment(self, model_coordinates):
fieldmodule = self._fitter.getFieldmodule()
with ChangeManager(fieldmodule):
# rotate, scale and translate model
modelCoordinatesTransformed = createFieldsTransformations(
modelCoordinates, self._rotation, self._scale, self._translation)[0]
fieldassignment = modelCoordinates.createFieldassignment(modelCoordinatesTransformed)
model_coordinates_transformed = createFieldsTransformations(
model_coordinates, self._rotation, self._scale, self._translation)[0]
fieldassignment = model_coordinates.createFieldassignment(model_coordinates_transformed)
result = fieldassignment.assign()
assert result in [RESULT_OK, RESULT_WARNING_PART_DONE], "Align: Failed to transform model"
self._fitter.updateModelReferenceCoordinates()
del fieldassignment
del modelCoordinatesTransformed
self._fitter.calculateDataProjections(self)
if modelFileNameStem:
self._fitter.writeModel(modelFileNameStem + "_align.exf")
self.setHasRun(True)
del model_coordinates_transformed

def _doAutoAlign(self):
"""
Expand All @@ -235,7 +345,7 @@ def _doAutoAlign(self):
meshGroup = self._fitter.getGroupDataProjectionMeshGroup(group)
if not meshGroup:
continue
groupName = group.getName()
groupName = f"{group.getName()}_group"
# use centre of bounding box as middle of data; previous use of mean was affected by uneven density
minDataCoordinates, maxDataCoordinates = evaluate_field_nodeset_range(dataCoordinates, dataGroup)
middleDataCoordinates = mult(add(minDataCoordinates, maxDataCoordinates), 0.5)
Expand All @@ -246,46 +356,16 @@ def _doAutoAlign(self):
del one

if self._alignMarkers:
markerGroup = self._fitter.getMarkerGroup()
assert markerGroup, "Align: No marker group to align with"

markerNodeGroup, markerLocation, markerCoordinates, markerName = self._fitter.getMarkerModelFields()
assert markerNodeGroup and markerCoordinates and markerName, \
"Align: No marker group, coordinates or name fields"
modelMarkers = getNodeNameCentres(markerNodeGroup, markerCoordinates, markerName)

markerDataGroup, markerDataCoordinates, markerDataName = self._fitter.getMarkerDataFields()
assert markerDataGroup and markerDataCoordinates and markerDataName, \
"Align: No marker data group, coordinates or name fields"
dataMarkers = getNodeNameCentres(markerDataGroup, markerDataCoordinates, markerDataName)

# match model and data markers, warn of unmatched markers
writeDiagnostics = self.getDiagnosticLevel() > 0
for modelName in modelMarkers:
# name match allows case and whitespace differences
matchName = modelName.strip().casefold()
for dataName in dataMarkers:
if dataName.strip().casefold() == matchName:
pointMap[modelName] = (modelMarkers[modelName], dataMarkers[dataName])
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' found in data" +
(" as '" + dataName + "'" if (dataName != modelName) else ""))
dataMarkers.pop(dataName)
break
else:
if writeDiagnostics:
print("Align: Model marker '" + modelName + "' not found in data")
if writeDiagnostics:
for dataName in dataMarkers:
print("Align: Data marker '" + dataName + "' not found in model")
matches = self._match_markers()
pointMap.update(matches)

self._optimiseAlignment(pointMap)

def getTransformationMatrix(self):
'''
"""
:return: 4x4 row-major transformation matrix with first index down rows, second across columns,
suitable for multiplication p' = Mp where p = [ x, y, z, h ].
'''
"""
# apply transformation in order: scale then rotation then translation
if not all((v == 0.0) for v in self._rotation):
rotationMatrix = euler_to_rotation_matrix(self._rotation)
Expand Down Expand Up @@ -313,13 +393,13 @@ def _optimiseAlignment(self, pointMap):
region = self._fitter.getContext().createRegion()
fieldmodule = region.getFieldmodule()

# get centre of mass CM and span of model coordinates and data
# Get centre of mass CM and span of model coordinates and data.
modelsum = [0.0, 0.0, 0.0]
datasum = [0.0, 0.0, 0.0]
modelMin = copy.deepcopy(list(pointMap.values())[0][0])
modelMax = copy.deepcopy(list(pointMap.values())[0][0])
dataMin = copy.deepcopy(list(pointMap.values())[0][1])
dataMax = copy.deepcopy(list(pointMap.values())[0][1])
modelMin = [math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][0])
modelMax = [-math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][0])
dataMin = [math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][1])
dataMax = [-math.inf] * 3 # copy.deepcopy(list(pointMap.values())[0][1])
for name, positions in pointMap.items():
modelx = positions[0]
datax = positions[1]
Expand Down
Loading

0 comments on commit 1451211

Please sign in to comment.