Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 28, 2023
1 parent 53de3dd commit 82f123e
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions gmso/formats/lammpsdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,15 +705,25 @@ def _write_header(out_file, top, atom_style):
if top.n_dihedrals > 0 and atom_style in ["full", "molecular"]:
unique_dtypes = top.dihedral_types(filter_by=pfilter)
from itertools import chain

nkeys = len(next(iter(unique_dtypes)).parameters.keys())
nparams = len(list(_flatten((map(lambda x: [val.tolist() for val in x.parameters.values()], unique_dtypes)))))
#nparams = len(list(chain([dt.parameters.values() for dt in unique_dtypes])))
ntypes = int(nparams / nkeys)
out_file.write(
"{:d} dihedral types\n".format(
ntypes
nparams = len(
list(
_flatten(
(
map(
lambda x: [
val.tolist() for val in x.parameters.values()
],
unique_dtypes,
)
)
)
)
)
# nparams = len(list(chain([dt.parameters.values() for dt in unique_dtypes])))
ntypes = int(nparams / nkeys)
out_file.write("{:d} dihedral types\n".format(ntypes))
if top.n_impropers > 0 and atom_style in ["full", "molecular"]:
out_file.write(
"{:d} improper types\n".format(
Expand Down Expand Up @@ -933,9 +943,13 @@ def _write_dihedraltypes(out_file, top, base_unyts, parser, cfactorsDict):
"""Write out dihedrals to LAMMPS file."""
test_dihedraltype = top.dihedrals[0].dihedral_type
out_file.write(f"\nDihedral Coeffs #{test_dihedraltype.name}\n")
param_labels0 = parser(test_dihedraltype) # tuple (paramsList, params_namesList)

if isinstance(param_labels0[0][0], list): # check for parsing out multiple instances from the dihedral
param_labels0 = parser(
test_dihedraltype
) # tuple (paramsList, params_namesList)

if isinstance(
param_labels0[0][0], list
): # check for parsing out multiple instances from the dihedral
param_labels = [
write_out_parameter_and_units(
name, convert_kelvin_to_energy_units(param, "kJ"), base_unyts
Expand All @@ -959,8 +973,10 @@ def _write_dihedraltypes(out_file, top, base_unyts, parser, cfactorsDict):
# handle variable lengths for parameters
base_msg = "{}\t" # handles index
end_msg = "# {}\t{}\t{}\t{}\n"

if parser.__name__ == "parse_opls_style_dihedral": # one opls parameter per dihedral type

if (
parser.__name__ == "parse_opls_style_dihedral"
): # one opls parameter per dihedral type
for idx, (dihedral_type, members) in enumerate(index_membersList):
param_labels = parser(dihedral_type)
variable_msg = "{:8}\t" * len(param_labels[1])
Expand All @@ -975,18 +991,22 @@ def _write_dihedraltypes(out_file, top, base_unyts, parser, cfactorsDict):
n_decimals=6,
name=parameterStr,
)
for parameter, parameterStr in zip(*parser(dihedral_type))
for parameter, parameterStr in zip(
*parser(dihedral_type)
)
],
*members,
)
)
elif parser.__name__ == "parse_charmm_style_dihedral":
ndecimalsDict = {"k":6, "n":0, "phi_eq":0, "weights":1}
ndecimalsDict = {"k": 6, "n": 0, "phi_eq": 0, "weights": 1}
for idx, (dihedral_type, members) in enumerate(index_membersList):
parameter_termList, parameterStrList = parser(dihedral_type)
variable_msg = "{:8}\t" * len(parameterStrList)
full_msg = base_msg + variable_msg + end_msg
for parameter_terms in parameter_termList: # list of params on each line
for (
parameter_terms
) in parameter_termList: # list of params on each line
out_file.write(
full_msg.format(
idx + 1,
Expand All @@ -997,7 +1017,9 @@ def _write_dihedraltypes(out_file, top, base_unyts, parser, cfactorsDict):
n_decimals=ndecimalsDict[parameterStr],
name=parameterStr,
)
for parameter, parameterStr in zip(parameter_terms, parameterStrList)
for parameter, parameterStr in zip(
parameter_terms, parameterStrList
)
],
*members,
)
Expand All @@ -1022,8 +1044,8 @@ def parse_charmm_style_dihedral(dihedral_type, weightsArray=None):
if not weightsArray: # used for amber forcefield weights are 0
weightsArray = np.zeros(kArray.size) * u.dimensionless
allParamsList = []
for a,b,c,d in zip(kArray, nArray, phi_eqArray, weightsArray):
allParamsList.append([a,b,c,d])
for a, b, c, d in zip(kArray, nArray, phi_eqArray, weightsArray):
allParamsList.append([a, b, c, d])
return allParamsList, ["k", "n", "phi_eq", "weights"]


Expand Down Expand Up @@ -1190,9 +1212,10 @@ def _default_lj_val(top, source):
f"Provided {source} for default LJ cannot be found in the topology."
)


def _flatten(iterable):
try:
for item in iterable:
yield from _flatten(item)
except TypeError:
yield iterable
yield iterable

0 comments on commit 82f123e

Please sign in to comment.