From f1d7e1364fe0b2fe226eac412d6da42627bc59fe Mon Sep 17 00:00:00 2001 From: Co Quach Date: Sat, 7 Oct 2023 14:24:57 -0500 Subject: [PATCH] add site posision restraints and positions restraints writer for gromacs top --- gmso/core/atom.py | 16 ++++++++ gmso/formats/top.py | 93 +++++++++++++++++++++++++-------------------- 2 files changed, 67 insertions(+), 42 deletions(-) diff --git a/gmso/core/atom.py b/gmso/core/atom.py index 3e09c931e..bd1a1c551 100644 --- a/gmso/core/atom.py +++ b/gmso/core/atom.py @@ -58,6 +58,15 @@ class Atom(Site): atom_type_: Optional[AtomType] = Field( None, description="AtomType associated with the atom" ) + restraint_: Optional[dict] = Field( + default=None, + description=""" + Restraint for this atom, must be a dict with the following keys: + 'kx', 'ky', 'kz' (unit of energy/(mol * length**2), + Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html + for more information. + """, + ) @property def charge(self) -> Union[u.unyt_quantity, None]: @@ -93,6 +102,11 @@ def atom_type(self) -> Union[AtomType, None]: """Return the atom_type associated with the atom.""" return self.__dict__.get("atom_type_", None) + @property + def restraint(self): + """Return the restraint of this atom.""" + return self.__dict__.get("restraint_") + def clone(self): """Clone this atom.""" return Atom( @@ -164,6 +178,7 @@ class Config: "mass_": "mass", "element_": "element", "atom_type_": "atom_type", + "restraint_": "restraint", } alias_to_fields = { @@ -171,6 +186,7 @@ class Config: "mass": "mass_", "element": "element_", "atom_type": "atom_type_", + "restraint": "restraint_", } validate_assignment = True diff --git a/gmso/formats/top.py b/gmso/formats/top.py index 24b83b4e6..0cb28f9c2 100644 --- a/gmso/formats/top.py +++ b/gmso/formats/top.py @@ -115,6 +115,7 @@ def write_top(top, filename, top_vars=None): # Section headers headers = { + "position_restraints": "\n[ position_restraints ]\n; ai\tfunct\tkx\tky\t\kz\funct\tb0\t\tkb\n", "bonds": "\n[ bonds ]\n; ai\taj\tfunct\tb0\t\tkb\n", "bond_restraints": "\n[ bonds ] ;Harmonic potential restraint\n" "; ai\taj\tfunct\tb0\t\tkb\n", @@ -146,12 +147,12 @@ def write_top(top, filename, top_vars=None): "[ atoms ]\n" "; nr\ttype\tresnr\tresidue\t\tatom\tcgnr\tcharge\tmass\n" ) - # Each unique molecule need to be reindexed (restarting from 0) + # Each unique molecule need to be reindexed (restarting from 1) # The shifted_idx_map is needed to make sure all the atom index used in # latter connection sections are acurate shifted_idx_map = dict() for idx, site in enumerate(unique_molecules[tag]["sites"]): - shifted_idx_map[top.get_index(site)] = idx + shifted_idx_map[top.get_index(site)] = idx + 1 out_file.write( "{0:8s}" "{1:12s}" @@ -161,7 +162,7 @@ def write_top(top, filename, top_vars=None): "{5:4s}" "{6:12.5f}" "{7:12.5f}\n".format( - str(idx + 1), + str(idx), site.atom_type.name, str(site.molecule.number if site.molecule else 1), tag, @@ -171,6 +172,10 @@ def write_top(top, filename, top_vars=None): site.atom_type.mass.in_units(u.amu).value, ) ) + if unique_molecules[tag]["position_restraints"]: + out_file.write(headers["position_restraints"]) + for site in unique_molecules[tag]["position_restraints"]: + out_file.write(_write_restraint(top, site, shifted_idx_map)) for conn_group in [ "pairs", @@ -323,6 +328,9 @@ def _get_unique_molecules(top): unique_molecules[top.name] = dict() unique_molecules[top.name]["subtags"] = [top.name] unique_molecules[top.name]["sites"] = list(top.sites) + unique_molecules[top.name]["position_restraints"] = list( + site for site in top.sites if site.restraint + ) unique_molecules[top.name]["pairs"] = generate_pairs_lists( top, refer_from_scaling_factor=True )["pairs14"] @@ -396,8 +404,8 @@ def _lookup_element_symbol(atom_type): def _write_pairs(top, pair, shifted_idx_map): """Workder function to write out pairs information.""" pair_idx = [ - shifted_idx_map[top.get_index(pair[0])] + 1, - shifted_idx_map[top.get_index(pair[1])] + 1, + shifted_idx_map[top.get_index(pair[0])], + shifted_idx_map[top.get_index(pair[1])], ] line = "{0:8s}{1:8s}{2:4s}\n".format( @@ -423,8 +431,8 @@ def _write_connection(top, connection, potential_name, shifted_idx_map): def _harmonic_bond_potential_writer(top, bond, shifted_idx_map): """Write harmonic bond information.""" line = "{0:8s}{1:8s}{2:4s}{3:15.5f}{4:15.5f}\n".format( - str(shifted_idx_map[top.get_index(bond.connection_members[0])] + 1), - str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1), + str(shifted_idx_map[top.get_index(bond.connection_members[0])]), + str(shifted_idx_map[top.get_index(bond.connection_members[1])]), "1", bond.connection_type.parameters["r_eq"].in_units(u.nm).value, bond.connection_type.parameters["k"] @@ -437,9 +445,9 @@ def _harmonic_bond_potential_writer(top, bond, shifted_idx_map): def _harmonic_angle_potential_writer(top, angle, shifted_idx_map): """Write harmonic angle information.""" line = "{0:8s}{1:8s}{2:8s}{3:4s}{4:15.5f}{5:15.5f}\n".format( - str(shifted_idx_map[top.get_index(angle.connection_members[0])] + 1), - str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1), - str(shifted_idx_map[top.get_index(angle.connection_members[2])] + 1), + str(shifted_idx_map[top.get_index(angle.connection_members[0])]), + str(shifted_idx_map[top.get_index(angle.connection_members[1])]), + str(shifted_idx_map[top.get_index(angle.connection_members[2])]), "1", angle.connection_type.parameters["theta_eq"].in_units(u.degree).value, angle.connection_type.parameters["k"] @@ -452,10 +460,10 @@ def _harmonic_angle_potential_writer(top, angle, shifted_idx_map): def _ryckaert_bellemans_torsion_writer(top, dihedral, shifted_idx_map): """Write Ryckaert-Bellemans Torsion information.""" line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:15.5f}{8:15.5f}{9:15.5f}{10:15.5f}\n".format( - str(shifted_idx_map[top.get_index(dihedral.connection_members[0])] + 1), - str(shifted_idx_map[top.get_index(dihedral.connection_members[1])] + 1), - str(shifted_idx_map[top.get_index(dihedral.connection_members[2])] + 1), - str(shifted_idx_map[top.get_index(dihedral.connection_members[3])] + 1), + str(shifted_idx_map[top.get_index(dihedral.connection_members[0])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[1])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[2])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[3])]), "3", dihedral.connection_type.parameters["c0"] .in_units(u.Unit("kJ/mol")) @@ -501,22 +509,10 @@ def _periodic_torsion_writer(top, dihedral, shifted_idx_map): lines = list() for i in range(layers): line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:4}\n".format( - str( - shifted_idx_map[top.get_index(dihedral.connection_members[0])] - + 1 - ), - str( - shifted_idx_map[top.get_index(dihedral.connection_members[1])] - + 1 - ), - str( - shifted_idx_map[top.get_index(dihedral.connection_members[2])] - + 1 - ), - str( - shifted_idx_map[top.get_index(dihedral.connection_members[3])] - + 1 - ), + str(shifted_idx_map[top.get_index(dihedral.connection_members[0])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[1])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[2])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[3])]), funct, dihedral.connection_type.parameters["phi_eq"][i] .in_units(u.degree) @@ -530,22 +526,35 @@ def _periodic_torsion_writer(top, dihedral, shifted_idx_map): return lines -def _write_restraint(top, connection, type, shifted_idx_map): +def _write_restraint(top, site_or_conn, type, shifted_idx_map): """Worker function to write various connection restraint information.""" worker_functions = { + "position_restraints": _position_restraints_writer, "bond_restraints": _bond_restraint_writer, "angle_restraints": _angle_restraint_writer, "dihedral_restraints": _dihedral_restraint_writer, } - return worker_functions[type](top, connection, shifted_idx_map) + return worker_functions[type](top, site_or_conn, shifted_idx_map) + + +def _position_restraints_writer(top, site, shifted_idx_map): + """Write site position restraint information.""" + line = "{0:8s}{1:4s}{2:15.5f}{3:15.5f}{4:15.5f}\n".format( + str(shifted_idx_map[top.get_index(site)]), + "1", + site.restraint["kx"].in_units(u.Unit("kJ/(mol * nm**2)")).value, + site.restraint["ky"].in_units(u.Unit("kJ/(mol * nm**2)")).value, + site.restraint["kz"].in_units(u.Unit("kJ/(mol * nm**2)")).value, + ) + return line def _bond_restraint_writer(top, bond, shifted_idx_map): """Write bond restraint information.""" line = "{0:8s}{1:8s}{2:4s}{3:15.5f}{4:15.5f}\n".format( - str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1), - str(shifted_idx_map[top.get_index(bond.connection_members[0])] + 1), + str(shifted_idx_map[top.get_index(bond.connection_members[1])]), + str(shifted_idx_map[top.get_index(bond.connection_members[0])]), "6", bond.restraint["r_eq"].in_units(u.nm).value, bond.restraint["k"].in_units(u.Unit("kJ/(mol * nm**2)")).value, @@ -556,10 +565,10 @@ def _bond_restraint_writer(top, bond, shifted_idx_map): def _angle_restraint_writer(top, angle, shifted_idx_map): """Write angle restraint information.""" line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:4}\n".format( - str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1), - str(shifted_idx_map[top.get_index(angle.connection_members[0])] + 1), - str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1), - str(shifted_idx_map[top.get_index(angle.connection_members[2])] + 1), + str(shifted_idx_map[top.get_index(angle.connection_members[1])]), + str(shifted_idx_map[top.get_index(angle.connection_members[0])]), + str(shifted_idx_map[top.get_index(angle.connection_members[1])]), + str(shifted_idx_map[top.get_index(angle.connection_members[2])]), "1", angle.restraint["theta_eq"].in_units(u.degree).value, angle.restraint["k"].in_units(u.Unit("kJ/mol")).value, @@ -571,10 +580,10 @@ def _angle_restraint_writer(top, angle, shifted_idx_map): def _dihedral_restraint_writer(top, dihedral, shifted_idx_map): """Write dihedral restraint information.""" line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:15.5f}\n".format( - str(shifted_idx_map[top.get_index(dihedral.connection_members[0])] + 1), - str(shifted_idx_map[top.get_index(dihedral.connection_members[1])] + 1), - str(shifted_idx_map[top.get_index(dihedral.connection_members[2])] + 1), - str(shifted_idx_map[top.get_index(dihedral.connection_members[3])] + 1), + str(shifted_idx_map[top.get_index(dihedral.connection_members[0])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[1])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[2])]), + str(shifted_idx_map[top.get_index(dihedral.connection_members[3])]), "1", dihedral.restraint["phi_eq"].in_units(u.degree).value, dihedral.restraint["delta_phi"].in_units(u.degree).value,