Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix solutes with v-sites #76

Merged
merged 4 commits into from
Nov 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 2 additions & 58 deletions absolv/fep.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Prepare OpenMM systems for FEP calculations."""

import copy
import itertools

Expand All @@ -22,54 +23,6 @@
)


def _find_v_sites(
    system: openmm.System, atom_indices: list[set[int]]
) -> list[set[int]]:
    """Expand per-molecule *atom* indices into per-molecule *particle* indices.

    Atoms are numbered in the order that the non-virtual particles appear in
    the system. Each virtual site is assigned to the molecule that owns the
    first particle it is defined relative to.

    Args:
        system: The system that may contain v-sites.
        atom_indices: A list of per-molecule atom indices.

    Returns:
        A list, with one entry per molecule, of the indices of every particle
        (atoms **and** v-sites) that belongs to that molecule.
    """

    molecule_of_atom = {
        atom_idx: molecule_idx
        for molecule_idx, molecule_atoms in enumerate(atom_indices)
        for atom_idx in molecule_atoms
    }

    n_particles = system.getNumParticles()

    # The position of a real particle in this list is its atom index.
    real_particles = [
        particle_idx
        for particle_idx in range(n_particles)
        if not system.isVirtualSite(particle_idx)
    ]
    atom_of_particle = {
        particle_idx: atom_idx for atom_idx, particle_idx in enumerate(real_particles)
    }

    particle_indices: list[set[int]] = [set() for _ in atom_indices]

    for particle_idx in range(n_particles):
        if particle_idx in atom_of_particle:
            owner_atom_idx = atom_of_particle[particle_idx]
        else:
            # A v-site inherits the molecule of its first parent particle.
            v_site = system.getVirtualSite(particle_idx)
            owner_atom_idx = atom_of_particle[v_site.getParticle(0)]

        particle_indices[molecule_of_atom[owner_atom_idx]].add(particle_idx)

    return particle_indices


def _find_nonbonded_forces(
system: openmm.System,
) -> tuple[
Expand Down Expand Up @@ -468,7 +421,7 @@ def apply_fep(
system: The chemical system to generate the alchemical system from
alchemical_indices: The atom indices corresponding to each molecule that
should be alchemically transformable. The atom indices **must**
correspond to **all** atoms in each molecule as alchemically
correspond to **all** atoms / v-sites in each molecule as alchemically
transforming part of a molecule is not supported.
persistent_indices: The atom indices corresponding to each molecule that
should **not** be alchemically transformable.
Expand All @@ -481,15 +434,6 @@ def apply_fep(

system = copy.deepcopy(system)

# Make sure we track v-sites attached to any solutes that may be alchemically
# turned off. We do this as a post-process step as the OpenFF toolkit does not
# currently expose a clean way to access this information.
atom_indices = alchemical_indices + persistent_indices
atom_indices = _find_v_sites(system, atom_indices)

alchemical_indices = atom_indices[: len(alchemical_indices)]
persistent_indices = atom_indices[len(alchemical_indices) :]

(
nonbonded_force,
custom_nonbonded_force,
Expand Down
172 changes: 158 additions & 14 deletions absolv/runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Run calculations defined by a config."""

import collections
import functools
import multiprocessing
import pathlib
Expand All @@ -17,6 +18,7 @@
import openff.toolkit
import openff.utilities
import openmm
import openmm.app
import openmm.unit
import pymbar
import tqdm
Expand All @@ -35,12 +37,152 @@ class PreparedSystem(typing.NamedTuple):
system: openmm.System
"""The alchemically modified OpenMM system."""

topology: openff.toolkit.Topology
"""The OpenFF topology with any box vectors set."""
topology: openmm.app.Topology
"""The OpenMM topology with any box vectors set."""
coords: openmm.unit.Quantity
"""The coordinates of the system."""


