Skip to content

Commit

Permalink
Add reset_labels kwarg to force_overlap, set default to False, extrac…
Browse files Browse the repository at this point in the history
…t label reset to a new method
  • Loading branch information
jaclark5 committed Aug 7, 2024
1 parent cd53189 commit 90aae2f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 48 deletions.
94 changes: 49 additions & 45 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,51 +1089,55 @@ def _check_if_empty(child):

# Reorder labels
if reset_labels:
new_labels = OrderedDict()
hoisted_children = {
key: val
for key, val in self.labels.items()
if (
not isinstance(val, list)
and val.parent is not None
and id(self) != id(val.parent)
)
}
new_labels.update(hoisted_children)
children_list = {
id(val): [key, val]
for key, val in self.labels.items()
if (not isinstance(val, list))
}
for child in self.children:
label = (
children_list[id(child)][0]
if "[" not in children_list[id(child)][0]
else None
)
if label is None:
if "Port" in child.name:
label = [
key
for key, x in self.labels.items()
if id(x) == id(child)
][0]
if "port" in label:
label = "port[$]"
else:
label = "{0}[$]".format(child.name)

if label.endswith("[$]"):
label = label[:-3]
if label not in new_labels:
new_labels[label] = []
label_pattern = label + "[{}]"

count = len(new_labels[label])
new_labels[label].append(child)
label = label_pattern.format(count)
new_labels[label] = child
self.labels = new_labels
self.reset_labels()

def reset_labels(self):
"""Reset Compound labels so that substituents and ports are renumbered."""
new_labels = OrderedDict()
hoisted_children = {
key: val
for key, val in self.labels.items()
if (
not isinstance(val, list)
and val.parent is not None
and id(self) != id(val.parent)
)
}
new_labels.update(hoisted_children)
children_list = {
id(val): [key, val]
for key, val in self.labels.items()
if (not isinstance(val, list))
}
for child in self.children:
label = (
children_list[id(child)][0]
if "[" not in children_list[id(child)][0]
else None
)
if label is None:
if "Port" in child.name:
label = [
key
for key, x in self.labels.items()
if id(x) == id(child)
][0]
if "port" in label:
label = "port[$]"
else:
label = "{0}[$]".format(child.name)

if label.endswith("[$]"):
label = label[:-3]
if label not in new_labels:
new_labels[label] = []
label_pattern = label + "[{}]"

count = len(new_labels[label])
new_labels[label].append(child)
label = label_pattern.format(count)
new_labels[label] = child
self.labels = new_labels

def _prune_ghost_ports(self):
"""Worker for remove(). Remove all ports whose anchor has been deleted."""
Expand Down
14 changes: 11 additions & 3 deletions mbuild/coordinate_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
]


def force_overlap(move_this, from_positions, to_positions, add_bond=True):
def force_overlap(
move_this, from_positions, to_positions, add_bond=True, reset_labels=False
):
"""Move a Compound such that a position overlaps with another.
Computes an affine transformation that maps the from_positions to the
Expand All @@ -33,6 +35,8 @@ def force_overlap(move_this, from_positions, to_positions, add_bond=True):
add_bond : bool, optional, default=True
If `from_positions` and `to_positions` are `Ports`, create a bond
between the two anchor atoms.
reset_labels : bool
If True, the Compound labels will be reset, renumbered
"""
from mbuild.port import Port

Expand Down Expand Up @@ -67,8 +71,12 @@ def force_overlap(move_this, from_positions, to_positions, add_bond=True):
to_positions.anchor.parent.add_bond(
(from_positions.anchor, to_positions.anchor)
)
from_positions.anchor.parent.remove(from_positions)
to_positions.anchor.parent.remove(to_positions)
from_positions.anchor.parent.remove(
from_positions, reset_labels=reset_labels
)
to_positions.anchor.parent.remove(
to_positions, reset_labels=reset_labels
)


class CoordinateTransform(object):
Expand Down

0 comments on commit 90aae2f

Please sign in to comment.