Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes to general sorting algorithms for hoomd and lammps #762

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 2 additions & 30 deletions gmso/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from gmso.core.dihedral_type import DihedralType
from gmso.core.improper import Improper
from gmso.core.improper_type import ImproperType
from gmso.utils.sorting import sort_by_types

__all__ = ["TopologyPotentialView", "PotentialFilters"]

Expand All @@ -37,35 +38,6 @@ def get_name_or_class(potential):
return potential.member_types or potential.member_classes


def get_sorted_names(potential):
"""Get identifier for a topology potential based on name or membertype/class."""
if isinstance(potential, AtomType):
return potential.name
elif isinstance(potential, BondType):
return tuple(sorted(potential.member_types))
elif isinstance(potential, AngleType):
if potential.member_types[0] > potential.member_types[2]:
return tuple(reversed(potential.member_types))
else:
return potential.member_types
elif isinstance(potential, DihedralType):
if potential.member_types[1] > potential.member_types[2] or (
potential.member_types[1] == potential.member_types[2]
and potential.member_types[0] > potential.member_types[3]
):
return tuple(reversed(potential.member_types))
else:
return potential.member_types
elif isinstance(potential, ImproperType):
return (
potential.member_types[0],
*potential.member_types[1:],
) # could sort using `sorted`
return ValueError(
f"Potential {potential} not one of {potential_attribute_map.values()}"
)