def _rebuild_topology(
    orig_top: openff.toolkit.Topology,
    orig_coords: openmm.unit.Quantity,
    system: openmm.System,
) -> tuple[openmm.app.Topology, openmm.unit.Quantity, list[set[int]]]:
    """Build an OpenMM topology with one atom entry per particle of ``system``.

    Every molecule in ``orig_top`` is treated as a 'residue'. Virtual sites
    found in ``system`` are added as element-less atoms and are associated with
    the residue of one of their parent particles.

    Args:
        orig_top: The OpenFF topology, which only describes real atoms.
        orig_coords: The coordinates of the atoms in ``orig_top``.
        system: The OpenMM system, which may contain extra v-site particles.

    Returns:
        The OpenMM topology, the coordinates of every particle (v-site
        positions are computed by OpenMM when any are present), and the set of
        particle indices belonging to each residue.
    """
    n_particles = system.getNumParticles()

    # Map each atom of the original topology onto the index of its molecule.
    atom_idx_to_residue_idx = dict(
        enumerate(
            residue_idx
            for residue_idx, molecule in enumerate(orig_top.molecules)
            for _ in molecule.atoms
        )
    )

    # Atoms are numbered in the order the non-virtual particles appear.
    atom_idx_to_particle_idx = dict(
        enumerate(
            particle_idx
            for particle_idx in range(n_particles)
            if not system.isVirtualSite(particle_idx)
        )
    )
    particle_idx_to_atom_idx = {
        particle_idx: atom_idx
        for atom_idx, particle_idx in atom_idx_to_particle_idx.items()
    }

    atoms_off = list(orig_top.atoms)

    # One (atomic number, residue index) pair per particle; -1 marks a v-site.
    particles = []

    for particle_idx in range(n_particles):
        if particle_idx in particle_idx_to_atom_idx:
            atom_idx = particle_idx_to_atom_idx[particle_idx]
            residue_idx = atom_idx_to_residue_idx[atom_idx]

            particles.append((atoms_off[atom_idx].atomic_number, residue_idx))
        else:
            v_site = system.getVirtualSite(particle_idx)

            parent_atom_idxs = {
                particle_idx_to_atom_idx[v_site.getParticle(i)]
                for i in range(v_site.getNumParticles())
            }
            parent_residue = atom_idx_to_residue_idx[next(iter(parent_atom_idxs))]

            particles.append((-1, parent_residue))

    topology = openmm.app.Topology()

    box_vectors = orig_top.box_vectors

    if box_vectors is not None:
        topology.setPeriodicBoxVectors(box_vectors.to_openmm())

    chain = topology.addChain()

    # Per-residue tally of each atomic number, used to build unique atom names.
    counts_per_residue = collections.defaultdict(collections.Counter)

    atoms = []
    residue_to_particle_idx = collections.defaultdict(list)

    residue = None
    previous_residue_idx = -1

    for particle_idx, (atomic_num, residue_idx) in enumerate(particles):
        # A new topology residue is started whenever the residue index changes
        # between consecutive particles.
        if residue_idx != previous_residue_idx:
            residue = topology.addResidue("UNK", chain)
            previous_residue_idx = residue_idx

        element = None

        if atomic_num >= 0:
            element = openmm.app.Element.getByAtomicNumber(atomic_num)

        symbol = element.symbol if element is not None else "X"

        counts_per_residue[residue_idx][atomic_num] += 1
        count = counts_per_residue[residue_idx][atomic_num]

        atom_name = f"{symbol}{count}".ljust(3, "x")
        atoms.append(topology.addAtom(atom_name, element, residue))

        residue_to_particle_idx[residue_idx].append(particle_idx)

    _rename_residues(topology)

    for bond in orig_top.bonds:
        first_atom = atoms[atom_idx_to_particle_idx[bond.atom1_index]]

        # Bonds whose first atom sits in a residue named HOH are not added.
        if first_atom.residue.name == "HOH":
            continue

        second_atom = atoms[atom_idx_to_particle_idx[bond.atom2_index]]
        topology.addBond(first_atom, second_atom)

    # Real particles copy their original coordinates; v-sites get a zero
    # placeholder row that is overwritten below.
    coord_rows = [
        (
            orig_coords[particle_idx_to_atom_idx[particle_idx]].value_in_unit(
                openmm.unit.angstrom
            )
            if particle_idx in particle_idx_to_atom_idx
            else numpy.zeros((1, 3))
        )
        for particle_idx in range(n_particles)
    ]
    coords_full = numpy.vstack(coord_rows) * openmm.unit.angstrom

    if len(coords_full) != len(orig_coords):
        # Extra rows mean v-sites are present: let OpenMM place them.
        context = openmm.Context(system, openmm.VerletIntegrator(1.0))
        context.setPositions(coords_full)
        context.computeVirtualSites()

        state = context.getState(getPositions=True)
        coords_full = state.getPositions(asNumpy=True)

    residues = [
        set(residue_to_particle_idx[residue_idx])
        for residue_idx in range(len(residue_to_particle_idx))
    ]

    return topology, coords_full, residues


