diff --git a/src/scaffoldmaker/utils/eft_utils.py b/src/scaffoldmaker/utils/eft_utils.py index 218ecd7a..96f79daf 100644 --- a/src/scaffoldmaker/utils/eft_utils.py +++ b/src/scaffoldmaker/utils/eft_utils.py @@ -5,6 +5,7 @@ from cmlibs.zinc.element import Elementbasis, Elementfieldtemplate from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK +from scaffoldmaker.utils.interpolation import interpolateHermiteLagrangeDerivative, interpolateLagrangeHermiteDerivative import copy import math @@ -530,6 +531,12 @@ def _determinePermutations(self): permutations.append((self._directions[kk], self._directions[ii], self._directions[jj])) return permutations + def getComplexity(self): + """ + :return: Integer complexity equal to number of axis permutations in node layout. + """ + return len(self._permutations) + def getDirections(self): return self._directions @@ -720,31 +727,63 @@ def determineCubicHermiteSerendipityEft(mesh, nodeParameters, nodeLayouts): derivativeLabels = [Node.VALUE_LABEL_D_DS1, Node.VALUE_LABEL_D_DS2, Node.VALUE_LABEL_D_DS3] derivativesPerNode = 3 if d3Defined else 2 functionsPerNode = 1 + derivativesPerNode + # order local nodes from default then simplest to most complex node layout + nodeOrder = [] for n in range(nodesCount): + if not nodeLayouts[n]: + nodeOrder.append(n) + while len(nodeOrder) < nodesCount: + lowestComplexity = 0 + next_n = None + for n in range(nodesCount): + if n in nodeOrder: + continue + complexity = nodeLayouts[n].getComplexity() + if (next_n is None) or (complexity < lowestComplexity): + lowestComplexity = complexity + next_n = n + nodeOrder.append(next_n) + for n in nodeOrder: ln = n + 1 nodeLayout = nodeLayouts[n] - if not nodeLayout: - # default node layout is regular, so nothing to do if no nodeLayout - continue nodeDerivatives = [ nodeParameters[n][1], nodeParameters[n][2], nodeParameters[n][3] if d3Defined else None] - derivativeWeightsList = nodeLayout.getDerivativeWeightsList(deltas[n], nodeDerivatives, n) + derivativeWeightsList =\ + nodeLayout.getDerivativeWeightsList(deltas[n], nodeDerivatives, n) if nodeLayout else None for ed in range(derivativesPerNode): - derivativeWeights = derivativeWeightsList[ed] - functionNumber = n * functionsPerNode + ed + 2 - termsCount = sum(1 for wt in derivativeWeights if wt != 0.0) - eft.setFunctionNumberOfTerms(functionNumber, termsCount) - term = 0 - for i in range(derivativesPerNode): - weight = derivativeWeights[i] - if weight: - term += 1 - eft.setTermNodeParameter(functionNumber, term, ln, derivativeLabels[i], 1) - if weight < 0.0: - if not scalefactors: - setEftScaleFactorIds(eft, [1], []) - scalefactors = [-1.0] - eft.setTermScaling(functionNumber, term, [1]) + if nodeLayout: + derivativeWeights = derivativeWeightsList[ed] + functionNumber = n * functionsPerNode + ed + 2 + termsCount = sum(1 for wt in derivativeWeights if wt != 0.0) + eft.setFunctionNumberOfTerms(functionNumber, termsCount) + term = 0 + elementDerivative = [0.0, 0.0, 0.0] + for i in range(derivativesPerNode): + weight = derivativeWeights[i] + if weight: + term += 1 + eft.setTermNodeParameter(functionNumber, term, ln, derivativeLabels[i], 1) + if weight < 0.0: + if not scalefactors: + setEftScaleFactorIds(eft, [1], []) + scalefactors = [-1.0] + eft.setTermScaling(functionNumber, term, [1]) + for c in range(3): + elementDerivative[c] += weight * nodeDerivatives[i][c] + else: + elementDerivative = nodeDerivatives[ed] + # update delta to equal the exact derivative + deltas[n][ed] = elementDerivative + # get other node index interpolated with this element derivative + on = n ^ (1 << ed) + if nodeOrder.index(on) > nodeOrder.index(n): + # update other node delta to work in with this derivative + otherElementDerivative = ( + interpolateHermiteLagrangeDerivative(nodeParameters[n][0], elementDerivative, nodeParameters[on][0], 1.0) + if (on > n) else + interpolateLagrangeHermiteDerivative(nodeParameters[on][0], nodeParameters[n][0], elementDerivative, 0.0)) + deltas[on][ed] = otherElementDerivative + return eft, scalefactors