def get_parameters(potential):
"""Return hashable version of parameters for a potential."""
return (
Expand Down Expand Up @@ -105,7 +77,7 @@ def all():

potential_identifiers = {
PotentialFilters.UNIQUE_NAME_CLASS: get_name_or_class,
PotentialFilters.UNIQUE_SORTED_NAMES: get_sorted_names,
PotentialFilters.UNIQUE_SORTED_NAMES: sort_by_types,
PotentialFilters.UNIQUE_EXPRESSION: lambda p: str(p.expression),
PotentialFilters.UNIQUE_PARAMETERS: get_parameters,
PotentialFilters.UNIQUE_ID: lambda p: id(p),
Expand Down
146 changes: 93 additions & 53 deletions gmso/external/convert_hoomd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,7 @@
)
from gmso.utils.geometry import coord_shift
from gmso.utils.io import has_gsd, has_hoomd
from gmso.utils.sorting import (
natural_sort,
sort_connection_members,
sort_member_types,
)
from gmso.utils.sorting import sort_by_classes, sort_connection_members

if has_gsd:
import gsd.hoomd
Expand Down Expand Up @@ -384,9 +380,9 @@

for bond in top.bonds:
if all([site.atom_type for site in bond.connection_members]):
connection_members = sort_connection_members(bond, "atom_type")
connection_members = sort_connection_members(bond, "atomclass")
bond_type = "-".join(
[site.atom_type.name for site in connection_members]
[site.atom_type.atomclass for site in connection_members]
)
else:
connection_members = sort_connection_members(bond, "name")
Expand All @@ -402,8 +398,8 @@

if isinstance(snapshot, hoomd.Snapshot):
snapshot.bonds.types = unique_bond_types
snapshot.bonds.typeid[0:] = bond_typeids
snapshot.bonds.group[0:] = bond_groups
snapshot.bonds.typeid[:] = bond_typeids
snapshot.bonds.group[:] = bond_groups
elif isinstance(snapshot, gsd.hoomd.Frame):
snapshot.bonds.types = unique_bond_types
snapshot.bonds.typeid = bond_typeids
Expand Down Expand Up @@ -431,9 +427,9 @@

for angle in top.angles:
if all([site.atom_type for site in angle.connection_members]):
connection_members = sort_connection_members(angle, "atom_type")
connection_members = sort_connection_members(angle, "atomclass")
angle_type = "-".join(
[site.atom_type.name for site in connection_members]
[site.atom_type.atomclass for site in connection_members]
)
else:
connection_members = sort_connection_members(angle, "name")
Expand All @@ -449,8 +445,8 @@

if isinstance(snapshot, hoomd.Snapshot):
snapshot.angles.types = unique_angle_types
snapshot.angles.typeid[0:] = angle_typeids
snapshot.angles.group[0:] = np.reshape(angle_groups, (-1, 3))
snapshot.angles.typeid[:] = angle_typeids
snapshot.angles.group[:] = np.reshape(angle_groups, (-1, 3))
elif isinstance(snapshot, gsd.hoomd.Frame):
snapshot.angles.types = unique_angle_types
snapshot.angles.typeid = angle_typeids
Expand All @@ -477,9 +473,9 @@

for dihedral in top.dihedrals:
if all([site.atom_type for site in dihedral.connection_members]):
connection_members = sort_connection_members(dihedral, "atom_type")
connection_members = sort_connection_members(dihedral, "atomclass")
dihedral_type = "-".join(
[site.atom_type.name for site in connection_members]
[site.atom_type.atomclass for site in connection_members]
)
else:
connection_members = sort_connection_members(dihedral, "name")
Expand All @@ -495,8 +491,8 @@

if isinstance(snapshot, hoomd.Snapshot):
snapshot.dihedrals.types = unique_dihedral_types
snapshot.dihedrals.typeid[0:] = dihedral_typeids
snapshot.dihedrals.group[0:] = np.reshape(dihedral_groups, (-1, 4))
snapshot.dihedrals.typeid[:] = dihedral_typeids
snapshot.dihedrals.group[:] = np.reshape(dihedral_groups, (-1, 4))
elif isinstance(snapshot, gsd.hoomd.Frame):
snapshot.dihedrals.types = unique_dihedral_types
snapshot.dihedrals.typeid = dihedral_typeids
Expand Down Expand Up @@ -525,9 +521,9 @@

for improper in top.impropers:
if all([site.atom_type for site in improper.connection_members]):
connection_members = sort_connection_members(improper, "atom_type")
connection_members = sort_connection_members(improper, "atomclass")

Check warning on line 524 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L524

Added line #L524 was not covered by tests
improper_type = "-".join(
[site.atom_type.name for site in connection_members]
[site.atom_type.atomclass for site in connection_members]
)
else:
connection_members = sort_connection_members(improper, "name")
Expand Down Expand Up @@ -994,8 +990,8 @@
):
for btype in btypes:
# TODO: Unit conversion
member_types = sort_member_types(btype)
container.params["-".join(member_types)] = {
member_classes = sort_by_classes(btype)
container.params["-".join(member_classes)] = {
"k": btype.parameters["k"],
"r0": btype.parameters["r_eq"],
}
Expand Down Expand Up @@ -1064,8 +1060,8 @@
agtypes,
):
for agtype in agtypes:
member_types = sort_member_types(agtype)
container.params["-".join(member_types)] = {
member_classes = sort_by_classes(agtype)
container.params["-".join(member_classes)] = {
"k": agtype.parameters["k"],
"t0": agtype.parameters["theta_eq"],
}
Expand Down Expand Up @@ -1094,23 +1090,33 @@
unique_dtypes = top.dihedral_types(
filter_by=PotentialFilters.UNIQUE_NAME_CLASS
)
unique_dihedrals = {}
for dihedral in top.dihedrals:
unique_members = tuple(
[site.atom_type.atomclass for site in dihedral.connection_members]
)
unique_dihedrals[unique_members] = dihedral
unique_dtypes = [

Check notice

Code scanning / CodeQL

Unused local variable Note library

Variable unique_dtypes is not used.
dihedral.dihedral_type for dihedral in unique_dihedrals.values()
]
groups = dict()
for dtype in unique_dtypes:
group = potential_types[dtype]
for dihedral in unique_dihedrals.values():
group = potential_types[dihedral.dihedral_type]
if group not in groups:
groups[group] = [dtype]
groups[group] = [dihedral]
else:
groups[group].append(dtype)
groups[group].append(dihedral)

expected_unitsDict = {}
for group in groups:
expected_units_dim = potential_refs[group][
expected_unitsDict[group] = potential_refs[group][
"expected_parameters_dimensions"
]
groups[group] = _convert_params_units(
groups[group],
expected_units_dim,
base_units,
)
# groups[group] = _convert_connection_params_units(
# groups[group],
# expected_units_dim,
# base_units,
# )
dtype_group_map = {
"OPLSTorsionPotential": {
"container": hoomd.md.dihedral.OPLS,
Expand Down Expand Up @@ -1146,19 +1152,24 @@
dihedral_forces.append(
dtype_group_map[group]["parser"](
container=dtype_group_map[group]["container"](),
dtypes=groups[group],
dihedrals=groups[group],
expected_units_dim=expected_unitsDict[group],
base_units=base_units,
)
)
return dihedral_forces


def _parse_periodic_dihedral(
container,
dtypes,
container, dihedrals, expected_units_dim, base_units
):
for dtype in dtypes:
member_types = sort_member_types(dtype)
container.params["-".join(member_types)] = {
for dihedral in dihedrals:
dtype = dihedral.dihedral_type
dtype = _convert_single_param_units(

Check warning on line 1168 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L1166-L1168

Added lines #L1166 - L1168 were not covered by tests
dtype, expected_units_dim, base_units
)
member_sites = sort_connection_members(dihedral, "atomclass")
container.params["-".join(member_sites)] = {

Check warning on line 1172 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L1171-L1172

Added lines #L1171 - L1172 were not covered by tests
"k": dtype.parameters["k"],
"d": 1,
"n": dtype.parameters["n"],
Expand All @@ -1167,14 +1178,17 @@
return container


def _parse_opls_dihedral(
container,
dtypes,
):
for dtype in dtypes:
def _parse_opls_dihedral(container, dihedrals, expected_units_dim, base_units):
for dihedral in dihedrals:
dtype = dihedral.dihedral_type
dtype = _convert_single_param_units(

Check warning on line 1184 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L1182-L1184

Added lines #L1182 - L1184 were not covered by tests
dtype, expected_units_dim, base_units
)
member_sites = sort_connection_members(dihedral, "atomclass")

Check warning on line 1187 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L1187

Added line #L1187 was not covered by tests
# TODO: The range of ks is mismatched (GMSO go from k0 to k5)
# May need to do a check that k0 == k5 == 0 or raise a warning
container.params["-".join(dtype.member_types)] = {
member_classes = sort_by_classes(dtype)

Check notice

Code scanning / CodeQL

Unused local variable Note library

Variable member_classes is not used.
container.params["-".join(member_sites)] = {

Check warning on line 1191 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L1190-L1191

Added lines #L1190 - L1191 were not covered by tests
"k1": dtype.parameters["k1"],
"k2": dtype.parameters["k2"],
"k3": dtype.parameters["k3"],
Expand All @@ -1183,19 +1197,22 @@
return container


def _parse_rb_dihedral(
container,
dtypes,
):
def _parse_rb_dihedral(container, dihedrals, expected_units_dim, base_units):
warnings.warn(
"RyckaertBellemansTorsionPotential will be converted to OPLSTorsionPotential."
)
for dtype in dtypes:
for dihedral in dihedrals:
dtype = dihedral.dihedral_type
dtype = _convert_single_param_units(
dtype, expected_units_dim, base_units
)
opls = convert_ryckaert_to_opls(dtype)
member_types = sort_member_types(dtype)
member_sites = sort_connection_members(dihedral, "atomclass")
# TODO: The range of ks is mismatched (GMSO go from k0 to k5)
# May need to do a check that k0 == k5 == 0 or raise a warning
container.params["-".join(member_types)] = {
container.params[
"-".join([site.atom_type.atomclass for site in member_sites])
] = {
"k1": opls.parameters["k1"],
"k2": opls.parameters["k2"],
"k3": opls.parameters["k3"],
Expand Down Expand Up @@ -1267,7 +1284,7 @@
itypes,
):
for itype in itypes:
member_types = sort_member_types(itype)
member_types = sort_by_classes(itype)

Check warning on line 1287 in gmso/external/convert_hoomd.py

View check run for this annotation

Codecov / codecov/patch

gmso/external/convert_hoomd.py#L1287

Added line #L1287 was not covered by tests
container.params["-".join(member_types)] = {
"k": itype.parameters["k"],
"chi0": itype.parameters["phi_eq"], # diff nomenclature?
Expand Down Expand Up @@ -1413,3 +1430,26 @@
potential.parameters = converted_params
converted_potentials.append(potential)
return converted_potentials


def _convert_single_param_units(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this belong to utils/units.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. It should get moved into using the new UnitsClasses, but maybe we can do that in a later PR.

potential,
expected_units_dim,
base_units,
):
"""Convert parameters' units in the potential to that specified in the base_units."""
converted_params = dict()
for parameter in potential.parameters:
unit_dim = expected_units_dim[parameter]
ind_units = re.sub("[^a-zA-Z]+", " ", unit_dim).split()
for unit in ind_units:
unit_dim = unit_dim.replace(
unit,
f"({str(base_units[unit].value)} * {str(base_units[unit].units)})",
)

converted_params[parameter] = potential.parameters[parameter].to(
unit_dim
)
potential.parameters = converted_params
return potential
11 changes: 6 additions & 5 deletions gmso/formats/lammpsdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from gmso.core.element import element_by_mass
from gmso.core.improper import Improper
from gmso.core.topology import Topology
from gmso.core.views import PotentialFilters, get_sorted_names
from gmso.core.views import PotentialFilters

pfilter = PotentialFilters.UNIQUE_SORTED_NAMES
from gmso.exceptions import NotYetImplementedWarning
Expand All @@ -37,6 +37,7 @@
convert_opls_to_ryckaert,
convert_ryckaert_to_opls,
)
from gmso.utils.sorting import sort_by_types
from gmso.utils.units import LAMMPS_UnitSystems, write_out_parameter_and_units


Expand Down Expand Up @@ -875,7 +876,7 @@ def _write_dihedraltypes(out_file, top, base_unyts, cfactorsDict):
out_file.write("#\t" + "\t".join(param_labels) + "\n")
indexList = list(top.dihedral_types(filter_by=pfilter))
index_membersList = [
(dihedral_type, get_sorted_names(dihedral_type))
(dihedral_type, sort_by_types(dihedral_type))
for dihedral_type in indexList
]
index_membersList.sort(key=lambda x: ([x[1][i] for i in [1, 2, 0, 3]]))
Expand Down Expand Up @@ -915,7 +916,7 @@ def _write_impropertypes(out_file, top, base_unyts, cfactorsDict):
out_file.write("#\t" + "\t".join(param_labels) + "\n")
indexList = list(top.improper_types(filter_by=pfilter))
index_membersList = [
(improper_type, get_sorted_names(improper_type))
(improper_type, sort_by_types(improper_type))
for improper_type in indexList
]
index_membersList.sort(key=lambda x: ([x[1][i] for i in [0, 1, 2, 3]]))
Expand Down Expand Up @@ -1005,14 +1006,14 @@ def _write_conn_data(out_file, top, connIter, connStr):
out_file.write(f"\n{connStr.capitalize()}\n\n")
indexList = list(
map(
get_sorted_names,
sort_by_types,
getattr(top, connStr[:-1] + "_types")(filter_by=pfilter),
)
)
indexList.sort(key=sorting_funcDict[connStr])

for i, conn in enumerate(getattr(top, connStr)):
typeStr = f"{i+1:<6d}\t{indexList.index(get_sorted_names(conn.connection_type))+1:<6d}\t"
typeStr = f"{i+1:<6d}\t{indexList.index(sort_by_types(conn.connection_type))+1:<6d}\t"
indexStr = "\t".join(
map(
lambda x: str(top.sites.index(x) + 1).ljust(6),
Expand Down
Loading
Loading