Skip to content

Commit

Permalink
add site posision restraints and positions restraints writer for grom…
Browse files Browse the repository at this point in the history
…acs top
  • Loading branch information
daico007 committed Oct 7, 2023
1 parent 615fd29 commit f1d7e13
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 42 deletions.
16 changes: 16 additions & 0 deletions gmso/core/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ class Atom(Site):
atom_type_: Optional[AtomType] = Field(
None, description="AtomType associated with the atom"
)
restraint_: Optional[dict] = Field(
default=None,
description="""
Restraint for this atom, must be a dict with the following keys:
'kx', 'ky', 'kz' (unit of energy/(mol * length**2),
Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html
for more information.
""",
)

@property
def charge(self) -> Union[u.unyt_quantity, None]:
Expand Down Expand Up @@ -93,6 +102,11 @@ def atom_type(self) -> Union[AtomType, None]:
"""Return the atom_type associated with the atom."""
return self.__dict__.get("atom_type_", None)

@property
def restraint(self):
"""Return the restraint of this atom."""
return self.__dict__.get("restraint_")

def clone(self):
"""Clone this atom."""
return Atom(
Expand Down Expand Up @@ -164,13 +178,15 @@ class Config:
"mass_": "mass",
"element_": "element",
"atom_type_": "atom_type",
"restraint_": "restraint",
}

alias_to_fields = {
"charge": "charge_",
"mass": "mass_",
"element": "element_",
"atom_type": "atom_type_",
"restraint": "restraint_",
}

validate_assignment = True
93 changes: 51 additions & 42 deletions gmso/formats/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def write_top(top, filename, top_vars=None):

# Section headers
headers = {
"position_restraints": "\n[ position_restraints ]\n; ai\tfunct\tkx\tky\t\kz\funct\tb0\t\tkb\n",
"bonds": "\n[ bonds ]\n; ai\taj\tfunct\tb0\t\tkb\n",
"bond_restraints": "\n[ bonds ] ;Harmonic potential restraint\n"
"; ai\taj\tfunct\tb0\t\tkb\n",
Expand Down Expand Up @@ -146,12 +147,12 @@ def write_top(top, filename, top_vars=None):
"[ atoms ]\n"
"; nr\ttype\tresnr\tresidue\t\tatom\tcgnr\tcharge\tmass\n"
)
# Each unique molecule need to be reindexed (restarting from 0)
# Each unique molecule need to be reindexed (restarting from 1)
# The shifted_idx_map is needed to make sure all the atom index used in
# latter connection sections are acurate
shifted_idx_map = dict()
for idx, site in enumerate(unique_molecules[tag]["sites"]):
shifted_idx_map[top.get_index(site)] = idx
shifted_idx_map[top.get_index(site)] = idx + 1
out_file.write(
"{0:8s}"
"{1:12s}"
Expand All @@ -161,7 +162,7 @@ def write_top(top, filename, top_vars=None):
"{5:4s}"
"{6:12.5f}"
"{7:12.5f}\n".format(
str(idx + 1),
str(idx),
site.atom_type.name,
str(site.molecule.number if site.molecule else 1),
tag,
Expand All @@ -171,6 +172,10 @@ def write_top(top, filename, top_vars=None):
site.atom_type.mass.in_units(u.amu).value,
)
)
if unique_molecules[tag]["position_restraints"]:
out_file.write(headers["position_restraints"])
for site in unique_molecules[tag]["position_restraints"]:
out_file.write(_write_restraint(top, site, shifted_idx_map))

