From 90aae2fd5a740c0e14f51cb0b6c188af23bcedff Mon Sep 17 00:00:00 2001 From: jac16 Date: Wed, 7 Aug 2024 07:58:32 -0400 Subject: [PATCH] Add reset_labels kwarg to force_overlap, set default to False, extract label reset to a new method --- mbuild/compound.py | 94 ++++++++++++++++++---------------- mbuild/coordinate_transform.py | 14 +++-- 2 files changed, 60 insertions(+), 48 deletions(-) diff --git a/mbuild/compound.py b/mbuild/compound.py index e30d15d1f..915d284b1 100644 --- a/mbuild/compound.py +++ b/mbuild/compound.py @@ -1089,51 +1089,55 @@ def _check_if_empty(child): # Reorder labels if reset_labels: - new_labels = OrderedDict() - hoisted_children = { - key: val - for key, val in self.labels.items() - if ( - not isinstance(val, list) - and val.parent is not None - and id(self) != id(val.parent) - ) - } - new_labels.update(hoisted_children) - children_list = { - id(val): [key, val] - for key, val in self.labels.items() - if (not isinstance(val, list)) - } - for child in self.children: - label = ( - children_list[id(child)][0] - if "[" not in children_list[id(child)][0] - else None - ) - if label is None: - if "Port" in child.name: - label = [ - key - for key, x in self.labels.items() - if id(x) == id(child) - ][0] - if "port" in label: - label = "port[$]" - else: - label = "{0}[$]".format(child.name) - - if label.endswith("[$]"): - label = label[:-3] - if label not in new_labels: - new_labels[label] = [] - label_pattern = label + "[{}]" - - count = len(new_labels[label]) - new_labels[label].append(child) - label = label_pattern.format(count) - new_labels[label] = child - self.labels = new_labels + self.reset_labels() + + def reset_labels(self): + """Reset Compound labels so that substituents and ports are renumbered.""" + new_labels = OrderedDict() + hoisted_children = { + key: val + for key, val in self.labels.items() + if ( + not isinstance(val, list) + and val.parent is not None + and id(self) != id(val.parent) + ) + } + new_labels.update(hoisted_children) + children_list = { + id(val): [key, val] + for key, val in self.labels.items() + if (not isinstance(val, list)) + } + for child in self.children: + label = ( + children_list[id(child)][0] + if "[" not in children_list[id(child)][0] + else None + ) + if label is None: + if "Port" in child.name: + label = [ + key + for key, x in self.labels.items() + if id(x) == id(child) + ][0] + if "port" in label: + label = "port[$]" + else: + label = "{0}[$]".format(child.name) + + if label.endswith("[$]"): + label = label[:-3] + if label not in new_labels: + new_labels[label] = [] + label_pattern = label + "[{}]" + + count = len(new_labels[label]) + new_labels[label].append(child) + label = label_pattern.format(count) + new_labels[label] = child + self.labels = new_labels def _prune_ghost_ports(self): """Worker for remove(). Remove all ports whose anchor has been deleted.""" diff --git a/mbuild/coordinate_transform.py b/mbuild/coordinate_transform.py index 1edf2d609..839348c9f 100644 --- a/mbuild/coordinate_transform.py +++ b/mbuild/coordinate_transform.py @@ -16,7 +16,9 @@ ] -def force_overlap(move_this, from_positions, to_positions, add_bond=True): +def force_overlap( + move_this, from_positions, to_positions, add_bond=True, reset_labels=False +): """Move a Compound such that a position overlaps with another. Computes an affine transformation that maps the from_positions to the @@ -33,6 +35,8 @@ def force_overlap(move_this, from_positions, to_positions, add_bond=True): add_bond : bool, optional, default=True If `from_positions` and `to_positions` are `Ports`, create a bond between the two anchor atoms. + reset_labels : bool + If True, the Compound labels will be reset, renumbered """ from mbuild.port import Port @@ -67,8 +71,12 @@ def force_overlap(move_this, from_positions, to_positions, add_bond=True): to_positions.anchor.parent.add_bond( (from_positions.anchor, to_positions.anchor) ) - from_positions.anchor.parent.remove(from_positions) - to_positions.anchor.parent.remove(to_positions) + from_positions.anchor.parent.remove( + from_positions, reset_labels=reset_labels + ) + to_positions.anchor.parent.remove( + to_positions, reset_labels=reset_labels + ) class CoordinateTransform(object):