diff --git a/mbuild/compound.py b/mbuild/compound.py index 4e8f2ee81..02eab970b 100644 --- a/mbuild/compound.py +++ b/mbuild/compound.py @@ -1013,14 +1013,14 @@ def add( "outside of the defined simulation box" ) - def remove(self, objs_to_remove, reset_labels=True): + def remove(self, objs_to_remove, reset_labels=False): """Remove children from the Compound cleanly. Parameters ---------- objs_to_remove : mb.Compound or list of mb.Compound The Compound(s) to be removed from self - reset_labels : bool + reset_labels : bool, optional, default=False If True, the Compound labels will be reset """ # Preprocessing and validating input type @@ -1087,51 +1087,55 @@ def _check_if_empty(child): # Reorder labels if reset_labels: - new_labels = OrderedDict() - hoisted_children = { - key: val - for key, val in self.labels.items() - if ( - not isinstance(val, list) - and val.parent is not None - and id(self) != id(val.parent) - ) - } - new_labels.update(hoisted_children) - children_list = { - id(val): [key, val] - for key, val in self.labels.items() - if (not isinstance(val, list)) - } - for child in self.children: - label = ( - children_list[id(child)][0] - if "[" not in children_list[id(child)][0] - else None - ) - if label is None: - if "Port" in child.name: - label = [ - key - for key, x in self.labels.items() - if id(x) == id(child) - ][0] - if "port" in label: - label = "port[$]" - else: - label = f"{child.name}[$]" - - if label.endswith("[$]"): - label = label[:-3] - if label not in new_labels: - new_labels[label] = [] - label_pattern = label + "[{}]" - - count = len(new_labels[label]) - new_labels[label].append(child) - label = label_pattern.format(count) - new_labels[label] = child - self.labels = new_labels + self.reset_labels() + + def reset_labels(self): + """Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports.""" + new_labels = OrderedDict() + hoisted_children = { + key: val + for key, val in self.labels.items() + if ( + not isinstance(val, list) + and val.parent is not None + and id(self) != id(val.parent) + ) + } + new_labels.update(hoisted_children) + children_list = { + id(val): [key, val] + for key, val in self.labels.items() + if (not isinstance(val, list)) + } + for child in self.children: + label = ( + children_list[id(child)][0] + if "[" not in children_list[id(child)][0] + else None + ) + if label is None: + if "Port" in child.name: + label = [ + key + for key, x in self.labels.items() + if id(x) == id(child) + ][0] + if "port" in label: + label = "port[$]" + else: + label = "{0}[$]".format(child.name) + + if label.endswith("[$]"): + label = label[:-3] + if label not in new_labels: + new_labels[label] = [] + label_pattern = label + "[{}]" + + count = len(new_labels[label]) + new_labels[label].append(child) + label = label_pattern.format(count) + new_labels[label] = child + self.labels = new_labels def _prune_ghost_ports(self): """Worker for remove(). Remove all ports whose anchor has been deleted.""" diff --git a/mbuild/coordinate_transform.py b/mbuild/coordinate_transform.py index c3a430fe7..533f45859 100644 --- a/mbuild/coordinate_transform.py +++ b/mbuild/coordinate_transform.py @@ -16,7 +16,9 @@ ] -def force_overlap(move_this, from_positions, to_positions, add_bond=True): +def force_overlap( + move_this, from_positions, to_positions, add_bond=True, reset_labels=False +): """Move a Compound such that a position overlaps with another. Computes an affine transformation that maps the from_positions to the @@ -33,6 +35,8 @@ def force_overlap(move_this, from_positions, to_positions, add_bond=True): add_bond : bool, optional, default=True If `from_positions` and `to_positions` are `Ports`, create a bond between the two anchor atoms. + reset_labels : bool, optional, default=False + If True, the Compound labels will be reset, renumbered using the Compound.reset_labels methods """ from mbuild.port import Port @@ -67,8 +71,12 @@ def force_overlap(move_this, from_positions, to_positions, add_bond=True): to_positions.anchor.parent.add_bond( (from_positions.anchor, to_positions.anchor) ) - from_positions.anchor.parent.remove(from_positions) - to_positions.anchor.parent.remove(to_positions) + from_positions.anchor.parent.remove( + from_positions, reset_labels=reset_labels + ) + to_positions.anchor.parent.remove( + to_positions, reset_labels=reset_labels + ) class CoordinateTransform(object): diff --git a/mbuild/tests/test_compound.py b/mbuild/tests/test_compound.py index e94f74ebe..6d162c605 100644 --- a/mbuild/tests/test_compound.py +++ b/mbuild/tests/test_compound.py @@ -915,9 +915,10 @@ def test_remove(self, ethane): ethane6 = mb.clone(ethane) ethane6.flatten() hydrogens = ethane6.particles_by_name("H") - ethane6.remove(hydrogens, reset_labels=False) - + ethane6.remove(hydrogens) assert list(ethane6.labels.keys()) == [ + "methyl1", + "methyl2", "C", "C[0]", "H", @@ -934,7 +935,7 @@ def test_remove(self, ethane): ethane7 = mb.clone(ethane) ethane7.flatten() hydrogens = ethane7.particles_by_name("H") - ethane7.remove(hydrogens) + ethane7.remove(hydrogens, reset_labels=True) assert list(ethane7.labels.keys()) == [ "C",