-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes to general sorting algorithms for hoomd and lammps #762
Changes from 3 commits
37fdf57
adebaaa
6fa1518
06e9ddd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,11 +20,7 @@ | |
) | ||
from gmso.utils.geometry import coord_shift | ||
from gmso.utils.io import has_gsd, has_hoomd | ||
from gmso.utils.sorting import ( | ||
natural_sort, | ||
sort_connection_members, | ||
sort_member_types, | ||
) | ||
from gmso.utils.sorting import sort_by_classes, sort_connection_members | ||
|
||
if has_gsd: | ||
import gsd.hoomd | ||
|
@@ -384,9 +380,9 @@ | |
|
||
for bond in top.bonds: | ||
if all([site.atom_type for site in bond.connection_members]): | ||
connection_members = sort_connection_members(bond, "atom_type") | ||
connection_members = sort_connection_members(bond, "atomclass") | ||
bond_type = "-".join( | ||
[site.atom_type.name for site in connection_members] | ||
[site.atom_type.atomclass for site in connection_members] | ||
) | ||
else: | ||
connection_members = sort_connection_members(bond, "name") | ||
|
@@ -402,8 +398,8 @@ | |
|
||
if isinstance(snapshot, hoomd.Snapshot): | ||
snapshot.bonds.types = unique_bond_types | ||
snapshot.bonds.typeid[0:] = bond_typeids | ||
snapshot.bonds.group[0:] = bond_groups | ||
snapshot.bonds.typeid[:] = bond_typeids | ||
snapshot.bonds.group[:] = bond_groups | ||
elif isinstance(snapshot, gsd.hoomd.Frame): | ||
snapshot.bonds.types = unique_bond_types | ||
snapshot.bonds.typeid = bond_typeids | ||
|
@@ -431,9 +427,9 @@ | |
|
||
for angle in top.angles: | ||
if all([site.atom_type for site in angle.connection_members]): | ||
connection_members = sort_connection_members(angle, "atom_type") | ||
connection_members = sort_connection_members(angle, "atomclass") | ||
angle_type = "-".join( | ||
[site.atom_type.name for site in connection_members] | ||
[site.atom_type.atomclass for site in connection_members] | ||
) | ||
else: | ||
connection_members = sort_connection_members(angle, "name") | ||
|
@@ -449,8 +445,8 @@ | |
|
||
if isinstance(snapshot, hoomd.Snapshot): | ||
snapshot.angles.types = unique_angle_types | ||
snapshot.angles.typeid[0:] = angle_typeids | ||
snapshot.angles.group[0:] = np.reshape(angle_groups, (-1, 3)) | ||
snapshot.angles.typeid[:] = angle_typeids | ||
snapshot.angles.group[:] = np.reshape(angle_groups, (-1, 3)) | ||
elif isinstance(snapshot, gsd.hoomd.Frame): | ||
snapshot.angles.types = unique_angle_types | ||
snapshot.angles.typeid = angle_typeids | ||
|
@@ -477,9 +473,9 @@ | |
|
||
for dihedral in top.dihedrals: | ||
if all([site.atom_type for site in dihedral.connection_members]): | ||
connection_members = sort_connection_members(dihedral, "atom_type") | ||
connection_members = sort_connection_members(dihedral, "atomclass") | ||
dihedral_type = "-".join( | ||
[site.atom_type.name for site in connection_members] | ||
[site.atom_type.atomclass for site in connection_members] | ||
) | ||
else: | ||
connection_members = sort_connection_members(dihedral, "name") | ||
|
@@ -495,8 +491,8 @@ | |
|
||
if isinstance(snapshot, hoomd.Snapshot): | ||
snapshot.dihedrals.types = unique_dihedral_types | ||
snapshot.dihedrals.typeid[0:] = dihedral_typeids | ||
snapshot.dihedrals.group[0:] = np.reshape(dihedral_groups, (-1, 4)) | ||
snapshot.dihedrals.typeid[:] = dihedral_typeids | ||
snapshot.dihedrals.group[:] = np.reshape(dihedral_groups, (-1, 4)) | ||
elif isinstance(snapshot, gsd.hoomd.Frame): | ||
snapshot.dihedrals.types = unique_dihedral_types | ||
snapshot.dihedrals.typeid = dihedral_typeids | ||
|
@@ -525,9 +521,9 @@ | |
|
||
for improper in top.impropers: | ||
if all([site.atom_type for site in improper.connection_members]): | ||
connection_members = sort_connection_members(improper, "atom_type") | ||
connection_members = sort_connection_members(improper, "atomclass") | ||
improper_type = "-".join( | ||
[site.atom_type.name for site in connection_members] | ||
[site.atom_type.atomclass for site in connection_members] | ||
) | ||
else: | ||
connection_members = sort_connection_members(improper, "name") | ||
|
@@ -994,8 +990,8 @@ | |
): | ||
for btype in btypes: | ||
# TODO: Unit conversion | ||
member_types = sort_member_types(btype) | ||
container.params["-".join(member_types)] = { | ||
member_classes = sort_by_classes(btype) | ||
container.params["-".join(member_classes)] = { | ||
"k": btype.parameters["k"], | ||
"r0": btype.parameters["r_eq"], | ||
} | ||
|
@@ -1064,8 +1060,8 @@ | |
agtypes, | ||
): | ||
for agtype in agtypes: | ||
member_types = sort_member_types(agtype) | ||
container.params["-".join(member_types)] = { | ||
member_classes = sort_by_classes(agtype) | ||
container.params["-".join(member_classes)] = { | ||
"k": agtype.parameters["k"], | ||
"t0": agtype.parameters["theta_eq"], | ||
} | ||
|
@@ -1094,23 +1090,33 @@ | |
unique_dtypes = top.dihedral_types( | ||
filter_by=PotentialFilters.UNIQUE_NAME_CLASS | ||
) | ||
unique_dihedrals = {} | ||
for dihedral in top.dihedrals: | ||
unique_members = tuple( | ||
[site.atom_type.atomclass for site in dihedral.connection_members] | ||
) | ||
unique_dihedrals[unique_members] = dihedral | ||
unique_dtypes = [ | ||
dihedral.dihedral_type for dihedral in unique_dihedrals.values() | ||
] | ||
groups = dict() | ||
for dtype in unique_dtypes: | ||
group = potential_types[dtype] | ||
for dihedral in unique_dihedrals.values(): | ||
group = potential_types[dihedral.dihedral_type] | ||
if group not in groups: | ||
groups[group] = [dtype] | ||
groups[group] = [dihedral] | ||
else: | ||
groups[group].append(dtype) | ||
groups[group].append(dihedral) | ||
|
||
expected_unitsDict = {} | ||
for group in groups: | ||
expected_units_dim = potential_refs[group][ | ||
expected_unitsDict[group] = potential_refs[group][ | ||
"expected_parameters_dimensions" | ||
] | ||
groups[group] = _convert_params_units( | ||
groups[group], | ||
expected_units_dim, | ||
base_units, | ||
) | ||
# groups[group] = _convert_connection_params_units( | ||
# groups[group], | ||
# expected_units_dim, | ||
# base_units, | ||
# ) | ||
dtype_group_map = { | ||
"OPLSTorsionPotential": { | ||
"container": hoomd.md.dihedral.OPLS, | ||
|
@@ -1146,19 +1152,24 @@ | |
dihedral_forces.append( | ||
dtype_group_map[group]["parser"]( | ||
container=dtype_group_map[group]["container"](), | ||
dtypes=groups[group], | ||
dihedrals=groups[group], | ||
expected_units_dim=expected_unitsDict[group], | ||
base_units=base_units, | ||
) | ||
) | ||
return dihedral_forces | ||
|
||
|
||
def _parse_periodic_dihedral( | ||
container, | ||
dtypes, | ||
container, dihedrals, expected_units_dim, base_units | ||
): | ||
for dtype in dtypes: | ||
member_types = sort_member_types(dtype) | ||
container.params["-".join(member_types)] = { | ||
for dihedral in dihedrals: | ||
dtype = dihedral.dihedral_type | ||
dtype = _convert_single_param_units( | ||
dtype, expected_units_dim, base_units | ||
) | ||
member_sites = sort_connection_members(dihedral, "atomclass") | ||
container.params["-".join(member_sites)] = { | ||
"k": dtype.parameters["k"], | ||
"d": 1, | ||
"n": dtype.parameters["n"], | ||
|
@@ -1167,14 +1178,17 @@ | |
return container | ||
|
||
|
||
def _parse_opls_dihedral( | ||
container, | ||
dtypes, | ||
): | ||
for dtype in dtypes: | ||
def _parse_opls_dihedral(container, dihedrals, expected_units_dim, base_units): | ||
for dihedral in dihedrals: | ||
dtype = dihedral.dihedral_type | ||
dtype = _convert_single_param_units( | ||
dtype, expected_units_dim, base_units | ||
) | ||
member_sites = sort_connection_members(dihedral, "atomclass") | ||
# TODO: The range of ks is mismatched (GMSO go from k0 to k5) | ||
# May need to do a check that k0 == k5 == 0 or raise a warning | ||
container.params["-".join(dtype.member_types)] = { | ||
member_classes = sort_by_classes(dtype) | ||
Check notice Code scanning / CodeQL Unused local variable Note library
Variable member_classes is not used.
|
||
container.params["-".join(member_sites)] = { | ||
"k1": dtype.parameters["k1"], | ||
"k2": dtype.parameters["k2"], | ||
"k3": dtype.parameters["k3"], | ||
|
@@ -1183,19 +1197,22 @@ | |
return container | ||
|
||
|
||
def _parse_rb_dihedral( | ||
container, | ||
dtypes, | ||
): | ||
def _parse_rb_dihedral(container, dihedrals, expected_units_dim, base_units): | ||
warnings.warn( | ||
"RyckaertBellemansTorsionPotential will be converted to OPLSTorsionPotential." | ||
) | ||
for dtype in dtypes: | ||
for dihedral in dihedrals: | ||
dtype = dihedral.dihedral_type | ||
dtype = _convert_single_param_units( | ||
dtype, expected_units_dim, base_units | ||
) | ||
opls = convert_ryckaert_to_opls(dtype) | ||
member_types = sort_member_types(dtype) | ||
member_sites = sort_connection_members(dihedral, "atomclass") | ||
# TODO: The range of ks is mismatched (GMSO go from k0 to k5) | ||
# May need to do a check that k0 == k5 == 0 or raise a warning | ||
container.params["-".join(member_types)] = { | ||
container.params[ | ||
"-".join([site.atom_type.atomclass for site in member_sites]) | ||
] = { | ||
"k1": opls.parameters["k1"], | ||
"k2": opls.parameters["k2"], | ||
"k3": opls.parameters["k3"], | ||
|
@@ -1267,7 +1284,7 @@ | |
itypes, | ||
): | ||
for itype in itypes: | ||
member_types = sort_member_types(itype) | ||
member_types = sort_by_classes(itype) | ||
container.params["-".join(member_types)] = { | ||
"k": itype.parameters["k"], | ||
"chi0": itype.parameters["phi_eq"], # diff nomenclature? | ||
|
@@ -1413,3 +1430,26 @@ | |
potential.parameters = converted_params | ||
converted_potentials.append(potential) | ||
return converted_potentials | ||
|
||
|
||
def _convert_single_param_units( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should this belong to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. It should get moved into using the new UnitsClasses, but maybe we can do that in a later PR. |
||
potential, | ||
expected_units_dim, | ||
base_units, | ||
): | ||
"""Convert parameters' units in the potential to that specified in the base_units.""" | ||
converted_params = dict() | ||
for parameter in potential.parameters: | ||
unit_dim = expected_units_dim[parameter] | ||
ind_units = re.sub("[^a-zA-Z]+", " ", unit_dim).split() | ||
for unit in ind_units: | ||
unit_dim = unit_dim.replace( | ||
unit, | ||
f"({str(base_units[unit].value)} * {str(base_units[unit].units)})", | ||
) | ||
|
||
converted_params[parameter] = potential.parameters[parameter].to( | ||
unit_dim | ||
) | ||
potential.parameters = converted_params | ||
return potential |
Check notice
Code scanning / CodeQL
Unused local variable Note library