Skip to content

Commit

Permalink
Merge branch 'main' into pre-commit-ci-update-config
Browse files Browse the repository at this point in the history
  • Loading branch information
daico007 authored Nov 28, 2023
2 parents c26a8b9 + b20472c commit 4eb4e52
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 7 deletions.
16 changes: 16 additions & 0 deletions gmso/core/atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ class Atom(Site):
atom_type_: Optional[AtomType] = Field(
None, description="AtomType associated with the atom"
)
restraint_: Optional[dict] = Field(
default=None,
description="""
Restraint for this atom, must be a dict with the following keys:
'kx', 'ky', 'kz' (unit of energy/(mol * length**2),
Refer to https://manual.gromacs.org/current/reference-manual/topologies/topology-file-formats.html
for more information.
""",
)

@property
def charge(self) -> Union[u.unyt_quantity, None]:
Expand Down Expand Up @@ -93,6 +102,11 @@ def atom_type(self) -> Union[AtomType, None]:
"""Return the atom_type associated with the atom."""
return self.__dict__.get("atom_type_", None)

@property
def restraint(self):
"""Return the restraint of this atom."""
return self.__dict__.get("restraint_")

def clone(self):
"""Clone this atom."""
return Atom(
Expand Down Expand Up @@ -164,13 +178,15 @@ class Config:
"mass_": "mass",
"element_": "element",
"atom_type_": "atom_type",
"restraint_": "restraint",
}

alias_to_fields = {
"charge": "charge_",
"mass": "mass_",
"element": "element_",
"atom_type": "atom_type_",
"restraint": "restraint_",
}

validate_assignment = True
3 changes: 2 additions & 1 deletion gmso/formats/gro.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ def _validate_positions(pos_array):
"in order to ensure all coordinates are non-negative."
)
min_xyz = np.min(pos_array, axis=0)
min_xyz0 = np.where(min_xyz < 0, min_xyz, 0) * min_xyz.units
unit = min_xyz.units
min_xyz0 = np.where(min_xyz < 0 * unit, min_xyz, 0 * unit)

pos_array -= min_xyz0

Expand Down
36 changes: 33 additions & 3 deletions gmso/formats/top.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def write_top(top, filename, top_vars=None):

# Section headers
headers = {
"position_restraints": "\n[ position_restraints ]\n; ai\tfunct\tkx\tky\tkz\tfunct\tb0\t\tkb\n",
"bonds": "\n[ bonds ]\n; ai\taj\tfunct\tb0\t\tkb\n",
"bond_restraints": "\n[ bonds ] ;Harmonic potential restraint\n"
"; ai\taj\tfunct\tb0\t\tkb\n",
Expand Down Expand Up @@ -165,12 +166,20 @@ def write_top(top, filename, top_vars=None):
site.atom_type.name,
str(site.molecule.number if site.molecule else 1),
tag,
site.atom_type.tags["element"],
site.atom_type.tags.get("element", site.element.symbol),
"1", # TODO: care about charge groups
site.charge.in_units(u.elementary_charge).value,
site.atom_type.mass.in_units(u.amu).value,
)
)
if unique_molecules[tag]["position_restraints"]:
out_file.write(headers["position_restraints"])
for site in unique_molecules[tag]["position_restraints"]:
out_file.write(
_write_restraint(
top, site, "position_restraints", shifted_idx_map
)
)

