From b20472c4c1b5219a1f1713c73bc351e5856d61d4 Mon Sep 17 00:00:00 2001 From: Co Quach <43968221+daico007@users.noreply.github.com> Date: Tue, 28 Nov 2023 13:15:30 -0600 Subject: [PATCH] Add position restraints for Site (#770) * add site posision restraints and positions restraints writer for gromacs top * revert changes to the shifted_idx_map, fix other bugs * fix typo * fix recent failing test (unyt related) * add spce water fixture * add test for site level * fix code order, add spce xml * fix typo * add test for spce water with restraints * remove duplicate spce.xml * used site.element if atom_type.tags["element"] does not exist * fix typo * add mass information to spce xml * fix sigma value for HW in spce.xml --- gmso/core/atom.py | 16 +++++++ gmso/formats/gro.py | 3 +- gmso/formats/top.py | 36 ++++++++++++++-- gmso/tests/base_test.py | 18 ++++++++ gmso/tests/files/spce_restraint.top | 43 +++++++++++++++++++ gmso/tests/test_atom.py | 6 +++ gmso/tests/test_top.py | 11 +++++ .../files/gmso_xmls/test_ffstyles/spce.xml | 6 +-- 8 files changed, 132 insertions(+), 7 deletions(-) create mode 100644 gmso/tests/files/spce_restraint.top diff --git a/gmso/core/atom.py b/gmso/core/atom.py index 3e09c931e..bd1a1c551 100644 --- a/gmso/core/atom.py +++ b/gmso/core/atom.py @@ -58,6 +58,15 @@ class Atom(Site): atom_type_: Optional[AtomType] = Field( None, description="AtomType associated with the atom" ) + restraint_: Optional[dict] = Field( + default=None, + description=""" + Restraint for this atom, must be a dict with the following keys: + 'kx', 'ky', 'kz' (unit of energy/(mol * length**2), + Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html + for more information. + """, + ) @property def charge(self) -> Union[u.unyt_quantity, None]: @@ -93,6 +102,11 @@ def atom_type(self) -> Union[AtomType, None]: """Return the atom_type associated with the atom.""" return self.__dict__.get("atom_type_", None) + @property + def restraint(self): + """Return the restraint of this atom.""" + return self.__dict__.get("restraint_") + def clone(self): """Clone this atom.""" return Atom( @@ -164,6 +178,7 @@ class Config: "mass_": "mass", "element_": "element", "atom_type_": "atom_type", + "restraint_": "restraint", } alias_to_fields = { @@ -171,6 +186,7 @@ class Config: "mass": "mass_", "element": "element_", "atom_type": "atom_type_", + "restraint": "restraint_", } validate_assignment = True diff --git a/gmso/formats/gro.py b/gmso/formats/gro.py index f9315648f..76d57ca2f 100644 --- a/gmso/formats/gro.py +++ b/gmso/formats/gro.py @@ -159,7 +159,8 @@ def _validate_positions(pos_array): "in order to ensure all coordinates are non-negative." ) min_xyz = np.min(pos_array, axis=0) - min_xyz0 = np.where(min_xyz < 0, min_xyz, 0) * min_xyz.units + unit = min_xyz.units + min_xyz0 = np.where(min_xyz < 0 * unit, min_xyz, 0 * unit) pos_array -= min_xyz0 diff --git a/gmso/formats/top.py b/gmso/formats/top.py index 24b83b4e6..57f16142e 100644 --- a/gmso/formats/top.py +++ b/gmso/formats/top.py @@ -115,6 +115,7 @@ def write_top(top, filename, top_vars=None): # Section headers headers = { + "position_restraints": "\n[ position_restraints ]\n; ai\tfunct\tkx\tky\tkz\tfunct\tb0\t\tkb\n", "bonds": "\n[ bonds ]\n; ai\taj\tfunct\tb0\t\tkb\n", "bond_restraints": "\n[ bonds ] ;Harmonic potential restraint\n" "; ai\taj\tfunct\tb0\t\tkb\n", @@ -165,12 +166,20 @@ def write_top(top, filename, top_vars=None): site.atom_type.name, str(site.molecule.number if site.molecule else 1), tag, - site.atom_type.tags["element"], + site.atom_type.tags.get("element", site.element.symbol), "1", # TODO: care about charge groups site.charge.in_units(u.elementary_charge).value, site.atom_type.mass.in_units(u.amu).value, ) ) + if unique_molecules[tag]["position_restraints"]: + out_file.write(headers["position_restraints"]) + for site in unique_molecules[tag]["position_restraints"]: + out_file.write( + _write_restraint( + top, site, "position_restraints", shifted_idx_map + ) + ) for conn_group in [ "pairs", @@ -323,6 +332,9 @@ def _get_unique_molecules(top): unique_molecules[top.name] = dict() unique_molecules[top.name]["subtags"] = [top.name] unique_molecules[top.name]["sites"] = list(top.sites) + unique_molecules[top.name]["position_restraints"] = list( + site for site in top.sites if site.restraint + ) unique_molecules[top.name]["pairs"] = generate_pairs_lists( top, refer_from_scaling_factor=True )["pairs14"] @@ -346,6 +358,11 @@ def _get_unique_molecules(top): unique_molecules[tag]["sites"] = list( top.iter_sites(key="molecule", value=molecule) ) + unique_molecules[tag]["position_restraints"] = list( + site + for site in top.sites + if (site.restraint and site.molecule == molecule) + ) unique_molecules[tag]["pairs"] = generate_pairs_lists( top, molecule )["pairs14"] @@ -530,15 +547,28 @@ def _periodic_torsion_writer(top, dihedral, shifted_idx_map): return lines -def _write_restraint(top, connection, type, shifted_idx_map): +def _write_restraint(top, site_or_conn, type, shifted_idx_map): """Worker function to write various connection restraint information.""" worker_functions = { + "position_restraints": _position_restraints_writer, "bond_restraints": _bond_restraint_writer, "angle_restraints": _angle_restraint_writer, "dihedral_restraints": _dihedral_restraint_writer, } - return worker_functions[type](top, connection, shifted_idx_map) + return worker_functions[type](top, site_or_conn, shifted_idx_map) + + +def _position_restraints_writer(top, site, shifted_idx_map): + """Write site position restraint information.""" + line = "{0:8s}{1:4s}{2:15.5f}{3:15.5f}{4:15.5f}\n".format( + str(shifted_idx_map[top.get_index(site)] + 1), + "1", + site.restraint["kx"].in_units(u.Unit("kJ/(mol * nm**2)")).value, + site.restraint["ky"].in_units(u.Unit("kJ/(mol * nm**2)")).value, + site.restraint["kz"].in_units(u.Unit("kJ/(mol * nm**2)")).value, + ) + return line def _bond_restraint_writer(top, bond, shifted_idx_map): diff --git a/gmso/tests/base_test.py b/gmso/tests/base_test.py index a91d410e4..383bcfd76 100644 --- a/gmso/tests/base_test.py +++ b/gmso/tests/base_test.py @@ -165,6 +165,24 @@ def _typed_topology(n_sites=100): return _typed_topology + @pytest.fixture + def spce_water(self): + spce_comp = mb.lib.molecules.water.WaterSPC() + spce_ff = ForceField(get_fn("gmso_xmls/test_ffstyles/spce.xml")) + spce_top = spce_comp.to_gmso() + spce_top.identify_connections() + + spce_top = apply(spce_top, spce_ff, remove_untyped=True) + + for site in spce_top.sites: + site.restraint = { + "kx": 1000 * u.Unit("kJ/(mol*nm**2)"), + "ky": 1000 * u.Unit("kJ/(mol*nm**2)"), + "kz": 1000 * u.Unit("kJ/(mol*nm**2)"), + } + + return spce_top + @pytest.fixture def water_system(self): water = Topology(name="water") diff --git a/gmso/tests/files/spce_restraint.top b/gmso/tests/files/spce_restraint.top new file mode 100644 index 000000000..7223aa986 --- /dev/null +++ b/gmso/tests/files/spce_restraint.top @@ -0,0 +1,43 @@ +; File Topology written by GMSO at 2023-11-13 16:47:04.916225 + +[ defaults ] +; nbfunc comb-rule gen-pairs fudgeLJ fudgeQQ +1 3 yes 0.5 0.5 + +[ atomtypes ] +; name at.num mass charge ptype sigma epsilon +opls_116 8 15.99940 -0.84760 A 0.31656 0.65019 +opls_117 1 1.00800 0.42380 A 0.00000 0.00000 + +[ moleculetype ] +; name nrexcl +WaterSPC 3 + +[ atoms ] +; nr type resnr residue atom cgnr charge mass +1 opls_116 1 WaterSPC O 1 -0.84760 15.99940 +2 opls_117 1 WaterSPC H 1 0.42380 1.00800 +3 opls_117 1 WaterSPC H 1 0.42380 1.00800 + +[ position_restraints ] +; ai funct kx ky kz funct b0 kb +1 1 1000.00000 1000.00000 1000.00000 +2 1 1000.00000 1000.00000 1000.00000 +3 1 1000.00000 1000.00000 1000.00000 + +[ bonds ] +; ai aj funct b0 kb +1 2 1 0.10000 345000.00000 +1 3 1 0.10000 345000.00000 + +[ angles ] +; ai aj ak funct phi_0 k0 +2 1 3 1 109.47000 383.00000 + +[ system ] +; name +Topology + +[ molecules ] +; molecule nmols +WaterSPC 1 diff --git a/gmso/tests/test_atom.py b/gmso/tests/test_atom.py index c32faa928..db58241f2 100644 --- a/gmso/tests/test_atom.py +++ b/gmso/tests/test_atom.py @@ -74,6 +74,12 @@ def test_bad_pos_input(self, position): with pytest.raises((u.exceptions.InvalidUnitOperation, ValueError)): Atom(name="atom", position=u.nm * position) + def test_pos_restraint(self, spce_water): + for site in spce_water.sites: + for ax in ["kx", "ky", "kz"]: + assert site.restraint[ax] == 1000 * u.Unit("kJ/(mol*nm**2)") + site.restraint[ax] = 500 * u.Unit("kJ/(mol*nm**2)") + def test_equivalence(self): ref = Atom(name="atom", position=u.nm * np.zeros(3)) same_atom = Atom(name="atom", position=u.nm * np.zeros(3)) diff --git a/gmso/tests/test_top.py b/gmso/tests/test_top.py index 433b758a1..4edb8a356 100644 --- a/gmso/tests/test_top.py +++ b/gmso/tests/test_top.py @@ -20,6 +20,17 @@ def test_write_top(self, typed_ar_system): top = typed_ar_system top.save("ar.top") + def test_top(self, spce_water): + spce_water.save("spce_restraint.top", overwrite=True) + with open("spce_restraint.top", "r") as f: + contents = f.readlines() + with open(get_path("spce_restraint.top"), "r") as f: + ref_contents = f.readlines() + + assert len(contents) == len(ref_contents) + for line, ref_line in zip(contents[1:], ref_contents[1:]): + assert line == ref_line + @pytest.mark.parametrize( "top", [ diff --git a/gmso/utils/files/gmso_xmls/test_ffstyles/spce.xml b/gmso/utils/files/gmso_xmls/test_ffstyles/spce.xml index cde76944a..2d891728e 100644 --- a/gmso/utils/files/gmso_xmls/test_ffstyles/spce.xml +++ b/gmso/utils/files/gmso_xmls/test_ffstyles/spce.xml @@ -6,16 +6,16 @@ - + - + - +