Skip to content

Commit

Permalink
fix label handling during flatten (#1208)
Browse files Browse the repository at this point in the history
* fix label handling during flatten

* Change the reset_labels method for compound.py to label the container lists with the format 'all-{name}s' for clarity

* Add Ruff to pre-commit hooks (#1207)

* update CI and precommit files

* add ruff changes

* remove gmso lines

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* change error type in test

* raise the error that is created

* remove duplicate windows 3.12 test

* fix precommit errors

* fix import error

* fix CI error

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* Fix labeling of windows compounds

* fix references in monomers tests

---------

Co-authored-by: Chris Jones <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 6, 2024
1 parent 0800b14 commit d33ba07
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 56 deletions.
41 changes: 30 additions & 11 deletions mbuild/compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,12 +699,13 @@ def add(

if label.endswith("[$]"):
label = label[:-3]
if label not in self.labels:
self.labels[label] = []
all_label = "all-" + label + "s"
if all_label not in self.labels:
self.labels[all_label] = []
label_pattern = label + "[{}]"

count = len(self.labels[label])
self.labels[label].append(new_child)
count = len(self.labels[all_label])
self.labels[all_label].append(new_child)
label = label_pattern.format(count)

if not replace and label in self.labels:
Expand Down Expand Up @@ -825,7 +826,21 @@ def _check_if_empty(child):
self.reset_labels()

def reset_labels(self):
"""Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N], where N-1 is the number of ports."""
"""Reset Compound labels so that substituents and ports are renumbered, indexed from port[0] to port[N-1], where N is the number of ports.
Notes
-----
Will renumber the labels in a given Compound. Duplicated labels are named in the format "{name}[$]", where the $ stands in for the 0-indexed
number in the Compound hierarchy with given "name".
i.e. self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"]
and
i.e. self.labels.keys() = ["CH2[1]", "CH2[3]", "CH2[5]"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"]
Additionally, duplicated labels that are numbered as above with the "[$]" will also be collected into a list, which is created if it doesn't already exist.
self.labels.keys() = ["CH2", "CH2", "CH2"] would transform into self.labels.keys() = ["CH2[0]", "CH2[1]", "CH2[2]"] as shown above, but also
have a label of self.labels["all-CH2s"], which is a list of all CH2 children in the Compound.
"""
new_labels = OrderedDict()
hoisted_children = {
key: val
Expand Down Expand Up @@ -856,16 +871,16 @@ def reset_labels(self):
if "port" in label:
label = "port[$]"
else:
label = "{0}[$]".format(child.name)

label = f"{child.name}[$]"
if label.endswith("[$]"):
label = label[:-3]
if label not in new_labels:
new_labels[label] = []
all_label = "all-" + label + "s"
if all_label not in new_labels:
new_labels[all_label] = []
label_pattern = label + "[{}]"

count = len(new_labels[label])
new_labels[label].append(child)
count = len(new_labels[all_label])
new_labels[all_label].append(child)
label = label_pattern.format(count)
new_labels[label] = child
self.labels = new_labels
Expand Down Expand Up @@ -1880,6 +1895,9 @@ def flatten(self, inplace=True):
for neighbor in nx.neighbors(bond_graph, particle):
new_bonds.append((particle, neighbor))

# Remove all labels which refer to children in the hierarchy
self.labels.clear()

# Remove all the children
if inplace:
for child in children_list:
Expand All @@ -1896,6 +1914,7 @@ def flatten(self, inplace=True):
comp = clone(self)
comp.flatten(inplace=True)
return comp
self.reset_labels()

def update_coordinates(self, filename, update_port_locations=True):
"""Update the coordinates of this Compound from a file.
Expand Down
92 changes: 48 additions & 44 deletions mbuild/tests/test_compound.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def test_add_by_list(self, h2o):
temp_comp.add(comp_list, label=label_list)
a = [k for k, v in temp_comp.labels.items()]
assert a == [
"water",
"all-waters",
"water[0]",
"water[1]",
"water[2]",
Expand Down Expand Up @@ -783,42 +783,14 @@ def test_remove(self, ethane):

# Test to reset labels after hydrogens
ethane6 = mb.clone(ethane)
ethane6.flatten()
hydrogens = ethane6.particles_by_name("H")
ethane6.remove(hydrogens)
ethane6.remove(hydrogens, reset_labels=True)
assert list(ethane6.labels.keys()) == [
"methyl1",
"methyl2",
"C",
"C[0]",
"H",
"C[1]",
"port",
"port[1]",
"port[3]",
"port[5]",
"port[7]",
"port[9]",
"port[11]",
]

ethane7 = mb.clone(ethane)
ethane7.flatten()
hydrogens = ethane7.particles_by_name("H")
ethane7.remove(hydrogens, reset_labels=True)

assert list(ethane7.labels.keys()) == [
"C",
"C[0]",
"C[1]",
"port",
"port[0]",
"port[1]",
"port[2]",
"port[3]",
"port[4]",
"port[5]",
]
assert ethane6.available_ports() == []
assert len(ethane6.all_ports()) == 6

def test_remove_many(self, ethane):
ethane.remove([ethane.children[0], ethane.children[1]])
Expand Down Expand Up @@ -1041,6 +1013,31 @@ def test_flatten_box_of_eth(self, ethane):
box_of_eth.flatten()
assert len(box_of_eth.children) == box_of_eth.n_particles == 8 * 2
assert box_of_eth.n_bonds == 7 * 2
assert list(box_of_eth.labels.keys()) == [
"all-Cs",
"C[0]",
"all-Hs",
"H[0]",
"H[1]",
"H[2]",
"C[1]",
"H[3]",
"H[4]",
"H[5]",
"C[2]",
"H[6]",
"H[7]",
"H[8]",
"C[3]",
"H[9]",
"H[10]",
"H[11]",
]

def test_flatten_then_fill_box(self, benzene):
    """Flatten a compound in place, pack it with fill_box, and check the bond graph.

    The assertion passes when the root of the first particle in the packed
    box exposes a truthy ``bond_graph``.
    """
    benzene.flatten(inplace=True)
    packed_box = mb.packing.fill_box(compound=benzene, n_compounds=2, density=0.3)
    first_particle = next(iter(packed_box.particles()))
    assert first_particle.root.bond_graph

def test_flatten_with_port(self, ethane):
ethane.remove(ethane[2])
Expand Down Expand Up @@ -1726,7 +1723,7 @@ def test_energy_minimize_shift_com(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_shift_anchor(self, octane):
anchor_compound = octane.labels["chain"].labels["CH3"][0]
anchor_compound = octane.labels["chain"].labels["CH3[0]"]
pos_old = anchor_compound.pos
octane.energy_minimize(anchor=anchor_compound)
# check to see if COM of the anchor Compound
Expand All @@ -1738,9 +1735,9 @@ def test_energy_minimize_shift_anchor(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_fix_compounds(self, octane):
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
methyl_end1 = octane.labels["chain"].labels["CH3[0]"]
carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
not_in_compound = mb.Compound(name="H")

# fix the whole molecule and make sure positions are close
Expand Down Expand Up @@ -1827,9 +1824,9 @@ def test_energy_minimize_fix_compounds(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_ignore_compounds(self, octane):
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
carbon_end = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
methyl_end1 = octane.labels["chain"].labels["CH3[1]"]
carbon_end = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
not_in_compound = mb.Compound(name="H")

# fix the whole molecule and make sure positions are close
Expand Down Expand Up @@ -1859,12 +1856,12 @@ def test_energy_minimize_ignore_compounds(self, octane):
"win" in sys.platform, reason="Unknown issue with Window's Open Babel "
)
def test_energy_minimize_distance_constraints(self, octane):
methyl_end0 = octane.labels["chain"].labels["CH3"][0]
methyl_end1 = octane.labels["chain"].labels["CH3"][1]
methyl_end0 = octane.labels["chain"].labels["CH3[0]"]
methyl_end1 = octane.labels["chain"].labels["CH3[1]"]

carbon_end0 = octane.labels["chain"].labels["CH3"][0].labels["C"][0]
carbon_end1 = octane.labels["chain"].labels["CH3"][1].labels["C"][0]
h_end0 = octane.labels["chain"].labels["CH3"][0].labels["H"][0]
carbon_end0 = octane.labels["chain"].labels["CH3[0]"].labels["C[0]"]
carbon_end1 = octane.labels["chain"].labels["CH3[1]"].labels["C[0]"]
h_end0 = octane.labels["chain"].labels["CH3[0]"].labels["H[0]"]

not_in_compound = mb.Compound(name="H")

Expand Down Expand Up @@ -2539,3 +2536,10 @@ def test_catalog_bondgraph_types(self, benzene):
catalog_bondgraph_type(compound.children[1][0], compound.bond_graph)
== "particle_graph"
)

def test_reset_labels(self):
    """Remove all "H" particles with ``reset_labels=True`` and check port labels.

    After the removal, the labels ``port[0]`` through ``port[5]`` must all be
    present among the compound's label keys.
    """
    molecule = mb.load("CC", smiles=True)
    hydrogens = molecule.particles_by_name("H")
    molecule.remove(hydrogens, reset_labels=True)
    expected_ports = {f"port[{idx}]" for idx in range(6)}
    assert expected_ports <= set(molecule.labels.keys())
2 changes: 1 addition & 1 deletion mbuild/tests/test_json_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_label_consistency(self):
parent.add(CH3())
compound_to_json(parent, "parent.json", include_ports=True)
parent_copy = compound_from_json("parent.json")
assert len(parent_copy["CH2"]) == len(parent["CH2"])
assert len(parent_copy["all-CH2s"]) == len(parent["all-CH2s"])
assert parent_copy.labels.keys() == parent.labels.keys()
for child, child_copy in zip(parent.successors(), parent_copy.successors()):
assert child.labels.keys() == child_copy.labels.keys()
Expand Down

0 comments on commit d33ba07

Please sign in to comment.