for conn_group in [
"pairs",
Expand Down Expand Up @@ -323,6 +328,9 @@ def _get_unique_molecules(top):
unique_molecules[top.name] = dict()
unique_molecules[top.name]["subtags"] = [top.name]
unique_molecules[top.name]["sites"] = list(top.sites)
unique_molecules[top.name]["position_restraints"] = list(
site for site in top.sites if site.restraint
)
unique_molecules[top.name]["pairs"] = generate_pairs_lists(
top, refer_from_scaling_factor=True
)["pairs14"]
Expand Down Expand Up @@ -396,8 +404,8 @@ def _lookup_element_symbol(atom_type):
def _write_pairs(top, pair, shifted_idx_map):
"""Workder function to write out pairs information."""
pair_idx = [
shifted_idx_map[top.get_index(pair[0])] + 1,
shifted_idx_map[top.get_index(pair[1])] + 1,
shifted_idx_map[top.get_index(pair[0])],
shifted_idx_map[top.get_index(pair[1])],
]

line = "{0:8s}{1:8s}{2:4s}\n".format(
Expand All @@ -423,8 +431,8 @@ def _write_connection(top, connection, potential_name, shifted_idx_map):
def _harmonic_bond_potential_writer(top, bond, shifted_idx_map):
"""Write harmonic bond information."""
line = "{0:8s}{1:8s}{2:4s}{3:15.5f}{4:15.5f}\n".format(
str(shifted_idx_map[top.get_index(bond.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(bond.connection_members[0])]),
str(shifted_idx_map[top.get_index(bond.connection_members[1])]),
"1",
bond.connection_type.parameters["r_eq"].in_units(u.nm).value,
bond.connection_type.parameters["k"]
Expand All @@ -437,9 +445,9 @@ def _harmonic_bond_potential_writer(top, bond, shifted_idx_map):
def _harmonic_angle_potential_writer(top, angle, shifted_idx_map):
"""Write harmonic angle information."""
line = "{0:8s}{1:8s}{2:8s}{3:4s}{4:15.5f}{5:15.5f}\n".format(
str(shifted_idx_map[top.get_index(angle.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[2])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[0])]),
str(shifted_idx_map[top.get_index(angle.connection_members[1])]),
str(shifted_idx_map[top.get_index(angle.connection_members[2])]),
"1",
angle.connection_type.parameters["theta_eq"].in_units(u.degree).value,
angle.connection_type.parameters["k"]
Expand All @@ -452,10 +460,10 @@ def _harmonic_angle_potential_writer(top, angle, shifted_idx_map):
def _ryckaert_bellemans_torsion_writer(top, dihedral, shifted_idx_map):
"""Write Ryckaert-Bellemans Torsion information."""
line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:15.5f}{8:15.5f}{9:15.5f}{10:15.5f}\n".format(
str(shifted_idx_map[top.get_index(dihedral.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[2])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[3])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[0])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[1])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[2])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[3])]),
"3",
dihedral.connection_type.parameters["c0"]
.in_units(u.Unit("kJ/mol"))
Expand Down Expand Up @@ -501,22 +509,10 @@ def _periodic_torsion_writer(top, dihedral, shifted_idx_map):
lines = list()
for i in range(layers):
line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:4}\n".format(
str(
shifted_idx_map[top.get_index(dihedral.connection_members[0])]
+ 1
),
str(
shifted_idx_map[top.get_index(dihedral.connection_members[1])]
+ 1
),
str(
shifted_idx_map[top.get_index(dihedral.connection_members[2])]
+ 1
),
str(
shifted_idx_map[top.get_index(dihedral.connection_members[3])]
+ 1
),
str(shifted_idx_map[top.get_index(dihedral.connection_members[0])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[1])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[2])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[3])]),
funct,
dihedral.connection_type.parameters["phi_eq"][i]
.in_units(u.degree)
Expand All @@ -530,22 +526,35 @@ def _periodic_torsion_writer(top, dihedral, shifted_idx_map):
return lines


def _write_restraint(top, connection, type, shifted_idx_map):
def _write_restraint(top, site_or_conn, type, shifted_idx_map):
"""Worker function to write various connection restraint information."""
worker_functions = {
"position_restraints": _position_restraints_writer,
"bond_restraints": _bond_restraint_writer,
"angle_restraints": _angle_restraint_writer,
"dihedral_restraints": _dihedral_restraint_writer,
}

return worker_functions[type](top, connection, shifted_idx_map)
return worker_functions[type](top, site_or_conn, shifted_idx_map)


def _position_restraints_writer(top, site, shifted_idx_map):
"""Write site position restraint information."""
line = "{0:8s}{1:4s}{2:15.5f}{3:15.5f}{4:15.5f}\n".format(
str(shifted_idx_map[top.get_index(site)]),
"1",
site.restraint["kx"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
site.restraint["ky"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
site.restraint["kz"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
)
return line


def _bond_restraint_writer(top, bond, shifted_idx_map):
"""Write bond restraint information."""
line = "{0:8s}{1:8s}{2:4s}{3:15.5f}{4:15.5f}\n".format(
str(shifted_idx_map[top.get_index(bond.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(bond.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(bond.connection_members[1])]),
str(shifted_idx_map[top.get_index(bond.connection_members[0])]),
"6",
bond.restraint["r_eq"].in_units(u.nm).value,
bond.restraint["k"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
Expand All @@ -556,10 +565,10 @@ def _bond_restraint_writer(top, bond, shifted_idx_map):
def _angle_restraint_writer(top, angle, shifted_idx_map):
"""Write angle restraint information."""
line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:4}\n".format(
str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[2])] + 1),
str(shifted_idx_map[top.get_index(angle.connection_members[1])]),
str(shifted_idx_map[top.get_index(angle.connection_members[0])]),
str(shifted_idx_map[top.get_index(angle.connection_members[1])]),
str(shifted_idx_map[top.get_index(angle.connection_members[2])]),
"1",
angle.restraint["theta_eq"].in_units(u.degree).value,
angle.restraint["k"].in_units(u.Unit("kJ/mol")).value,
Expand All @@ -571,10 +580,10 @@ def _angle_restraint_writer(top, angle, shifted_idx_map):
def _dihedral_restraint_writer(top, dihedral, shifted_idx_map):
"""Write dihedral restraint information."""
line = "{0:8s}{1:8s}{2:8s}{3:8s}{4:4s}{5:15.5f}{6:15.5f}{7:15.5f}\n".format(
str(shifted_idx_map[top.get_index(dihedral.connection_members[0])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[1])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[2])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[3])] + 1),
str(shifted_idx_map[top.get_index(dihedral.connection_members[0])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[1])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[2])]),
str(shifted_idx_map[top.get_index(dihedral.connection_members[3])]),
"1",
dihedral.restraint["phi_eq"].in_units(u.degree).value,
dihedral.restraint["delta_phi"].in_units(u.degree).value,
Expand Down

0 comments on commit f1d7e13

Please sign in to comment.