def _rename_residues(topology: openmm.app.Topology):
    """Attempts to assign standard residue names to known residues.

    Only water is currently recognised: a residue whose atoms *that have an
    element* are exactly two hydrogens and one oxygen is renamed to ``HOH``,
    its oxygen is named ``OW`` and its hydrogens ``HW{i}``, where ``i`` is the
    position of the hydrogen within the residue.

    Atoms without an element (e.g. virtual sites) are ignored when matching
    the residue and keep their existing name.

    Args:
        topology: The topology to rename the residues of in-place.
    """

    for residue in topology.residues():
        symbols = sorted(
            atom.element.symbol
            for atom in residue.atoms()
            if atom.element is not None
        )

        if symbols != ["H", "H", "O"]:
            continue

        residue.name = "HOH"

        for i, atom in enumerate(residue.atoms()):
            if atom.element is None:
                # Element-less atoms were excluded from the match above, so a
                # water with an extra site reaches this loop. Reading
                # ``atom.element.symbol`` on it would raise an AttributeError.
                continue

            atom.name = "OW" if atom.element.symbol == "O" else f"HW{i}"


def _setup_solvent(
solvent_idx: typing.Literal["solvent-a", "solvent-b"],
components: list[tuple[str, int]],
Expand All @@ -67,19 +209,21 @@ def _setup_solvent(

is_vacuum = n_solvent_molecules == 0

topology, coords = absolv.setup.setup_system(components)
topology.box_vectors = None if is_vacuum else topology.box_vectors
topology_off, coords = absolv.setup.setup_system(components)
topology_off.box_vectors = None if is_vacuum else topology_off.box_vectors

if isinstance(force_field, openff.toolkit.ForceField):
original_system = force_field.create_openmm_system(topology_off)
else:
original_system: openmm.System = force_field(topology_off, coords, solvent_idx)

atom_indices = absolv.utils.topology.topology_to_atom_indices(topology)
topology, coords, atom_indices = _rebuild_topology(
topology_off, coords, original_system
)

alchemical_indices = atom_indices[:n_solute_molecules]
persistent_indices = atom_indices[n_solute_molecules:]

if isinstance(force_field, openff.toolkit.ForceField):
original_system = force_field.create_openmm_system(topology)
else:
original_system: openmm.System = force_field(topology, coords, solvent_idx)

alchemical_system = absolv.fep.apply_fep(
original_system,
alchemical_indices,
Expand Down Expand Up @@ -196,7 +340,7 @@ def _run_eq_phase(
"""
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

Expand Down Expand Up @@ -312,7 +456,7 @@ def _run_phase_end_states(
):
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

Expand Down Expand Up @@ -363,11 +507,11 @@ def _run_switching(
):
platform = (
femto.md.constants.OpenMMPlatform.REFERENCE
if prepared_system.topology.box_vectors is None
if prepared_system.topology.getPeriodicBoxVectors() is None
else platform
)

mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology.to_openmm())
mdtraj_topology = mdtraj.Topology.from_openmm(prepared_system.topology)

trajectory_0 = mdtraj.load_dcd(str(output_dir / "state-0.dcd"), mdtraj_topology)
trajectory_1 = mdtraj.load_dcd(str(output_dir / "state-1.dcd"), mdtraj_topology)
Expand Down
22 changes: 0 additions & 22 deletions absolv/tests/test_fep.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,11 @@
_add_electrostatics_lambda,
_add_lj_vdw_lambda,
_find_nonbonded_forces,
_find_v_sites,
apply_fep,
)
from absolv.tests import is_close


def test_find_v_sites():
    """Check that v-sites found in an OpenMM system are grouped together with
    the molecule that their parent particles belong to."""

    # The mock system has the particle layout V A A A V A A, where particles
    # (0, 5, 6), (3,) and (4, 1, 2) each make up one molecule.
    n_particles = 7
    v_site_parents = {0: (5, 6), 4: (1, 2)}

    system = openmm.System()

    for _ in range(n_particles):
        system.addParticle(1.0)

    for v_site_idx, (parent_a, parent_b) in v_site_parents.items():
        system.setVirtualSite(
            v_site_idx, openmm.TwoParticleAverageSite(parent_a, parent_b, 0.5, 0.5)
        )

    expected_particle_indices = [{1, 2, 4}, {3}, {0, 5, 6}]

    assert _find_v_sites(system, [{0, 1}, {2}, {3, 4}]) == expected_particle_indices


def test_find_nonbonded_forces_lj_only(aq_nacl_lj_system):
(
nonbonded_force,
Expand Down
Loading
Loading