for conn_group in [
"pairs",
Expand Down Expand Up @@ -323,6 +332,9 @@ def _get_unique_molecules(top):
unique_molecules[top.name] = dict()
unique_molecules[top.name]["subtags"] = [top.name]
unique_molecules[top.name]["sites"] = list(top.sites)
unique_molecules[top.name]["position_restraints"] = list(
site for site in top.sites if site.restraint
)
unique_molecules[top.name]["pairs"] = generate_pairs_lists(
top, refer_from_scaling_factor=True
)["pairs14"]
Expand All @@ -346,6 +358,11 @@ def _get_unique_molecules(top):
unique_molecules[tag]["sites"] = list(
top.iter_sites(key="molecule", value=molecule)
)
unique_molecules[tag]["position_restraints"] = list(
site
for site in top.sites
if (site.restraint and site.molecule == molecule)
)
unique_molecules[tag]["pairs"] = generate_pairs_lists(
top, molecule
)["pairs14"]
Expand Down Expand Up @@ -530,15 +547,28 @@ def _periodic_torsion_writer(top, dihedral, shifted_idx_map):
return lines


def _write_restraint(top, connection, type, shifted_idx_map):
def _write_restraint(top, site_or_conn, type, shifted_idx_map):
"""Worker function to write various connection restraint information."""
worker_functions = {
"position_restraints": _position_restraints_writer,
"bond_restraints": _bond_restraint_writer,
"angle_restraints": _angle_restraint_writer,
"dihedral_restraints": _dihedral_restraint_writer,
}

return worker_functions[type](top, connection, shifted_idx_map)
return worker_functions[type](top, site_or_conn, shifted_idx_map)


def _position_restraints_writer(top, site, shifted_idx_map):
"""Write site position restraint information."""
line = "{0:8s}{1:4s}{2:15.5f}{3:15.5f}{4:15.5f}\n".format(
str(shifted_idx_map[top.get_index(site)] + 1),
"1",
site.restraint["kx"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
site.restraint["ky"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
site.restraint["kz"].in_units(u.Unit("kJ/(mol * nm**2)")).value,
)
return line


def _bond_restraint_writer(top, bond, shifted_idx_map):
Expand Down
18 changes: 18 additions & 0 deletions gmso/tests/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,24 @@ def _typed_topology(n_sites=100):

return _typed_topology

@pytest.fixture
def spce_water(self):
spce_comp = mb.lib.molecules.water.WaterSPC()
spce_ff = ForceField(get_fn("gmso_xmls/test_ffstyles/spce.xml"))
spce_top = spce_comp.to_gmso()
spce_top.identify_connections()

spce_top = apply(spce_top, spce_ff, remove_untyped=True)

for site in spce_top.sites:
site.restraint = {
"kx": 1000 * u.Unit("kJ/(mol*nm**2)"),
"ky": 1000 * u.Unit("kJ/(mol*nm**2)"),
"kz": 1000 * u.Unit("kJ/(mol*nm**2)"),
}

return spce_top

@pytest.fixture
def water_system(self):
water = Topology(name="water")
Expand Down
43 changes: 43 additions & 0 deletions gmso/tests/files/spce_restraint.top
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
; File Topology written by GMSO at 2023-11-13 16:47:04.916225

[ defaults ]
; nbfunc comb-rule gen-pairs fudgeLJ fudgeQQ
1 3 yes 0.5 0.5

[ atomtypes ]
; name at.num mass charge ptype sigma epsilon
opls_116 8 15.99940 -0.84760 A 0.31656 0.65019
opls_117 1 1.00800 0.42380 A 0.00000 0.00000

[ moleculetype ]
; name nrexcl
WaterSPC 3

[ atoms ]
; nr type resnr residue atom cgnr charge mass
1 opls_116 1 WaterSPC O 1 -0.84760 15.99940
2 opls_117 1 WaterSPC H 1 0.42380 1.00800
3 opls_117 1 WaterSPC H 1 0.42380 1.00800

[ position_restraints ]
; ai funct kx ky kz funct b0 kb
1 1 1000.00000 1000.00000 1000.00000
2 1 1000.00000 1000.00000 1000.00000
3 1 1000.00000 1000.00000 1000.00000

[ bonds ]
; ai aj funct b0 kb
1 2 1 0.10000 345000.00000
1 3 1 0.10000 345000.00000

[ angles ]
; ai aj ak funct phi_0 k0
2 1 3 1 109.47000 383.00000

[ system ]
; name
Topology

[ molecules ]
; molecule nmols
WaterSPC 1
6 changes: 6 additions & 0 deletions gmso/tests/test_atom.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def test_bad_pos_input(self, position):
with pytest.raises((u.exceptions.InvalidUnitOperation, ValueError)):
Atom(name="atom", position=u.nm * position)

def test_pos_restraint(self, spce_water):
for site in spce_water.sites:
for ax in ["kx", "ky", "kz"]:
assert site.restraint[ax] == 1000 * u.Unit("kJ/(mol*nm**2)")
site.restraint[ax] = 500 * u.Unit("kJ/(mol*nm**2)")

def test_equivalence(self):
ref = Atom(name="atom", position=u.nm * np.zeros(3))
same_atom = Atom(name="atom", position=u.nm * np.zeros(3))
Expand Down
11 changes: 11 additions & 0 deletions gmso/tests/test_top.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ def test_write_top(self, typed_ar_system):
top = typed_ar_system
top.save("ar.top")

def test_top(self, spce_water):
spce_water.save("spce_restraint.top", overwrite=True)
with open("spce_restraint.top", "r") as f:
contents = f.readlines()
with open(get_path("spce_restraint.top"), "r") as f:
ref_contents = f.readlines()

assert len(contents) == len(ref_contents)
for line, ref_line in zip(contents[1:], ref_contents[1:]):
assert line == ref_line

@pytest.mark.parametrize(
"top",
[
Expand Down
6 changes: 3 additions & 3 deletions gmso/utils/files/gmso_xmls/test_ffstyles/spce.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
<AtomTypes expression="4*epsilon*(-sigma**6/r**6 + sigma**12/r**12)">
<ParametersUnitDef parameter="epsilon" unit="kJ/mol"/>
<ParametersUnitDef parameter="sigma" unit="nm"/>
<AtomType name="opls_116" mass="0.0" charge="-1.35800490379008e-19" atomclass="OW" definition="O">
<AtomType name="opls_116" element="O" mass="15.99940" charge="-1.35800490379008e-19" atomclass="OW" definition="O">
<Parameters>
<Parameter name="epsilon" value="0.650194"/>
<Parameter name="sigma" value="0.316557"/>
</Parameters>
</AtomType>
<AtomType name="opls_117" mass="0.0" charge="6.7900245189504e-20" atomclass="HW" definition="H">
<AtomType name="opls_117" element="H" mass="1.00800" charge="6.7900245189504e-20" atomclass="HW" definition="H">
<Parameters>
<Parameter name="epsilon" value="0.0"/>
<Parameter name="sigma" value="0.1"/>
<Parameter name="sigma" value="0.0"/>
</Parameters>
</AtomType>
</AtomTypes>
Expand Down

0 comments on commit 4eb4e52

Please sign in to comment.