From 37fdf5729d98f8b9e495c83fff625b7130924c9e Mon Sep 17 00:00:00 2001 From: CalCraven Date: Mon, 11 Sep 2023 13:11:35 +0100 Subject: [PATCH] Fixes to general sorting algorithms for hoomd and lammps --- gmso/core/views.py | 32 +---- gmso/external/convert_hoomd.py | 65 ++++++----- gmso/formats/lammpsdata.py | 11 +- gmso/tests/files/alkanes.xml | 122 +++++++++++++++++++ gmso/tests/files/alkanes_wildcards.xml | 112 ++++++++++++++++++ gmso/tests/test_hoomd.py | 155 +++++++++++++++++++++--- gmso/utils/sorting.py | 156 +++++++++++++++++++------ 7 files changed, 538 insertions(+), 115 deletions(-) create mode 100644 gmso/tests/files/alkanes.xml create mode 100644 gmso/tests/files/alkanes_wildcards.xml diff --git a/gmso/core/views.py b/gmso/core/views.py index 0ed3be1e4..2f4970148 100644 --- a/gmso/core/views.py +++ b/gmso/core/views.py @@ -11,6 +11,7 @@ from gmso.core.dihedral_type import DihedralType from gmso.core.improper import Improper from gmso.core.improper_type import ImproperType +from gmso.utils.sorting import sort_by_types __all__ = ["TopologyPotentialView", "PotentialFilters"] @@ -37,35 +38,6 @@ def get_name_or_class(potential): return potential.member_types or potential.member_classes -def get_sorted_names(potential): - """Get identifier for a topology potential based on name or membertype/class.""" - if isinstance(potential, AtomType): - return potential.name - elif isinstance(potential, BondType): - return tuple(sorted(potential.member_types)) - elif isinstance(potential, AngleType): - if potential.member_types[0] > potential.member_types[2]: - return tuple(reversed(potential.member_types)) - else: - return potential.member_types - elif isinstance(potential, DihedralType): - if potential.member_types[1] > potential.member_types[2] or ( - potential.member_types[1] == potential.member_types[2] - and potential.member_types[0] > potential.member_types[3] - ): - return tuple(reversed(potential.member_types)) - else: - return potential.member_types - elif isinstance(potential, ImproperType): - return ( - potential.member_types[0], - *potential.member_types[1:], - ) # could sort using `sorted` - return ValueError( - f"Potential {potential} not one of {potential_attribute_map.values()}" - ) - - def get_parameters(potential): """Return hashable version of parameters for a potential.""" return ( @@ -105,7 +77,7 @@ def all(): potential_identifiers = { PotentialFilters.UNIQUE_NAME_CLASS: get_name_or_class, - PotentialFilters.UNIQUE_SORTED_NAMES: get_sorted_names, + PotentialFilters.UNIQUE_SORTED_NAMES: sort_by_types, PotentialFilters.UNIQUE_EXPRESSION: lambda p: str(p.expression), PotentialFilters.UNIQUE_PARAMETERS: get_parameters, PotentialFilters.UNIQUE_ID: lambda p: id(p), diff --git a/gmso/external/convert_hoomd.py b/gmso/external/convert_hoomd.py index 637b00d2e..6c7825314 100644 --- a/gmso/external/convert_hoomd.py +++ b/gmso/external/convert_hoomd.py @@ -20,11 +20,7 @@ ) from gmso.utils.geometry import coord_shift from gmso.utils.io import has_gsd, has_hoomd -from gmso.utils.sorting import ( - natural_sort, - sort_connection_members, - sort_member_types, -) +from gmso.utils.sorting import sort_by_classes, sort_connection_members if has_gsd: import gsd.hoomd @@ -384,9 +380,9 @@ def _parse_bond_information(snapshot, top): for bond in top.bonds: if all([site.atom_type for site in bond.connection_members]): - connection_members = sort_connection_members(bond, "atom_type") + connection_members = sort_connection_members(bond, "atomclass") bond_type = "-".join( - [site.atom_type.name for site in connection_members] + [site.atom_type.atomclass for site in connection_members] ) else: connection_members = sort_connection_members(bond, "name") @@ -402,8 +398,8 @@ def _parse_bond_information(snapshot, top): if isinstance(snapshot, hoomd.Snapshot): snapshot.bonds.types = unique_bond_types - snapshot.bonds.typeid[0:] = bond_typeids - snapshot.bonds.group[0:] = bond_groups + snapshot.bonds.typeid[:] = bond_typeids + snapshot.bonds.group[:] = bond_groups elif isinstance(snapshot, gsd.hoomd.Frame): snapshot.bonds.types = unique_bond_types snapshot.bonds.typeid = bond_typeids @@ -431,9 +427,9 @@ def _parse_angle_information(snapshot, top): for angle in top.angles: if all([site.atom_type for site in angle.connection_members]): - connection_members = sort_connection_members(angle, "atom_type") + connection_members = sort_connection_members(angle, "atomclass") angle_type = "-".join( - [site.atom_type.name for site in connection_members] + [site.atom_type.atomclass for site in connection_members] ) else: connection_members = sort_connection_members(angle, "name") @@ -449,8 +445,8 @@ def _parse_angle_information(snapshot, top): if isinstance(snapshot, hoomd.Snapshot): snapshot.angles.types = unique_angle_types - snapshot.angles.typeid[0:] = angle_typeids - snapshot.angles.group[0:] = np.reshape(angle_groups, (-1, 3)) + snapshot.angles.typeid[:] = angle_typeids + snapshot.angles.group[:] = np.reshape(angle_groups, (-1, 3)) elif isinstance(snapshot, gsd.hoomd.Frame): snapshot.angles.types = unique_angle_types snapshot.angles.typeid = angle_typeids @@ -477,9 +473,9 @@ def _parse_dihedral_information(snapshot, top): for dihedral in top.dihedrals: if all([site.atom_type for site in dihedral.connection_members]): - connection_members = sort_connection_members(dihedral, "atom_type") + connection_members = sort_connection_members(dihedral, "atomclass") dihedral_type = "-".join( - [site.atom_type.name for site in connection_members] + [site.atom_type.atomclass for site in connection_members] ) else: connection_members = sort_connection_members(dihedral, "name") @@ -495,8 +491,8 @@ def _parse_dihedral_information(snapshot, top): if isinstance(snapshot, hoomd.Snapshot): snapshot.dihedrals.types = unique_dihedral_types - snapshot.dihedrals.typeid[0:] = dihedral_typeids - snapshot.dihedrals.group[0:] = np.reshape(dihedral_groups, (-1, 4)) + snapshot.dihedrals.typeid[:] = dihedral_typeids + snapshot.dihedrals.group[:] = np.reshape(dihedral_groups, (-1, 4)) elif isinstance(snapshot, gsd.hoomd.Frame): snapshot.dihedrals.types = unique_dihedral_types snapshot.dihedrals.typeid = dihedral_typeids @@ -525,9 +521,9 @@ def _parse_improper_information(snapshot, top): for improper in top.impropers: if all([site.atom_type for site in improper.connection_members]): - connection_members = sort_connection_members(improper, "atom_type") + connection_members = sort_connection_members(improper, "atomclass") improper_type = "-".join( - [site.atom_type.name for site in connection_members] + [site.atom_type.atomclass for site in connection_members] ) else: connection_members = sort_connection_members(improper, "name") @@ -994,8 +990,8 @@ def _parse_harmonic_bond( ): for btype in btypes: # TODO: Unit conversion - member_types = sort_member_types(btype) - container.params["-".join(member_types)] = { + member_classes = sort_by_classes(btype) + container.params["-".join(member_classes)] = { "k": btype.parameters["k"], "r0": btype.parameters["r_eq"], } @@ -1064,8 +1060,8 @@ def _parse_harmonic_angle( agtypes, ): for agtype in agtypes: - member_types = sort_member_types(agtype) - container.params["-".join(member_types)] = { + member_classes = sort_by_classes(agtype) + container.params["-".join(member_classes)] = { "k": agtype.parameters["k"], "t0": agtype.parameters["theta_eq"], } @@ -1094,6 +1090,16 @@ def _parse_dihedral_forces( unique_dtypes = top.dihedral_types( filter_by=PotentialFilters.UNIQUE_NAME_CLASS ) + unique_dihedrals = {} + for dihedral in top.dihedrals: + unique_members = tuple( + [site.atom_type.atomclass for site in dihedral.connection_members] + ) + unique_dihedrals[unique_members] = dihedral + + unique_dtypes = [ + dihedral.dihedral_type for dihedral in unique_dihedrals.values() + ] groups = dict() for dtype in unique_dtypes: group = potential_types[dtype] @@ -1157,8 +1163,8 @@ def _parse_periodic_dihedral( dtypes, ): for dtype in dtypes: - member_types = sort_member_types(dtype) - container.params["-".join(member_types)] = { + member_classes = sort_by_classes(dtype) + container.params["-".join(member_classes)] = { "k": dtype.parameters["k"], "d": 1, "n": dtype.parameters["n"], @@ -1174,7 +1180,8 @@ def _parse_opls_dihedral( for dtype in dtypes: # TODO: The range of ks is mismatched (GMSO go from k0 to k5) # May need to do a check that k0 == k5 == 0 or raise a warning - container.params["-".join(dtype.member_types)] = { + member_classes = sort_by_classes(dtype) + container.params["-".join(member_classes)] = { "k1": dtype.parameters["k1"], "k2": dtype.parameters["k2"], "k3": dtype.parameters["k3"], @@ -1192,10 +1199,10 @@ def _parse_rb_dihedral( ) for dtype in dtypes: opls = convert_ryckaert_to_opls(dtype) - member_types = sort_member_types(dtype) + member_classes = sort_by_classes(dtype) # TODO: The range of ks is mismatched (GMSO go from k0 to k5) # May need to do a check that k0 == k5 == 0 or raise a warning - container.params["-".join(member_types)] = { + container.params["-".join(member_classes)] = { "k1": opls.parameters["k1"], "k2": opls.parameters["k2"], "k3": opls.parameters["k3"], @@ -1267,7 +1274,7 @@ def _parse_harmonic_improper( itypes, ): for itype in itypes: - member_types = sort_member_types(itype) + member_types = sort_by_classes(itype) container.params["-".join(member_types)] = { "k": itype.parameters["k"], "chi0": itype.parameters["phi_eq"], # diff nomenclature? diff --git a/gmso/formats/lammpsdata.py b/gmso/formats/lammpsdata.py index 69ce71a64..198584c33 100644 --- a/gmso/formats/lammpsdata.py +++ b/gmso/formats/lammpsdata.py @@ -26,7 +26,7 @@ from gmso.core.element import element_by_mass from gmso.core.improper import Improper from gmso.core.topology import Topology -from gmso.core.views import PotentialFilters, get_sorted_names +from gmso.core.views import PotentialFilters pfilter = PotentialFilters.UNIQUE_SORTED_NAMES from gmso.exceptions import NotYetImplementedWarning @@ -37,6 +37,7 @@ convert_opls_to_ryckaert, convert_ryckaert_to_opls, ) +from gmso.utils.sorting import sort_by_types from gmso.utils.units import LAMMPS_UnitSystems, write_out_parameter_and_units @@ -875,7 +876,7 @@ def _write_dihedraltypes(out_file, top, base_unyts, cfactorsDict): out_file.write("#\t" + "\t".join(param_labels) + "\n") indexList = list(top.dihedral_types(filter_by=pfilter)) index_membersList = [ - (dihedral_type, get_sorted_names(dihedral_type)) + (dihedral_type, sort_by_types(dihedral_type)) for dihedral_type in indexList ] index_membersList.sort(key=lambda x: ([x[1][i] for i in [1, 2, 0, 3]])) @@ -915,7 +916,7 @@ def _write_impropertypes(out_file, top, base_unyts, cfactorsDict): out_file.write("#\t" + "\t".join(param_labels) + "\n") indexList = list(top.improper_types(filter_by=pfilter)) index_membersList = [ - (improper_type, get_sorted_names(improper_type)) + (improper_type, sort_by_types(improper_type)) for improper_type in indexList ] index_membersList.sort(key=lambda x: ([x[1][i] for i in [0, 1, 2, 3]])) @@ -1005,14 +1006,14 @@ def _write_conn_data(out_file, top, connIter, connStr): out_file.write(f"\n{connStr.capitalize()}\n\n") indexList = list( map( - get_sorted_names, + sort_by_types, getattr(top, connStr[:-1] + "_types")(filter_by=pfilter), ) ) indexList.sort(key=sorting_funcDict[connStr]) for i, conn in enumerate(getattr(top, connStr)): - typeStr = f"{i+1:<6d}\t{indexList.index(get_sorted_names(conn.connection_type))+1:<6d}\t" + typeStr = f"{i+1:<6d}\t{indexList.index(sort_by_types(conn.connection_type))+1:<6d}\t" indexStr = "\t".join( map( lambda x: str(top.sites.index(x) + 1).ljust(6), diff --git a/gmso/tests/files/alkanes.xml b/gmso/tests/files/alkanes.xml new file mode 100644 index 000000000..7d621c475 --- /dev/null +++ b/gmso/tests/files/alkanes.xml @@ -0,0 +1,122 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gmso/tests/files/alkanes_wildcards.xml b/gmso/tests/files/alkanes_wildcards.xml new file mode 100644 index 000000000..7cd7f3f2b --- /dev/null +++ b/gmso/tests/files/alkanes_wildcards.xml @@ -0,0 +1,112 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gmso/tests/test_hoomd.py b/gmso/tests/test_hoomd.py index 76651034a..0b57859da 100644 --- a/gmso/tests/test_hoomd.py +++ b/gmso/tests/test_hoomd.py @@ -5,12 +5,14 @@ import unyt as u from mbuild.formats.hoomd_forcefield import create_hoomd_forcefield +from gmso import ForceField from gmso.external import from_mbuild from gmso.external.convert_hoomd import to_hoomd_forcefield, to_hoomd_snapshot from gmso.parameterization import apply from gmso.tests.base_test import BaseTest from gmso.tests.utils import get_path from gmso.utils.io import has_hoomd, has_mbuild, import_ +from gmso.utils.sorting import sort_connection_strings if has_hoomd: hoomd = import_("hoomd") @@ -25,7 +27,7 @@ class TestGsd(BaseTest): def test_mbuild_comparison(self): compound = mb.load("CCC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=20) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.g / u.mol, "length": u.nm, @@ -86,21 +88,37 @@ def test_mbuild_comparison(self): mb_forcefield, key=lambda cls: str(cls.__class__) ) for mb_force, gmso_force in zip(sorted_mbuild_ff, sorted_gmso_ff): - if not isinstance(mb_force, hoomd.md.long_range.pppm.Coulomb): - keys = mb_force.params.param_dict.keys() - for key in keys: - mb_params = mb_force.params.param_dict[key] - gmso_params = gmso_force.params.param_dict[key] - variables = mb_params.keys() - for var in variables: - assert np.isclose(mb_params[var], gmso_params[var]) + if ( # TODO: why are these skipped? + isinstance(mb_force, hoomd.md.long_range.pppm.Coulomb) + or isinstance(mb_force, hoomd.md.pair.pair.LJ) + or isinstance(mb_force, hoomd.md.special_pair.LJ) + or isinstance(mb_force, hoomd.md.pair.pair.Ewald) + or isinstance(mb_force, hoomd.md.special_pair.Coulomb) + ): + continue + keys = mb_force.params.param_dict.keys() + gmso_keys = gmso_force.params.param_dict.keys() + print("\n\n", keys, gmso_keys, gmso_force, "\n\n") + for key in keys: + gmso_key = key.replace("opls_135", "CT") + gmso_key = gmso_key.replace("opls_136", "CT") + gmso_key = gmso_key.replace("opls_140", "HC") + gmso_key = "-".join( + sort_connection_strings(gmso_key.split("-")) + ) + mb_params = mb_force.params.param_dict[key] + gmso_params = gmso_force.params.param_dict[gmso_key] + variables = mb_params.keys() + for var in variables: + print(key, gmso_key, var, mb_params[var], gmso_params[var]) + assert np.isclose(mb_params[var], gmso_params[var]) @pytest.mark.skipif( int(hoomd_version[0]) < 4, reason="Unsupported features in HOOMD 3" ) def test_hoomd4_simulation(self): compound = mb.load("CCC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=200) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.g / u.mol, "length": u.nm, @@ -159,7 +177,7 @@ def test_hoomd4_simulation(self): ) def test_hoomd4_simulation_auto_scaled(self): compound = mb.load("CCC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=200) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.g / u.mol, "length": u.nm, @@ -221,7 +239,7 @@ def test_hoomd4_simulation_auto_scaled(self): ) def test_hoomd3_simulation(self): compound = mb.load("CCC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=200) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.g / u.mol, "length": u.nm, @@ -277,7 +295,7 @@ def test_hoomd3_simulation(self): ) def test_hoomd3_simulation_auto_scaled(self): compound = mb.load("CCC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=200) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.g / u.mol, "length": u.nm, @@ -333,7 +351,7 @@ def test_hoomd3_simulation_auto_scaled(self): def test_diff_base_units(self): compound = mb.load("CC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=100) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.amu, "length": u.nm, @@ -357,7 +375,7 @@ def test_diff_base_units(self): def test_default_units(self): compound = mb.load("CC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=100) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.amu, "length": u.nm, @@ -413,7 +431,7 @@ def test_ff_zero_parameter(self): def test_zero_charges(self): compound = mb.load("CC", smiles=True) - com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=20) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) base_units = { "mass": u.amu, "length": u.nm, @@ -436,3 +454,108 @@ def test_zero_charges(self): assert not isinstance(force, hoomd.md.pair.pair.Ewald) assert not isinstance(force, hoomd.md.long_range.pppm.Coulomb) assert not isinstance(force, hoomd.md.special_pair.Coulomb) + + def test_forces_connections_match(self): + compound = mb.load("CC", smiles=True) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=2) + base_units = { + "mass": u.amu, + "length": u.nm, + "energy": u.kJ / u.mol, + } + top = com_box.to_gmso() + top.identify_connections() + ethaneFF = ForceField(get_path("alkanes.xml")) + ethaneFF.atom_types["opls_01"] = ethaneFF.atom_types.pop("opls_140") + ethaneFF.atom_types["opls_01"].name = "opls_01" + ethaneFF.atom_types["opls_01"].atomclass = "opls_01" + ethaneFF.atom_types["opls_1004"] = ethaneFF.atom_types.pop("opls_135") + ethaneFF.atom_types["opls_1004"].name = "opls_1004" + ethaneFF.atom_types["opls_1004"].atomclass = "opls_1004" + xDict = {"bond": {}, "angle": {}, "dihedral": {}} + for dictKey in xDict: + for connection in getattr(ethaneFF, dictKey + "_types"): + newname = connection + for atomclass, atomname in { + "HC": "opls_01", + "CT": "opls_1004", + }.items(): + newname = newname.replace(atomclass, atomname) + xDict[dictKey][connection] = newname + + for dictKey in xDict: + for oldname, newname in xDict[dictKey].items(): + getattr(ethaneFF, dictKey + "_types")[newname] = getattr( + ethaneFF, dictKey + "_types" + ).pop(oldname) + getattr(ethaneFF, dictKey + "_types")[newname].name = newname + getattr(ethaneFF, dictKey + "_types")[ + newname + ].member_types = tuple(newname.split("~")) + + # ethaneFF.bond_types["opls_01~opls_1004"] = ethaneFF.bond_types.pop("CT~HC") + # ethaneFF.bond_types["opls_01~opls_01"] = ethaneFF.bond_types.pop("CT~CT") + # ethaneFF.angle_types["opls_01~opls_01~opls_1004"] = ethaneFF.angle_types.pop("CT~CT~HC") + # ethaneFF.angle_types["opls_1004~opls_01~opls_1004"] = ethaneFF.angle_types.pop("HC~CT~HC") + # ethaneFF.dihedral_types["opls_1004~opls_01~opls_01~opls_1004"] = ethaneFF.dihedral_types.pop("HC~CT~CT~HC") + # should sort these opls_01, opls_1004 + top = apply(top, ethaneFF, remove_untyped=True) + + snapshot, snapshot_base_units = to_hoomd_snapshot( + top, base_units=base_units + ) + + forces, forces_base_units = to_hoomd_forcefield( + top=top, r_cut=1.4, base_units=base_units + ) + + assert forces_base_units == snapshot_base_units + for conntype in snapshot.bonds.types: + assert conntype in list(forces["bonds"][0].params.keys()) + + def test_forces_connections_match2(self): + compound = mb.load("CC", smiles=True) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=1) + base_units = { + "mass": u.amu, + "length": u.nm, + "energy": u.kJ / u.mol, + } + top = com_box.to_gmso() + top.identify_connections() + ethaneFF = ForceField(get_path("alkanes.xml")) + + top = apply(top, ethaneFF, remove_untyped=True) + + snapshot, snapshot_base_units = to_hoomd_snapshot( + top, base_units=base_units + ) + assert "CT-HC" in snapshot.bonds.types + + forces, forces_base_units = to_hoomd_forcefield( + top=top, r_cut=1.4, base_units=base_units + ) + assert "CT-HC" in forces["bonds"][0].params.keys() + + def test_forces_wildcards(self): + compound = mb.load("CCCC", smiles=True) + com_box = mb.packing.fill_box(compound, box=[5, 5, 5], n_compounds=1) + base_units = { + "mass": u.amu, + "length": u.nm, + "energy": u.kJ / u.mol, + } + top = com_box.to_gmso() + top.identify_connections() + ethaneFF = ForceField(get_path("alkanes_wildcards.xml")) + top = apply(top, ethaneFF, remove_untyped=True) + + snapshot, _ = to_hoomd_snapshot(top, base_units=base_units) + assert "CT-HC" in snapshot.bonds.types + + forces, _ = to_hoomd_forcefield( + top=top, r_cut=1.4, base_units=base_units + ) + assert "CT-CT-CT-HC" in list(forces["dihedrals"][0].params) + for conntype in snapshot.dihedrals.types: + assert conntype in list(forces["dihedrals"][0].params) diff --git a/gmso/utils/sorting.py b/gmso/utils/sorting.py index 602885b61..da2ad34d4 100644 --- a/gmso/utils/sorting.py +++ b/gmso/utils/sorting.py @@ -2,6 +2,24 @@ import re import gmso +from gmso.core.angle import Angle +from gmso.core.angle_type import AngleType +from gmso.core.atom import Atom +from gmso.core.atom_type import AtomType +from gmso.core.bond import Bond +from gmso.core.bond_type import BondType +from gmso.core.dihedral import Dihedral +from gmso.core.dihedral_type import DihedralType +from gmso.core.improper import Improper +from gmso.core.improper_type import ImproperType + +potential_attribute_map = { + Atom: "atom_type", + Bond: "bond_type", + Angle: "angle_type", + Dihedral: "dihedral_type", + Improper: "improper_type", +} def _atoi(text): @@ -14,42 +32,14 @@ def natural_sort(text): return [_atoi(a) for a in re.split(r"(\d+)", text)] -def sort_member_types(connection_type): - """Sort connection_members of connection_type.""" - if isinstance(connection_type, gmso.BondType): - type1, type2 = connection_type.member_types - type1, type2 = sorted([type1, type2], key=natural_sort) - return [type1, type2] - elif isinstance(connection_type, gmso.AngleType): - type1, type2, type3 = connection_type.member_types - type1, type3 = sorted([type1, type3], key=natural_sort) - return [type1, type2, type3] - elif isinstance(connection_type, gmso.DihedralType): - type1, type2, type3, type4 = connection_type.member_types - if [type2, type3] == sorted([type2, type3], key=natural_sort): - return [type1, type2, type3, type4] - else: - return [type4, type3, type2, type1] - elif isinstance(connection_type, gmso.ImproperType): - type1, type2, type3, type4 = connection_type.member_types - type2, type3, type4 = sorted([type2, type3, type4], key=natural_sort) - return [type1, type2, type3, type4] - else: - raise TypeError("Provided connection_type not supported.") - - def sort_connection_members(connection, sort_by="name"): """Sort connection_members of connection.""" if sort_by == "name": - - def sorting_key(site): - return site.name - + sorting_key = lambda site: natural_sort(site.name) elif sort_by == "atom_type": - - def sorting_key(site): - return site.atom_type.name - + sorting_key = lambda site: natural_sort(site.atom_type.name) + elif sort_by == "atomclass": + sorting_key = lambda site: natural_sort(site.atom_type.atomclass) else: raise ValueError("Unsupported sort_by value provided.") @@ -63,13 +53,109 @@ def sorting_key(site): return [site1, site2, site3] elif isinstance(connection, gmso.Dihedral): site1, site2, site3, site4 = connection.connection_members - if [site2, site3] == sorted([site2, site3], key=sorting_key): - return [site1, site2, site3, site4] - else: + if sorting_key(site2) > sorting_key(site3) or ( + sorting_key(site2) == sorting_key(site3) + and sorting_key(site1) > sorting_key(site4) + ): return [site4, site3, site2, site1] + else: + return [site1, site2, site3, site4] elif isinstance(connection, gmso.Improper): site1, site2, site3, site4 = connection.connection_members site2, site3, site4 = sorted([site2, site3, site4], key=sorting_key) return [site1, site2, site3, site4] else: raise TypeError("Provided connection not supported.") + + +def sort_by_classes(potential): + """Get list of classes for a topology potential based on memberclass.""" + if isinstance(potential, AtomType): + return potential.atom_type.atomclass + elif isinstance(potential, BondType): + return tuple(sorted(potential.member_classes)) + elif isinstance(potential, AngleType): + if potential.member_classes[0] > potential.member_classes[2]: + return tuple(reversed(potential.member_classes)) + else: + return potential.member_classes + elif isinstance(potential, DihedralType): + if potential.member_classes[1] > potential.member_classes[2] or ( + potential.member_classes[1] == potential.member_classes[2] + and potential.member_classes[0] > potential.member_classes[3] + ): + return tuple(reversed(potential.member_classes)) + else: + return potential.member_classes + elif isinstance(potential, ImproperType): + return ( + potential.member_classes[0], + *potential.member_classes[1:], + ) # could sort using `sorted` + return ValueError( + f"Potential {potential} not one of {potential_attribute_map.values()}" + ) + + +def sort_by_types(potential): + """Get list of types for a topology potential based on membertype.""" + if isinstance(potential, AtomType): + return potential.name + elif isinstance(potential, BondType): + return tuple(sorted(potential.member_types)) + elif isinstance(potential, AngleType): + if potential.member_types[0] > potential.member_types[2]: + return tuple(reversed(potential.member_types)) + else: + return potential.member_types + elif isinstance(potential, DihedralType): + if potential.member_types[1] > potential.member_types[2] or ( + potential.member_types[1] == potential.member_types[2] + and potential.member_types[0] > potential.member_types[3] + ): + return tuple(reversed(potential.member_types)) + else: + return potential.member_types + elif isinstance(potential, ImproperType): + return ( + potential.member_types[0], + *potential.member_types[1:], + ) # could sort using `sorted` + return ValueError( + f"Potential {potential} not one of {potential_attribute_map.values()}" + ) + + +def sort_connection_strings(namesList, improperBool=False): + """Sort list of strings for a connection to get proper ordering of the connection. + + Parameters + ---------- + namesList : list + List of strings connected to a compound to sort. + improperBool : bool, option, default=False + whether or not a four member list refers to an improper + """ + if len(namesList) == 2: # assume bonds + return tuple(sorted(namesList)) + elif len(namesList) == 3: + if namesList[0] > namesList[2]: + return tuple(reversed(namesList)) + else: + return tuple(namesList) + elif len(namesList) == 4 and improperBool: + return tuple( + namesList[0], + sorted(*namesList[1:]), + ) + elif len(namesList) == 4 and not improperBool: + if namesList[1] > namesList[2] or ( + namesList[1] == namesList[2] and namesList[0] > namesList[3] + ): + return tuple(reversed(namesList)) + else: + return tuple(namesList) + else: + return ValueError( + f"Cannot sort {namesList}. It is not a length of 2,3, or 4 members." + )