From 3e848a6e1f7bd7484c337c545df04f1d0d98690a Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 4 Dec 2024 13:24:29 +0000 Subject: [PATCH 01/29] first pass at abstract classes --- .../protocols/openmm_utils/omm_restraints.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 openfe/protocols/openmm_utils/omm_restraints.py diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py new file mode 100644 index 000000000..2a62d30ae --- /dev/null +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -0,0 +1,64 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Classes for applying restraints to OpenMM Systems. + +Acknowledgements +---------------- +Many of the classes here are at least in part inspired, if not taken from +`Yank `_ and +`OpenMMTools `_. + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from typing import Optional, Union + +from openmmtools.states import GlobalParameterState + + +class RestraintParameterState(GlobalParameterState): + """ + Composable state to control `lambda_restraints` OpenMM Force parameters. + + See :class:`openmmtools.states.GlobalParameterState` for more details. + + Parameters + ---------- + parameters_name_suffix : Optional[str] + If specified, the state will control a modified version of the parameter + ``lambda_restraints_{parameters_name_suffix}` instead of just ``lambda_restraints``. + lambda_restraints : Optional[float] + The strength of the restraint. If defined, must be between 0 and 1. + + Acknowledgement + --------------- + Partially reproduced from Yank. + """ + + lambda_restraints = GlobalParameterState.GlobalParameter('lambda_restraints', standard_value=1.0) + + @lambda_restraints.validator + def lambda_restraints(self, instance, new_value): + if new_value is not None and not (0.0 <= new_value <= 1.0): + errmsg = ("lambda_restraints must be between 0.0 and 1.0, " + f"got {new_value}") + raise ValueError(errmsg) + # Not crashing out on None to match upstream behaviour + return new_value + + +class BaseHostGuestRestraints(abc.ABC): + """ + An abstract base class for defining objects that apply a restraint between + two entities (referred to as a Host and a Guest). + + + TODO + ---- + Add some examples here. + """ + def __init__(self, host_atoms: list[int], guest_atoms: list[int], restraint_settings: SettingBaseModel, restraint_geometry: BaseRestraintGeometry): + From ef050e5ab47fc228d98d6fba00476b331c46ab3d Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 5 Dec 2024 15:31:32 +0000 Subject: [PATCH 02/29] A start at restraints and forces --- openfe/protocols/openmm_utils/omm_forces.py | 134 +++++++++++++ .../protocols/openmm_utils/omm_restraints.py | 187 +++++++++++++++++- 2 files changed, 315 insertions(+), 6 deletions(-) create mode 100644 openfe/protocols/openmm_utils/omm_forces.py diff --git a/openfe/protocols/openmm_utils/omm_forces.py b/openfe/protocols/openmm_utils/omm_forces.py new file mode 100644 index 000000000..e9246b694 --- /dev/null +++ b/openfe/protocols/openmm_utils/omm_forces.py @@ -0,0 +1,134 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Custom OpenMM Forces + +TODO +---- +* Add relevant duecredit entries. +""" +import numpy as np +import openmm + + +def get_boresch_energy_function( + control_parameter: str, + K_r: float, r_aA0: float, + K_thetaA: float, theta_A0: float, + K_thetaB: float, theta_B0: float, + K_phiA: float, phi_A0: float, + K_phiB: float, phi_B0: float, + K_phiC: float, phi_C0: float +) -> str: + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2*pi)+0.5)*(2*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2*pi)+0.5)*(2*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2*pi)+0.5)*(2*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + f"K_r = {K_r}; " + f"r_aA0 = {r_aA0}; " + f"K_thetaA = {K_thetaA}; " + f"theta_A0 = {theta_A0}; " + f"K_thetaB = {K_thetaB}; " + f"theta_B0 = {theta_B0}; " + f"K_phiA = {K_phiA}; " + f"phi_A0 = {phi_A0}; " + f"K_phiB = {K_phiB}; " + f"phi_B0 = {phi_B0}; " + f"K_phiC = {K_phiC}; " + f"phi_C0 = {phi_C0}; " + ) + return energy_function + + +def get_periodic_boresch_energy_function( + control_parameter: str, + K_r: float, r_aA0: float, + K_thetaA: float, theta_A0: float, + K_thetaB: float, theta_B0: float, + K_phiA: float, phi_A0: float, + K_phiB: float, phi_B0: float, + K_phiC: float, phi_C0: float +) -> str: + energy_function = ( + f"{control_parameter} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + f"K_r = {K_r}; " + f"r_aA0 = {r_aA0}; " + f"K_thetaA = {K_thetaA}; " + f"theta_A0 = {theta_A0}; " + f"K_thetaB = {K_thetaB}; " + f"theta_B0 = {theta_B0}; " + f"K_phiA = {K_phiA}; " + f"phi_A0 = {phi_A0}; " + f"K_phiB = {K_phiB}; " + f"phi_B0 = {phi_B0}; " + f"K_phiC = {K_phiC}; " + f"phi_C0 = {phi_C0}; " + ) + return energy_function + + +def get_custom_compound_bond_force( + n_particles: int = 6, energy_function: str = BORESCH_ENERGY_FUNCTION +): + """ + Return an OpenMM CustomCompoundForce + + TODO + ---- + Change this to a direct subclass like openmmtools.force. + + Acknowledgements + ---------------- + Boresch-like energy functions are reproduced from `Yank `_ + """ + return openmm.CustomCompoundBondForce(n_particles, energy_function) + + +def add_force_in_separate_group( + system: openmm.System, + force: openmm.Force, +): + """ + Add force to a System in a separate force group. + + Parameters + ---------- + system : openmm.System + System to add the Force to. + force : openmm.Force + The Force to add to the System. + + Raises + ------ + ValueError + If all 32 force groups are occupied. + + + TODO + ---- + Unlike the original Yank implementation, we assume that + all 32 force groups will not be filled. Should this be an issue + we can consider just separating it from NonbondedForce. + + Acknowledgements + ---------------- + Mostly reproduced from `Yank `_. + """ + available_force_groups = set(range(32)) + for force in system.getForces(): + available_force_groups.discard(force.getForceGroup()) + + force.setForceGroup(min(available_force_groups)) + system.addForce(force) diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index 2a62d30ae..f678428c8 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -14,9 +14,28 @@ * Add relevant duecredit entries. """ import abc -from typing import Optional, Union +from typing import Optional, Union, Callable -from openmmtools.states import GlobalParameterState +import openmm +from openmmtools.forces import ( + HarmonicRestraintForce, + HarmonicRestraintBondForce, + FlatBottomRestraintForce, + FlatBottomRestraintBondForce, +) +from openmmtools.states import GlobalParameterState, ThermodynamicState + +from gufe.settings.models import SettingsBaseModel +from openfe.protocols.openmm_utils.omm_forces import ( + get_custom_compound_bond_force, + add_force_in_separate_group, + get_boresch_energy_function, + get_periodic_boresch_energy_function, +) + + +class BaseRestraintGeometry: + pass class RestraintParameterState(GlobalParameterState): @@ -38,13 +57,16 @@ class RestraintParameterState(GlobalParameterState): Partially reproduced from Yank. """ - lambda_restraints = GlobalParameterState.GlobalParameter('lambda_restraints', standard_value=1.0) + lambda_restraints = GlobalParameterState.GlobalParameter( + "lambda_restraints", standard_value=1.0 + ) @lambda_restraints.validator def lambda_restraints(self, instance, new_value): if new_value is not None and not (0.0 <= new_value <= 1.0): - errmsg = ("lambda_restraints must be between 0.0 and 1.0, " - f"got {new_value}") + errmsg = ( + "lambda_restraints must be between 0.0 and 1.0, " f"got {new_value}" + ) raise ValueError(errmsg) # Not crashing out on None to match upstream behaviour return new_value @@ -60,5 +82,158 @@ class BaseHostGuestRestraints(abc.ABC): ---- Add some examples here. """ - def __init__(self, host_atoms: list[int], guest_atoms: list[int], restraint_settings: SettingBaseModel, restraint_geometry: BaseRestraintGeometry): + def __init__( + self, + host_atoms: list[int], + guest_atoms: list[int], + restraint_settings: SettingsBaseModel, + restraint_geometry: BaseRestraintGeometry, + controlling_parameter_name: str = "lambda_restraints", + ): + self.host_atoms = host_atoms + self.guest_atoms = guest_atoms + self.settings = restraint_settings + self.geometry = restraint_geometry + self._verify_input() + + @abc.abstractmethod + def _verify_inputs(self): + pass + + @abc.abstractmethod + def add_force(self, thermodynamic_state: ThermodynamicState): + pass + + @abc.abstractmethod + def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState): + pass + + @abc.abstractmethod + def _get_force(self): + pass + + +class SingleBondMixin: + def _verify_input(self): + if len(self.host_atoms) != 1 or len(self.guest_atoms) != 1: + errmsg = ( + "host_atoms and guest_atoms must only include a single index " + f"each, got {len(host_atoms)} and " + f"{len(guest_atoms)} respectively." + ) + raise ValueError(errmsg) + super()._verify_inputs() + + +class BaseRadialllySymmetricRestraintForce(BaseHostGuestRestraints): + def _verify_inputs(self) -> None: + if not isinstance(self.settings, BaseDistanceRestraintSettings): + errmsg = f"Incorrect settings type {self.settings} passed through" + raise ValueError(errmsg) + if not isinstance(self.geometry, DistanceRestraintGeometry): + errmsg = f"Incorrect geometry type {self.geometry} passed through" + raise ValueError(errmsg) + + def add_force(self, thermodynamic_state: ThermodynamicState) -> None: + force = self._get_force() + force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + # Note .system is a call to get_system() so it's returning a copy + system = thermodynamic_state.system + add_force_in_separate_group(system, force) + thermodynamic_state.system = system + + def get_standard_state_correction( + self, thermodynamic_state: ThermodynamicState + ) -> float: + force = self._get_force() + return force.compute_standard_state_correction( + thermodynamic_state, volume="system" + ) + + def _get_force(self): + raise NotImplementedError("only implemented in child classes") + + +class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): + def _get_force(self) -> openmm.Force: + return HarmonicRestraintBondForce( + spring_constant=self.settings.spring_constant, + restrained_atom_index1=self.host_atoms[0], + restrained_atom_index2=self.guest_atoms[0], + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): + def _get_force(self) -> openmm.Force: + return FlatBottomRestraintBondForce( + spring_constant=self.settings.spring_constant, + well_radius=self.settings.well_radius, + restrained_atom_index1=self.host_atoms[0], + restrained_atom_index2=self.guest_atoms[0], + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): + def _get_force(self) -> openmm.Force: + return HarmonicRestraintForce( + spring_constant=self.settings.spring_constant, + restrained_atom_index1=self.host_atoms, + restrained_atom_index2=self.guest_atoms, + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): + def _get_force(self): + return FlatBottomRestraintBondForce( + spring_constant=self.settings.spring_constant, + well_radius=self.settings.well_radius, + restrained_atom_index1=self.host_atoms, + restrained_atom_index2=self.guest_atoms, + controlling_parameter_name=self.controlling_parameter_name, + ) + + +class BoreschRestraint(BaseHostGuestRestraints): + _EFUNC_METHOD: Callable = get_boresch_energy_function + def _verify_inputs(self) -> None: + if not isinstance(self.settings, BoreschRestraintSettings): + errmsg = f"Incorrect settings type {self.settings} passed through" + raise ValueError(errmsg) + if not isinstance(self.geometry, BoreschRestraintGeometry): + errmsg = f"Incorrect geometry type {self.geometry} passed through" + raise ValueError(errmsg) + + def add_force(self, thermodynamic_state: ThermodynamicState) -> None: + force = self._get_force() + force.addGlobalParameter(self.controlling_parameter_name, 1.0) + force.addBond(self.host_atoms + self.guest_atoms, []) + force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + # Note .system is a call to get_system() so it's returning a copy + system = thermodynamic_state.system + add_force_in_separate_group(system, force) + thermodynamic_state.system = system + + def _get_force(self) -> openmm.Force: + efunc = _EFUNC_METHOD( + self.controlling_parameter_name, + self.settings.K_r, + self.geometry.r_aA0, + self.settings.K_thetaA, + self.geometry.theta_A0, + self.settings.K_thetaB, + self.geometry.theta_B0, + self.settings.K_phiA, + self.geometry.phi_A0, + self.settings.K_phiB, + self.geometry.phi_B0, + self.settings.K_phiC, + self.geometry.phi_C0, + ) + + return get_custom_compound_bond_force( + n_particles=6, energy_function=efunc + ) From 420f3e56d6321a2df49f06a865fefc628e771d87 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 6 Dec 2024 16:24:15 +0000 Subject: [PATCH 03/29] Add boresch restraint class --- openfe/protocols/openmm_utils/omm_forces.py | 42 +------- .../protocols/openmm_utils/omm_restraints.py | 95 ++++++++++++++----- 2 files changed, 76 insertions(+), 61 deletions(-) diff --git a/openfe/protocols/openmm_utils/omm_forces.py b/openfe/protocols/openmm_utils/omm_forces.py index e9246b694..3ad9d0aa6 100644 --- a/openfe/protocols/openmm_utils/omm_forces.py +++ b/openfe/protocols/openmm_utils/omm_forces.py @@ -13,46 +13,22 @@ def get_boresch_energy_function( control_parameter: str, - K_r: float, r_aA0: float, - K_thetaA: float, theta_A0: float, - K_thetaB: float, theta_B0: float, - K_phiA: float, phi_A0: float, - K_phiB: float, phi_B0: float, - K_phiC: float, phi_C0: float ) -> str: energy_function = ( f"{control_parameter} * E; " "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " - "dphi_A = dA - floor(dA/(2*pi)+0.5)*(2*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " - "dphi_B = dB - floor(dB/(2*pi)+0.5)*(2*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " - "dphi_C = dC - floor(dC/(2*pi)+0.5)*(2*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " f"pi = {np.pi}; " - f"K_r = {K_r}; " - f"r_aA0 = {r_aA0}; " - f"K_thetaA = {K_thetaA}; " - f"theta_A0 = {theta_A0}; " - f"K_thetaB = {K_thetaB}; " - f"theta_B0 = {theta_B0}; " - f"K_phiA = {K_phiA}; " - f"phi_A0 = {phi_A0}; " - f"K_phiB = {K_phiB}; " - f"phi_B0 = {phi_B0}; " - f"K_phiC = {K_phiC}; " - f"phi_C0 = {phi_C0}; " ) return energy_function def get_periodic_boresch_energy_function( control_parameter: str, - K_r: float, r_aA0: float, - K_thetaA: float, theta_A0: float, - K_thetaB: float, theta_B0: float, - K_phiA: float, phi_A0: float, - K_phiB: float, phi_B0: float, - K_phiC: float, phi_C0: float ) -> str: energy_function = ( f"{control_parameter} * E; " @@ -63,18 +39,6 @@ def get_periodic_boresch_energy_function( "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " f"pi = {np.pi}; " - f"K_r = {K_r}; " - f"r_aA0 = {r_aA0}; " - f"K_thetaA = {K_thetaA}; " - f"theta_A0 = {theta_A0}; " - f"K_thetaB = {K_thetaB}; " - f"theta_B0 = {theta_B0}; " - f"K_phiA = {K_phiA}; " - f"phi_A0 = {phi_A0}; " - f"K_phiB = {K_phiB}; " - f"phi_B0 = {phi_B0}; " - f"K_phiC = {K_phiC}; " - f"phi_C0 = {phi_C0}; " ) return energy_function diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index f678428c8..d03b1f195 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -16,7 +16,9 @@ import abc from typing import Optional, Union, Callable +import numpy as np import openmm +from openmm import unit as omm_unit from openmmtools.forces import ( HarmonicRestraintForce, HarmonicRestraintBondForce, @@ -24,6 +26,7 @@ FlatBottomRestraintBondForce, ) from openmmtools.states import GlobalParameterState, ThermodynamicState +from openff.units.openmm import to_openmm from gufe.settings.models import SettingsBaseModel from openfe.protocols.openmm_utils.omm_forces import ( @@ -143,7 +146,7 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: add_force_in_separate_group(system, force) thermodynamic_state.system = system - def get_standard_state_correction( + def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState ) -> float: force = self._get_force() @@ -157,8 +160,9 @@ def _get_force(self): class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( - spring_constant=self.settings.spring_constant, + spring_constant=spring_constant, restrained_atom_index1=self.host_atoms[0], restrained_atom_index2=self.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, @@ -167,9 +171,11 @@ def _get_force(self) -> openmm.Force: class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( - spring_constant=self.settings.spring_constant, - well_radius=self.settings.well_radius, + spring_constant=spring_constant, + well_radius=well_radius, restrained_atom_index1=self.host_atoms[0], restrained_atom_index2=self.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, @@ -178,8 +184,9 @@ def _get_force(self) -> openmm.Force: class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( - spring_constant=self.settings.spring_constant, + spring_constant=spring_constant, restrained_atom_index1=self.host_atoms, restrained_atom_index2=self.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, @@ -188,9 +195,11 @@ def _get_force(self) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self): + spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( - spring_constant=self.settings.spring_constant, - well_radius=self.settings.well_radius, + spring_constant=spring_constant, + well_radius=well_radius, restrained_atom_index1=self.host_atoms, restrained_atom_index2=self.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, @@ -209,8 +218,6 @@ def _verify_inputs(self) -> None: def add_force(self, thermodynamic_state: ThermodynamicState) -> None: force = self._get_force() - force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(self.host_atoms + self.guest_atoms, []) force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system @@ -220,20 +227,64 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: def _get_force(self) -> openmm.Force: efunc = _EFUNC_METHOD( self.controlling_parameter_name, - self.settings.K_r, - self.geometry.r_aA0, - self.settings.K_thetaA, - self.geometry.theta_A0, - self.settings.K_thetaB, - self.geometry.theta_B0, - self.settings.K_phiA, - self.geometry.phi_A0, - self.settings.K_phiB, - self.geometry.phi_B0, - self.settings.K_phiC, - self.geometry.phi_C0, ) - return get_custom_compound_bond_force( + force = get_custom_compound_bond_force( n_particles=6, energy_function=efunc ) + + param_values = [] + + parameter_dict = { + 'K_r': self.settings.K_r, + 'r_aA0': self.geometry.r_aA0, + 'K_thetaA': self.settings.K_thetaA, + 'theta_A0': self.geometry.theta_A0, + 'K_thetaB': self.settings.K_thetaB, + 'theta_B0': self.geometry.theta_B0, + 'K_phiA': self.settings.K_phiA, + 'phi_A0': self.geometry.phi_A0, + 'K_phiB': self.settings.K_phiB, + 'phi_B0': self.geometry.phi_B0, + 'K_phiC': self.settings.K_phiC, + 'phi_C0': self.geometry.phi_C0, + } + for key, val in parameter_dict.items(): + param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) + force.addPerBondParameter(key) + + force.addGlobalParameter(self.controlling_parameter_name, 1.0) + force.addBond(self.host_atoms + self.guest_atoms, param_values) + return force + + def get_standard_state_correction( + self, thermodynamic_state: ThermodynamicState + ) -> float: + + StandardV = 1660.53928 * unit.angstroms**3 + kt = from_openmm(thermodynamic_state.kT) + + # distances + r_aA0 = self.geometry.r_aA0.to('nm') + sin_thetaA0 = np.sin(self.geometry.theta_A0.to('radians')) + sin_thetaB0 = np.sin(self.geometry.theta_B0.to('radians')) + + # restraint energies + K_r = self.settings.K_r.to('kilojoule_per_mole') + K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole') + k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole') + K_phiA = self.settings.K_phiA.to('kilojoule_per_mole') + K_phiB = self.settings.K_phiB.to('kilojoule_per_mole') + K_phiC = self.settings.K_phiC.to('kilojoule_per_mole') + + numerator1 = 8.0 * (np.pi**2) * StandardV + denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 + numerator2 = np.sqrt(K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC) + denum2 = (2.0 * np.pi * kt)**3 + + dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) + + return dG + + +# TODO - implement periodic torsion Boresch restraint From 2e2b52e38a44c636fb431c0544c12503de59e0cf Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 6 Dec 2024 16:35:00 +0000 Subject: [PATCH 04/29] Fix units --- openfe/protocols/openmm_utils/omm_restraints.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index d03b1f195..104b07a4a 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -261,7 +261,7 @@ def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState ) -> float: - StandardV = 1660.53928 * unit.angstroms**3 + StandardV = 1.66053928 * unit.nanometer**3 kt = from_openmm(thermodynamic_state.kT) # distances @@ -270,12 +270,12 @@ def get_standard_state_correction( sin_thetaB0 = np.sin(self.geometry.theta_B0.to('radians')) # restraint energies - K_r = self.settings.K_r.to('kilojoule_per_mole') - K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole') - k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole') - K_phiA = self.settings.K_phiA.to('kilojoule_per_mole') - K_phiB = self.settings.K_phiB.to('kilojoule_per_mole') - K_phiC = self.settings.K_phiC.to('kilojoule_per_mole') + K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') + K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole / radians ** 2') + k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') + K_phiA = self.settings.K_phiA.to('kilojoule_per_mole / radians ** 2') + K_phiB = self.settings.K_phiB.to('kilojoule_per_mole / radians ** 2') + K_phiC = self.settings.K_phiC.to('kilojoule_per_mole / radians ** 2') numerator1 = 8.0 * (np.pi**2) * StandardV denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 From 76c5fcfe845250b527ab815cda89e204daedc206 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 9 Dec 2024 12:25:46 +0000 Subject: [PATCH 05/29] Fix correction return in kj/mole --- openfe/protocols/openmm_utils/omm_restraints.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/omm_restraints.py index 104b07a4a..915d51d3c 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/omm_restraints.py @@ -12,6 +12,7 @@ TODO ---- * Add relevant duecredit entries. +* Add Periodic Torsion Boresch class """ import abc from typing import Optional, Union, Callable @@ -26,7 +27,8 @@ FlatBottomRestraintBondForce, ) from openmmtools.states import GlobalParameterState, ThermodynamicState -from openff.units.openmm import to_openmm +from openff.units.openmm import to_openmm, from_openmm +from openff.units import unit from gufe.settings.models import SettingsBaseModel from openfe.protocols.openmm_utils.omm_forces import ( @@ -148,11 +150,13 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState - ) -> float: + ) -> unit.Quantity: force = self._get_force() - return force.compute_standard_state_correction( + corr = force.compute_standard_state_correction( thermodynamic_state, volume="system" ) + dg = corr * thermodynamic_state.kT + return from_openmm(dg).to('kilojoule_per_mole') def _get_force(self): raise NotImplementedError("only implemented in child classes") @@ -194,7 +198,7 @@ def _get_force(self) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): - def _get_force(self): + def _get_force(self) -> openmm.Force: spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( @@ -259,7 +263,7 @@ def _get_force(self) -> openmm.Force: def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState - ) -> float: + ) -> unit.Quantity: StandardV = 1.66053928 * unit.nanometer**3 kt = from_openmm(thermodynamic_state.kT) @@ -285,6 +289,3 @@ def get_standard_state_correction( dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) return dG - - -# TODO - implement periodic torsion Boresch restraint From e9cd918c60a9d3053f2459acb61b37cead9f2579 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 9 Dec 2024 12:57:04 +0000 Subject: [PATCH 06/29] Add more restraint API bits --- .../openmm_utils/restraints/__init__.py | 0 .../openmm_utils/restraints/geometry.py | 56 +++++++++++++++++++ .../{ => restraints}/omm_forces.py | 0 .../{ => restraints}/omm_restraints.py | 34 +++++------ 4 files changed, 71 insertions(+), 19 deletions(-) create mode 100644 openfe/protocols/openmm_utils/restraints/__init__.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry.py rename openfe/protocols/openmm_utils/{ => restraints}/omm_forces.py (100%) rename openfe/protocols/openmm_utils/{ => restraints}/omm_restraints.py (90%) diff --git a/openfe/protocols/openmm_utils/restraints/__init__.py b/openfe/protocols/openmm_utils/restraints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/openmm_utils/restraints/geometry.py b/openfe/protocols/openmm_utils/restraints/geometry.py new file mode 100644 index 000000000..3d1f37a10 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry.py @@ -0,0 +1,56 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + + +class BaseRestraintGeometry(BaseModel, abc.ABC): + class Config: + arbitrary_types_allowed = True + + +class HostGuestRestraintGeometry(BaseRestraintGeometry): + """ + An ordered list of guest atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + guest_atoms: list[int] + """ + An ordered list of host atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + host_atoms: list[int] + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v + + +class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry): + @validator("host_atoms", "guest_atoms") + def single_atoms(cls, v): + if len(v) != 1: + errmsg = ( + "Host and guest atom lists must only include a single atom, " + f"got {len(v)} atoms." + ) + raise ValueError(errmsg) + return v diff --git a/openfe/protocols/openmm_utils/omm_forces.py b/openfe/protocols/openmm_utils/restraints/omm_forces.py similarity index 100% rename from openfe/protocols/openmm_utils/omm_forces.py rename to openfe/protocols/openmm_utils/restraints/omm_forces.py diff --git a/openfe/protocols/openmm_utils/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/omm_restraints.py similarity index 90% rename from openfe/protocols/openmm_utils/omm_restraints.py rename to openfe/protocols/openmm_utils/restraints/omm_restraints.py index 915d51d3c..0bfb6eb80 100644 --- a/openfe/protocols/openmm_utils/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/omm_restraints.py @@ -90,14 +90,10 @@ class BaseHostGuestRestraints(abc.ABC): def __init__( self, - host_atoms: list[int], - guest_atoms: list[int], restraint_settings: SettingsBaseModel, restraint_geometry: BaseRestraintGeometry, controlling_parameter_name: str = "lambda_restraints", ): - self.host_atoms = host_atoms - self.guest_atoms = guest_atoms self.settings = restraint_settings self.geometry = restraint_geometry self._verify_input() @@ -121,7 +117,7 @@ def _get_force(self): class SingleBondMixin: def _verify_input(self): - if len(self.host_atoms) != 1 or len(self.guest_atoms) != 1: + if len(self.geometry.host_atoms) != 1 or len(self.geometry.guest_atoms) != 1: errmsg = ( "host_atoms and guest_atoms must only include a single index " f"each, got {len(host_atoms)} and " @@ -148,7 +144,7 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: add_force_in_separate_group(system, force) thermodynamic_state.system = system - def get_standard_state_correction( + def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState ) -> unit.Quantity: force = self._get_force() @@ -164,48 +160,48 @@ def _get_force(self): class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( spring_constant=spring_constant, - restrained_atom_index1=self.host_atoms[0], - restrained_atom_index2=self.guest_atoms[0], + restrained_atom_index1=self.geometry.host_atoms[0], + restrained_atom_index2=self.geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.host_atoms[0], - restrained_atom_index2=self.guest_atoms[0], + restrained_atom_index1=self.geometry.host_atoms[0], + restrained_atom_index2=self.geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( spring_constant=spring_constant, - restrained_atom_index1=self.host_atoms, - restrained_atom_index2=self.guest_atoms, + restrained_atom_index1=self.geometry.host_atoms, + restrained_atom_index2=self.geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: - spring_constant = to_openmm(self.settings.sprint_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.host_atoms, - restrained_atom_index2=self.guest_atoms, + restrained_atom_index1=self.geometry.host_atoms, + restrained_atom_index2=self.geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) @@ -258,7 +254,7 @@ def _get_force(self) -> openmm.Force: force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(self.host_atoms + self.guest_atoms, param_values) + force.addBond(self.geometry.host_atoms + self.geometry.guest_atoms, param_values) return force def get_standard_state_correction( From f1bbd8a440fd254f2b1eb0dfa17f1326e605a11e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 9 Dec 2024 15:50:07 +0000 Subject: [PATCH 07/29] move some things around --- openfe/protocols/openmm_utils/restraints/omm_restraints.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/omm_restraints.py index 0bfb6eb80..599230ccc 100644 --- a/openfe/protocols/openmm_utils/restraints/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/omm_restraints.py @@ -5,7 +5,7 @@ Acknowledgements ---------------- -Many of the classes here are at least in part inspired, if not taken from +Many of the classes here are at least in part inspired from `Yank `_ and `OpenMMTools `_. @@ -39,10 +39,6 @@ ) -class BaseRestraintGeometry: - pass - - class RestraintParameterState(GlobalParameterState): """ Composable state to control `lambda_restraints` OpenMM Force parameters. From ac452e9df96f2c8ae23536be702e6939b1e98769 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 11 Dec 2024 11:52:04 +0000 Subject: [PATCH 08/29] Some changes --- .../openmm_utils/restraints/geometry.py | 54 +++++++++++++++---- .../openmm_utils/restraints/omm_restraints.py | 4 +- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry.py b/openfe/protocols/openmm_utils/restraints/geometry.py index 3d1f37a10..14e1cd289 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry.py +++ b/openfe/protocols/openmm_utils/restraints/geometry.py @@ -10,6 +10,10 @@ import abc from pydantic.v1 import BaseModel, validator +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds + class BaseRestraintGeometry(BaseModel, abc.ABC): class Config: @@ -25,6 +29,7 @@ class HostGuestRestraintGeometry(BaseRestraintGeometry): The order matters! It will be used to define the underlying force. """ + guest_atoms: list[int] """ An ordered list of host atoms to restrain. @@ -44,13 +49,42 @@ def positive_idxs(cls, v): return v -class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry): - @validator("host_atoms", "guest_atoms") - def single_atoms(cls, v): - if len(v) != 1: - errmsg = ( - "Host and guest atom lists must only include a single atom, " - f"got {len(v)} atoms." - ) - raise ValueError(errmsg) - return v +class CentroidDistanceMixin: + def get_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + ag1 = u.atoms[self.host_atoms] + ag2 = u.atoms[self.guest_atoms] + bond = calc_bonds( + ag1.center_of_mass(), ag2.center_of_mass(), u.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + +def _check_single_atoms(value): + if len(value) != 1: + errmsg = ( + "Host and guest atom lists must only include a single atom, " + f"got {len(value)} atoms." + ) + raise ValueError(errmsg) + return value + + +class BondDistanceMixin: + def get_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[self.host_atoms[0]] + at2 = u.atoms[self.guest_atoms[0]] + bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) + # convert to float so we avoid having a np.float64 value + return float(bond) * unit.angstrom + + +class CentroidDistanceRestraintGeometry(HostGuestRestraintGeometry, CentroidDistanceMixin): + pass + + +class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry, BondDistanceMixin): + _check_host_atoms: classmethod = validator("host_atoms", allow_reuse=True)(_check_single_atoms) + _check_guest_atoms: classmethod = validator("guest_atoms", allow_reuse=True)(_check_single_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/omm_restraints.py index 599230ccc..ab5f4e821 100644 --- a/openfe/protocols/openmm_utils/restraints/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/omm_restraints.py @@ -168,7 +168,7 @@ def _get_force(self) -> openmm.Force: class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): def _get_force(self) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, @@ -192,7 +192,7 @@ def _get_force(self) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): def _get_force(self) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.settings.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, From 5d0b6837a2a5fdd129118c4283934f4bda787118 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 11 Dec 2024 13:18:29 +0000 Subject: [PATCH 09/29] Towards ABFE protocol --- openfe/protocols/openmm_afe/base.py | 149 ++- .../openmm_afe/equil_binding_afe_method.py | 901 ++++++++++++++++++ 2 files changed, 998 insertions(+), 52 deletions(-) create mode 100644 openfe/protocols/openmm_afe/equil_binding_afe_method.py diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 633ec884a..7ec2acea2 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -31,6 +31,7 @@ from openmmtools import multistate from openmmtools.states import (SamplerState, ThermodynamicState, + GlobalParameterState, create_thermodynamic_state_protocol,) from openmmtools.alchemy import (AlchemicalRegion, AbsoluteAlchemicalFactory, AlchemicalState,) @@ -469,45 +470,70 @@ def _get_modeller( def _get_omm_objects( self, - system_modeller: app.Modeller, - system_generator: SystemGenerator, - smc_components: list[OFFMolecule], - ) -> tuple[app.Topology, openmm.unit.Quantity, openmm.System]: + settings: dict[str, SettingsBaseModel], + protein_component: Optional[ProteinComponent], + solvent_component: Optional[SolventComponent], + smc_components: dict[SmallMoleculeComponent, OFFMolecule], + ) -> tuple[ + app.Topology, + openmm.System, + openmm.unit.Quantity, + dict[str, npt.NDArray], + ]: """ Get the OpenMM Topology, Positions and System of the parameterised system. Parameters ---------- - system_modeller : app.Modeller - OpenMM Modeller object representing the system to be - parametrized. - system_generator : SystemGenerator - SystemGenerator object to create a System with. - smc_components : list[openff.toolkit.Molecule] - A list of openff Molecules to add to the system. + settings : dict[str, SettingsBaseModel] + Protocol settings + protein_component : Optional[ProteinComponent] + Protein component for the system. + solvent_component : Optional[SolventComponent] + Solvent component for the system. + smc_components : dict[str, OFFMolecule] + SmallMoleculeComponents defining ligands to be added to the system Returns ------- topology : app.Topology - Topology object describing the parameterized system + OpenMM Topology object describing the parameterized system. system : openmm.System - An OpenMM System of the alchemical system. - positionns : openmm.unit.Quantity + An non-alchemical OpenMM System of the simulated system. + positions : openmm.unit.Quantity Positions of the system. + comp_resids : dict[str, npt.NDArray] + A dictionary of residues for each component in the System. """ - topology = system_modeller.getTopology() + if self.verbose: + self.logger.info("Parameterizing system") + + system_generator = self._get_system_generator( + settings, solvent_component + ) + + modeller, comp_resids = self._get_modeller( + protein_component, + solvent_component, + smc_components, + system_generator, + settings['charge_settings'], + settings['solvation_settings'] + ) + + topology = modeller.getTopology() # roundtrip positions to remove vec3 issues - positions = to_openmm(from_openmm(system_modeller.getPositions())) + positions = to_openmm(from_openmm(modeller.getPositions())) # Block out oechem backend to avoid any issues with # smiles roundtripping between rdkit and oechem with without_oechem_backend(): system = system_generator.create_system( - system_modeller.topology, + modeller.topology, molecules=smc_components, ) - return topology, system, positions + return topology, system, positions, comp_resids def _get_lambda_schedule( self, settings: dict[str, SettingsBaseModel] @@ -533,13 +559,16 @@ def _get_lambda_schedule( lambda_elec = settings['lambda_settings'].lambda_elec lambda_vdw = settings['lambda_settings'].lambda_vdw + lambda_rest = settings['lambda_settings'].lambda_restraints # Reverse lambda schedule since in AbsoluteAlchemicalFactory 1 # means fully interacting, not stateB - lambda_elec = [1-x for x in lambda_elec] - lambda_vdw = [1-x for x in lambda_vdw] - lambdas['lambda_electrostatics'] = lambda_elec - lambdas['lambda_sterics'] = lambda_vdw + for name, schedule in [ + ('lambda_electrostatics', lambda_elec), + ('lambda_sterics', lambda_vdw), + ('lambda_restraints', lambda_rest), + ]: + lambdas[name] = [1-x for x in schedule] return lambdas @@ -547,7 +576,7 @@ def _add_restraints(self, system, topology, settings): """ Placeholder method to add restraints if necessary """ - return + return None, None def _get_alchemical_system( self, @@ -607,6 +636,7 @@ def _get_states( settings: dict[str, SettingsBaseModel], lambdas: dict[str, npt.NDArray], solvent_comp: Optional[SolventComponent], + restraint_state: Optional[GlobalParameterState], ) -> tuple[list[SamplerState], list[ThermodynamicState]]: """ Get a list of sampler and thermodynmic states from an @@ -624,6 +654,8 @@ def _get_states( A dictionary of lambda scales. solvent_comp : Optional[SolventComponent] The solvent component of the system, if there is one. + restraint_state : Optional[GlobalParameterState] + The restraint parameter control state, if there is one. Returns ------- @@ -641,9 +673,14 @@ def _get_states( if solvent_comp is not None: constants['pressure'] = ensure_quantity(pressure, 'openmm') + if restraint_state is not None: + composable_states = [alchemical_state, restraint_state] + else: + composable_states = [alchemical_state,] + cmp_states = create_thermodynamic_state_protocol( alchemical_system, protocol=lambdas, - constants=constants, composable_states=[alchemical_state], + constants=constants, composable_states=composable_states, ) sampler_state = SamplerState(positions=positions) @@ -873,6 +910,7 @@ def _run_simulation( sampler: multistate.MultiStateSampler, reporter: multistate.MultiStateReporter, settings: dict[str, SettingsBaseModel], + standard_state_corr: Optional[unit.Quantity] dry: bool ): """ @@ -886,6 +924,8 @@ def _run_simulation( The reporter associated with the sampler. settings : dict[str, SettingsBaseModel] The dictionary of settings for the protocol. + standard_state_corr : Optional[unit.Quantity] + The standard state correction, if available. dry : bool Whether or not to dry run the simulation @@ -944,7 +984,12 @@ def _run_simulation( analyzer.plot(filepath=self.shared_basepath, filename_prefix="") analyzer.close() - return analyzer.unit_results_dict + return_dict = analyzer.unit_results_dict + + if standard_state_corr is not None: + return_dict['standard_state_correction'] = standard_state_corr + + return return_dict else: # close reporter when you're done, prevent file handle clashes @@ -991,44 +1036,40 @@ def run(self, dry=False, verbose=True, # 2. Get settings settings = self._handle_settings() - # 3. Get system generator - system_generator = self._get_system_generator(settings, solv_comp) - - # 4. Get modeller - system_modeller, comp_resids = self._get_modeller( - prot_comp, solv_comp, smc_comps, system_generator, - settings['charge_settings'], - settings['solvation_settings'], + # 3. Get OpenMM topology, positions, and system + omm_topology, omm_system, position, comp_resids = self._get_omm_objects( + settings, prot_comps, solv_comps, smc_comps, ) - # 5. Get OpenMM topology, positions and system - omm_topology, omm_system, positions = self._get_omm_objects( - system_modeller, system_generator, list(smc_comps.values()) - ) - - # 6. Pre-equilbrate System (Test + Avoid NaNs + get stable system) + # 4. Pre-equilbrate System (Test + Avoid NaNs + get stable system) positions = self._pre_equilibrate( omm_system, omm_topology, positions, settings, dry ) - # 7. Get lambdas + # 5. Get lambdas lambdas = self._get_lambda_schedule(settings) - # 8. Add restraints - self._add_restraints(omm_system, omm_topology, settings) + # 6. Add restraints + restraint_parameter_state, standard_state_corr = self._add_restraints( + omm_system, omm_topology, settings + ) - # 9. Get alchemical system + # 7. Get alchemical system alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( omm_topology, omm_system, comp_resids, alchem_comps ) - # 10. Get compound and sampler states + # 7. Get compound and sampler states sampler_states, cmp_states = self._get_states( - alchem_system, positions, settings, - lambdas, solv_comp + alchem_system, + positions, + settings, + lambdas, + solv_comp, + restraint_parameter_state, ) - # 11. Create the multistate reporter & create PDB + # 9. Create the multistate reporter & create PDB reporter = self._get_reporter( omm_topology, positions, settings['simulation_settings'], @@ -1037,19 +1078,19 @@ def run(self, dry=False, verbose=True, # Wrap in try/finally to avoid memory leak issues try: - # 12. Get context caches + # 10. Get context caches energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches( settings['forcefield_settings'], settings['engine_settings'] ) - # 13. Get integrator + # 11. Get integrator integrator = self._get_integrator( settings['integrator_settings'], settings['simulation_settings'], ) - # 14. Get sampler + # 12. Get sampler sampler = self._get_sampler( integrator, reporter, settings['simulation_settings'], settings['thermo_settings'], @@ -1057,9 +1098,13 @@ def run(self, dry=False, verbose=True, energy_ctx_cache, sampler_ctx_cache ) - # 15. Run simulation + # 13. Run simulation unit_result_dict = self._run_simulation( - sampler, reporter, settings, dry + sampler, + reporter, + settings, + standard_state_corr, + dry ) finally: diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py new file mode 100644 index 000000000..7943ca4c3 --- /dev/null +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -0,0 +1,901 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method` +=============================================================================================================== + +This module implements the necessary methodology tooling to run calculate an +absolute solvation free energy using OpenMM tools and one of the following +alchemical sampling methods: + +* Hamiltonian Replica Exchange +* Self-adjusted mixture sampling +* Independent window sampling + +Current limitations +------------------- +* Disapearing molecules are only allowed in state A. Support for + appearing molecules will be added in due course. +* Only small molecules are allowed to act as alchemical molecules. + Alchemically changing protein or solvent components would induce + perturbations which are too large to be handled by this Protocol. + + +Acknowledgements +---------------- +* Originally based on hydration.py in + `espaloma_charge `_ + +""" +from __future__ import annotations + +import pathlib +import logging +import warnings +from collections import defaultdict +import gufe +from gufe.components import Component +import itertools +import numpy as np +import numpy.typing as npt +from openff.units import unit +from openmmtools import multistate +from typing import Optional, Union +from typing import Any, Iterable +import uuid + +from gufe import ( + settings, + ChemicalSystem, SmallMoleculeComponent, + ProteinComponent, SolventComponent +) +from openfe.protocols.openmm_afe.equil_afe_settings import ( + AbsoluteSolvationSettings, + OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, + MDSimulationSettings, MDOutputSettings, + MultiStateSimulationSettings, OpenMMEngineSettings, + IntegratorSettings, MultiStateOutputSettings, + OpenFFPartialChargeSettings, + SettingsBaseModel, +) +from ..openmm_utils import system_validation, settings_validation +from .base import BaseAbsoluteUnit +from openfe.utils import log_system_probe +from openfe.due import due, Doi + + +due.cite(Doi("10.5281/zenodo.596504"), + description="Yank", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True) + +due.cite(Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True) + +due.cite(Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True) + + +logger = logging.getLogger(__name__) + + +class AbsoluteBindingProtocolResult(gufe.ProtocolResult): + """Dict-like container for the output of a AbsoluteBindingProtocol + """ + def __init__(self, **data): + super().__init__(**data) + # TODO: Detect when we have extensions and stitch these together? + if any(len(pur_list) > 2 for pur_list + in itertools.chain(self.data['solvent'].values(), self.data['vacuum'].values())): + raise NotImplementedError("Can't stitch together results yet") + + def get_individual_estimates(self) -> dict[str, list[tuple[unit.Quantity, unit.Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[unit.Quantity, unit.Quantity]]] + A dictionary, keyed `solvent`, `complex`, and 'standard_state' + representing each portion of the thermodynamic cycle, + with lists of tuples containing the individual free energy + estimates and, for 'solvent' and 'complex', the associated MBAR + uncertainties for each repeat of that simulation type. + + TODO + ---- + * Work out poperly what to do with the standard state correction. + """ + complex_dGs = [] + correction_dGs = [] + solv_dGs = [] + + for pus in self.data['complex'].values(): + complex_dGs.append(( + pus[0].outputs['unit_estimate'], + pus[0].outputs['unit_estimate_error'] + )) + correction_dGs.append(( + pus[0].outputs['standard_state_correction'] + )) + + for pus in self.data['solvent'].values(): + solv_dGs.append(( + pus[0].outputs['unit_estimate'], + pus[0].outputs['unit_estimate_error'] + )) + + return {'solvent': solv_dGs, 'complex': complex_dGs, 'standard_state': correction_dGs} + + def get_estimate(self): + """Get the binding free energy estimate for this calculation. + + Returns + ------- + dG : unit.Quantity + The binding free energy. This is a Quantity defined with units. + """ + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + complex_dG = _get_average(individual_estimates['complex']) + solv_dG = _get_average(individual_estimates['solvent']) + standard_state_dG = _get_average( + individual_estimates['standard_state'] + ) + + return - complex_dG + solv_dG + standard_state_dG + + def get_uncertainty(self): + """Get the binding free energy error for this calculation. + + Returns + ------- + err : unit.Quantity + The standard deviation between estimates of the binding free + energy. This is a Quantity defined with units. + """ + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + complex_err = _get_stdev(individual_estimates['complex']) + solv_err = _get_stdev(individual_estimates['solvent']) + standard_state_err = _get_stdev(individual_estimates['standard_state']) + + # return the combined error + return np.sqrt(complex_err**2 + solv_err**2 + standard_state_err**2) + + def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]]]: + """ + Get the reverse and forward analysis of the free energies. + + Returns + ------- + forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]]] + A dictionary, keyed `solvent` and `complex` for each leg of the + thermodynamic cycle which each contain a list of dictionaries + containing the forward and reverse analysis of each repeat + of that simulation type. + + The forward and reverse analysis dictionaries contain: + - `fractions`: npt.NDArray + The fractions of data used for the estimates + - `forward_DGs`, `reverse_DGs`: unit.Quantity + The forward and reverse estimates for each fraction of data + - `forward_dDGs`, `reverse_dDGs`: unit.Quantity + The forward and reverse estimate uncertainty for each + fraction of data. + + If one of the cycle leg list entries is ``None``, this indicates + that the analysis could not be carried out for that repeat. This + is most likely caused by MBAR convergence issues when attempting to + calculate free energies from too few samples. + + Raises + ------ + UserWarning + * If any of the forward and reverse dictionaries are ``None`` in a + given thermodynamic cycle leg. + """ + + forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, unit.Quantity]]]]] = {} + + for key in ['solvent', 'complex']: + forward_reverse[key] = [ + pus[0].outputs['forward_and_reverse_energies'] + for pus in self.data[key].values() + ] + + if None in forward_reverse[key]: + wmsg = ( + "One or more ``None`` entries were found in the forward " + f"and reverse dictionaries of the repeats of the {key} " + "calculations. This is likely caused by an MBAR convergence " + "failure caused by too few independent samples when " + "calculating the free energies of the 10% timeseries slice." + ) + warnings.warn(wmsg) + + return forward_reverse + + def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get a the MBAR overlap estimates for all legs of the simulation. + + Returns + ------- + overlap_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys `solvent` and `complex` for each + leg of the thermodynamic cycle, which each containing a + list of dictionaries with the MBAR overlap estimates of + each repeat of that simulation type. + + The underlying MBAR dictionaries contain the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + + for key in ['solvent', 'complex']: + overlap_stats[key] = [ + pus[0].outputs['unit_mbar_overlap'] + for pus in self.data[key].values() + ] + + return overlap_stats + + def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get the replica exchange transition statistics for all + legs of the simulation. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. + + Returns + ------- + repex_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys `solvent` and `complex` for each + leg of the thermodynamic cycle, which each containing + a list of dictionaries containing the replica transition + statistics for each repeat of that simulation type. + + The replica transition statistics dictionaries contain the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + try: + for key in ['solvent', 'complex']: + repex_stats[key] = [ + pus[0].outputs['replica_exchange_statistics'] + for pus in self.data[key].values() + ] + except KeyError: + errmsg = ("Replica exchange statistics were not found, " + "did you run a repex calculation?") + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> dict[str, list[npt.NDArray]]: + """ + Get the timeseries of replica states for all simulation legs. + + Returns + ------- + replica_states : dict[str, list[npt.NDArray]] + Dictionary keyed `solvent` and `complex` for each leg of + the thermodynamic cycle, with lists of replica states + timeseries for each repeat of that simulation type. + """ + replica_states: dict[str, list[npt.NDArray]] = { + 'solvent': [], 'complex': [] + } + + def is_file(filename: str): + p = pathlib.Path(filename) + + if not p.exists(): + errmsg = f"File could not be found {p}" + raise ValueError(errmsg) + + return p + + def get_replica_state(nc, chk): + nc = is_file(nc) + dir_path = nc.parents[0] + chk = is_file(dir_path / chk).name + + reporter = multistate.MultiStateReporter( + storage=nc, checkpoint_storage=chk, open_mode='r' + ) + + retval = np.asarray(reporter.read_replica_thermodynamic_states()) + reporter.close() + + return retval + + for key in ['solvent', 'complex']: + for pus in self.data[key].values(): + states = get_replica_state( + pus[0].outputs['nc'], + pus[0].outputs['last_checkpoint'], + ) + replica_states[key].append(states) + + return replica_states + + def equilibration_iterations(self) -> dict[str, list[float]]: + """ + Get the number of equilibration iterations for each simulation. + + Returns + ------- + equilibration_lengths : dict[str, list[float]] + Dictionary keyed `solvent` and `complex` for each leg + of the thermodynamic cycle, with lists containing the + number of equilibration iterations for each repeat + of that simulation type. + """ + equilibration_lengths: dict[str, list[float]] = {} + + for key in ['solvent', 'complex']: + equilibration_lengths[key] = [ + pus[0].outputs['equilibration_iterations'] + for pus in self.data[key].values() + ] + + return equilibration_lengths + + def production_iterations(self) -> dict[str, list[float]]: + """ + Get the number of production iterations for each simulation. + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : dict[str, list[float]] + Dictionary keyed `solvent` and `complex` for each leg of the + thermodynamic cycle, with lists with the number + of production iterations for each repeat of that simulation + type. + """ + production_lengths: dict[str, list[float]] = {} + + for key in ['solvent', 'complex']: + production_lengths[key] = [ + pus[0].outputs['production_iterations'] + for pus in self.data[key].values() + ] + + return production_lengths + + +class AbsoluteBindingProtocol(gufe.Protocol): + """ + Absolute binding free energy calculations using OpenMM and OpenMMTools. + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingSettings` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingProtocolResult` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingSolventUnit` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingComplexUnit` + """ + result_cls = AbsoluteBindingProtocolResult + _settings: AbsoluteBindingSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. + + Returns + ------- + Settings + a set of default settings + """ + return AbsoluteBindingSettings( + protocol_repeats=3, + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * unit.kelvin, + pressure=1 * unit.bar, + ), + alchemical_settings=AlchemicalSettings(), + solvent_lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, + 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0], + lambda_restraints=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ), + complex_lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, + 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, + 0.2, 0.3, 0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0], + lambda_restraints=[ + 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0], + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + solvation_settings=OpenMMSolvationSettings(), + engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + solvent_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.1 * unit.nanosecond, + equilibration_length=0.2 * unit.nanosecond, + production_length=0.5 * unit.nanosecond, + ), + solvent_equil_output_settings=MDOutputSettings( + equil_nvt_structure='equil_nvt_structure.pdb', + equil_npt_structure='equil_npt_structure.pdb', + production_trajectory_filename='production_equil.xtc', + log_output='equil_simulation.log', + ), + solvent_simulation_settings=MultiStateSimulationSettings( + n_replicas=14, + equilibration_length=1.0 * unit.nanosecond, + production_length=10.0 * unit.nanosecond, + ), + solvent_output_settings=MultiStateOutputSettings( + output_filename='solvent.nc', + checkpoint_storage_filename='solvent_checkpoint.nc', + ), + complex_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.25 * unit.nanosecond, + equilibration_length=0.5 * unit.nanosecond, + production_length=5.0 * unit.nanosecond, + ), + complex_equil_output_settings=MDOutputSettings( + equil_nvt_structure='equil_nvt_structure.pdb', + equil_npt_structure='equil_npt_structure.pdb', + production_trajectory_filename='production_equil.xtc', + log_output='equil_simulation.log', + ), + vacuum_simulation_settings=MultiStateSimulationSettings( + n_replicas=28, + equilibration_length=1 * unit.nanosecond, + production_length=10.0 * unit.nanosecond, + ), + vacuum_output_settings=MultiStateOutputSettings( + output_filename='complex.nc', + checkpoint_storage_filename='complex_checkpoint.nc' + ), + ) + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, stateB: ChemicalSystem, + ) -> None: + """ + A binding transformation is defined (in terms of gufe components) + as starting from one or more ligands with one protein and solvent, + that then ends up in a state with one less ligand. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A + stateB : ChemicalSystem + The chemical system of end state B + + Raises + ------ + ValueError + If stateA does not contain a ProteinComponent + If stateA does not contain a SolventComponent + If stateA has more than one unique Component + If the stateA unique Component is not a SmallMoleculeComponent + If stateB contains any unique Components + """ + if not any( + isinstance(comp, ProteinComponent) for comp in stateA.values() + ): + errmsg = "No ProteinComponent found" + raise ValueError(errmsg) + + if not any( + isinstance(comp, SolventComponent) for comp in stateA.values() + ): + errmsg = "No SolventComponent found" + raise ValueError(errmsg) + + # Needs gufe 1.3 + diff = stateA.component_diff(stateB) + if len(diff[0]) > 1: + errmsg = ("More than unique components found in stateA, " + "only one alchemical species is supported") + raise ValueError(errmsg) + + if not isinstance(diff[0][0], SmallMoleculeComponent): + errmsg = ("Only dissapearing smalll molecule components " + "are supported by this protocol. " + f"Found a {type(diff[0][0])}") + raise ValueError(errmsg) + + if len(diff[1]) > 0: + errmsg = ("Unique components are found in stateB, " + "this should not happen") + raise ValueError(errmsg) + + @staticmethod + def _validate_lambda_schedule( + lambda_settings: LambdaSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> None: + """ + Checks that the lambda schedule is set up correctly. + + Parameters + ---------- + lambda_settings : LambdaSettings + the lambda schedule Settings + simulation_settings : MultiStateSimulationSettings + the settings for either the vacuum or solvent phase + + Raises + ------ + ValueError + If the number of lambda windows differs for electrostatics, sterics, + and restraints. + If the number of replicas does not match the number of lambda windows. + If there are states with naked charges. + """ + + lambda_elec = lambda_settings.lambda_elec + lambda_vdw = lambda_settings.lambda_vdw + lambda_restraints = lambda_settings.lambda_restraints + n_replicas = simulation_settings.n_replicas + + # Ensure that all lambda components have equal amount of windows + lambda_components = [lambda_vdw, lambda_elec, lambda_restraints] + it = iter(lambda_components) + the_len = len(next(it)) + if not all(len(l) == the_len for l in it): + errmsg = ( + "Components elec, vdw, and restraints must have equal amount" + f" of lambda windows. Got {len(lambda_elec)} elec lambda" + f" windows, {len(lambda_vdw)} vdw lambda windows, and" + f"{len(lambda_restraints)} restraints lambda windows.") + raise ValueError(errmsg) + + # Ensure that number of overall lambda windows matches number of lambda + # windows for individual components + if n_replicas != len(lambda_vdw): + errmsg = (f"Number of replicas {n_replicas} does not equal the" + f" number of lambda windows {len(lambda_vdw)}") + raise ValueError(errmsg) + + # Check if there are no lambda windows with naked charges + for inx, lam in enumerate(lambda_elec): + if lam < 1 and lambda_vdw[inx] == 1: + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: lambda {inx}: " + f"elec {lam} vdW {lambda_vdw[inx]}") + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # TODO: extensions + if extends: # pragma: no-cover + raise NotImplementedError("Can't extend simulations yet") + + # Validate components and get alchemical components + self._validate_endstates(stateA, stateB) + alchem_comps = system_validation.get_alchemical_components( + stateA, stateB, + ) + + # Validate the lambda schedule + self._validate_lambda_schedule(self.settings.solvent_lambda_settings, + self.settings.solvent_simulation_settings) + self._validate_lambda_schedule(self.settings.complex_lambda_settings, + self.settings.complex_simulation_settings) + + # Check nonbond & solvent compatibility + nonbonded_method = self.settings.forcefield_settings.nonbonded_method + # Use the more complete system validation solvent checks + system_validation.validate_solvent(stateA, nonbonded_method) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings( + self.settings.solvation_settings + ) + + # Get the name of the alchemical species + alchname = alchem_comps['stateA'][0].name + + # Create list units for vacuum and solvent transforms + + solvent_units = [ + AbsoluteBindingSolventUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, + generation=0, repeat_id=int(uuid.uuid4()), + name=(f"Absolute Binding, {alchname} solvent leg: " + f"repeat {i} generation 0"), + ) + for i in range(self.settings.protocol_repeats) + ] + + complex_units = [ + AbsoluteBindingComplexUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, + generation=0, repeat_id=int(uuid.uuid4()), + name=(f"Absolute Binding, {alchname} complex leg: " + f"repeat {i} generation 0"), + ) + for i in range(self.settings.protocol_repeats) + ] + + return solvent_units + complex_units + + def _gather( + self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] + ) -> dict[str, dict[str, Any]]: + # result units will have a repeat_id and generation + # first group according to repeat_id + unsorted_solvent_repeats = defaultdict(list) + unsorted_complex_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + if not pu.ok(): + continue + if pu.outputs['simtype'] == 'solvent': + unsorted_solvent_repeats[pu.outputs['repeat_id']].append(pu) + else: + unsorted_complex_repeats[pu.outputs['repeat_id']].append(pu) + + repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { + 'solvent': {}, 'complex': {}, + } + for k, v in unsorted_solvent_repeats.items(): + repeats['solvent'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + + for k, v in unsorted_complex_repeats.items(): + repeats['complex'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + return repeats + + +class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the vacuum phase of an absolute solvation free energy + """ + def _get_components(self): + """ + Get the relevant components for a vacuum transformation. + + Returns + ------- + alchem_comps : dict[str, list[Component]] + A list of alchemical components + solv_comp : None + For the gas phase transformation, None will always be returned + for the solvent component of the chemical system. + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[Component, OpenFF Molecule] + The openff Molecules to add to the system. This + is equivalent to the alchemical components in stateA (since + we only allow for disappearing ligands). + """ + stateA = self._inputs['stateA'] + alchem_comps = self._inputs['alchemical_components'] + + off_comps = {m: m.to_openff() + for m in alchem_comps['stateA']} + + _, prot_comp, _ = system_validation.get_components(stateA) + + # Notes: + # 1. Our input state will contain a solvent, we ``None`` that out + # since this is the gas phase unit. + # 2. Our small molecules will always just be the alchemical components + # (of stateA since we enforce only one disappearing ligand) + return alchem_comps, None, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a vacuum transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : SimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs['protocol'].settings + + settings = {} + settings['forcefield_settings'] = prot_settings.vacuum_forcefield_settings + settings['thermo_settings'] = prot_settings.thermo_settings + settings['charge_settings'] = prot_settings.partial_charge_settings + settings['solvation_settings'] = prot_settings.solvation_settings + settings['alchemical_settings'] = prot_settings.alchemical_settings + settings['lambda_settings'] = prot_settings.lambda_settings + settings['engine_settings'] = prot_settings.vacuum_engine_settings + settings['integrator_settings'] = prot_settings.integrator_settings + settings['equil_simulation_settings'] = prot_settings.vacuum_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.vacuum_equil_output_settings + settings['simulation_settings'] = prot_settings.vacuum_simulation_settings + settings['output_settings'] = prot_settings.vacuum_output_settings + + settings_validation.validate_timestep( + settings['forcefield_settings'].hydrogen_mass, + settings['integrator_settings'].timestep + ) + + return settings + + def _execute( + self, ctx: gufe.Context, **kwargs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run(scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared) + + return { + 'repeat_id': self._inputs['repeat_id'], + 'generation': self._inputs['generation'], + 'simtype': 'vacuum', + **outputs + } + + +class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): + """ + Protocol Unit for the solvent phase of an absolute solvation free energy + """ + def _get_components(self): + """ + Get the relevant components for a solvent transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs['stateA'] + alchem_comps = self._inputs['alchemical_components'] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp since that's also + # disallowed on create + return alchem_comps, solv_comp, prot_comp, off_comps + + def _handle_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a vacuum transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs['protocol'].settings + + settings = {} + settings['forcefield_settings'] = prot_settings.solvent_forcefield_settings + settings['thermo_settings'] = prot_settings.thermo_settings + settings['charge_settings'] = prot_settings.partial_charge_settings + settings['solvation_settings'] = prot_settings.solvation_settings + settings['alchemical_settings'] = prot_settings.alchemical_settings + settings['lambda_settings'] = prot_settings.lambda_settings + settings['engine_settings'] = prot_settings.solvent_engine_settings + settings['integrator_settings'] = prot_settings.integrator_settings + settings['equil_simulation_settings'] = prot_settings.solvent_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.solvent_equil_output_settings + settings['simulation_settings'] = prot_settings.solvent_simulation_settings + settings['output_settings'] = prot_settings.solvent_output_settings + + settings_validation.validate_timestep( + settings['forcefield_settings'].hydrogen_mass, + settings['integrator_settings'].timestep + ) + + return settings + + def _execute( + self, ctx: gufe.Context, **kwargs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run(scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared) + + return { + 'repeat_id': self._inputs['repeat_id'], + 'generation': self._inputs['generation'], + 'simtype': 'solvent', + **outputs + } From 22fbc371137c342d4ea51c020ddb92940478574e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 11 Dec 2024 14:00:14 +0000 Subject: [PATCH 10/29] Add base code for settings --- .../openmm_afe/equil_afe_settings.py | 107 ++++++++++++++++++ .../openmm_afe/equil_binding_afe_method.py | 14 ++- openfe/protocols/openmm_utils/omm_settings.py | 95 ++++++++++++++++ 3 files changed, 211 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/openfe/protocols/openmm_afe/equil_afe_settings.py index 1e45e007f..860a67a61 100644 --- a/openfe/protocols/openmm_afe/equil_afe_settings.py +++ b/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -30,6 +30,10 @@ MultiStateOutputSettings, MDSimulationSettings, MDOutputSettings, + BaseRestraintSettings, + HarmonicRestraintSettings, + FlatBottomRestraintSettings, + BoreschRestraintSettings, ) import numpy as np @@ -217,3 +221,106 @@ def must_be_positive(cls, v): including the partial charge assignment method, and the number of conformers used to generate the partial charges. """ + + +class AbsoluteBindingSettings(SettingsBaseModel): + """ + Configuration object for ``AbsoluteBindingPProtocol`` + + See Also + -------- + openfe.protocols.openmm_afe.AbsoluteBindingProtocol + """ + protocol_repeats: int + """ + The number of completely independent repeats of the entire sampling + process. The mean of the repeats defines the final estimate of FE + difference, while the variance between repeats is used as the uncertainty. + """ + + @validator('protocol_repeats') + def must_be_positive(cls, v): + if v <= 0: + errmsg = f"protocol_repeats must be a positive value, got {v}." + raise ValueError(errmsg) + return v + + forcefield_settings: OpenMMSystemGeneratorFFSettings + """Parameters to set up the force field with OpenMM Force Fields""" + thermo_settings: ThermoSettings + """Settings for thermodynamic parameters""" + + solvation_settings: OpenMMSolvationSettings + """Settings for solvating the system.""" + + # Alchemical settings + alchemical_settings: AlchemicalSettings + """ + Alchemical protocol settings. + """ + lambda_settings: LambdaSettings + """ + Settings for controlling the lambda schedule for the different components + (vdw, elec, restraints). + """ + + # MD Engine things + engine_settings: OpenMMEngineSettings + """ + Settings specific to the OpenMM engine, such as the compute platform. + """ + + # Sampling State defining things + integrator_settings: IntegratorSettings + """ + Settings for controlling the integrator, such as the timestep and + barostat settings. + """ + + # Simulation run settings + complex_equil_simulation_settings: MDSimulationSettings + """ + Pre-alchemical complex simulation control settings. + """ + complex_simulation_settings: MultiStateSimulationSettings + """ + Simulation control settings, including simulation lengths + for the complex transformation. + """ + solvent_equil_simulation_settings: MDSimulationSettings + """ + Pre-alchemical solvent simulation control settings. + """ + solvent_simulation_settings: MultiStateSimulationSettings + """ + Simulation control settings, including simulation lengths + for the solvent transformation. + """ + complex_equil_output_settings: MDOutputSettings + """ + Simulation output settings for the complex non-alchemical equilibration. + """ + complex_output_settings: MultiStateOutputSettings + """ + Simulation output settings for the complex transformation. + """ + solvent_equil_output_settings: MDOutputSettings + """ + Simulation output settings for the solvent non-alchemical equilibration. + """ + solvent_output_settings: MultiStateOutputSettings + """ + Simulation output settings for the solvent transformation. + """ + partial_charge_settings: OpenFFPartialChargeSettings + """ + Settings for controlling how to assign partial charges, + including the partial charge assignment method, and the + number of conformers used to generate the partial charges. + """ + restraint_settings: BaseRestraintSettings + """ + Settings controlling how restraints are added to the system in the + complex simulation. + """ + diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index 7943ca4c3..ca0bab06e 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -56,6 +56,9 @@ IntegratorSettings, MultiStateOutputSettings, OpenFFPartialChargeSettings, SettingsBaseModel, + HarmonicRestraintSettings, + FlatBottomRestraintSettings, + BoreschRestraintSettings, ) from ..openmm_utils import system_validation, settings_validation from .base import BaseAbsoluteUnit @@ -463,6 +466,7 @@ def _default_settings(cls): solvation_settings=OpenMMSolvationSettings(), engine_settings=OpenMMEngineSettings(), integrator_settings=IntegratorSettings(), + restraint_settings=BoreschRestraintSettings(), solvent_equil_simulation_settings=MDSimulationSettings( equilibration_length_nvt=0.1 * unit.nanosecond, equilibration_length=0.2 * unit.nanosecond, @@ -494,12 +498,12 @@ def _default_settings(cls): production_trajectory_filename='production_equil.xtc', log_output='equil_simulation.log', ), - vacuum_simulation_settings=MultiStateSimulationSettings( + complex_simulation_settings=MultiStateSimulationSettings( n_replicas=28, equilibration_length=1 * unit.nanosecond, production_length=10.0 * unit.nanosecond, ), - vacuum_output_settings=MultiStateOutputSettings( + complex_output_settings=MultiStateOutputSettings( output_filename='complex.nc', checkpoint_storage_filename='complex_checkpoint.nc' ), @@ -573,7 +577,7 @@ def _validate_lambda_schedule( lambda_settings : LambdaSettings the lambda schedule Settings simulation_settings : MultiStateSimulationSettings - the settings for either the vacuum or solvent phase + the settings for either the complex or solvent phase Raises ------ @@ -654,7 +658,7 @@ def _create( # Get the name of the alchemical species alchname = alchem_comps['stateA'][0].name - # Create list units for vacuum and solvent transforms + # Create list units for complex and solvent transforms solvent_units = [ AbsoluteBindingSolventUnit( @@ -712,7 +716,7 @@ def _gather( return repeats -class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): +class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): """ Protocol Unit for the vacuum phase of an absolute solvation free energy """ diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index 63cb5789c..d1e03ffb1 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -660,3 +660,98 @@ class Config: Filename for writing the log of the MD simulation, including timesteps, energies, density, etc. """ + +class BaseRestraintSettings(SettingsBaseModel): + """ + Settings contolling how to add restraints to a system. + """ + class Config: + arbitrary_types_allowed = True + + +class BaseDistanceRestraintSettings(BaseRestraintSettings): + """ + Base settings for a harmonic or flatbottom distance between two groups of + atoms. + """ + spring_constant: FloatQuantity['kilojoule_per_mole / nanometer**2'] + """ + The spring constant K between the two atom groups. + """ + atom_group1: Union[list[int], str] + """ + A definition for the first atom group to restrain. + Can either be a list of atom indices or an mdanalysis atom selection query. + """ + atom_group2: Union[list[int], str] + """ + A definition for the second atom group to restrain. + Can either be a list of atom indices or an mdanalysis atom selection query. + """ + +class HarmonicRestraintSettings(BaseDistanceRestraintSettings): + """ + Settings for a harmonic restraint between two groups of atoms. + """ + pass + + +class FlatBottomRestraintSettings(BaseDistanceRestraintSettings): + """ + Settings for a flat bottom restraint between two groups of atoms. + """ + well_radius: FloatQuantity['nanometer']] + """ + The well radius for the flat bottom restraint. + + TODO + ---- + * Implement an option to automatically pick the well radius. + """ + +class BoreschRestraintSettings(SettingsBaseModel): + """ + Settings for a Boresch-style restraint. + """ + host_atoms : Optional[list[int]] + """ + A list 3 host atom indices. + + TODO: How do you relate this back to your input? + """ + guest_atoms : Optional[list[int]] + """ + A list of 3 guest atom indices. + + TODO: How do you relate this back to your input? + """ + K_r: FloatQuantity['kilocalorie_per_mole / nm ** 2'] = 2000.0 * unit.kilocalorie_per_mole / unit.nm **2 + """ + The spring constant for the distance restraint between + host_atom[2] and guest_atom[0]. + """ + K_thetaA: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + angle(host_atoms[1], host_atoms[2], guest_atoms[2]) + """ + K_thetaB: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + angle(host_atoms[2], guest_atoms[0], guest_atoms[1]) + """ + K_phiA: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + dihedral(host_atoms[0], host_atoms[1], host_atoms[2], guest_atoms[0]) + """ + K_phiB: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + dihedral(host_atoms[1], host_atoms[2], guest_atoms[0], guest_atoms[1]) + """ + K_phiC: FloatQuantity['kilocalorie_per_mole / radians ** 2'] = 20 * unit.kilocalorie_per_mole / unit.radians**2 + """ + The spring constant for + dihedral(host_atoms[2], guest_atoms[0], guest_atoms[1], guest_aotms[2]) + """ From e1ef18c9ee271170c26676595e39bbd8be0cb212 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 11 Dec 2024 14:09:16 +0000 Subject: [PATCH 11/29] Fix up the units a bit --- .../openmm_afe/equil_binding_afe_method.py | 69 +++++++++---------- .../openmm_afe/equil_solvation_afe_method.py | 2 +- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index ca0bab06e..ff976815b 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -718,44 +718,38 @@ def _gather( class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): """ - Protocol Unit for the vacuum phase of an absolute solvation free energy + Protocol Unit for the complex phase of an absolute binding free energy """ def _get_components(self): """ - Get the relevant components for a vacuum transformation. + Get the relevant components for a complex transformation. Returns ------- - alchem_comps : dict[str, list[Component]] + alchem_comps : dict[str, Component] A list of alchemical components - solv_comp : None - For the gas phase transformation, None will always be returned - for the solvent component of the chemical system. + solv_comp : SolventComponent + The SolventComponent of the system prot_comp : Optional[ProteinComponent] The protein component of the system, if it exists. - small_mols : dict[Component, OpenFF Molecule] - The openff Molecules to add to the system. This - is equivalent to the alchemical components in stateA (since - we only allow for disappearing ligands). + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. """ stateA = self._inputs['stateA'] alchem_comps = self._inputs['alchemical_components'] - off_comps = {m: m.to_openff() - for m in alchem_comps['stateA']} - - _, prot_comp, _ = system_validation.get_components(stateA) + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} - # Notes: - # 1. Our input state will contain a solvent, we ``None`` that out - # since this is the gas phase unit. - # 2. Our small molecules will always just be the alchemical components - # (of stateA since we enforce only one disappearing ligand) - return alchem_comps, None, prot_comp, off_comps + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp + return alchem_comps, solv_comp, prot_comp, off_comps def _handle_settings(self) -> dict[str, SettingsBaseModel]: """ - Extract the relevant settings for a vacuum transformation. + Extract the relevant settings for a complex transformation. Returns ------- @@ -773,22 +767,24 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: * equil_output_settings : MDOutputSettings * simulation_settings : SimulationSettings * output_settings: MultiStateOutputSettings + * restraint_settings: BaseRestraintSettings """ prot_settings = self._inputs['protocol'].settings settings = {} - settings['forcefield_settings'] = prot_settings.vacuum_forcefield_settings + settings['forcefield_settings'] = prot_settings.forcefield_settings settings['thermo_settings'] = prot_settings.thermo_settings settings['charge_settings'] = prot_settings.partial_charge_settings settings['solvation_settings'] = prot_settings.solvation_settings settings['alchemical_settings'] = prot_settings.alchemical_settings - settings['lambda_settings'] = prot_settings.lambda_settings - settings['engine_settings'] = prot_settings.vacuum_engine_settings + settings['lambda_settings'] = prot_settings.complex_lambda_settings + settings['engine_settings'] = prot_settings.engine_settings settings['integrator_settings'] = prot_settings.integrator_settings - settings['equil_simulation_settings'] = prot_settings.vacuum_equil_simulation_settings - settings['equil_output_settings'] = prot_settings.vacuum_equil_output_settings - settings['simulation_settings'] = prot_settings.vacuum_simulation_settings - settings['output_settings'] = prot_settings.vacuum_output_settings + settings['equil_simulation_settings'] = prot_settings.complex_equil_simulation_settings + settings['equil_output_settings'] = prot_settings.complex_equil_output_settings + settings['simulation_settings'] = prot_settings.complex_simulation_settings + settings['output_settings'] = prot_settings.complex_output_settings + settings['restraint_settings'] = prot_settings.restraint_settings settings_validation.validate_timestep( settings['forcefield_settings'].hydrogen_mass, @@ -808,14 +804,14 @@ def _execute( return { 'repeat_id': self._inputs['repeat_id'], 'generation': self._inputs['generation'], - 'simtype': 'vacuum', + 'simtype': 'complex', **outputs } -class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): +class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): """ - Protocol Unit for the solvent phase of an absolute solvation free energy + Protocol Unit for the solvent phase of an absolute binding free energy """ def _get_components(self): """ @@ -841,13 +837,12 @@ def _get_components(self): # We don't need to check that solv_comp is not None, otherwise # an error will have been raised when calling `validate_solvent` # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp since that's also - # disallowed on create + # Similarly we don't need to check prot_comp return alchem_comps, solv_comp, prot_comp, off_comps def _handle_settings(self) -> dict[str, SettingsBaseModel]: """ - Extract the relevant settings for a vacuum transformation. + Extract the relevant settings for a solvent transformation. Returns ------- @@ -869,13 +864,13 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: prot_settings = self._inputs['protocol'].settings settings = {} - settings['forcefield_settings'] = prot_settings.solvent_forcefield_settings + settings['forcefield_settings'] = prot_settings.forcefield_settings settings['thermo_settings'] = prot_settings.thermo_settings settings['charge_settings'] = prot_settings.partial_charge_settings settings['solvation_settings'] = prot_settings.solvation_settings settings['alchemical_settings'] = prot_settings.alchemical_settings - settings['lambda_settings'] = prot_settings.lambda_settings - settings['engine_settings'] = prot_settings.solvent_engine_settings + settings['lambda_settings'] = prot_settings.solvent_lambda_settings + settings['engine_settings'] = prot_settings.engine_settings settings['integrator_settings'] = prot_settings.integrator_settings settings['equil_simulation_settings'] = prot_settings.solvent_equil_simulation_settings settings['equil_output_settings'] = prot_settings.solvent_equil_output_settings diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index b77df9dfb..12d16dd30 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -889,7 +889,7 @@ def _get_components(self): def _handle_settings(self) -> dict[str, SettingsBaseModel]: """ - Extract the relevant settings for a vacuum transformation. + Extract the relevant settings for a solvent transformation. Returns ------- From 158ce40b55dd761cf2b008a191d74909a3a3bd6c Mon Sep 17 00:00:00 2001 From: IAlibay Date: Wed, 11 Dec 2024 14:45:05 +0000 Subject: [PATCH 12/29] Deal with the restraint addition --- openfe/protocols/openmm_afe/base.py | 4 +- .../openmm_afe/equil_binding_afe_method.py | 68 +++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 7ec2acea2..6c3a47ac7 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -576,7 +576,7 @@ def _add_restraints(self, system, topology, settings): """ Placeholder method to add restraints if necessary """ - return None, None + return None, None, system def _get_alchemical_system( self, @@ -1050,7 +1050,7 @@ def run(self, dry=False, verbose=True, lambdas = self._get_lambda_schedule(settings) # 6. Add restraints - restraint_parameter_state, standard_state_corr = self._add_restraints( + restraint_parameter_state, standard_state_corr, omm_system = self._add_restraints( omm_system, omm_topology, settings ) diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index ff976815b..3c1e63acb 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -39,6 +39,7 @@ import numpy.typing as npt from openff.units import unit from openmmtools import multistate +from openmmtools.state import ThermodynamicState, GlobalParameterState from typing import Optional, Union from typing import Any, Iterable import uuid @@ -793,6 +794,73 @@ def _handle_settings(self) -> dict[str, SettingsBaseModel]: return settings + def _add_restraints( + self, + system: openmm.System, + topology: openmm.app.Topology, + settings: dict[str, SettingsBaseModel] + ) -> [GlobalParameterState, unit.Quantity, openmm.System]: + """ + Find and add restraints to the OpenMM System. + + Parameters + ---------- + system : openmm.System + The System to add the restraint to. + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + settings : dict[str, SettingsBaseModel] + A dictionary of settings that defines how to find and set + the restraint. + + Returns + ------- + restraint_parameter_state : RestraintParameterState + A RestraintParameterState object that defines the control + parameter for the restraint. + correction : unit.Quantity + The standard state correction for the restraint. + system : openmm.System + A copy of the System with the restraint added. + """ + from openfe.protocols.openmm_utils import ( + omm_restraints, geometry, search + ) + + if isinstance(settings['restraints_settings'], BoreschRestraintSettings): + geom = search.get_boresch_restraint( + topology, + self.shared_basepath / settings['equil_output_settings'].production_trajectory_filename + ) + + restraint = omm_restraints.BoreschRestraint( + settings['restraints_settings'], + geom, + controlling_parameter_name='lambda_restraints' + ) + else: + # TODO turn this into a direction for different restraint types supported? + raise NotImplementedError() + + # We need a temporary thermodynamic state to add the restraint + # & get the correction + thermodynamic_state = ThermodynamicState( + system, + temperature=to_openmm(settings['thermo_settings'].temperature), + pressure=to_openmm(settings['thermo_settings'].pressure), + ) + + # Add the force to the thermodynamic state + restraint.add_force(thermodynamic_state) + # Get the standard state correction as a unit.Quantity + correction = restraint.get_standard_state_correction(thermodynamic_state) + + # Get the GlobalParameterState for the restraint + retraint_parameter_state = omm_restraints.RestraintParameterState( + lambda_restraints=1.0 + ) + return restraint_parameter_state, correction, thermodynamic_state.system + def _execute( self, ctx: gufe.Context, **kwargs, ) -> dict[str, Any]: From a19c86cae5100520cef62a87de7a6dcd44cb5050 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 12:03:52 +0000 Subject: [PATCH 13/29] refactor restraints --- .../openmm_utils/restraints/geometry.py | 90 ----- .../restraints/geometry/__init__.py | 0 .../openmm_utils/restraints/geometry/base.py | 50 +++ .../restraints/geometry/boresch.py | 66 ++++ .../restraints/geometry/flatbottom.py | 90 +++++ .../restraints/geometry/harmonic.py | 94 +++++ .../openmm_utils/restraints/geometry/utils.py | 360 ++++++++++++++++++ .../restraints/openmm/__init__.py | 0 .../restraints/{ => openmm}/omm_forces.py | 0 .../restraints/{ => openmm}/omm_restraints.py | 0 .../openmm_utils/restraints/search.py | 360 ++++++++++++++++++ 11 files changed, 1020 insertions(+), 90 deletions(-) delete mode 100644 openfe/protocols/openmm_utils/restraints/geometry.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/__init__.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/base.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/boresch.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/harmonic.py create mode 100644 openfe/protocols/openmm_utils/restraints/geometry/utils.py create mode 100644 openfe/protocols/openmm_utils/restraints/openmm/__init__.py rename openfe/protocols/openmm_utils/restraints/{ => openmm}/omm_forces.py (100%) rename openfe/protocols/openmm_utils/restraints/{ => openmm}/omm_restraints.py (100%) create mode 100644 openfe/protocols/openmm_utils/restraints/search.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry.py b/openfe/protocols/openmm_utils/restraints/geometry.py deleted file mode 100644 index 14e1cd289..000000000 --- a/openfe/protocols/openmm_utils/restraints/geometry.py +++ /dev/null @@ -1,90 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Restraint Geometry classes - -TODO ----- -* Add relevant duecredit entries. -""" -import abc -from pydantic.v1 import BaseModel, validator - -from openff.units import unit -import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds - - -class BaseRestraintGeometry(BaseModel, abc.ABC): - class Config: - arbitrary_types_allowed = True - - -class HostGuestRestraintGeometry(BaseRestraintGeometry): - """ - An ordered list of guest atoms to restrain. - - Note - ---- - The order matters! It will be used to define the underlying - force. - """ - - guest_atoms: list[int] - """ - An ordered list of host atoms to restrain. - - Note - ---- - The order matters! It will be used to define the underlying - force. - """ - host_atoms: list[int] - - @validator("guest_atoms", "host_atoms") - def positive_idxs(cls, v): - if any([i < 0 for i in v]): - errmsg = "negative indices passed" - raise ValueError(errmsg) - return v - - -class CentroidDistanceMixin: - def get_distance(self, topology, coordinates) -> unit.Quantity: - u = mda.Universe(topology, coordinates) - ag1 = u.atoms[self.host_atoms] - ag2 = u.atoms[self.guest_atoms] - bond = calc_bonds( - ag1.center_of_mass(), ag2.center_of_mass(), u.atoms.dimensions - ) - # convert to float so we avoid having a np.float64 - return float(bond) * unit.angstrom - - -def _check_single_atoms(value): - if len(value) != 1: - errmsg = ( - "Host and guest atom lists must only include a single atom, " - f"got {len(value)} atoms." - ) - raise ValueError(errmsg) - return value - - -class BondDistanceMixin: - def get_distance(self, topology, coordinates) -> unit.Quantity: - u = mda.Universe(topology, coordinates) - at1 = u.atoms[self.host_atoms[0]] - at2 = u.atoms[self.guest_atoms[0]] - bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) - # convert to float so we avoid having a np.float64 value - return float(bond) * unit.angstrom - - -class CentroidDistanceRestraintGeometry(HostGuestRestraintGeometry, CentroidDistanceMixin): - pass - - -class BondDistanceRestraintGeoemtry(HostGuestRestraintGeometry, BondDistanceMixin): - _check_host_atoms: classmethod = validator("host_atoms", allow_reuse=True)(_check_single_atoms) - _check_guest_atoms: classmethod = validator("guest_atoms", allow_reuse=True)(_check_single_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py b/openfe/protocols/openmm_utils/restraints/geometry/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/openmm_utils/restraints/geometry/base.py b/openfe/protocols/openmm_utils/restraints/geometry/base.py new file mode 100644 index 000000000..21a714cde --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/base.py @@ -0,0 +1,50 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles + + +class BaseRestraintGeometry(BaseModel, abc.ABC): + class Config: + arbitrary_types_allowed = True + + +class HostGuestRestraintGeometry(BaseRestraintGeometry): + """ + An ordered list of guest atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + + guest_atoms: list[int] + """ + An ordered list of host atoms to restrain. + + Note + ---- + The order matters! It will be used to define the underlying + force. + """ + host_atoms: list[int] + + @validator("guest_atoms", "host_atoms") + def positive_idxs(cls, v): + if any([i < 0 for i in v]): + errmsg = "negative indices passed" + raise ValueError(errmsg) + return v + diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py new file mode 100644 index 000000000..822382b9c --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -0,0 +1,66 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles + +from .base import HostGuestRestraintGeometry + + +class BoreschRestraintGeometry(HostGuestRestraintGeometry): + """ + A class that defines the restraint geometry for a Boresch restraint. + + The restraint is defined by the following: + + H0 G2 + - - + - - + H1 - - H2 -- G0 - - G1 + + Where HX represents the X index of ``host_atoms`` and GX + the X index of ``guest_atoms``. + """ + def get_bond_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[host_atoms[2]] + at2 = u.atoms[guest_atoms[0]] + bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + def get_angles(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[host_atoms[1]] + at2 = u.atoms[host_atoms[2]] + at3 = u.atoms[guest_atoms[0]] + at4 = u.atoms[guest_atoms[1]] + + angleA = calc_angles(at1.position, at2.position, at3.position, u.atoms.dimensions) + angleB = calc_angles(at2.position, at3.position, at4.position, u.atoms.dimensions) + return angleA, angleB + + def get_dihedrals(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + at1 = u.atoms[host_atoms[0]] + at2 = u.atoms[host_atoms[1]] + at3 = u.atoms[host_atoms[2]] + at4 = u.atoms[guest_atoms[0]] + at5 = u.atoms[guest_atoms[1]] + at6 = u.atoms[guest_atoms[2]] + + dihA = calc_dihedrals(at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions) + dihB = calc_dihedrals(at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions) + dihC = calc_dihedrals(at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions) + + return dihA, dihB, dihC diff --git a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py new file mode 100644 index 000000000..c7e987736 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py @@ -0,0 +1,90 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +import numpy as np +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles + +from .harmonic import ( + DistanceRestraintGeometry, + _get_selection, +) + + +class FlatBottomDistanceGeometry(DistanceRestraintGeometry): + """ + A geometry class for a flat bottom distance restraint between two groups + of atoms. + """ + + well_radius: FloatQuantity["nanometer"] + + +class COMDistanceAnalysis(AnalysisBase): + """ + Get a timeseries of COM distances between two AtomGroups + + Parameters + ---------- + group1 : MDAnalysis.AtomGroup + Atoms defining the first centroid. + group2 : MDANalysis.AtomGroup + Atoms defining the second centroid. + """ + + _analysis_algorithm_is_parallelizable = False + + def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + self.ag1 = group1 + self.ag2 = group2 + + def _prepare(self): + self.results.distances = np.zeros(self.n_frames) + + def _single_frame(self): + com_dist = calc_bonds( + self.ag1.center_of_mass(), + self.ag2.center_of_mass(), + box=self.ag1.universe.dimensions, + ) + self.results.distances[self._frame_index] = com_dist + + def _conclude(self): + pass + + +def get_flatbottom_distance_restraint( + topology: Union[str, openmm.app.Topology], + trajectory: pathlib.Path, + topology_format: Optional[str] = None, + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, + padding: unit.Quantity = 0.5 * unit.nanometer, +) -> FlatBottomDistanceGeometry: + u = mda.Universe(topology, trajectory, topology_format=topology_format) + + guest_ag = _get_selection(u, guest_atoms, guest_selection) + host_ag = _get_selection(u, host_atoms, host_selection) + + com_dists = COMDistanceAnalysis(guest_ag, host_ag) + com_dists.run() + + well_radius = com_dists.results.distances.max() * unit.angstrom + padding + return FlatBottomDistanceGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms, well_radius=well_radius + ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py new file mode 100644 index 000000000..36e7a61a7 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py @@ -0,0 +1,94 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds, calc_angles +from rdkit import Chem + +from .base import HostGuestRestraintGeometry +from .utils import _get_central_atom_idx + + +class DistanceRestraintGeometry(HostGuestRestraintGeometry): + """ + A geometry class for a distance restraint between two groups of atoms. + """ + + def get_distance(self, topology, coordinates) -> unit.Quantity: + u = mda.Universe(topology, coordinates) + ag1 = u.atoms[self.host_atoms] + ag2 = u.atoms[self.guest_atoms] + bond = calc_bonds( + ag1.center_of_mass(), ag2.center_of_mass(), box=u.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + +def _get_selection(universe, atom_list, selection): + if atom_list is None: + if selection is None: + raise ValueError( + "one of either the atom lists or selections must be defined" + ) + + ag = universe.select_atoms(selection) + else: + ag = universe.atoms[atom_list] + + return ag + + +def get_distance_restraint( + topology: Union[str, openmm.app.Topology], + trajectory: pathlib.Path, + topology_format: Optional[str] = None, + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, +) -> DistanceRestraintGeometry: + u = mda.Universe(topology, trajectory, topology_format=topology_format) + + guest_ag = _get_selection(u, guest_atoms, guest_selection) + host_ag = _get_selection(u, host_atoms, host_selection) + + return DistanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) + + +def get_molecule_centers_restraint( + topology: Union[str, openmm.app.Topology], + trajectory: pathlib.Path, + molA_rdmol: Chem.Mol, + molB_rdmol: Chem.Mol, + molA_idxs: list[int], + molB_idxs: list[int], + topology_format: Optional[str] = None, +): + # We assume that the mol idxs are ordered + centerA = molA_idxs[_get_central_atom_idx(molA_rdmol)] + centerB = molB_idxs[_get_central_atom_idx(molB_rdmol)] + + u = mda.Universe(topology, trajectory, topology_format=topology_format) + guest_ag = _get_selection( + u, + [centerA], + None, + ) + guest_ag = _get_selection( + u, + [centerB], + None, + ) + + return DistsanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py new file mode 100644 index 000000000..6b3d94eb7 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -0,0 +1,360 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Search methods for generating Geometry objects + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.toolkit import Molecule as OFFMol +from openff.units import unit +import networkx as nx +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles + + +def _get_aromatic_atom_idxs(rdmol) -> list[int]: + """ + Helper method to get aromatic atoms idxs + in a RDKit Molecule + + Parameters + ---------- + rdmol : ??? + RDKit Molecule + + Returns + ------- + list[int] + A list of the aromatic atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetIsAromatic() + ] + return idxs + + +def _get_heavy_atom_idxs(rdmol) -> list[int]: + """ + Get idxs of heavy atoms in an RDKit Molecule + + Parameters + ---------- + rmdol : ??? + + Returns + ------- + list[int] + A list of heavy atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetAtomicNum() > 1 + ] + return idxs + + +def _get_central_atom_idx(rdmol) -> int: + offmol = OFFMol(rdmol, allow_undefined_stereo=True) + # We take the zero-th entry if there are multiple center + # atoms (e.g. equal likelihood centers) + center = nx.center(offmol.to_networkx())[0] + return center + + +def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: + """ + Sort a list of atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the target atom. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : ??? + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() + for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: + """ + Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). + + Parameters + ---------- + rdmol : ??? + Molecule object for the ligand to apply a restraint to. + + Returns + ------- + angles : list[tuple[int, int, int]] + A list of ligand atom triples denoting the possible l1, l2, and l3 + restraint atoms. Ordered by likelihood of restraint-ability. + """ + # Find the central atom + center = _get_central_atom_idx(rdmol) + + # Get a pool of potential anchor atoms looking for aromatic atoms + anchor_pool = _get_aromatic_atoms(rdmol) + + # If there are not enough aromatic atoms, then default to heavy atoms + if len(anchor_pool) < 3: + anchor_pool = _get_heavy_atoms(rdmol) + + # Raise an error if we have less than 3 anchors + if len(anchor_pool) < 3: + errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" + raise ValueError(errmsg) + + # Sort the pool of anchor atoms by their distance from the central atom + sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) + + # Get a list of ligand anchor angle atoms + angles = [] + for atom in sorted_anchor_pool: + angles.extend( + _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) + ) + + +def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): + """ + Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. + + Parameters + ---------- + positions : openmm.unit.Quantity + Positions of the input system + topology : openmm.app.Topology + OpenMM Topology for input system + exclude_resids : list[int] + List of residue numbers to exclude from host selection + lig_anchor_idx : int + The index of the l1 ligand anchor. + selection : str + Selection string for the host atoms. + """ + # Create an mdtraj trajectory to manipulate + # First fetch the box vectors and pass them as lengths and angles + vectors = from_openmm(topology.getPeriodicBoxVectors()) + a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) + + traj = mdt.Trajectory( + positions[np.newaxis, ...], + mdt.Topology.from_openmm(topology) + ) + + # Get all the potential protein atoms matching the selection + host_sel = traj.topology.select(selection) + + # Get residues to exclude from the selection + exclude_sel = np.array([ + at.index for at in + chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) + ]) + + # Remove exclusion + anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] + + # Compute distanecs from ligand l1 anchor atom + pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T + + distances = mdt.compute_distances(traj, pairs, periodic=True) + + return np.array([pairs[i][0] for i in np.argsort(distances[0])]) + + +def is_collinear(positions, atoms, threshold=0.9): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Parameters + ---------- + positions : openmm.unit.Quantity + System positions. + atoms : list[int] + The indices of the atoms to test. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; False otherwise. + + Notes + ----- + Originally from Yank, with modifications from Separated Topologies + """ + results = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def check_angle(angle, force_constant=83.68): + """ + Check whether the chosen angle is less than 10 kT from 0 or 180 + + Parameters + ---------- + angle : float + The angle to check in degrees. + force_constant : float + Force constant of the angle. + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + """ + # TODO: convert this to unit.Quantity so we don't end up with + # conversion errors + RT = 8.31445985 * 0.001 * 298.15 + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) + check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1 < 10.0 or ang_check_2 < 10.0: + return False + return True + + + + +class FindHostAtoms(AnalysisBase): + """ + Class filter host atoms based on their distance + from a set of guest atoms. + + Parameters + ---------- + host_atoms : MDAnalysis.AtomGroup + Initial selection of host atoms to filter from. + guest_atoms : MDANalysis.AtomGroup + Selection of guest atoms to search around. + search_distance: unit.Quantity + Distance to filter atoms within. + """ + _analysis_algorithm_is_parallelizable = False + + def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + self.host_ag = host_atoms + self.guest_ag = guest_atoms + self.cutoff = search_distance.to('angstrom').m + + def _prepare(self): + self.results.host_idxs = set() + + def _single_frame(self): + pairs = capped_distance( + reference=self.host_ag.positions, + configuration=self.guest_ag.positions, + max_cutoff=self.cutoff, + min_cutoff=None + box=self.guest_ag.universe.dimensions, + return_distances=False) + + host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] + self.results.host_idxs.update(set(host_idxs)) + + def _conclude(self): + pass + + +def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: + """ + Get an AtomGroup of the host atoms based on their distances from the guest atoms. + """ + u = mda.Universe(topology, trajectory) + + def _get_selection(selection): + """ + If it's a str, call select_atoms, if not a list of atom idxs + """ + if isinstance(selection, str): + ag = u.select_atoms(host_selection) + else: + ag = u.atoms[host_ag] + return ag + + host_ag = _get_selection(host_selection) + guest_ag = _get_selection(guest_selection) + + finder = FindHostAtoms(host_ag, guest_ag, cutoff) + finder.run() + + return u.atoms[list(finder.results.host_idxs)] + +def get_molecule_center_idx(atomgroup): + offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) + # Check if the molecule is whole, otherwise throw an error. + nx = offmol.to_networkx() + + +def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): + u = mda.Universe(topology, trajectory) + + if guest_atoms is None: + if guest_selection is None: + raise ValueError("one of guest_atoms or guest_selections must be defined") + guest_ag = u.select_atoms(guest_selection) + else: + + + if host_atoms is None: + if host_selection is None: + raise ValueError("one of host_atoms or host_selection must be defined") + + host_ag = u.select_atoms(host_selection) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/__init__.py b/openfe/protocols/openmm_utils/restraints/openmm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/protocols/openmm_utils/restraints/omm_forces.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/omm_forces.py rename to openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py diff --git a/openfe/protocols/openmm_utils/restraints/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/omm_restraints.py rename to openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py diff --git a/openfe/protocols/openmm_utils/restraints/search.py b/openfe/protocols/openmm_utils/restraints/search.py new file mode 100644 index 000000000..6b3d94eb7 --- /dev/null +++ b/openfe/protocols/openmm_utils/restraints/search.py @@ -0,0 +1,360 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Search methods for generating Geometry objects + +TODO +---- +* Add relevant duecredit entries. +""" +import abc +from pydantic.v1 import BaseModel, validator + +from openff.toolkit import Molecule as OFFMol +from openff.units import unit +import networkx as nx +import MDAnalysis as mda +from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles + + +def _get_aromatic_atom_idxs(rdmol) -> list[int]: + """ + Helper method to get aromatic atoms idxs + in a RDKit Molecule + + Parameters + ---------- + rdmol : ??? + RDKit Molecule + + Returns + ------- + list[int] + A list of the aromatic atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetIsAromatic() + ] + return idxs + + +def _get_heavy_atom_idxs(rdmol) -> list[int]: + """ + Get idxs of heavy atoms in an RDKit Molecule + + Parameters + ---------- + rmdol : ??? + + Returns + ------- + list[int] + A list of heavy atom idxs + """ + idxs = [ + at.GetIdx() for at in rdmol.GetAtoms() + if at.GetAtomicNum() > 1 + ] + return idxs + + +def _get_central_atom_idx(rdmol) -> int: + offmol = OFFMol(rdmol, allow_undefined_stereo=True) + # We take the zero-th entry if there are multiple center + # atoms (e.g. equal likelihood centers) + center = nx.center(offmol.to_networkx())[0] + return center + + +def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: + """ + Sort a list of atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the target atom. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : ??? + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() + for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: + """ + Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). + + Parameters + ---------- + rdmol : ??? + Molecule object for the ligand to apply a restraint to. + + Returns + ------- + angles : list[tuple[int, int, int]] + A list of ligand atom triples denoting the possible l1, l2, and l3 + restraint atoms. Ordered by likelihood of restraint-ability. + """ + # Find the central atom + center = _get_central_atom_idx(rdmol) + + # Get a pool of potential anchor atoms looking for aromatic atoms + anchor_pool = _get_aromatic_atoms(rdmol) + + # If there are not enough aromatic atoms, then default to heavy atoms + if len(anchor_pool) < 3: + anchor_pool = _get_heavy_atoms(rdmol) + + # Raise an error if we have less than 3 anchors + if len(anchor_pool) < 3: + errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" + raise ValueError(errmsg) + + # Sort the pool of anchor atoms by their distance from the central atom + sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) + + # Get a list of ligand anchor angle atoms + angles = [] + for atom in sorted_anchor_pool: + angles.extend( + _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) + ) + + +def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): + """ + Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. + + Parameters + ---------- + positions : openmm.unit.Quantity + Positions of the input system + topology : openmm.app.Topology + OpenMM Topology for input system + exclude_resids : list[int] + List of residue numbers to exclude from host selection + lig_anchor_idx : int + The index of the l1 ligand anchor. + selection : str + Selection string for the host atoms. + """ + # Create an mdtraj trajectory to manipulate + # First fetch the box vectors and pass them as lengths and angles + vectors = from_openmm(topology.getPeriodicBoxVectors()) + a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) + + traj = mdt.Trajectory( + positions[np.newaxis, ...], + mdt.Topology.from_openmm(topology) + ) + + # Get all the potential protein atoms matching the selection + host_sel = traj.topology.select(selection) + + # Get residues to exclude from the selection + exclude_sel = np.array([ + at.index for at in + chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) + ]) + + # Remove exclusion + anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] + + # Compute distanecs from ligand l1 anchor atom + pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T + + distances = mdt.compute_distances(traj, pairs, periodic=True) + + return np.array([pairs[i][0] for i in np.argsort(distances[0])]) + + +def is_collinear(positions, atoms, threshold=0.9): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Parameters + ---------- + positions : openmm.unit.Quantity + System positions. + atoms : list[int] + The indices of the atoms to test. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; False otherwise. + + Notes + ----- + Originally from Yank, with modifications from Separated Topologies + """ + results = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def check_angle(angle, force_constant=83.68): + """ + Check whether the chosen angle is less than 10 kT from 0 or 180 + + Parameters + ---------- + angle : float + The angle to check in degrees. + force_constant : float + Force constant of the angle. + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + """ + # TODO: convert this to unit.Quantity so we don't end up with + # conversion errors + RT = 8.31445985 * 0.001 * 298.15 + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) + check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1 < 10.0 or ang_check_2 < 10.0: + return False + return True + + + + +class FindHostAtoms(AnalysisBase): + """ + Class filter host atoms based on their distance + from a set of guest atoms. + + Parameters + ---------- + host_atoms : MDAnalysis.AtomGroup + Initial selection of host atoms to filter from. + guest_atoms : MDANalysis.AtomGroup + Selection of guest atoms to search around. + search_distance: unit.Quantity + Distance to filter atoms within. + """ + _analysis_algorithm_is_parallelizable = False + + def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + super().__init__(host_atoms.universe.trajectory, **kwargs) + + self.host_ag = host_atoms + self.guest_ag = guest_atoms + self.cutoff = search_distance.to('angstrom').m + + def _prepare(self): + self.results.host_idxs = set() + + def _single_frame(self): + pairs = capped_distance( + reference=self.host_ag.positions, + configuration=self.guest_ag.positions, + max_cutoff=self.cutoff, + min_cutoff=None + box=self.guest_ag.universe.dimensions, + return_distances=False) + + host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] + self.results.host_idxs.update(set(host_idxs)) + + def _conclude(self): + pass + + +def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: + """ + Get an AtomGroup of the host atoms based on their distances from the guest atoms. + """ + u = mda.Universe(topology, trajectory) + + def _get_selection(selection): + """ + If it's a str, call select_atoms, if not a list of atom idxs + """ + if isinstance(selection, str): + ag = u.select_atoms(host_selection) + else: + ag = u.atoms[host_ag] + return ag + + host_ag = _get_selection(host_selection) + guest_ag = _get_selection(guest_selection) + + finder = FindHostAtoms(host_ag, guest_ag, cutoff) + finder.run() + + return u.atoms[list(finder.results.host_idxs)] + +def get_molecule_center_idx(atomgroup): + offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) + # Check if the molecule is whole, otherwise throw an error. + nx = offmol.to_networkx() + + +def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): + u = mda.Universe(topology, trajectory) + + if guest_atoms is None: + if guest_selection is None: + raise ValueError("one of guest_atoms or guest_selections must be defined") + guest_ag = u.select_atoms(guest_selection) + else: + + + if host_atoms is None: + if host_selection is None: + raise ValueError("one of host_atoms or host_selection must be defined") + + host_ag = u.select_atoms(host_selection) From 20dd1dcce9f6608d0fe5b3d37fcd257709f4a109 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 14:02:05 +0000 Subject: [PATCH 14/29] add some angle checks --- .../openmm_utils/restraints/geometry/utils.py | 132 +++++++++++++++++- 1 file changed, 126 insertions(+), 6 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 6b3d94eb7..80b7c3372 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -12,20 +12,22 @@ from openff.toolkit import Molecule as OFFMol from openff.units import unit +from openff.units.types import FloatQuantity import networkx as nx +from rdkit import Chem import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.lib.distances import calc_bonds, calc_angles -def _get_aromatic_atom_idxs(rdmol) -> list[int]: +def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: """ Helper method to get aromatic atoms idxs in a RDKit Molecule Parameters ---------- - rdmol : ??? + rdmol : Chem.Mol RDKit Molecule Returns @@ -40,13 +42,13 @@ def _get_aromatic_atom_idxs(rdmol) -> list[int]: return idxs -def _get_heavy_atom_idxs(rdmol) -> list[int]: +def get_heavy_atom_idxs(rdmol: Chem.Mol) -> list[int]: """ Get idxs of heavy atoms in an RDKit Molecule Parameters ---------- - rmdol : ??? + rmdol : Chem.Mol Returns ------- @@ -60,14 +62,132 @@ def _get_heavy_atom_idxs(rdmol) -> list[int]: return idxs -def _get_central_atom_idx(rdmol) -> int: +def get_central_atom_idx(rdmol: Chem.Mol) -> int: + """ + Get the central atom in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molcule to query + + Returns + ------- + center : int + Index of central atom in Molecule + + Note + ---- + If there are equal likelihood centers, will return + the first entry. + """ + # TODO: switch to a manual conversion to avoid an OpenFF dependency offmol = OFFMol(rdmol, allow_undefined_stereo=True) + nx_mol = offmol.to_networkx() + if not nx.is_weakly_connected(nx_mol): + errmsg = "A disconnected molecule was passed, cannot find the center" + raise ValueError(errmsg) + # We take the zero-th entry if there are multiple center # atoms (e.g. equal likelihood centers) - center = nx.center(offmol.to_networkx())[0] + center = nx.center(nx_mol)[0] return center +def is_collinear(positions, atoms, threshold=0.9): + """ + Check whether any sequential vectors in a sequence of atoms are collinear. + + Parameters + ---------- + positions : openmm.unit.Quantity + System positions. + atoms : list[int] + The indices of the atoms to test. + threshold : float + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. + + Returns + ------- + result : bool + Returns True if any sequential pair of vectors is collinear; False otherwise. + + Notes + ----- + Originally from Yank, with modifications from Separated Topologies + """ + results = False + for i in range(len(atoms) - 2): + v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] + v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + result = result or (np.abs(normalized_inner_product) > threshold) + return result + + +def check_angle_energy( + angle: FloatQuantity['radians'], + force_constant: FloatQuantity['unit.kilojoule_per_mole / unit.radians**2'] = 83.68 * unit.kilojoule_per_mole / unit.radians**2, + temperature: FloatQuantity['kelvin'] = 298.15 * unit.kelvin +) -> bool: + """ + Check whether the chosen angle is less than 10 kT from 0 or 180 + + Parameters + ---------- + angle : unit.Quantity + The angle to check in units compatible with radians. + force_constant : unit.Quantity + Force constant of the angle in units compatible with kilojoule_per_mole / radians ** 2. + temperature: unit.Quantity + The system temperature in units compatible with Kelvin. + + Note + ---- + We assume the temperature to be 298.15 Kelvin. + """ + # Convert things + angle_rads = angle.to('radians') + frc_const = force_constant.to('unit.kilojoule_per_mole / unit.radians**2') + temp_kelvin = temperature.to('kelvin') + RT = 8.31445985 * 0.001 * temp_kelvin + + # check if angle is <10kT from 0 or 180 + check1 = 0.5 * frc_const * np.power((angle - 0.0), 2) + check2 = 0.5 * frc_const * np.power((angle - np.pi), 2) + ang_check_1 = check1 / RT + ang_check_2 = check2 / RT + if ang_check_1 < 10.0 or ang_check_2 < 10.0: + return False + return True + + +def check_dihedral_bounds( + dihedral: FloatQuantity['radians'] + lower_cutoff: FloatQuantity['radians'] = 2.618 * unit.radians, + upper_cutoff: FloatQuantity['radians'] = -2.6.18 * unit.radians, +): + """ + Check that a dihedral does not exceed the bounds set by + lower_cutoff and upper_cutoff. + + Parameters + ---------- + dihedral : unit.Quantity + Dihedral in units compatible with radians. + lower_cutoff : unit.Quantity + Dihedral lower cutoff in units compatible with radians. + upper_cutoff : unit.Quantity + Dihedral upper cutoff in units compatible with radians. + """ + if (dihedral < lower_cutoff) or (dihedral > upper_cutoff): + return False + return True + + + + def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: """ Sort a list of atoms by their distance from a target atom. From 9ab74a8225d7efc511fac79975abd5a2f795b5de Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 14:19:39 +0000 Subject: [PATCH 15/29] only construct with settings --- .../restraints/openmm/omm_restraints.py | 111 ++++++++++-------- 1 file changed, 61 insertions(+), 50 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index ab5f4e821..e53e828d5 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -87,40 +87,42 @@ class BaseHostGuestRestraints(abc.ABC): def __init__( self, restraint_settings: SettingsBaseModel, - restraint_geometry: BaseRestraintGeometry, controlling_parameter_name: str = "lambda_restraints", ): self.settings = restraint_settings - self.geometry = restraint_geometry - self._verify_input() + self._verify_settings() @abc.abstractmethod - def _verify_inputs(self): + def _verify_settings(self): pass @abc.abstractmethod - def add_force(self, thermodynamic_state: ThermodynamicState): + def _verify_geometry(self, geometry): pass @abc.abstractmethod - def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState): + def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): pass @abc.abstractmethod - def _get_force(self): + def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): + pass + + @abc.abstractmethod + def _get_force(self, geometry: BaseRestraintGeometry): pass class SingleBondMixin: - def _verify_input(self): - if len(self.geometry.host_atoms) != 1 or len(self.geometry.guest_atoms) != 1: + def _verify_geometry(self, geometry: BaseRestraintGeometry): + if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: errmsg = ( "host_atoms and guest_atoms must only include a single index " f"each, got {len(host_atoms)} and " f"{len(guest_atoms)} respectively." ) raise ValueError(errmsg) - super()._verify_inputs() + super()._verify_geometry(geometry) class BaseRadialllySymmetricRestraintForce(BaseHostGuestRestraints): @@ -128,12 +130,15 @@ def _verify_inputs(self) -> None: if not isinstance(self.settings, BaseDistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) - if not isinstance(self.geometry, DistanceRestraintGeometry): - errmsg = f"Incorrect geometry type {self.geometry} passed through" + + def _verify_geometry(self, geometry: DistanceRestraintGeometry) + if not isinstance(geometry, DistanceRestraintGeometry): + errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState) -> None: - force = self._get_force() + def add_force(self, thermodynamic_state: ThermodynamicState, geometry: DistanceRestraintGeometry) -> None: + self._verify_geometry(geometry) + force = self._get_force(geometry) force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system @@ -141,87 +146,92 @@ def add_force(self, thermodynamic_state: ThermodynamicState) -> None: thermodynamic_state.system = system def get_standard_state_correction( - self, thermodynamic_state: ThermodynamicState + self, + thermodynamic_state: ThermodynamicState, + geometry: DistanceRestraintGeometry, ) -> unit.Quantity: - force = self._get_force() + self._verify_geometry(geometry) + force = self._get_force(geometry) corr = force.compute_standard_state_correction( thermodynamic_state, volume="system" ) dg = corr * thermodynamic_state.kT return from_openmm(dg).to('kilojoule_per_mole') - def _get_force(self): + def _get_force(self, geometry: DistanceRestraintGeometry): raise NotImplementedError("only implemented in child classes") class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( spring_constant=spring_constant, - restrained_atom_index1=self.geometry.host_atoms[0], - restrained_atom_index2=self.geometry.guest_atoms[0], + restrained_atom_index1=geometry.host_atoms[0], + restrained_atom_index2=geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.geometry.host_atoms[0], - restrained_atom_index2=self.geometry.guest_atoms[0], + restrained_atom_index1=geometry.host_atoms[0], + restrained_atom_index2=geometry.guest_atoms[0], controlling_parameter_name=self.controlling_parameter_name, ) class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( spring_constant=spring_constant, - restrained_atom_index1=self.geometry.host_atoms, - restrained_atom_index2=self.geometry.guest_atoms, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): - def _get_force(self) -> openmm.Force: + def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(self.geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, - restrained_atom_index1=self.geometry.host_atoms, - restrained_atom_index2=self.geometry.guest_atoms, + restrained_atom_index1=geometry.host_atoms, + restrained_atom_index2=geometry.guest_atoms, controlling_parameter_name=self.controlling_parameter_name, ) class BoreschRestraint(BaseHostGuestRestraints): - _EFUNC_METHOD: Callable = get_boresch_energy_function - def _verify_inputs(self) -> None: + def _verify_settings(self) -> None: if not isinstance(self.settings, BoreschRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) - if not isinstance(self.geometry, BoreschRestraintGeometry): - errmsg = f"Incorrect geometry type {self.geometry} passed through" + + def _verify_geometry(self, geometry: BoreschRestraintGeometry): + if not isinstance(geometry, BoreschRestraintGeometry): + errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState) -> None: - force = self._get_force() + def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry) -> None: + _verify_geometry(geometry) + force = self._get_force(geometry) force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system add_force_in_separate_group(system, force) thermodynamic_state.system = system - def _get_force(self) -> openmm.Force: - efunc = _EFUNC_METHOD( + def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: + efunc = get_boresch_energy_function( self.controlling_parameter_name, ) @@ -233,37 +243,38 @@ def _get_force(self) -> openmm.Force: parameter_dict = { 'K_r': self.settings.K_r, - 'r_aA0': self.geometry.r_aA0, + 'r_aA0': geometry.r_aA0, 'K_thetaA': self.settings.K_thetaA, - 'theta_A0': self.geometry.theta_A0, + 'theta_A0': geometry.theta_A0, 'K_thetaB': self.settings.K_thetaB, - 'theta_B0': self.geometry.theta_B0, + 'theta_B0': geometry.theta_B0, 'K_phiA': self.settings.K_phiA, - 'phi_A0': self.geometry.phi_A0, + 'phi_A0': geometry.phi_A0, 'K_phiB': self.settings.K_phiB, - 'phi_B0': self.geometry.phi_B0, + 'phi_B0': geometry.phi_B0, 'K_phiC': self.settings.K_phiC, - 'phi_C0': self.geometry.phi_C0, + 'phi_C0': geometry.phi_C0, } for key, val in parameter_dict.items(): param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(self.geometry.host_atoms + self.geometry.guest_atoms, param_values) + force.addBond(geometry.host_atoms + geometry.guest_atoms, param_values) return force def get_standard_state_correction( - self, thermodynamic_state: ThermodynamicState + self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry ) -> unit.Quantity: + self._verify_geometry(geometry) StandardV = 1.66053928 * unit.nanometer**3 kt = from_openmm(thermodynamic_state.kT) # distances - r_aA0 = self.geometry.r_aA0.to('nm') - sin_thetaA0 = np.sin(self.geometry.theta_A0.to('radians')) - sin_thetaB0 = np.sin(self.geometry.theta_B0.to('radians')) + r_aA0 = geometry.r_aA0.to('nm') + sin_thetaA0 = np.sin(geometry.theta_A0.to('radians')) + sin_thetaB0 = np.sin(geometry.theta_B0.to('radians')) # restraint energies K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') From 8f2e1e03dd613caf1f91df2d07d61e9751a03c3d Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 14:35:33 +0000 Subject: [PATCH 16/29] Add more checks to utilities --- .../openmm_utils/restraints/geometry/utils.py | 45 +++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 80b7c3372..30e81123f 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -10,9 +10,12 @@ import abc from pydantic.v1 import BaseModel, validator +import numpy as np +from scipy.stats import circvar, circmean, circstd + from openff.toolkit import Molecule as OFFMol from openff.units import unit -from openff.units.types import FloatQuantity +from openff.models.types import FloatQuantity, ArrayQuantity import networkx as nx from rdkit import Chem import MDAnalysis as mda @@ -132,7 +135,7 @@ def check_angle_energy( temperature: FloatQuantity['kelvin'] = 298.15 * unit.kelvin ) -> bool: """ - Check whether the chosen angle is less than 10 kT from 0 or 180 + Check whether the chosen angle is less than 10 kT from 0 or pi radians Parameters ---------- @@ -143,6 +146,12 @@ def check_angle_energy( temperature: unit.Quantity The system temperature in units compatible with Kelvin. + + Returns + ------- + bool + If the angle is less than 10 kT from 0 or pi radians + Note ---- We assume the temperature to be 298.15 Kelvin. @@ -167,7 +176,7 @@ def check_dihedral_bounds( dihedral: FloatQuantity['radians'] lower_cutoff: FloatQuantity['radians'] = 2.618 * unit.radians, upper_cutoff: FloatQuantity['radians'] = -2.6.18 * unit.radians, -): +) -> bool: """ Check that a dihedral does not exceed the bounds set by lower_cutoff and upper_cutoff. @@ -180,12 +189,42 @@ def check_dihedral_bounds( Dihedral lower cutoff in units compatible with radians. upper_cutoff : unit.Quantity Dihedral upper cutoff in units compatible with radians. + + Returns + ------- + bool + ``True`` if the dihedral is within the upper and lower + cutoff bounds. """ if (dihedral < lower_cutoff) or (dihedral > upper_cutoff): return False return True +def check_angular_variance( + angles: ArrayQuantity['radians'] + width: FloatQuantity['radians'] +) -> bool: + """ + Check that the variance of a list of ``angles`` does not exceed + a given ``width`` + + Parameters + ---------- + angles : ArrayLike[unit.Quantity] + An array of angles in units compatible with radians. + width : unit.Quantity + The width to check the variance against, in units compatible with radians. + + Returns + ------- + bool + ``True`` if the variance of the angles is less than the width. + + """ + array = angles.to('radians').m + variance = circvar(array) + return not (variance * unit.radians > width) def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: From 0a480aa0771f7112b9c92647877c7f02c5bc6a50 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 21:14:59 +0000 Subject: [PATCH 17/29] host finding code --- .../restraints/geometry/boresch.py | 213 +++++++++++++++- .../openmm_utils/restraints/geometry/utils.py | 240 ++++++++++-------- 2 files changed, 337 insertions(+), 116 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 822382b9c..d6241f3d7 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -10,6 +10,8 @@ import abc from pydantic.v1 import BaseModel, validator +from rdkit import Chem + from openff.units import unit import MDAnalysis as mda from MDAnalysis.lib.distances import calc_bonds, calc_angles @@ -31,6 +33,7 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ + def get_bond_distance(self, topology, coordinates) -> unit.Quantity: u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[2]] @@ -46,8 +49,12 @@ def get_angles(self, topology, coordinates) -> unit.Quantity: at3 = u.atoms[guest_atoms[0]] at4 = u.atoms[guest_atoms[1]] - angleA = calc_angles(at1.position, at2.position, at3.position, u.atoms.dimensions) - angleB = calc_angles(at2.position, at3.position, at4.position, u.atoms.dimensions) + angleA = calc_angles( + at1.position, at2.position, at3.position, u.atoms.dimensions + ) + angleB = calc_angles( + at2.position, at3.position, at4.position, u.atoms.dimensions + ) return angleA, angleB def get_dihedrals(self, topology, coordinates) -> unit.Quantity: @@ -59,8 +66,204 @@ def get_dihedrals(self, topology, coordinates) -> unit.Quantity: at5 = u.atoms[guest_atoms[1]] at6 = u.atoms[guest_atoms[2]] - dihA = calc_dihedrals(at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions) - dihB = calc_dihedrals(at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions) - dihC = calc_dihedrals(at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions) + dihA = calc_dihedrals( + at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions + ) + dihB = calc_dihedrals( + at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions + ) + dihC = calc_dihedrals( + at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions + ) return dihA, dihB, dihC + + +def _sort_by_distance_from_atom( + rdmol: Chem.Mol, target_idx: int, atom_idxs: Iterable[int] +) -> list[int]: + """ + Sort a list of RDMol atoms by their distance from a target atom. + + Parameters + ---------- + target_idx : int + The idx of the atom to measure from. + atom_idxs : list[int] + The idx values of the atoms to sort. + rdmol : Chem.Mol + RDKit Molecule the atoms belong to + + Returns + ------- + list[int] + The input atom idxs sorted by their distance from the target atom. + """ + distances = [] + + conformer = rdmol.GetConformer() + # Get the target atom position + target_pos = conformer.GetAtomPosition(target_idx) + + for idx in atom_idxs: + pos = conformer.GetAtomPosition(idx) + distances.append(((target_pos - pos).Length(), idx)) + + return [i[1] for i in sorted(distances)] + + +def _get_bonded_angles_from_pool( + rdmol: Chem.Mol, atom_idx: int, atom_pool: list[int] +) -> list[tuple[int, int, int]]: + """ + Get all bonded angles starting from ``atom_idx`` from a pool of atoms. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule + atom_idx : int + The index of the atom to search angles from. + atom_pool : list[int] + The list of indices to pick possible angle partners from. + + Returns + ------- + list[tuple[int, int, int]] + A list of tuples containing all the angles. + """ + angles = [] + + # Get the base atom and its neighbors + at1 = rdmol.GetAtomWithIdx(atom_idx) + at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] + + # We loop at2 and at3 through the sorted atom_pool in order to get + # a list of angles in the branch that are sorted by how close the atoms + # are from the central atom + for at2 in atom_pool: + if at2 in at1_neighbors: + at2_neighbors = [ + at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + ] + for at3 in atom_pool: + if at3 != atom_idx and at3 in at2_neighbors: + angles.append((atom_idx, at2, at3)) + return angles + + +def get_small_molecule_atom_candidates( + topology: Union[str, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + rdmol: Chem.Mol, + ligand_idxs: list[int], + rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, + angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, +): + """ + Get a list of potential ligand atom choices for a Boresch restraint + being applied to a given small molecule. + + TODO: remember to update the RDMol with the last frame positions + """ + if isinstance(topology, openmm.app.Topology): + topology_format = "OPENMMTOPOLOGY" + else: + topology_format = None + + u = mda.Universe(topology, trajectory, topology_format=topology_format) + ligand_ag = u.atoms[ligand_idxs] + + # 0. Get the ligand RMSF + rmsf = get_local_rmsf(ligand_ag) + u.trajectory[-1] # forward to the last frame + + # 1. Get the pool of atoms to work with + # TODO: move to a helper function to make it easier to test + # Get a list of all the aromatic rings + # Note: no need to keep track of rings because we'll filter by + # bonded terms after, so if we only keep rings then all the bonded + # atoms should be within the same ring system. + atom_pool = set() + for ring in get_aromatic_rings(rdmol): + max_rmsf = rmsf[list(ring)].max() + if max_rmsf < rmsf_cutoff: + atom_pool.update(ring) + + # if we don't have enough atoms just get all the heavy atoms + if len(atom_pool) < 3: + heavy_atoms = get_heavy_atom_idxs(rdmol) + atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) + if len(atom_pool) < 3: + errmsg = ( + "No suitable ligand atoms for " "the boresch restraint could be found" + ) + raise ValueError(errmsg) + + # 2. Get the central atom + center = get_central_atom_idx(rdmol) + + # 3. Sort the atom pool based on their distance from the center + sorted_anchor_pool = _sort_by_distance_from_atom(rdmol, center, anchor_pool) + + # 4. Get a list of probable angles + angles_list = [] + for atom in sorted_anchor_pool: + angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) + for angle in _angles: + angle_ag = ligand_ag.atoms[angle] + collinear = is_collinear(ligand_ag.positions, angle) + angle_value = ( + calc_angle( + angle_ag.atoms[0].position, + angle_ag.atoms[1].position, + angle_ag.atoms[2].position, + box=angle_ag.universe.dimensions, + ) + * unit.radians + ) + energy = check_angle_energy( + angle_value, angle_force_constant, 298.15 * unit.kelvin + ) + if not collinear and energy: + angles_list.append(angle) + + return angles_list + + +def get_host_atom_candidates( + topology: Union[str, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + host_idxs: list[int], + l1_idx: int, + host_selection: str, + dssp_filter: bool = False, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, + min_distance: unit.Quantity = 10 * unit.nanometer, + max_distance: unit.Quantity = 30 * unit.nanometer, + angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, +): + if isinstance(topology, openmm.app.Topology): + topology_format = "OPENMMTOPOLOGY" + else: + topology_format = None + + u = mda.Universe(topology, trajectory, topology_format=topology_format) + protein_ag1 = u.atoms[host_idxs] + protein_ag2 = protein_ag.select_atoms(protein_selection) + + # 0. TODO: implement DSSP filter + # Should be able to just call MDA's DSSP method, but will need to catch an exception + if dssp_filter: + raise NotImplementedError("DSSP filtering is not currently implemented") + + # 1. Get the RMSF & filter + rmsf = get_local_rmsf(sub_protein_ag) + protein_ag3 = sub_protein_ag.atoms[rmsf[heavy_atoms] < rmsf_cutoff] + + # 2. Search of atoms within the min/max cutoff + atom_finder = FindHostAtoms( + protein_ag3, u.atoms[l1_idx], min_search_distance, max_search_distance + ) + atom_finder.run() + return atom_finder.results.host_idxs diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 30e81123f..c8226af0d 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -20,8 +20,37 @@ from rdkit import Chem import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase +from MDAnalysis.analysis.rmsf import RMSF from MDAnalysis.lib.distances import calc_bonds, calc_angles +from openfe_analysis.transformations import Aligner, NoJump + + +def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: + """ + Get a list of tuples with the indices for each ring in an rdkit Molecule. + + Parameters + ---------- + rdmol : Chem.Mol + RDKit Molecule + + Returns + ------- + list[tuple[int]] + List of tuples for each ring. + """ + ringinfo = rdmol.GetRingInfo() + arom_idxs = get_aromatic_atom_idxs(rdmol) + + aromatic_rings = [] + + for ring in ringinfo.AtomRings(): + if all(a in aroms for a in ring): + aromatic_rings.append(ring) + + return aromatic_rings + def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: """ @@ -38,10 +67,7 @@ def get_aromatic_atom_idxs(rdmol: Chem.Mol) -> list[int]: list[int] A list of the aromatic atom idxs """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetIsAromatic() - ] + idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetIsAromatic()] return idxs @@ -58,10 +84,7 @@ def get_heavy_atom_idxs(rdmol: Chem.Mol) -> list[int]: list[int] A list of heavy atom idxs """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetAtomicNum() > 1 - ] + idxs = [at.GetIdx() for at in rdmol.GetAtoms() if at.GetAtomicNum() > 1] return idxs @@ -124,15 +147,19 @@ def is_collinear(positions, atoms, threshold=0.9): for i in range(len(atoms) - 2): v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + normalized_inner_product = np.dot(v1, v2) / np.sqrt( + np.dot(v1, v1) * np.dot(v2, v2) + ) result = result or (np.abs(normalized_inner_product) > threshold) return result def check_angle_energy( - angle: FloatQuantity['radians'], - force_constant: FloatQuantity['unit.kilojoule_per_mole / unit.radians**2'] = 83.68 * unit.kilojoule_per_mole / unit.radians**2, - temperature: FloatQuantity['kelvin'] = 298.15 * unit.kelvin + angle: FloatQuantity["radians"], + force_constant: FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] = 83.68 + * unit.kilojoule_per_mole + / unit.radians**2, + temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, ) -> bool: """ Check whether the chosen angle is less than 10 kT from 0 or pi radians @@ -157,9 +184,9 @@ def check_angle_energy( We assume the temperature to be 298.15 Kelvin. """ # Convert things - angle_rads = angle.to('radians') - frc_const = force_constant.to('unit.kilojoule_per_mole / unit.radians**2') - temp_kelvin = temperature.to('kelvin') + angle_rads = angle.to("radians") + frc_const = force_constant.to("unit.kilojoule_per_mole / unit.radians**2") + temp_kelvin = temperature.to("kelvin") RT = 8.31445985 * 0.001 * temp_kelvin # check if angle is <10kT from 0 or 180 @@ -167,15 +194,15 @@ def check_angle_energy( check2 = 0.5 * frc_const * np.power((angle - np.pi), 2) ang_check_1 = check1 / RT ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: + if ang_check_1 < 10.0 or ang_check_2 < 10.0: return False return True def check_dihedral_bounds( - dihedral: FloatQuantity['radians'] - lower_cutoff: FloatQuantity['radians'] = 2.618 * unit.radians, - upper_cutoff: FloatQuantity['radians'] = -2.6.18 * unit.radians, + dihedral: FloatQuantity["radians"], + lower_cutoff: FloatQuantity["radians"] = 2.618 * unit.radians, + upper_cutoff: FloatQuantity["radians"] = -2.618 * unit.radians, ) -> bool: """ Check that a dihedral does not exceed the bounds set by @@ -202,8 +229,7 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity['radians'] - width: FloatQuantity['radians'] + angles: ArrayQuantity["radians"], width: FloatQuantity["radians"] ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -222,45 +248,14 @@ def check_angular_variance( ``True`` if the variance of the angles is less than the width. """ - array = angles.to('radians').m + array = angles.to("radians").m variance = circvar(array) return not (variance * unit.radians > width) -def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: - """ - Sort a list of atoms by their distance from a target atom. - - Parameters - ---------- - target_idx : int - The idx of the target atom. - atom_idxs : list[int] - The idx values of the atoms to sort. - rdmol : ??? - RDKit Molecule the atoms belong to - - Returns - ------- - list[int] - The input atom idxs sorted by their distance from the target atom. - """ - distances = [] - - conformer = rdmol.GetConformer() - # Get the target atom position - target_pos = conformer.GetAtomPosition(target_idx) - - for idx in atom_idxs: - pos = conformer.GetAtomPosition(idx) - distances.append(((target_pos - pos).Length(), idx)) - - return [i[1] for i in sorted(distances)] - - def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): angles = [] - + # Get the base atom and its neighbors at1 = rdmol.GetAtomWithIdx(atom_idx) at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] @@ -271,8 +266,7 @@ def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): for at2 in atom_pool: if at2 in at1_neighbors: at2_neighbors = [ - at.GetIdx() - for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() + at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() ] for at3 in atom_pool: if at3 != atom_idx and at3 in at2_neighbors: @@ -283,12 +277,12 @@ def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: """ Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). - + Parameters ---------- rdmol : ??? Molecule object for the ligand to apply a restraint to. - + Returns ------- angles : list[tuple[int, int, int]] @@ -304,7 +298,7 @@ def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: # If there are not enough aromatic atoms, then default to heavy atoms if len(anchor_pool) < 3: anchor_pool = _get_heavy_atoms(rdmol) - + # Raise an error if we have less than 3 anchors if len(anchor_pool) < 3: errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" @@ -316,15 +310,15 @@ def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: # Get a list of ligand anchor angle atoms angles = [] for atom in sorted_anchor_pool: - angles.extend( - _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) - ) + angles.extend(_get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool)) -def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): +def get_host_anchors( + positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str +): """ Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. - + Parameters ---------- positions : openmm.unit.Quantity @@ -341,30 +335,35 @@ def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_ # Create an mdtraj trajectory to manipulate # First fetch the box vectors and pass them as lengths and angles vectors = from_openmm(topology.getPeriodicBoxVectors()) - a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) - + a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles( + vectors[0].m, vectors[1].m, vectors[2].m + ) + traj = mdt.Trajectory( - positions[np.newaxis, ...], - mdt.Topology.from_openmm(topology) + positions[np.newaxis, ...], mdt.Topology.from_openmm(topology) ) - + # Get all the potential protein atoms matching the selection host_sel = traj.topology.select(selection) - + # Get residues to exclude from the selection - exclude_sel = np.array([ - at.index for at in - chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) - ]) - + exclude_sel = np.array( + [ + at.index + for at in chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) + ] + ) + # Remove exclusion anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] - + # Compute distanecs from ligand l1 anchor atom - pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T - + pairs = np.vstack( + (anchors, np.array([lig_anchor_idx for _ in range(len(anchors))])) + ).T + distances = mdt.compute_distances(traj, pairs, periodic=True) - + return np.array([pairs[i][0] for i in np.argsort(distances[0])]) @@ -395,7 +394,9 @@ def is_collinear(positions, atoms, threshold=0.9): for i in range(len(atoms) - 2): v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) + normalized_inner_product = np.dot(v1, v2) / np.sqrt( + np.dot(v1, v1) * np.dot(v2, v2) + ) result = result or (np.abs(normalized_inner_product) > threshold) return result @@ -403,7 +404,7 @@ def is_collinear(positions, atoms, threshold=0.9): def check_angle(angle, force_constant=83.68): """ Check whether the chosen angle is less than 10 kT from 0 or 180 - + Parameters ---------- angle : float @@ -417,19 +418,17 @@ def check_angle(angle, force_constant=83.68): """ # TODO: convert this to unit.Quantity so we don't end up with # conversion errors - RT = 8.31445985 * 0.001 * 298.15 + RT = 8.31445985 * 0.001 * 298.15 # check if angle is <10kT from 0 or 180 check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) ang_check_1 = check1 / RT ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: + if ang_check_1 < 10.0 or ang_check_2 < 10.0: return False return True - - class FindHostAtoms(AnalysisBase): """ Class filter host atoms based on their distance @@ -441,17 +440,28 @@ class FindHostAtoms(AnalysisBase): Initial selection of host atoms to filter from. guest_atoms : MDANalysis.AtomGroup Selection of guest atoms to search around. - search_distance: unit.Quantity - Distance to filter atoms within. + min_search_distance: unit.Quantity + Minimum distance to filter atoms within. + max_search_distance: unit.Quantity + Maximum distance to filter atoms within. """ + _analysis_algorithm_is_parallelizable = False - def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): + def __init__( + self, + host_atoms, + guest_atoms, + min_search_distance, + max_search_distance, + **kwargs, + ): super().__init__(host_atoms.universe.trajectory, **kwargs) self.host_ag = host_atoms self.guest_ag = guest_atoms - self.cutoff = search_distance.to('angstrom').m + self.min_cutoff = min_search_distance.to("angstrom").m + self.max_cutoff = max_search_distance.to("angstrom").m def _prepare(self): self.results.host_idxs = set() @@ -460,19 +470,22 @@ def _single_frame(self): pairs = capped_distance( reference=self.host_ag.positions, configuration=self.guest_ag.positions, - max_cutoff=self.cutoff, - min_cutoff=None + max_cutoff=self.max_cutoff, + min_cutoff=self.min_cutoff, box=self.guest_ag.universe.dimensions, - return_distances=False) + return_distances=False, + ) - host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] + host_idxs = [self.guest_ag.atoms[p].ix for p in pairs[:, 1]] self.results.host_idxs.update(set(host_idxs)) def _conclude(self): - pass + self.results.host_idxs = np.array(self.results.host_idxs) -def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: +def find_host_atoms( + topology, trajectory, host_selection, guest_selection, cutoff +) -> mda.AtomGroup: """ Get an AtomGroup of the host atoms based on their distances from the guest atoms. """ @@ -487,7 +500,7 @@ def _get_selection(selection): else: ag = u.atoms[host_ag] return ag - + host_ag = _get_selection(host_selection) guest_ag = _get_selection(guest_selection) @@ -496,24 +509,29 @@ def _get_selection(selection): return u.atoms[list(finder.results.host_idxs)] -def get_molecule_center_idx(atomgroup): - offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) - # Check if the molecule is whole, otherwise throw an error. - nx = offmol.to_networkx() +def get_local_rmsf(atomgroup: mda.AtomGroup): + """ + Get the RMSF of an AtomGroup when aligned upon itself. -def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): - u = mda.Universe(topology, trajectory) + Parameters + ---------- + atomgroup : MDAnalysis.AtomGroup - if guest_atoms is None: - if guest_selection is None: - raise ValueError("one of guest_atoms or guest_selections must be defined") - guest_ag = u.select_atoms(guest_selection) - else: + Return + ------ + rmsf + ArrayQuantity of RMSF values. + """ + # First let's copy our Universe + copy_u = atomgroup.universe.copy() + ag = copy_u.atoms[atomgroup.atoms.ix] + nojump = NoJump(ag) + align = Aligner(ag) - if host_atoms is None: - if host_selection is None: - raise ValueError("one of host_atoms or host_selection must be defined") + copy_u.trajectory.add_transformations(nojump, align) - host_ag = u.select_atoms(host_selection) + rmsf = RMSF(ag) + rmsf.run() + return rmsf.results.rmsf * unit.angstrom From 7a7be903a62ca03ecfa76326e621636e78cb9537 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 21:21:30 +0000 Subject: [PATCH 18/29] fix up weird black wrapping --- .../openmm_utils/restraints/geometry/utils.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index c8226af0d..82e7e621f 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -26,6 +26,10 @@ from openfe_analysis.transformations import Aligner, NoJump +DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 +ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] + + def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: """ Get a list of tuples with the indices for each ring in an rdkit Molecule. @@ -156,9 +160,7 @@ def is_collinear(positions, atoms, threshold=0.9): def check_angle_energy( angle: FloatQuantity["radians"], - force_constant: FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] = 83.68 - * unit.kilojoule_per_mole - / unit.radians**2, + force_constant: ANGLE_FRC_CONSTANT_TYPE = DEFAULT_ANGLE_FRC_CONSTANT, temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, ) -> bool: """ @@ -170,10 +172,9 @@ def check_angle_energy( The angle to check in units compatible with radians. force_constant : unit.Quantity Force constant of the angle in units compatible with kilojoule_per_mole / radians ** 2. - temperature: unit.Quantity + temperature : unit.Quantity The system temperature in units compatible with Kelvin. - Returns ------- bool From 733f3b3c681a1112d081e043f7866cb198fbf3aa Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 21:47:56 +0000 Subject: [PATCH 19/29] remove old search file, add more changes to boresch search --- .../restraints/geometry/boresch.py | 111 ++++-- .../openmm_utils/restraints/search.py | 360 ------------------ 2 files changed, 83 insertions(+), 388 deletions(-) delete mode 100644 openfe/protocols/openmm_utils/restraints/search.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index d6241f3d7..a35c4f5a9 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -15,6 +15,8 @@ from openff.units import unit import MDAnalysis as mda from MDAnalysis.lib.distances import calc_bonds, calc_angles +import numpy as np +import numpy.typing as npt from .base import HostGuestRestraintGeometry @@ -152,19 +154,84 @@ def _get_bonded_angles_from_pool( return angles -def get_small_molecule_atom_candidates( +def _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) -> Optional[set[int]]: + """ + Filter atoms based on rmsf & rings, defaulting to heavy atoms if + there are not enough. + + Parameters + ---------- + rdmol : Chem.Mol + The RDKit Molecule to search through + rmsf : npt.NDArray + A 1-D array of RMSF values for each atom. + + Returns + ------- + atom_pool : Optional[set[int]] + """ + # Get a list of all the aromatic rings + # Note: no need to keep track of rings because we'll filter by + # bonded terms after, so if we only keep rings then all the bonded + # atoms should be within the same ring system. + atom_pool = set() + for ring in get_aromatic_rings(rdmol): + max_rmsf = rmsf[list(ring)].max() + if max_rmsf < rmsf_cutoff: + atom_pool.update(ring) + + # if we don't have enough atoms just get all the heavy atoms + if len(atom_pool) < 3: + heavy_atoms = get_heavy_atom_idxs(rdmol) + atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) + if len(atom_pool) < 3: + return None + + return atom_pool + + +def get_small_molecule_guest_atom_candidates( topology: Union[str, openmm.app.Topology], trajectory: Union[str, pathlib.Path], rdmol: Chem.Mol, ligand_idxs: list[int], rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, - angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, -): + angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, +) -> list[tuple[int]]: """ Get a list of potential ligand atom choices for a Boresch restraint being applied to a given small molecule. - TODO: remember to update the RDMol with the last frame positions + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + The topology of the system. + trajectory : Union[str, pathlib.Path] + A path to the system's coordinate trajectory. + rdmol : Chem.Mol + An RDKit Molecule representing the small molecule ordered in + the same way as it is listed in the topology. + ligand_idxs : list[int] + The ligand indices in the topology. + rmsf_cutoff : unit.Quantity + The RMSF filter cut-off. + angle_force_constant : unit.Quantity + The force constant for the l1-l2-l3 atom angle. + + Returns + ------- + angle_list : list[tuple[int]] + A list of tuples for each valid l1, l2, l3 angle. If ``None``, no + angles could be found. + + Raises + ------ + ValueError + If no suitable ligand atoms could be found. + + TODO + ---- + Remember to update the RDMol with the last frame positions. """ if isinstance(topology, openmm.app.Topology): topology_format = "OPENMMTOPOLOGY" @@ -179,26 +246,12 @@ def get_small_molecule_atom_candidates( u.trajectory[-1] # forward to the last frame # 1. Get the pool of atoms to work with - # TODO: move to a helper function to make it easier to test - # Get a list of all the aromatic rings - # Note: no need to keep track of rings because we'll filter by - # bonded terms after, so if we only keep rings then all the bonded - # atoms should be within the same ring system. - atom_pool = set() - for ring in get_aromatic_rings(rdmol): - max_rmsf = rmsf[list(ring)].max() - if max_rmsf < rmsf_cutoff: - atom_pool.update(ring) + atom_pool = _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) - # if we don't have enough atoms just get all the heavy atoms - if len(atom_pool) < 3: - heavy_atoms = get_heavy_atom_idxs(rdmol) - atom_pool = set(heavy_atoms[rmsf[heavy_atoms] < rmsf_cutoff]) - if len(atom_pool) < 3: - errmsg = ( - "No suitable ligand atoms for " "the boresch restraint could be found" - ) - raise ValueError(errmsg) + if atom_pool is None: + # We don't have enough atoms so we raise an error + errmsg = "No suitable ligand atoms were found for the restraint" + raise ValueError(errmsg) # 2. Get the central atom center = get_central_atom_idx(rdmol) @@ -211,7 +264,7 @@ def get_small_molecule_atom_candidates( for atom in sorted_anchor_pool: angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) for angle in _angles: - angle_ag = ligand_ag.atoms[angle] + angle_ag = ligand_ag.atoms[list(angle)] collinear = is_collinear(ligand_ag.positions, angle) angle_value = ( calc_angle( @@ -219,8 +272,7 @@ def get_small_molecule_atom_candidates( angle_ag.atoms[1].position, angle_ag.atoms[2].position, box=angle_ag.universe.dimensions, - ) - * unit.radians + ) * unit.radians ) energy = check_angle_energy( angle_value, angle_force_constant, 298.15 * unit.kelvin @@ -239,10 +291,13 @@ def get_host_atom_candidates( host_selection: str, dssp_filter: bool = False, rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, - min_distance: unit.Quantity = 10 * unit.nanometer, - max_distance: unit.Quantity = 30 * unit.nanometer, + min_distance: unit.Quantity = 1 * unit.nanometer, + max_distance: unit.Quantity = 3 * unit.nanometer, angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, ): + """ + + """ if isinstance(topology, openmm.app.Topology): topology_format = "OPENMMTOPOLOGY" else: diff --git a/openfe/protocols/openmm_utils/restraints/search.py b/openfe/protocols/openmm_utils/restraints/search.py deleted file mode 100644 index 6b3d94eb7..000000000 --- a/openfe/protocols/openmm_utils/restraints/search.py +++ /dev/null @@ -1,360 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Search methods for generating Geometry objects - -TODO ----- -* Add relevant duecredit entries. -""" -import abc -from pydantic.v1 import BaseModel, validator - -from openff.toolkit import Molecule as OFFMol -from openff.units import unit -import networkx as nx -import MDAnalysis as mda -from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.lib.distances import calc_bonds, calc_angles - - -def _get_aromatic_atom_idxs(rdmol) -> list[int]: - """ - Helper method to get aromatic atoms idxs - in a RDKit Molecule - - Parameters - ---------- - rdmol : ??? - RDKit Molecule - - Returns - ------- - list[int] - A list of the aromatic atom idxs - """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetIsAromatic() - ] - return idxs - - -def _get_heavy_atom_idxs(rdmol) -> list[int]: - """ - Get idxs of heavy atoms in an RDKit Molecule - - Parameters - ---------- - rmdol : ??? - - Returns - ------- - list[int] - A list of heavy atom idxs - """ - idxs = [ - at.GetIdx() for at in rdmol.GetAtoms() - if at.GetAtomicNum() > 1 - ] - return idxs - - -def _get_central_atom_idx(rdmol) -> int: - offmol = OFFMol(rdmol, allow_undefined_stereo=True) - # We take the zero-th entry if there are multiple center - # atoms (e.g. equal likelihood centers) - center = nx.center(offmol.to_networkx())[0] - return center - - -def _sort_by_distance_from_target(rdmol, target_idx: int, atom_idxs: list[int]) -> list[int]: - """ - Sort a list of atoms by their distance from a target atom. - - Parameters - ---------- - target_idx : int - The idx of the target atom. - atom_idxs : list[int] - The idx values of the atoms to sort. - rdmol : ??? - RDKit Molecule the atoms belong to - - Returns - ------- - list[int] - The input atom idxs sorted by their distance from the target atom. - """ - distances = [] - - conformer = rdmol.GetConformer() - # Get the target atom position - target_pos = conformer.GetAtomPosition(target_idx) - - for idx in atom_idxs: - pos = conformer.GetAtomPosition(idx) - distances.append(((target_pos - pos).Length(), idx)) - - return [i[1] for i in sorted(distances)] - - -def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): - angles = [] - - # Get the base atom and its neighbors - at1 = rdmol.GetAtomWithIdx(atom_idx) - at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] - - # We loop at2 and at3 through the sorted atom_pool in order to get - # a list of angles in the branch that are sorted by how close the atoms - # are from the central atom - for at2 in atom_pool: - if at2 in at1_neighbors: - at2_neighbors = [ - at.GetIdx() - for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() - ] - for at3 in atom_pool: - if at3 != atom_idx and at3 in at2_neighbors: - angles.append((atom_idx, at2, at3)) - return angles - - -def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: - """ - Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). - - Parameters - ---------- - rdmol : ??? - Molecule object for the ligand to apply a restraint to. - - Returns - ------- - angles : list[tuple[int, int, int]] - A list of ligand atom triples denoting the possible l1, l2, and l3 - restraint atoms. Ordered by likelihood of restraint-ability. - """ - # Find the central atom - center = _get_central_atom_idx(rdmol) - - # Get a pool of potential anchor atoms looking for aromatic atoms - anchor_pool = _get_aromatic_atoms(rdmol) - - # If there are not enough aromatic atoms, then default to heavy atoms - if len(anchor_pool) < 3: - anchor_pool = _get_heavy_atoms(rdmol) - - # Raise an error if we have less than 3 anchors - if len(anchor_pool) < 3: - errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" - raise ValueError(errmsg) - - # Sort the pool of anchor atoms by their distance from the central atom - sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) - - # Get a list of ligand anchor angle atoms - angles = [] - for atom in sorted_anchor_pool: - angles.extend( - _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) - ) - - -def get_host_anchors(positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str): - """ - Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. - - Parameters - ---------- - positions : openmm.unit.Quantity - Positions of the input system - topology : openmm.app.Topology - OpenMM Topology for input system - exclude_resids : list[int] - List of residue numbers to exclude from host selection - lig_anchor_idx : int - The index of the l1 ligand anchor. - selection : str - Selection string for the host atoms. - """ - # Create an mdtraj trajectory to manipulate - # First fetch the box vectors and pass them as lengths and angles - vectors = from_openmm(topology.getPeriodicBoxVectors()) - a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles(vectors[0].m, vectors[1].m, vectors[2].m) - - traj = mdt.Trajectory( - positions[np.newaxis, ...], - mdt.Topology.from_openmm(topology) - ) - - # Get all the potential protein atoms matching the selection - host_sel = traj.topology.select(selection) - - # Get residues to exclude from the selection - exclude_sel = np.array([ - at.index for at in - chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) - ]) - - # Remove exclusion - anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] - - # Compute distanecs from ligand l1 anchor atom - pairs = np.vstack((anchors, np.array([lig_anchor_idx for _ in range(len(anchors))]))).T - - distances = mdt.compute_distances(traj, pairs, periodic=True) - - return np.array([pairs[i][0] for i in np.argsort(distances[0])]) - - -def is_collinear(positions, atoms, threshold=0.9): - """ - Check whether any sequential vectors in a sequence of atoms are collinear. - - Parameters - ---------- - positions : openmm.unit.Quantity - System positions. - atoms : list[int] - The indices of the atoms to test. - threshold : float - Atoms are not collinear if their sequential vector separation dot - products are less than ``threshold``. Default 0.9. - - Returns - ------- - result : bool - Returns True if any sequential pair of vectors is collinear; False otherwise. - - Notes - ----- - Originally from Yank, with modifications from Separated Topologies - """ - results = False - for i in range(len(atoms) - 2): - v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] - v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) - result = result or (np.abs(normalized_inner_product) > threshold) - return result - - -def check_angle(angle, force_constant=83.68): - """ - Check whether the chosen angle is less than 10 kT from 0 or 180 - - Parameters - ---------- - angle : float - The angle to check in degrees. - force_constant : float - Force constant of the angle. - - Note - ---- - We assume the temperature to be 298.15 Kelvin. - """ - # TODO: convert this to unit.Quantity so we don't end up with - # conversion errors - RT = 8.31445985 * 0.001 * 298.15 - # check if angle is <10kT from 0 or 180 - check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) - check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) - ang_check_1 = check1 / RT - ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: - return False - return True - - - - -class FindHostAtoms(AnalysisBase): - """ - Class filter host atoms based on their distance - from a set of guest atoms. - - Parameters - ---------- - host_atoms : MDAnalysis.AtomGroup - Initial selection of host atoms to filter from. - guest_atoms : MDANalysis.AtomGroup - Selection of guest atoms to search around. - search_distance: unit.Quantity - Distance to filter atoms within. - """ - _analysis_algorithm_is_parallelizable = False - - def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): - super().__init__(host_atoms.universe.trajectory, **kwargs) - - self.host_ag = host_atoms - self.guest_ag = guest_atoms - self.cutoff = search_distance.to('angstrom').m - - def _prepare(self): - self.results.host_idxs = set() - - def _single_frame(self): - pairs = capped_distance( - reference=self.host_ag.positions, - configuration=self.guest_ag.positions, - max_cutoff=self.cutoff, - min_cutoff=None - box=self.guest_ag.universe.dimensions, - return_distances=False) - - host_idxs = [self.guest_ag.atoms[p].index for p in pairs[:, 1]] - self.results.host_idxs.update(set(host_idxs)) - - def _conclude(self): - pass - - -def find_host_atoms(topology, trajectory, host_selection, guest_selection, cutoff) -> mda.AtomGroup: - """ - Get an AtomGroup of the host atoms based on their distances from the guest atoms. - """ - u = mda.Universe(topology, trajectory) - - def _get_selection(selection): - """ - If it's a str, call select_atoms, if not a list of atom idxs - """ - if isinstance(selection, str): - ag = u.select_atoms(host_selection) - else: - ag = u.atoms[host_ag] - return ag - - host_ag = _get_selection(host_selection) - guest_ag = _get_selection(guest_selection) - - finder = FindHostAtoms(host_ag, guest_ag, cutoff) - finder.run() - - return u.atoms[list(finder.results.host_idxs)] - -def get_molecule_center_idx(atomgroup): - offmol = Molecule(atomgroup.convert_to("RDKIT"), allow_undefined_stereo=True) - # Check if the molecule is whole, otherwise throw an error. - nx = offmol.to_networkx() - - -def get_distance_restraint(topology, trajectory, host_atoms, guest_atoms, host_selection, guest_selection): - u = mda.Universe(topology, trajectory) - - if guest_atoms is None: - if guest_selection is None: - raise ValueError("one of guest_atoms or guest_selections must be defined") - guest_ag = u.select_atoms(guest_selection) - else: - - - if host_atoms is None: - if host_selection is None: - raise ValueError("one of host_atoms or host_selection must be defined") - - host_ag = u.select_atoms(host_selection) From 96decfffe7036ab5cff19ad4453517957a3291dd Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 22:19:34 +0000 Subject: [PATCH 20/29] Remove duplicate methods --- .../restraints/geometry/boresch.py | 34 ++- .../openmm_utils/restraints/geometry/utils.py | 218 ++++-------------- 2 files changed, 75 insertions(+), 177 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index a35c4f5a9..cf14b73aa 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -37,6 +37,13 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): """ def get_bond_distance(self, topology, coordinates) -> unit.Quantity: + """ + Get the H2 - G0 distance + + Parameters + ---------- + topology : + """ u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[2]] at2 = u.atoms[guest_atoms[0]] @@ -293,10 +300,30 @@ def get_host_atom_candidates( rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, min_distance: unit.Quantity = 1 * unit.nanometer, max_distance: unit.Quantity = 3 * unit.nanometer, - angle_force_constant=83.68 * unit.kilojoule_per_mole / unit.radians**2, ): """ + Get a list of suitable host atoms. + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + The topology of the system. + trajectory : Union[str, pathlib.Path] + A path to the system's coordinate trajectory. + host_idxs : list[int] + A list of the host indices in the system topology. + l1_idx : int + The index of the proposed l1 binding atom. + host_selection : str + An MDAnalysis selection string to fileter the host by. + dssp_filter : bool + Whether or not to apply a DSSP filter on the host selection. + rmsf_cutoff : uni.Quantity + The maximum RMSF value allowwed for any candidate host atom. + min_distance : unit.Quantity + The minimum search distance around l1 for suitable candidate atoms. + max_distance : unit.Quantity + The maximum search distance around l1 for suitable candidate atoms. """ if isinstance(topology, openmm.app.Topology): topology_format = "OPENMMTOPOLOGY" @@ -322,3 +349,8 @@ def get_host_atom_candidates( ) atom_finder.run() return atom_finder.results.host_idxs + + +def select_boresch_atoms( + +): diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 82e7e621f..a74e83ed3 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -11,6 +11,7 @@ from pydantic.v1 import BaseModel, validator import numpy as np +import numpy.typing as npt from scipy.stats import circvar, circmean, circstd from openff.toolkit import Molecule as OFFMol @@ -22,6 +23,7 @@ from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.analysis.rmsf import RMSF from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.coordinates.memory import MemoryReader from openfe_analysis.transformations import Aligner, NoJump @@ -30,6 +32,46 @@ ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] +def _get_mda_coord_format(coordinates: Union[str, npt.NDArray]) -> Optional[MemoryReader]: + """ + Helper to set the coordinate format to MemoryReader + if the coordinates are an NDArray. + + Parameters + ---------- + coordinates : Union[str, npt.NDArray] + + Returns + ------- + Optional[MemoryReader] + Either the MemoryReader class or None. + """ + if isinstance(coordinates, npt.NDArray): + return MemoryReader + else: + return None + +def _get_mda_topology_format(topology: Union[str, openmm.app.Topology]) -> Optional[str]: + """ + Helper to set the topology format to OPENMMTOPOLOGY + if the topology is an openmm.app.Topology. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + + + Returns + ------- + Optional[str] + The string `OPENMMTOPOLOGY` or None. + """ + if isinstance(topology, openmm.app.Topology): + return "OPENMMTOPOLOGY" + else: + return None + + def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: """ Get a list of tuples with the indices for each ring in an rdkit Molecule. @@ -254,182 +296,6 @@ def check_angular_variance( return not (variance * unit.radians > width) -def _get_bonded_angles_from_pool(rdmol, atom_idx, atom_pool): - angles = [] - - # Get the base atom and its neighbors - at1 = rdmol.GetAtomWithIdx(atom_idx) - at1_neighbors = [at.GetIdx() for at in at1.GetNeighbors()] - - # We loop at2 and at3 through the sorted atom_pool in order to get - # a list of angles in the branch that are sorted by how close the atoms - # are from the central atom - for at2 in atom_pool: - if at2 in at1_neighbors: - at2_neighbors = [ - at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() - ] - for at3 in atom_pool: - if at3 != atom_idx and at3 in at2_neighbors: - angles.append((atom_idx, at2, at3)) - return angles - - -def get_ligand_anchor_atoms(rdmol) -> list[tuple[int, int, int]]: - """ - Get a list of ligand anchor atoms (e.g. l1, l2, and l3 of an orientational restraint). - - Parameters - ---------- - rdmol : ??? - Molecule object for the ligand to apply a restraint to. - - Returns - ------- - angles : list[tuple[int, int, int]] - A list of ligand atom triples denoting the possible l1, l2, and l3 - restraint atoms. Ordered by likelihood of restraint-ability. - """ - # Find the central atom - center = _get_central_atom_idx(rdmol) - - # Get a pool of potential anchor atoms looking for aromatic atoms - anchor_pool = _get_aromatic_atoms(rdmol) - - # If there are not enough aromatic atoms, then default to heavy atoms - if len(anchor_pool) < 3: - anchor_pool = _get_heavy_atoms(rdmol) - - # Raise an error if we have less than 3 anchors - if len(anchor_pool) < 3: - errmsg = f"Too few potential ligand anchor atoms, {len(anchor_pool)}" - raise ValueError(errmsg) - - # Sort the pool of anchor atoms by their distance from the central atom - sorted_anchor_pool = _sort_by_distance_from_target(rdmol, center, anchor_pool) - - # Get a list of ligand anchor angle atoms - angles = [] - for atom in sorted_anchor_pool: - angles.extend(_get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool)) - - -def get_host_anchors( - positions, topology, exclude_resids: list[int], lig_anchor_idx: int, selection: str -): - """ - Get a list of host anchor atomss sorted by their distance from a ligand anchor atom. - - Parameters - ---------- - positions : openmm.unit.Quantity - Positions of the input system - topology : openmm.app.Topology - OpenMM Topology for input system - exclude_resids : list[int] - List of residue numbers to exclude from host selection - lig_anchor_idx : int - The index of the l1 ligand anchor. - selection : str - Selection string for the host atoms. - """ - # Create an mdtraj trajectory to manipulate - # First fetch the box vectors and pass them as lengths and angles - vectors = from_openmm(topology.getPeriodicBoxVectors()) - a, b, c, alpha, beta, gamma = mdt.utils.box_vectors_to_lengths_and_angles( - vectors[0].m, vectors[1].m, vectors[2].m - ) - - traj = mdt.Trajectory( - positions[np.newaxis, ...], mdt.Topology.from_openmm(topology) - ) - - # Get all the potential protein atoms matching the selection - host_sel = traj.topology.select(selection) - - # Get residues to exclude from the selection - exclude_sel = np.array( - [ - at.index - for at in chain(*[traj.topology.residue(i).atoms for i in exclude_resids]) - ] - ) - - # Remove exclusion - anchors = host_sel[np.isin(host_sel, exclude_sel, invert=True)] - - # Compute distanecs from ligand l1 anchor atom - pairs = np.vstack( - (anchors, np.array([lig_anchor_idx for _ in range(len(anchors))])) - ).T - - distances = mdt.compute_distances(traj, pairs, periodic=True) - - return np.array([pairs[i][0] for i in np.argsort(distances[0])]) - - -def is_collinear(positions, atoms, threshold=0.9): - """ - Check whether any sequential vectors in a sequence of atoms are collinear. - - Parameters - ---------- - positions : openmm.unit.Quantity - System positions. - atoms : list[int] - The indices of the atoms to test. - threshold : float - Atoms are not collinear if their sequential vector separation dot - products are less than ``threshold``. Default 0.9. - - Returns - ------- - result : bool - Returns True if any sequential pair of vectors is collinear; False otherwise. - - Notes - ----- - Originally from Yank, with modifications from Separated Topologies - """ - results = False - for i in range(len(atoms) - 2): - v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] - v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] - normalized_inner_product = np.dot(v1, v2) / np.sqrt( - np.dot(v1, v1) * np.dot(v2, v2) - ) - result = result or (np.abs(normalized_inner_product) > threshold) - return result - - -def check_angle(angle, force_constant=83.68): - """ - Check whether the chosen angle is less than 10 kT from 0 or 180 - - Parameters - ---------- - angle : float - The angle to check in degrees. - force_constant : float - Force constant of the angle. - - Note - ---- - We assume the temperature to be 298.15 Kelvin. - """ - # TODO: convert this to unit.Quantity so we don't end up with - # conversion errors - RT = 8.31445985 * 0.001 * 298.15 - # check if angle is <10kT from 0 or 180 - check1 = 0.5 * force_constant * np.power((angle - 0.0) / 180.0 * np.pi, 2) - check2 = 0.5 * force_constant * np.power((angle - 180.0) / 180.0 * np.pi, 2) - ang_check_1 = check1 / RT - ang_check_2 = check2 / RT - if ang_check_1 < 10.0 or ang_check_2 < 10.0: - return False - return True - - class FindHostAtoms(AnalysisBase): """ Class filter host atoms based on their distance From 2d97de82dc762f5f294e503282efcbe7bb0f03bf Mon Sep 17 00:00:00 2001 From: Irfan Alibay Date: Thu, 12 Dec 2024 22:21:50 +0000 Subject: [PATCH 21/29] Apply suggestions from code review --- .../openmm_utils/restraints/openmm/omm_restraints.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index e53e828d5..a3fe777d3 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -125,7 +125,7 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): super()._verify_geometry(geometry) -class BaseRadialllySymmetricRestraintForce(BaseHostGuestRestraints): +class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): def _verify_inputs(self) -> None: if not isinstance(self.settings, BaseDistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" @@ -162,7 +162,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry): raise NotImplementedError("only implemented in child classes") -class HarmonicBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): +class HarmonicBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( @@ -173,7 +173,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class FlatBottomBondRestraint(BaseRadialllySymmetricRestraintForce, SingleBondMixin): +class FlatBottomBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) @@ -186,7 +186,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class CentroidHarmonicRestraint(BaseRadialllySymmetricRestraintForce): +class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( @@ -197,7 +197,7 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class CentroidFlatBottomRestraint(BaseRadialllySymmetricRestraintForce): +class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) From 116ba64b0cb29acf69124a7743fc74c2d24815c7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Thu, 12 Dec 2024 22:33:09 +0000 Subject: [PATCH 22/29] Add some more docstring --- .../restraints/geometry/boresch.py | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index cf14b73aa..cfaf443dc 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -35,23 +35,48 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ - - def get_bond_distance(self, topology, coordinates) -> unit.Quantity: + def get_bond_distance( + self, + topology: Union[str, openmm.app.Topology], + coordinates: Union[str, npt.NDArray], + ) -> unit.Quantity: """ - Get the H2 - G0 distance + Get the H2 - G0 distance. Parameters ---------- - topology : + topology : Union[str, openmm.app.Topology] + coordinates : Union[str, npt.NDArray] + A coordinate file or NDArray in frame-atom-coordinate + order in Angstrom. """ - u = mda.Universe(topology, coordinates) + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) at1 = u.atoms[host_atoms[2]] at2 = u.atoms[guest_atoms[0]] bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 return float(bond) * unit.angstrom - def get_angles(self, topology, coordinates) -> unit.Quantity: + def get_angles( + self, + topology: Union[str, openmm.app.Topology], + coordinates: Union[str, npt.NDArray], + ) -> unit.Quantity: + """ + Get the H1-H2-G0, and H2-G0-G1 angles. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + coordinates : Union[str, npt.NDArray] + A coordinate file or NDArray in frame-atom-coordinate + order in Angstrom. + """ u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[1]] at2 = u.atoms[host_atoms[2]] @@ -66,7 +91,21 @@ def get_angles(self, topology, coordinates) -> unit.Quantity: ) return angleA, angleB - def get_dihedrals(self, topology, coordinates) -> unit.Quantity: + def get_dihedrals( + self, + topology: Union[str, openmm.app.Topology], + coordinates: Union[str, npt.NDArray], + ) -> unit.Quantity: + """ + Get the H0-H1-H2-G0, H1-H2-G0-G1, and H2-G0-G1-G2 dihedrals. + + Parameters + ---------- + topology : Union[str, openmm.app.Topology] + coordinates : Union[str, npt.NDArray] + A coordinate file or NDArray in frame-atom-coordinate + order in Angstrom. + """ u = mda.Universe(topology, coordinates) at1 = u.atoms[host_atoms[0]] at2 = u.atoms[host_atoms[1]] @@ -84,7 +123,6 @@ def get_dihedrals(self, topology, coordinates) -> unit.Quantity: dihC = calc_dihedrals( at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions ) - return dihA, dihB, dihC @@ -203,7 +241,6 @@ def get_small_molecule_guest_atom_candidates( rdmol: Chem.Mol, ligand_idxs: list[int], rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, - angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, ) -> list[tuple[int]]: """ Get a list of potential ligand atom choices for a Boresch restraint @@ -222,8 +259,6 @@ def get_small_molecule_guest_atom_candidates( The ligand indices in the topology. rmsf_cutoff : unit.Quantity The RMSF filter cut-off. - angle_force_constant : unit.Quantity - The force constant for the l1-l2-l3 atom angle. Returns ------- @@ -271,20 +306,9 @@ def get_small_molecule_guest_atom_candidates( for atom in sorted_anchor_pool: angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) for angle in _angles: + # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] - collinear = is_collinear(ligand_ag.positions, angle) - angle_value = ( - calc_angle( - angle_ag.atoms[0].position, - angle_ag.atoms[1].position, - angle_ag.atoms[2].position, - box=angle_ag.universe.dimensions, - ) * unit.radians - ) - energy = check_angle_energy( - angle_value, angle_force_constant, 298.15 * unit.kelvin - ) - if not collinear and energy: + if not is_collinear(ligand_ag.positions, angle): angles_list.append(angle) return angles_list From 9ae60da278e4c54eaf2f98feaf2387bef09a76a7 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Fri, 13 Dec 2024 01:20:40 +0000 Subject: [PATCH 23/29] Add minimized vectors on the collinear checks --- .../restraints/geometry/boresch.py | 181 +++++++++++++++--- .../openmm_utils/restraints/geometry/utils.py | 26 ++- 2 files changed, 168 insertions(+), 39 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index cfaf443dc..363e22c6b 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -8,13 +8,15 @@ * Add relevant duecredit entries. """ import abc +import pathlib from pydantic.v1 import BaseModel, validator from rdkit import Chem from openff.units import unit import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDANalysis.analysis.base import AnalysisBase +from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals import numpy as np import numpy.typing as npt @@ -37,8 +39,8 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): """ def get_bond_distance( self, - topology: Union[str, openmm.app.Topology], - coordinates: Union[str, npt.NDArray], + topology: Union[str, pathlib.Path, openmm.app.Topology], + coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ Get the H2 - G0 distance. @@ -64,8 +66,8 @@ def get_bond_distance( def get_angles( self, - topology: Union[str, openmm.app.Topology], - coordinates: Union[str, npt.NDArray], + topology: Union[str, pathlib.Path, openmm.app.Topology], + coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ Get the H1-H2-G0, and H2-G0-G1 angles. @@ -77,7 +79,12 @@ def get_angles( A coordinate file or NDArray in frame-atom-coordinate order in Angstrom. """ - u = mda.Universe(topology, coordinates) + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) at1 = u.atoms[host_atoms[1]] at2 = u.atoms[host_atoms[2]] at3 = u.atoms[guest_atoms[0]] @@ -93,8 +100,8 @@ def get_angles( def get_dihedrals( self, - topology: Union[str, openmm.app.Topology], - coordinates: Union[str, npt.NDArray], + topology: Union[str, pathlib.Path, openmm.app.Topology], + coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ Get the H0-H1-H2-G0, H1-H2-G0-G1, and H2-G0-G1-G2 dihedrals. @@ -106,7 +113,12 @@ def get_dihedrals( A coordinate file or NDArray in frame-atom-coordinate order in Angstrom. """ - u = mda.Universe(topology, coordinates) + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) at1 = u.atoms[host_atoms[0]] at2 = u.atoms[host_atoms[1]] at3 = u.atoms[host_atoms[2]] @@ -235,12 +247,12 @@ def _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) -> Optional[set[int]]: return atom_pool -def get_small_molecule_guest_atom_candidates( - topology: Union[str, openmm.app.Topology], +def get_guest_atom_candidates( + topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], rdmol: Chem.Mol, - ligand_idxs: list[int], - rmsf_cutoff: unit.Quantity = 1 * unit.angstrom, + guest_idxs: list[int], + rmsf_cutoff: unit.Quantity = 1 * unit.nanometer, ) -> list[tuple[int]]: """ Get a list of potential ligand atom choices for a Boresch restraint @@ -255,7 +267,7 @@ def get_small_molecule_guest_atom_candidates( rdmol : Chem.Mol An RDKit Molecule representing the small molecule ordered in the same way as it is listed in the topology. - ligand_idxs : list[int] + guest_idxs : list[int] The ligand indices in the topology. rmsf_cutoff : unit.Quantity The RMSF filter cut-off. @@ -275,13 +287,14 @@ def get_small_molecule_guest_atom_candidates( ---- Remember to update the RDMol with the last frame positions. """ - if isinstance(topology, openmm.app.Topology): - topology_format = "OPENMMTOPOLOGY" - else: - topology_format = None + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) - u = mda.Universe(topology, trajectory, topology_format=topology_format) - ligand_ag = u.atoms[ligand_idxs] + ligand_ag = u.atoms[guest_idxs] # 0. Get the ligand RMSF rmsf = get_local_rmsf(ligand_ag) @@ -308,14 +321,20 @@ def get_small_molecule_guest_atom_candidates( for angle in _angles: # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] - if not is_collinear(ligand_ag.positions, angle): - angles_list.append(angle) + if not is_collinear(ligand_ag.positions, angle, u.dimensions): + angles_list.append( + ( + angle_ag.atoms[0].ix, + angle_ag.atoms[1].ix, + angle_ag.atoms[2].ix + ) + ) return angles_list def get_host_atom_candidates( - topology: Union[str, openmm.app.Topology], + topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], host_idxs: list[int], l1_idx: int, @@ -349,12 +368,13 @@ def get_host_atom_candidates( max_distance : unit.Quantity The maximum search distance around l1 for suitable candidate atoms. """ - if isinstance(topology, openmm.app.Topology): - topology_format = "OPENMMTOPOLOGY" - else: - topology_format = None + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) - u = mda.Universe(topology, trajectory, topology_format=topology_format) protein_ag1 = u.atoms[host_idxs] protein_ag2 = protein_ag.select_atoms(protein_selection) @@ -375,6 +395,107 @@ def get_host_atom_candidates( return atom_finder.results.host_idxs -def select_boresch_atoms( +class EvaluateH2Atoms(AnalysisBase): + """ + Class to evaluate the suitability of a set of host atoms + as a H2 atom (i.e. bonded to the guest G0 atom). + + Parameters + ---------- + guest_atoms: MDAnalysis.AtomGroup + The guest atoms representing G0-G1-G2. + host_atom_pool: MDAnalysis.AtomGroup + The pool of atoms to pick a H2 from. + angle_force_constant : unit.Quantity + The force constant for the H2-G0-G1 angle. + """ + + +def find_boresch_restraint( + topology: Union[str, pathlib.Path, openmm.app.Topology], + trajectory: Union[str, pathlib.Path], + guest_rdmol: Chem.Mol, + guest_idxs: list[int], + host_idxs: list[int], + guest_restraint_atom_idxs: Optional[list[int]] = None, + host_restraint_atoms_idxs Optional[list[int]] = None, + host_selection: str = 'all', + dssp_filter: bool = False, + rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, + host_min_distance: unit.Quantity = 1 * unit.nanometer, + host_max_distance: unit.Quantity = 3 * unit.nanometer, +) -> BoreschRestraintGeometry: + """ + Find suitable Boresch-style restraints between a host and guest entity. + + Parameters + ---------- + ... + + Returns + ------- + ... + """ + u = mda.Universe( + topology, + coordinates, + format=_get_mda_coord_format(coordinates), + topology_format=_get_mda_topology_format(topology) + ) + u.trajectory[-1] # Work with the final frame + + if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): + # In this case assume the picked atoms were intentional / representative + # of the input and go with it + guest_ag = u.select_atoms[guest_idxs] + guest_angle = (at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]) + host_ag = u.select_atoms[host_idxs] + host_angle = (at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]) + # TODO sort out the return on this + return BoreschRestraintGeometry(...) + + if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): + # This is not an intended outcome, crash out here + errmsg = ( + "both ``guest_restraints_atoms_idxs`` and ``host_restraint_atoms_idxs`` " + "must be set or both must be None. " + f"Got {guest_restraint_atoms_idxs} and {host_atoms_restraint_atoms_idxs}" + ) + raise ValueError(errmsg) + + # Fetch the guest angles + guest_angles = get_guest_atom_candidates( + topology=topology, + trajectory=trajectory, + rdmol=guest_rdmol, + guest_idxs=guest_idxs, + rmsf_cutoff=rmsf_cutoff, + ) + + guest_angle = guest_angles[0] + + # Fetch the host atom pool + host_pool = get_host_atom_candidates( + topology=topology, + trajectory=trajectory, + host_idxs=host_idxs, + l1_idx=guest_angle[0], + host_selection=host_selection, + dssp_filter=dssp_filter, + rmsf_cutoff=rmsf_custoff, + min_distance=host_min_distance, + max_distance=host_max_distance, + ) + + # Get the guest angle atomgroup + guest_ag = u.atoms[list(guest_angle)] + + # Find all suitable H2 idxs + h2_idxs = [] + for i in host_pool: + host2_at = u.atoms[i] + pos = np.vstack((at.position, guest_ag.positions)) + angle = calc_angles(pos[0], pos[1], pos[2], box=u.dimensions) * unit.radians + dihed = calc_dihedrals(pos[0], pos[1], pos[2], pos[3], box=u.dimensions) * unit.radians + collinear = is_collinear(positions, [0, 1, 2, 3]) -): diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index a74e83ed3..a1f983621 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -22,7 +22,7 @@ import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.analysis.rmsf import RMSF -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.lib.distances import calc_bonds, calc_angles, minimize_vectors from MDAnalysis.coordinates.memory import MemoryReader from openfe_analysis.transformations import Aligner, NoJump @@ -166,19 +166,21 @@ def get_central_atom_idx(rdmol: Chem.Mol) -> int: return center -def is_collinear(positions, atoms, threshold=0.9): +def is_collinear(positions, atoms, dimensions=None, threshold=0.9): """ Check whether any sequential vectors in a sequence of atoms are collinear. Parameters ---------- positions : openmm.unit.Quantity - System positions. + System positions. atoms : list[int] - The indices of the atoms to test. + The indices of the atoms to test. + dimensions : Optional[npt.NDArray] + The dimensions of the system to minimize vectors. threshold : float - Atoms are not collinear if their sequential vector separation dot - products are less than ``threshold``. Default 0.9. + Atoms are not collinear if their sequential vector separation dot + products are less than ``threshold``. Default 0.9. Returns ------- @@ -187,12 +189,18 @@ def is_collinear(positions, atoms, threshold=0.9): Notes ----- - Originally from Yank, with modifications from Separated Topologies + Originally from Yank. """ results = False for i in range(len(atoms) - 2): - v1 = positions[atoms[i + 1], :] - positions[atoms[i], :] - v2 = positions[atoms[i + 2], :] - positions[atoms[i + 1], :] + v1 = minimize_vectors( + positions[atoms[i + 1], :] - positions[atoms[i], :], + box=dimensions, + ) + v2 = minimize_vectors( + positions[atoms[i + 2], :] - positions[atoms[i + 1], :], + box=dimensions, + ) normalized_inner_product = np.dot(v1, v2) / np.sqrt( np.dot(v1, v1) * np.dot(v2, v2) ) From 3cce308a8548df90305d7ce11c03f128dee6cc0d Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 00:07:26 +0000 Subject: [PATCH 24/29] add host atom finding routine --- .../restraints/geometry/boresch.py | 277 +++++++++++++++--- .../openmm_utils/restraints/geometry/utils.py | 20 +- 2 files changed, 259 insertions(+), 38 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 363e22c6b..41bd3b3b7 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -19,6 +19,7 @@ from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals import numpy as np import numpy.typing as npt +from scipy.stats import circmean from .base import HostGuestRestraintGeometry @@ -29,10 +30,10 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): The restraint is defined by the following: - H0 G2 + H2 G2 - - - - - H1 - - H2 -- G0 - - G1 + H1 - - H0 -- G0 - - G1 Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. @@ -43,7 +44,7 @@ def get_bond_distance( coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ - Get the H2 - G0 distance. + Get the H0 - G0 distance. Parameters ---------- @@ -58,7 +59,7 @@ def get_bond_distance( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology) ) - at1 = u.atoms[host_atoms[2]] + at1 = u.atoms[host_atoms[0]] at2 = u.atoms[guest_atoms[0]] bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 @@ -70,7 +71,7 @@ def get_angles( coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ - Get the H1-H2-G0, and H2-G0-G1 angles. + Get the H1-H0-G0, and H0-G0-G1 angles. Parameters ---------- @@ -86,7 +87,7 @@ def get_angles( topology_format=_get_mda_topology_format(topology) ) at1 = u.atoms[host_atoms[1]] - at2 = u.atoms[host_atoms[2]] + at2 = u.atoms[host_atoms[0]] at3 = u.atoms[guest_atoms[0]] at4 = u.atoms[guest_atoms[1]] @@ -104,7 +105,7 @@ def get_dihedrals( coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ - Get the H0-H1-H2-G0, H1-H2-G0-G1, and H2-G0-G1-G2 dihedrals. + Get the H2-H1-H0-G0, H1-H0-G0-G1, and H0-G0-G1-G2 dihedrals. Parameters ---------- @@ -119,9 +120,9 @@ def get_dihedrals( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology) ) - at1 = u.atoms[host_atoms[0]] + at1 = u.atoms[host_atoms[2]] at2 = u.atoms[host_atoms[1]] - at3 = u.atoms[host_atoms[2]] + at3 = u.atoms[host_atoms[0]] at4 = u.atoms[guest_atoms[0]] at5 = u.atoms[guest_atoms[1]] at6 = u.atoms[guest_atoms[2]] @@ -275,7 +276,7 @@ def get_guest_atom_candidates( Returns ------- angle_list : list[tuple[int]] - A list of tuples for each valid l1, l2, l3 angle. If ``None``, no + A list of tuples for each valid G0, G1, G2 angle. If ``None``, no angles could be found. Raises @@ -343,7 +344,7 @@ def get_host_atom_candidates( rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, min_distance: unit.Quantity = 1 * unit.nanometer, max_distance: unit.Quantity = 3 * unit.nanometer, -): +) -> npt.NDArray: """ Get a list of suitable host atoms. @@ -367,6 +368,11 @@ def get_host_atom_candidates( The minimum search distance around l1 for suitable candidate atoms. max_distance : unit.Quantity The maximum search distance around l1 for suitable candidate atoms. + + Return + ------ + NDArray + Array of host atom indexes """ u = mda.Universe( topology, @@ -395,20 +401,212 @@ def get_host_atom_candidates( return atom_finder.results.host_idxs -class EvaluateH2Atoms(AnalysisBase): +class EvaluateHostAtoms1(AnalysisBase): """ Class to evaluate the suitability of a set of host atoms - as a H2 atom (i.e. bonded to the guest G0 atom). + as H1 atoms (i.e. the second host atom). Parameters ---------- - guest_atoms: MDAnalysis.AtomGroup - The guest atoms representing G0-G1-G2. - host_atom_pool: MDAnalysis.AtomGroup - The pool of atoms to pick a H2 from. + reference : MDAnalysis.AtomGroup + The reference preceeding three atoms. + host_atom_pool : MDAnalysis.AtomGroup + The pool of atoms to pick an atom from. + minimum_distance : unit.Quantity + The minimum distance from the bound reference atom. angle_force_constant : unit.Quantity - The force constant for the H2-G0-G1 angle. + The force constant for the angle. + temperature : unit.Quantity + The system temperature in Kelvin """ + def __init__( + self, + reference, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + **kwargs + ): + super().__init__(reference.universe.trajectory, **kwargs) + + if len(reference) != 3: + errmsg = "Incorrect number of reference atoms passed" + raise ValueError(errmsg) + + self.reference = reference + self.host_atom_pool = host_atom_pool + self.minimum_distance = minimum_distance.to('angstrom').m + self.angle_force_constant = angle_force_constant + self.temperature = temperature + + def _prepare(self): + self.results.distances = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.angles = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + angle = calc_angles( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + dimensions=self.reference.dimensions, + ) + self.results.distances[i][self._frame_index] = distance + self.results.angles[i][self._frame_index] = angle + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in enumerate(self.host_atom_pool): + distance_bounds = all( + self.results.distances[i] > self.minimum_distance + ) + mean_angle = circmean(self.results.angles[i], high=np.pi, low=0) + angle_bounds = check_angle_not_flat( + angle=mean_angle * unit.radians, + force_constant=self.angle_force_constant, + temperature=self.temperature, + ) + angle_variance = check_angular_variance( + self.results.angles[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=0 * unit.radians, + width=1.745 * unit.radians, + ) + mean_dihed = circmean(self.results.dihedrals[i], high=np.pi, low=-np.pi) + dihed_bounds = check_dihedral_bounds(mean_dihed) + dihed_variance = check_angular_variance( + self.results.dihedrals[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=5.23 * unit.radians, + ) + not_collinear = not all(self.results.collinear[i]) + if all([distance_bounds, angle_bounds, angle_variance, dihed_bounds, dihed_variance, not_collinear]): + self.results.valid[i] = True + + +class EvaluateHostAtoms2(EvaluateH21Atoms): + def _prepare(self): + self.results.distances1 = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.ditances2 = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.collinear = np.empty( + (len(self.host_atom_pool), self.n_frames), + dtype=bool, + ) + self.results.valid = np.empty( + len(self.host_atom_pool), + dtype=bool, + ) + + def _single_frame(self): + for i, at in enumerate(self.host_atom_pool): + distance1 = calc_bonds( + at.position, + self.reference.atoms[0].position, + box=self.reference.dimensions, + ) + distance2 = calc_bonds( + at.position, + self.reference.atoms[1].position, + box=self.reference.dimensions, + ) + dihedral = calc_dihedrals( + at.position, + self.reference.atoms[0].position, + self.reference.atoms[1].position, + self.reference.atoms[2].position, + box=self.reference.dimensions + ) + collinear = is_collinear( + positions=np.vstack((at.position, self.reference.positions)), + dimensions=self.reference.dimensions, + ) + self.results.distances1[i][self._frame_index] = distance + self.results.distances2[i][self._frame_index] = angle + self.results.dihedrals[i][self._frame_index] = dihedral + self.results.collinear[i][self._frame_index] = collinear + + def _conclude(self): + for i, at in enumerate(self.host_atom_pool): + distance1_bounds = all( + self.results.distances1[i] > self.minimum_distance + ) + distance2_bounds = all( + self.results.distances2[i] > self.minimum_distance + ) + mean_dihed = circmean(self.results.dihedrals[i], high=np.pi, low=-np.pi) + dihed_bounds = check_dihedral_bounds(mean_dihed) + dihed_variance = check_angular_variance( + self.results.dihedrals[i] * unit.radians, + upper_bound=np.pi * unit.radians, + lower_bound=-np.pi * unit.radians, + width=5.23 * unit.radians, + ) + not_collinear = not all(self.results.collinear[i]) + if all([distance1_bounds, distance2_bounds, dihed_bounds, dihed_variance, not_collinear]): + self.results.valid[i] = True + + +def _find_host_angle(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature): + h0_eval = EvaluateHAtoms1(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + h0_eval.run() + + for i, valid_h0 in enumerate(h0_eval.results.valid): + if valid_h0: + g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] + h1_eval = EvaluateHAtoms1(g1g2h0_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + for j, valid_h1 in enumerate(h1_eval.results.valid): + g2h0h1_atoms = g1g2h0_atoms.atoms[1:] + host_atom_pool.atoms[j] + h2_eval = EvaluateHAtoms2(g2h0h1_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + + if any(h2_eval.ressults.valid): + d1_avgs = [d.mean() for d in h2_eval.results.distances1] + d2_avgs = [d.mean() for d in h2_eval.results.distances2] + dsum_avgs = d1_avgs + d2_avgs + k = dsum_avgs.argmin() + + return host_atom_pool.atoms[[i, j, k]].ix + return None def find_boresch_restraint( @@ -424,6 +622,8 @@ def find_boresch_restraint( rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, host_min_distance: unit.Quantity = 1 * unit.nanometer, host_max_distance: unit.Quantity = 3 * unit.nanometer, + angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, + temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ Find suitable Boresch-style restraints between a host and guest entity. @@ -448,11 +648,11 @@ def find_boresch_restraint( # In this case assume the picked atoms were intentional / representative # of the input and go with it guest_ag = u.select_atoms[guest_idxs] - guest_angle = (at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]) + guest_angle = [at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]] host_ag = u.select_atoms[host_idxs] - host_angle = (at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]) + host_angle = [at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]] # TODO sort out the return on this - return BoreschRestraintGeometry(...) + return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # This is not an intended outcome, crash out here @@ -463,7 +663,7 @@ def find_boresch_restraint( ) raise ValueError(errmsg) - # Fetch the guest angles + # 1. Fetch the guest angles guest_angles = get_guest_atom_candidates( topology=topology, trajectory=trajectory, @@ -472,9 +672,14 @@ def find_boresch_restraint( rmsf_cutoff=rmsf_cutoff, ) + if len(guest_angles) != 0: + errmsg = "No suitable ligand atoms found for the restraint." + raise ValueError(errmsg) + + # We pick the first angle / ligand atom set as the one to use guest_angle = guest_angles[0] - # Fetch the host atom pool + # 2. We next fetch the host atom pool host_pool = get_host_atom_candidates( topology=topology, trajectory=trajectory, @@ -487,15 +692,21 @@ def find_boresch_restraint( max_distance=host_max_distance, ) - # Get the guest angle atomgroup - guest_ag = u.atoms[list(guest_angle)] + # 3. We then loop through the guest angles to find suitable host atoms + for guest_angle in guest_angles: + host_angle = _find_host_angle( + g0g1g2_atoms=u.atoms[list(guest_angle)], + host_atom_pool=u.atoms[host_pool], + minimum_distance=0.5 * unit.nanometer, + angle_force_constant=angle_force_constant, + temperature=temperature, + ) + # continue if it's empty, otherwise stop + if host_angle is not None: + break - # Find all suitable H2 idxs - h2_idxs = [] - for i in host_pool: - host2_at = u.atoms[i] - pos = np.vstack((at.position, guest_ag.positions)) - angle = calc_angles(pos[0], pos[1], pos[2], box=u.dimensions) * unit.radians - dihed = calc_dihedrals(pos[0], pos[1], pos[2], pos[3], box=u.dimensions) * unit.radians - collinear = is_collinear(positions, [0, 1, 2, 3]) + if host_angle is None: + errmsg = "No suitable host atoms could be found" + raise ValueError(errmsg) + return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index a1f983621..96a665ee5 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -208,7 +208,7 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): return result -def check_angle_energy( +def check_angle_not_flat( angle: FloatQuantity["radians"], force_constant: ANGLE_FRC_CONSTANT_TYPE = DEFAULT_ANGLE_FRC_CONSTANT, temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, @@ -228,7 +228,7 @@ def check_angle_energy( Returns ------- bool - If the angle is less than 10 kT from 0 or pi radians + False if the angle is less than 10 kT from 0 or pi radians Note ---- @@ -280,7 +280,10 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity["radians"], width: FloatQuantity["radians"] + angles: ArrayQuantity["radians"], width: FloatQuantity["radians"], + upper_bound: FloatQuantity['radians'], + lower_bound: FloatQuantity['radians'], + width: FloatQuantity['radians'], ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -290,6 +293,10 @@ def check_angular_variance( ---------- angles : ArrayLike[unit.Quantity] An array of angles in units compatible with radians. + upper_bound: FloatQuantity['radians'] + The upper bound in the angle range. + lower_bound: FloatQuantity['radians'] + The lower bound in the angle range. width : unit.Quantity The width to check the variance against, in units compatible with radians. @@ -299,8 +306,11 @@ def check_angular_variance( ``True`` if the variance of the angles is less than the width. """ - array = angles.to("radians").m - variance = circvar(array) + variance = circvar( + angles.to("radians").m, + high=upper_bound.to("radians").m, + low=lower_bound.to("radians").m + ) return not (variance * unit.radians > width) From 9171d3992e2418f1453e0888cf48ec2661c2bc25 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 00:12:14 +0000 Subject: [PATCH 25/29] autoformatting --- .../restraints/geometry/boresch.py | 133 ++++++++++-------- .../openmm_utils/restraints/geometry/utils.py | 24 ++-- 2 files changed, 93 insertions(+), 64 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 41bd3b3b7..844817f52 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -38,9 +38,10 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ + def get_bond_distance( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], + topology: Union[str, pathlib.Path, openmm.app.Topology], coordinates: Union[str, pathlib.Path, npt.NDArray], ) -> unit.Quantity: """ @@ -57,7 +58,7 @@ def get_bond_distance( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) at1 = u.atoms[host_atoms[0]] at2 = u.atoms[guest_atoms[0]] @@ -84,7 +85,7 @@ def get_angles( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) at1 = u.atoms[host_atoms[1]] at2 = u.atoms[host_atoms[0]] @@ -118,7 +119,7 @@ def get_dihedrals( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) at1 = u.atoms[host_atoms[2]] at2 = u.atoms[host_atoms[1]] @@ -292,7 +293,7 @@ def get_guest_atom_candidates( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) ligand_ag = u.atoms[guest_idxs] @@ -302,7 +303,7 @@ def get_guest_atom_candidates( u.trajectory[-1] # forward to the last frame # 1. Get the pool of atoms to work with - atom_pool = _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) + atom_pool = _get_atom_pool(rdmol, rmsf) if atom_pool is None: # We don't have enough atoms so we raise an error @@ -324,11 +325,7 @@ def get_guest_atom_candidates( angle_ag = ligand_ag.atoms[list(angle)] if not is_collinear(ligand_ag.positions, angle, u.dimensions): angles_list.append( - ( - angle_ag.atoms[0].ix, - angle_ag.atoms[1].ix, - angle_ag.atoms[2].ix - ) + (angle_ag.atoms[0].ix, angle_ag.atoms[1].ix, angle_ag.atoms[2].ix) ) return angles_list @@ -378,7 +375,7 @@ def get_host_atom_candidates( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) protein_ag1 = u.atoms[host_idxs] @@ -419,6 +416,7 @@ class EvaluateHostAtoms1(AnalysisBase): temperature : unit.Quantity The system temperature in Kelvin """ + def __init__( self, reference, @@ -426,7 +424,7 @@ def __init__( minimum_distance, angle_force_constant, temperature, - **kwargs + **kwargs, ): super().__init__(reference.universe.trajectory, **kwargs) @@ -436,20 +434,14 @@ def __init__( self.reference = reference self.host_atom_pool = host_atom_pool - self.minimum_distance = minimum_distance.to('angstrom').m + self.minimum_distance = minimum_distance.to("angstrom").m self.angle_force_constant = angle_force_constant self.temperature = temperature def _prepare(self): - self.results.distances = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.angles = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.dihedrals = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) + self.results.distances = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.angles = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.collinear = np.empty( (len(self.host_atom_pool), self.n_frames), dtype=bool, @@ -477,7 +469,7 @@ def _single_frame(self): self.reference.atoms[0].position, self.reference.atoms[1].position, self.reference.atoms[2].position, - box=self.reference.dimensions + box=self.reference.dimensions, ) collinear = is_collinear( positions=np.vstack((at.position, self.reference.positions)), @@ -490,9 +482,7 @@ def _single_frame(self): def _conclude(self): for i, at in enumerate(self.host_atom_pool): - distance_bounds = all( - self.results.distances[i] > self.minimum_distance - ) + distance_bounds = all(self.results.distances[i] > self.minimum_distance) mean_angle = circmean(self.results.angles[i], high=np.pi, low=0) angle_bounds = check_angle_not_flat( angle=mean_angle * unit.radians, @@ -514,21 +504,24 @@ def _conclude(self): width=5.23 * unit.radians, ) not_collinear = not all(self.results.collinear[i]) - if all([distance_bounds, angle_bounds, angle_variance, dihed_bounds, dihed_variance, not_collinear]): + if all( + [ + distance_bounds, + angle_bounds, + angle_variance, + dihed_bounds, + dihed_variance, + not_collinear, + ] + ): self.results.valid[i] = True class EvaluateHostAtoms2(EvaluateH21Atoms): def _prepare(self): - self.results.distances1 = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.ditances2 = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) - self.results.dihedrals = np.zeros( - (len(self.host_atom_pool), self.n_frames) - ) + self.results.distances1 = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.ditances2 = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.collinear = np.empty( (len(self.host_atom_pool), self.n_frames), dtype=bool, @@ -555,7 +548,7 @@ def _single_frame(self): self.reference.atoms[0].position, self.reference.atoms[1].position, self.reference.atoms[2].position, - box=self.reference.dimensions + box=self.reference.dimensions, ) collinear = is_collinear( positions=np.vstack((at.position, self.reference.positions)), @@ -568,12 +561,8 @@ def _single_frame(self): def _conclude(self): for i, at in enumerate(self.host_atom_pool): - distance1_bounds = all( - self.results.distances1[i] > self.minimum_distance - ) - distance2_bounds = all( - self.results.distances2[i] > self.minimum_distance - ) + distance1_bounds = all(self.results.distances1[i] > self.minimum_distance) + distance2_bounds = all(self.results.distances2[i] > self.minimum_distance) mean_dihed = circmean(self.results.dihedrals[i], high=np.pi, low=-np.pi) dihed_bounds = check_dihedral_bounds(mean_dihed) dihed_variance = check_angular_variance( @@ -583,21 +572,49 @@ def _conclude(self): width=5.23 * unit.radians, ) not_collinear = not all(self.results.collinear[i]) - if all([distance1_bounds, distance2_bounds, dihed_bounds, dihed_variance, not_collinear]): + if all( + [ + distance1_bounds, + distance2_bounds, + dihed_bounds, + dihed_variance, + not_collinear, + ] + ): self.results.valid[i] = True -def _find_host_angle(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature): - h0_eval = EvaluateHAtoms1(g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) +def _find_host_angle( + g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature +): + h0_eval = EvaluateHAtoms1( + g0g1g2_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) h0_eval.run() for i, valid_h0 in enumerate(h0_eval.results.valid): if valid_h0: g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] - h1_eval = EvaluateHAtoms1(g1g2h0_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + h1_eval = EvaluateHAtoms1( + g1g2h0_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) for j, valid_h1 in enumerate(h1_eval.results.valid): g2h0h1_atoms = g1g2h0_atoms.atoms[1:] + host_atom_pool.atoms[j] - h2_eval = EvaluateHAtoms2(g2h0h1_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature) + h2_eval = EvaluateHAtoms2( + g2h0h1_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature, + ) if any(h2_eval.ressults.valid): d1_avgs = [d.mean() for d in h2_eval.results.distances1] @@ -616,13 +633,15 @@ def find_boresch_restraint( guest_idxs: list[int], host_idxs: list[int], guest_restraint_atom_idxs: Optional[list[int]] = None, - host_restraint_atoms_idxs Optional[list[int]] = None, - host_selection: str = 'all', + host_restraint_atoms_idxs: Optional[list[int]] = None, + host_selection: str = "all", dssp_filter: bool = False, rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, host_min_distance: unit.Quantity = 1 * unit.nanometer, host_max_distance: unit.Quantity = 3 * unit.nanometer, - angle_force_constant: unit.Quantity = 83.68 * unit.kilojoule_per_mole / unit.radians**2, + angle_force_constant: unit.Quantity = ( + 83.68 * unit.kilojoule_per_mole / unit.radians**2 + ), temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ @@ -640,11 +659,13 @@ def find_boresch_restraint( topology, coordinates, format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology) + topology_format=_get_mda_topology_format(topology), ) u.trajectory[-1] # Work with the final frame - if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): + if (guest_restraint_atoms_idxs is not None) and ( + host_restraint_atoms_idxs is not None + ): # In this case assume the picked atoms were intentional / representative # of the input and go with it guest_ag = u.select_atoms[guest_idxs] @@ -654,7 +675,9 @@ def find_boresch_restraint( # TODO sort out the return on this return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) - if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): + if (guest_restraint_atoms_idxs is not None) ^ ( + host_restraint_atoms_idxs is not None + ): # This is not an intended outcome, crash out here errmsg = ( "both ``guest_restraints_atoms_idxs`` and ``host_restraint_atoms_idxs`` " diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 96a665ee5..91ce61e8b 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -32,7 +32,9 @@ ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] -def _get_mda_coord_format(coordinates: Union[str, npt.NDArray]) -> Optional[MemoryReader]: +def _get_mda_coord_format( + coordinates: Union[str, npt.NDArray] +) -> Optional[MemoryReader]: """ Helper to set the coordinate format to MemoryReader if the coordinates are an NDArray. @@ -51,7 +53,10 @@ def _get_mda_coord_format(coordinates: Union[str, npt.NDArray]) -> Optional[Memo else: return None -def _get_mda_topology_format(topology: Union[str, openmm.app.Topology]) -> Optional[str]: + +def _get_mda_topology_format( + topology: Union[str, openmm.app.Topology] +) -> Optional[str]: """ Helper to set the topology format to OPENMMTOPOLOGY if the topology is an openmm.app.Topology. @@ -59,7 +64,7 @@ def _get_mda_topology_format(topology: Union[str, openmm.app.Topology]) -> Optio Parameters ---------- topology : Union[str, openmm.app.Topology] - + Returns ------- @@ -177,7 +182,7 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): atoms : list[int] The indices of the atoms to test. dimensions : Optional[npt.NDArray] - The dimensions of the system to minimize vectors. + The dimensions of the system to minimize vectors. threshold : float Atoms are not collinear if their sequential vector separation dot products are less than ``threshold``. Default 0.9. @@ -280,10 +285,11 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity["radians"], width: FloatQuantity["radians"], - upper_bound: FloatQuantity['radians'], - lower_bound: FloatQuantity['radians'], - width: FloatQuantity['radians'], + angles: ArrayQuantity["radians"], + width: FloatQuantity["radians"], + upper_bound: FloatQuantity["radians"], + lower_bound: FloatQuantity["radians"], + width: FloatQuantity["radians"], ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -309,7 +315,7 @@ def check_angular_variance( variance = circvar( angles.to("radians").m, high=upper_bound.to("radians").m, - low=lower_bound.to("radians").m + low=lower_bound.to("radians").m, ) return not (variance * unit.radians > width) From 033a1e44aadfbf330531b0443842b6f21095e6ec Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 01:21:39 +0000 Subject: [PATCH 26/29] various fixes --- .../restraints/geometry/__init__.py | 4 + .../openmm_utils/restraints/geometry/base.py | 5 - .../restraints/geometry/boresch.py | 170 +++++++++++------- .../restraints/geometry/flatbottom.py | 16 +- .../restraints/geometry/harmonic.py | 33 ++-- .../openmm_utils/restraints/geometry/utils.py | 56 +++--- .../restraints/openmm/omm_forces.py | 2 +- .../restraints/openmm/omm_restraints.py | 108 +++++++---- 8 files changed, 236 insertions(+), 158 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py b/openfe/protocols/openmm_utils/restraints/geometry/__init__.py index e69de29bb..1c1b4c56a 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/__init__.py @@ -0,0 +1,4 @@ +from .base import BaseRestraintGeometry +from .harmonic import DistanceRestraintGeometry +from .flatbottom import FlatBottomDistanceGeometry +from .boresch import BoreschRestraintGeometry diff --git a/openfe/protocols/openmm_utils/restraints/geometry/base.py b/openfe/protocols/openmm_utils/restraints/geometry/base.py index 21a714cde..5db9225ac 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/base.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/base.py @@ -10,10 +10,6 @@ import abc from pydantic.v1 import BaseModel, validator -from openff.units import unit -import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds, calc_angles - class BaseRestraintGeometry(BaseModel, abc.ABC): class Config: @@ -47,4 +43,3 @@ def positive_idxs(cls, v): errmsg = "negative indices passed" raise ValueError(errmsg) return v - diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py index 844817f52..0d6806611 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/boresch.py @@ -7,21 +7,34 @@ ---- * Add relevant duecredit entries. """ -import abc import pathlib -from pydantic.v1 import BaseModel, validator +from typing import Union, Optional, Iterable from rdkit import Chem +import openmm from openff.units import unit import MDAnalysis as mda -from MDANalysis.analysis.base import AnalysisBase +from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals import numpy as np import numpy.typing as npt from scipy.stats import circmean from .base import HostGuestRestraintGeometry +from .utils import ( + _get_mda_coord_format, + _get_mda_topology_format, + get_aromatic_rings, + get_heavy_atom_idxs, + get_central_atom_idx, + is_collinear, + check_angular_variance, + check_dihedral_bounds, + check_angle_not_flat, + FindHostAtoms, + get_local_rmsf +) class BoreschRestraintGeometry(HostGuestRestraintGeometry): @@ -60,8 +73,8 @@ def get_bond_distance( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology), ) - at1 = u.atoms[host_atoms[0]] - at2 = u.atoms[guest_atoms[0]] + at1 = u.atoms[self.host_atoms[0]] + at2 = u.atoms[self.guest_atoms[0]] bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 return float(bond) * unit.angstrom @@ -87,10 +100,10 @@ def get_angles( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology), ) - at1 = u.atoms[host_atoms[1]] - at2 = u.atoms[host_atoms[0]] - at3 = u.atoms[guest_atoms[0]] - at4 = u.atoms[guest_atoms[1]] + at1 = u.atoms[self.host_atoms[1]] + at2 = u.atoms[self.host_atoms[0]] + at3 = u.atoms[self.guest_atoms[0]] + at4 = u.atoms[self.guest_atoms[1]] angleA = calc_angles( at1.position, at2.position, at3.position, u.atoms.dimensions @@ -121,21 +134,24 @@ def get_dihedrals( format=_get_mda_coord_format(coordinates), topology_format=_get_mda_topology_format(topology), ) - at1 = u.atoms[host_atoms[2]] - at2 = u.atoms[host_atoms[1]] - at3 = u.atoms[host_atoms[0]] - at4 = u.atoms[guest_atoms[0]] - at5 = u.atoms[guest_atoms[1]] - at6 = u.atoms[guest_atoms[2]] + at1 = u.atoms[self.host_atoms[2]] + at2 = u.atoms[self.host_atoms[1]] + at3 = u.atoms[self.host_atoms[0]] + at4 = u.atoms[self.guest_atoms[0]] + at5 = u.atoms[self.guest_atoms[1]] + at6 = u.atoms[self.guest_atoms[2]] dihA = calc_dihedrals( - at1.position, at2.position, at3.position, at4.position, u.atoms.dimensions + at1.position, at2.position, at3.position, at4.position, + box=u.dimensions ) dihB = calc_dihedrals( - at2.position, at3.position, at4.position, at5.position, u.atoms.dimensions + at2.position, at3.position, at4.position, at5.position, + box=u.dimensions ) dihC = calc_dihedrals( - at3.position, at4.position, at5.position, at6.position, u.atoms.dimensions + at3.position, at4.position, at5.position, at6.position, + box=u.dimensions ) return dihA, dihB, dihC @@ -213,7 +229,11 @@ def _get_bonded_angles_from_pool( return angles -def _get_atom_pool(rdmol: Chem.Mol, rmsf: npt.NDArray) -> Optional[set[int]]: +def _get_atom_pool( + rdmol: Chem.Mol, + rmsf: npt.NDArray, + rmsf_cutoff: unit.Quantity +) -> Optional[set[int]]: """ Filter atoms based on rmsf & rings, defaulting to heavy atoms if there are not enough. @@ -291,8 +311,8 @@ def get_guest_atom_candidates( """ u = mda.Universe( topology, - coordinates, - format=_get_mda_coord_format(coordinates), + trajectory, + format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) @@ -314,18 +334,22 @@ def get_guest_atom_candidates( center = get_central_atom_idx(rdmol) # 3. Sort the atom pool based on their distance from the center - sorted_anchor_pool = _sort_by_distance_from_atom(rdmol, center, anchor_pool) + sorted_atom_pool = _sort_by_distance_from_atom(rdmol, center, atom_pool) # 4. Get a list of probable angles angles_list = [] - for atom in sorted_anchor_pool: - angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_anchor_pool) - for angle in _angles: + for atom in sorted_atom_pool: + angles = _get_bonded_angles_from_pool(rdmol, atom, sorted_atom_pool) + for angle in angles: # Check that the angle is at least not collinear angle_ag = ligand_ag.atoms[list(angle)] if not is_collinear(ligand_ag.positions, angle, u.dimensions): angles_list.append( - (angle_ag.atoms[0].ix, angle_ag.atoms[1].ix, angle_ag.atoms[2].ix) + ( + angle_ag.atoms[0].ix, + angle_ag.atoms[1].ix, + angle_ag.atoms[2].ix + ) ) return angles_list @@ -373,26 +397,29 @@ def get_host_atom_candidates( """ u = mda.Universe( topology, - coordinates, - format=_get_mda_coord_format(coordinates), + trajectory, + format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) - protein_ag1 = u.atoms[host_idxs] - protein_ag2 = protein_ag.select_atoms(protein_selection) + host_ag1 = u.atoms[host_idxs] + host_ag2 = host_ag1.select_atoms(host_selection) # 0. TODO: implement DSSP filter - # Should be able to just call MDA's DSSP method, but will need to catch an exception + # Should be able to just call MDA's DSSP method + # but will need to catch an exception if dssp_filter: - raise NotImplementedError("DSSP filtering is not currently implemented") + raise NotImplementedError( + "DSSP filtering is not currently implemented" + ) # 1. Get the RMSF & filter - rmsf = get_local_rmsf(sub_protein_ag) - protein_ag3 = sub_protein_ag.atoms[rmsf[heavy_atoms] < rmsf_cutoff] + rmsf = get_local_rmsf(host_ag2) + protein_ag3 = host_ag2.atoms[rmsf < rmsf_cutoff] # 2. Search of atoms within the min/max cutoff atom_finder = FindHostAtoms( - protein_ag3, u.atoms[l1_idx], min_search_distance, max_search_distance + protein_ag3, u.atoms[l1_idx], min_distance, max_distance ) atom_finder.run() return atom_finder.results.host_idxs @@ -439,9 +466,15 @@ def __init__( self.temperature = temperature def _prepare(self): - self.results.distances = np.zeros((len(self.host_atom_pool), self.n_frames)) - self.results.angles = np.zeros((len(self.host_atom_pool), self.n_frames)) - self.results.dihedrals = np.zeros((len(self.host_atom_pool), self.n_frames)) + self.results.distances = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.angles = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) + self.results.dihedrals = np.zeros( + (len(self.host_atom_pool), self.n_frames) + ) self.results.collinear = np.empty( (len(self.host_atom_pool), self.n_frames), dtype=bool, @@ -517,7 +550,7 @@ def _conclude(self): self.results.valid[i] = True -class EvaluateHostAtoms2(EvaluateH21Atoms): +class EvaluateHostAtoms2(EvaluateHostAtoms1): def _prepare(self): self.results.distances1 = np.zeros((len(self.host_atom_pool), self.n_frames)) self.results.ditances2 = np.zeros((len(self.host_atom_pool), self.n_frames)) @@ -554,8 +587,8 @@ def _single_frame(self): positions=np.vstack((at.position, self.reference.positions)), dimensions=self.reference.dimensions, ) - self.results.distances1[i][self._frame_index] = distance - self.results.distances2[i][self._frame_index] = angle + self.results.distances1[i][self._frame_index] = distance1 + self.results.distances2[i][self._frame_index] = distance2 self.results.dihedrals[i][self._frame_index] = dihedral self.results.collinear[i][self._frame_index] = collinear @@ -585,9 +618,13 @@ def _conclude(self): def _find_host_angle( - g0g1g2_atoms, host_atom_pool, minimum_distance, angle_force_constant, temperature + g0g1g2_atoms, + host_atom_pool, + minimum_distance, + angle_force_constant, + temperature ): - h0_eval = EvaluateHAtoms1( + h0_eval = EvaluateHostAtoms1( g0g1g2_atoms, host_atom_pool, minimum_distance, @@ -599,7 +636,7 @@ def _find_host_angle( for i, valid_h0 in enumerate(h0_eval.results.valid): if valid_h0: g1g2h0_atoms = g0g1g2_atoms.atoms[1:] + host_atom_pool.atoms[i] - h1_eval = EvaluateHAtoms1( + h1_eval = EvaluateHostAtoms1( g1g2h0_atoms, host_atom_pool, minimum_distance, @@ -608,7 +645,7 @@ def _find_host_angle( ) for j, valid_h1 in enumerate(h1_eval.results.valid): g2h0h1_atoms = g1g2h0_atoms.atoms[1:] + host_atom_pool.atoms[j] - h2_eval = EvaluateHAtoms2( + h2_eval = EvaluateHostAtoms2( g2h0h1_atoms, host_atom_pool, minimum_distance, @@ -632,11 +669,11 @@ def find_boresch_restraint( guest_rdmol: Chem.Mol, guest_idxs: list[int], host_idxs: list[int], - guest_restraint_atom_idxs: Optional[list[int]] = None, + guest_restraint_atoms_idxs: Optional[list[int]] = None, host_restraint_atoms_idxs: Optional[list[int]] = None, host_selection: str = "all", dssp_filter: bool = False, - rmsf_custoff: unit.Quantity = 0.1 * unit.nanometer, + rmsf_cutoff: unit.Quantity = 0.1 * unit.nanometer, host_min_distance: unit.Quantity = 1 * unit.nanometer, host_max_distance: unit.Quantity = 3 * unit.nanometer, angle_force_constant: unit.Quantity = ( @@ -657,32 +694,35 @@ def find_boresch_restraint( """ u = mda.Universe( topology, - coordinates, - format=_get_mda_coord_format(coordinates), + trajectory, + format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) u.trajectory[-1] # Work with the final frame - if (guest_restraint_atoms_idxs is not None) and ( - host_restraint_atoms_idxs is not None - ): - # In this case assume the picked atoms were intentional / representative - # of the input and go with it + if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): # fmt: skip + # In this case assume the picked atoms were intentional / + # representative of the input and go with it guest_ag = u.select_atoms[guest_idxs] - guest_angle = [at.ix for at in guest_ag.atoms[guest_restraint_atom_idxs]] + guest_angle = [ + at.ix for at in guest_ag.atoms[guest_restraint_atoms_idxs] + ] host_ag = u.select_atoms[host_idxs] - host_angle = [at.ix for at in host_ag.atoms[host_restraint_atoms_idxs]] + host_angle = [ + at.ix for at in host_ag.atoms[host_restraint_atoms_idxs] + ] # TODO sort out the return on this - return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) + return BoreschRestraintGeometry( + host_atoms=host_angle, guest_atoms=guest_angle + ) - if (guest_restraint_atoms_idxs is not None) ^ ( - host_restraint_atoms_idxs is not None - ): + if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # fmt: skip # This is not an intended outcome, crash out here errmsg = ( - "both ``guest_restraints_atoms_idxs`` and ``host_restraint_atoms_idxs`` " + "both ``guest_restraints_atoms_idxs`` and " + "``host_restraint_atoms_idxs`` " "must be set or both must be None. " - f"Got {guest_restraint_atoms_idxs} and {host_atoms_restraint_atoms_idxs}" + f"Got {guest_restraint_atoms_idxs} and {host_restraint_atoms_idxs}" ) raise ValueError(errmsg) @@ -710,7 +750,7 @@ def find_boresch_restraint( l1_idx=guest_angle[0], host_selection=host_selection, dssp_filter=dssp_filter, - rmsf_cutoff=rmsf_custoff, + rmsf_cutoff=rmsf_cutoff, min_distance=host_min_distance, max_distance=host_max_distance, ) @@ -732,4 +772,6 @@ def find_boresch_restraint( errmsg = "No suitable host atoms could be found" raise ValueError(errmsg) - return BoreschRestraintGeometry(host_atoms=host_angle, guest_atoms=guest_angle) + return BoreschRestraintGeometry( + host_atoms=host_angle, guest_atoms=guest_angle + ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py index c7e987736..c9007dd59 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py @@ -7,14 +7,15 @@ ---- * Add relevant duecredit entries. """ -import abc -from pydantic.v1 import BaseModel, validator - +import pathlib +from typing import Union, Optional import numpy as np +from openmm import app from openff.units import unit +from openff.models.types import FloatQuantity import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.lib.distances import calc_bonds from .harmonic import ( DistanceRestraintGeometry, @@ -27,7 +28,6 @@ class FlatBottomDistanceGeometry(DistanceRestraintGeometry): A geometry class for a flat bottom distance restraint between two groups of atoms. """ - well_radius: FloatQuantity["nanometer"] @@ -45,8 +45,8 @@ class COMDistanceAnalysis(AnalysisBase): _analysis_algorithm_is_parallelizable = False - def __init__(self, host_atoms, guest_atoms, search_distance, **kwargs): - super().__init__(host_atoms.universe.trajectory, **kwargs) + def __init__(self, group1, group2, **kwargs): + super().__init__(group1.universe.trajectory, **kwargs) self.ag1 = group1 self.ag2 = group2 @@ -67,7 +67,7 @@ def _conclude(self): def get_flatbottom_distance_restraint( - topology: Union[str, openmm.app.Topology], + topology: Union[str, app.Topology], trajectory: pathlib.Path, topology_format: Optional[str] = None, host_atoms: Optional[list[int]] = None, diff --git a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py index 36e7a61a7..770f86bcb 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py @@ -7,12 +7,12 @@ ---- * Add relevant duecredit entries. """ -import abc -from pydantic.v1 import BaseModel, validator - +import pathlib +from typing import Union, Optional +from openmm import app from openff.units import unit import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds, calc_angles +from MDAnalysis.lib.distances import calc_bonds from rdkit import Chem from .base import HostGuestRestraintGeometry @@ -50,7 +50,7 @@ def _get_selection(universe, atom_list, selection): def get_distance_restraint( - topology: Union[str, openmm.app.Topology], + topology: Union[str, app.Topology], trajectory: pathlib.Path, topology_format: Optional[str] = None, host_atoms: Optional[list[int]] = None, @@ -61,34 +61,25 @@ def get_distance_restraint( u = mda.Universe(topology, trajectory, topology_format=topology_format) guest_ag = _get_selection(u, guest_atoms, guest_selection) + guest_atoms = [a.ix for a in guest_ag] host_ag = _get_selection(u, host_atoms, host_selection) + host_atoms = [a.ix for a in host_ag] - return DistanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) + return DistanceRestraintGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms + ) def get_molecule_centers_restraint( - topology: Union[str, openmm.app.Topology], - trajectory: pathlib.Path, molA_rdmol: Chem.Mol, molB_rdmol: Chem.Mol, molA_idxs: list[int], molB_idxs: list[int], - topology_format: Optional[str] = None, ): # We assume that the mol idxs are ordered centerA = molA_idxs[_get_central_atom_idx(molA_rdmol)] centerB = molB_idxs[_get_central_atom_idx(molB_rdmol)] - u = mda.Universe(topology, trajectory, topology_format=topology_format) - guest_ag = _get_selection( - u, - [centerA], - None, - ) - guest_ag = _get_selection( - u, - [centerB], - None, + return DistanceRestraintGeometry( + guest_atoms=[centerA], host_atoms=[centerB] ) - - return DistsanceRestraintGeometry(guest_atoms=guest_atoms, host_atoms=host_atoms) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/openmm_utils/restraints/geometry/utils.py index 91ce61e8b..7d6906650 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/openmm_utils/restraints/geometry/utils.py @@ -7,29 +7,26 @@ ---- * Add relevant duecredit entries. """ -import abc -from pydantic.v1 import BaseModel, validator - +from typing import Union, Optional import numpy as np import numpy.typing as npt -from scipy.stats import circvar, circmean, circstd +from scipy.stats import circvar +import openmm from openff.toolkit import Molecule as OFFMol from openff.units import unit -from openff.models.types import FloatQuantity, ArrayQuantity import networkx as nx from rdkit import Chem import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase -from MDAnalysis.analysis.rmsf import RMSF -from MDAnalysis.lib.distances import calc_bonds, calc_angles, minimize_vectors +from MDAnalysis.analysis.rms import RMSF +from MDAnalysis.lib.distances import minimize_vectors, capped_distance from MDAnalysis.coordinates.memory import MemoryReader from openfe_analysis.transformations import Aligner, NoJump DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 -ANGLE_FRC_CONSTANT_TYPE = FloatQuantity["unit.kilojoule_per_mole / unit.radians**2"] def _get_mda_coord_format( @@ -97,7 +94,7 @@ def get_aromatic_rings(rdmol: Chem.Mol) -> list[tuple[int, ...]]: aromatic_rings = [] for ring in ringinfo.AtomRings(): - if all(a in aroms for a in ring): + if all(a in arom_idxs for a in ring): aromatic_rings.append(ring) return aromatic_rings @@ -190,13 +187,14 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): Returns ------- result : bool - Returns True if any sequential pair of vectors is collinear; False otherwise. + Returns True if any sequential pair of vectors is collinear; + False otherwise. Notes ----- Originally from Yank. """ - results = False + result = False for i in range(len(atoms) - 2): v1 = minimize_vectors( positions[atoms[i + 1], :] - positions[atoms[i], :], @@ -214,9 +212,9 @@ def is_collinear(positions, atoms, dimensions=None, threshold=0.9): def check_angle_not_flat( - angle: FloatQuantity["radians"], - force_constant: ANGLE_FRC_CONSTANT_TYPE = DEFAULT_ANGLE_FRC_CONSTANT, - temperature: FloatQuantity["kelvin"] = 298.15 * unit.kelvin, + angle: unit.Quantity, + force_constant: unit.Quantity = DEFAULT_ANGLE_FRC_CONSTANT, + temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> bool: """ Check whether the chosen angle is less than 10 kT from 0 or pi radians @@ -246,8 +244,8 @@ def check_angle_not_flat( RT = 8.31445985 * 0.001 * temp_kelvin # check if angle is <10kT from 0 or 180 - check1 = 0.5 * frc_const * np.power((angle - 0.0), 2) - check2 = 0.5 * frc_const * np.power((angle - np.pi), 2) + check1 = 0.5 * frc_const * np.power((angle_rads - 0.0), 2) + check2 = 0.5 * frc_const * np.power((angle_rads - np.pi), 2) ang_check_1 = check1 / RT ang_check_2 = check2 / RT if ang_check_1 < 10.0 or ang_check_2 < 10.0: @@ -256,9 +254,9 @@ def check_angle_not_flat( def check_dihedral_bounds( - dihedral: FloatQuantity["radians"], - lower_cutoff: FloatQuantity["radians"] = 2.618 * unit.radians, - upper_cutoff: FloatQuantity["radians"] = -2.618 * unit.radians, + dihedral: unit.Quantity, + lower_cutoff: unit.Quantity = 2.618 * unit.radians, + upper_cutoff: unit.Quantity = -2.618 * unit.radians, ) -> bool: """ Check that a dihedral does not exceed the bounds set by @@ -285,11 +283,10 @@ def check_dihedral_bounds( def check_angular_variance( - angles: ArrayQuantity["radians"], - width: FloatQuantity["radians"], - upper_bound: FloatQuantity["radians"], - lower_bound: FloatQuantity["radians"], - width: FloatQuantity["radians"], + angles: unit.Quantity, + upper_bound: unit.Quantity, + lower_bound: unit.Quantity, + width: unit.Quantity, ) -> bool: """ Check that the variance of a list of ``angles`` does not exceed @@ -299,12 +296,13 @@ def check_angular_variance( ---------- angles : ArrayLike[unit.Quantity] An array of angles in units compatible with radians. - upper_bound: FloatQuantity['radians'] - The upper bound in the angle range. - lower_bound: FloatQuantity['radians'] - The lower bound in the angle range. + upper_bound: unit.Quantity + The upper bound in the angle range in radians compatible units. + lower_bound: unit.Quantity + The lower bound in the angle range in radians compatible units. width : unit.Quantity - The width to check the variance against, in units compatible with radians. + The width to check the variance against, in units compatible with + radians. Returns ------- diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py index 3ad9d0aa6..9c288515d 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py @@ -44,7 +44,7 @@ def get_periodic_boresch_energy_function( def get_custom_compound_bond_force( - n_particles: int = 6, energy_function: str = BORESCH_ENERGY_FUNCTION + energy_function: str, n_particles: int = 6, ): """ Return an OpenMM CustomCompoundForce diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index a3fe777d3..2b8898a22 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -15,7 +15,6 @@ * Add Periodic Torsion Boresch class """ import abc -from typing import Optional, Union, Callable import numpy as np import openmm @@ -31,11 +30,16 @@ from openff.units import unit from gufe.settings.models import SettingsBaseModel -from openfe.protocols.openmm_utils.omm_forces import ( + +from openfe.protocols.openmm_utils.restraints.geometry import ( + BaseRestraintGeometry, + DistanceRestraintGeometry, + BoreschRestraintGeometry +) +from .omm_forces import ( get_custom_compound_bond_force, add_force_in_separate_group, get_boresch_energy_function, - get_periodic_boresch_energy_function, ) @@ -49,7 +53,8 @@ class RestraintParameterState(GlobalParameterState): ---------- parameters_name_suffix : Optional[str] If specified, the state will control a modified version of the parameter - ``lambda_restraints_{parameters_name_suffix}` instead of just ``lambda_restraints``. + ``lambda_restraints_{parameters_name_suffix}` instead of just + ``lambda_restraints``. lambda_restraints : Optional[float] The strength of the restraint. If defined, must be between 0 and 1. @@ -66,7 +71,8 @@ class RestraintParameterState(GlobalParameterState): def lambda_restraints(self, instance, new_value): if new_value is not None and not (0.0 <= new_value <= 1.0): errmsg = ( - "lambda_restraints must be between 0.0 and 1.0, " f"got {new_value}" + "lambda_restraints must be between 0.0 and 1.0 " + f"and got {new_value}" ) raise ValueError(errmsg) # Not crashing out on None to match upstream behaviour @@ -101,11 +107,19 @@ def _verify_geometry(self, geometry): pass @abc.abstractmethod - def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: BaseRestraintGeometry + ): pass @abc.abstractmethod - def get_standard_state_correction(self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry): + def get_standard_state_correction( + self, + thermodynamic_state: ThermodynamicState, + geometry: BaseRestraintGeometry + ): pass @abc.abstractmethod @@ -118,8 +132,8 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: errmsg = ( "host_atoms and guest_atoms must only include a single index " - f"each, got {len(host_atoms)} and " - f"{len(guest_atoms)} respectively." + f"each, got {len(geometry.host_atoms)} and " + f"{len(geometry.guest_atoms)} respectively." ) raise ValueError(errmsg) super()._verify_geometry(geometry) @@ -127,19 +141,25 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): def _verify_inputs(self) -> None: - if not isinstance(self.settings, BaseDistanceRestraintSettings): + if not isinstance(self.settings, DistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) - def _verify_geometry(self, geometry: DistanceRestraintGeometry) + def _verify_geometry(self, geometry: DistanceRestraintGeometry): if not isinstance(geometry, DistanceRestraintGeometry): errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState, geometry: DistanceRestraintGeometry) -> None: + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: DistanceRestraintGeometry + ) -> None: self._verify_geometry(geometry) force = self._get_force(geometry) - force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + force.setUsesPeriodicBoundaryConditions( + thermodynamic_state.is_periodic + ) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system add_force_in_separate_group(system, force) @@ -162,9 +182,13 @@ def _get_force(self, geometry: DistanceRestraintGeometry): raise NotImplementedError("only implemented in child classes") -class HarmonicBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): +class HarmonicBondRestraint( + BaseRadiallySymmetricRestraintForce, SingleBondMixin +): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintBondForce( spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms[0], @@ -173,10 +197,16 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: ) -class FlatBottomBondRestraint(BaseRadiallySymmetricRestraintForce, SingleBondMixin): +class FlatBottomBondRestraint( + BaseRadiallySymmetricRestraintForce, SingleBondMixin +): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm( + geometry.well_radius + ).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, @@ -188,7 +218,9 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) return HarmonicRestraintForce( spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms, @@ -199,9 +231,13 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: - spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system(omm_unit.md_unit_system) - well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) - return FlatBottomRestraintBondForce( + spring_constant = to_openmm( + self.settings.spring_constant + ).value_in_unit_system(omm_unit.md_unit_system) + well_radius = to_openmm( + geometry.well_radius + ).value_in_unit_system(omm_unit.md_unit_system) + return FlatBottomRestraintForce( spring_constant=spring_constant, well_radius=well_radius, restrained_atom_index1=geometry.host_atoms, @@ -221,10 +257,16 @@ def _verify_geometry(self, geometry: BoreschRestraintGeometry): errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) - def add_force(self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry) -> None: - _verify_geometry(geometry) + def add_force( + self, + thermodynamic_state: ThermodynamicState, + geometry: BoreschRestraintGeometry + ) -> None: + self._verify_geometry(geometry) force = self._get_force(geometry) - force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) + force.setUsesPeriodicBoundaryConditions( + thermodynamic_state.is_periodic + ) # Note .system is a call to get_system() so it's returning a copy system = thermodynamic_state.system add_force_in_separate_group(system, force) @@ -236,7 +278,7 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: ) force = get_custom_compound_bond_force( - n_particles=6, energy_function=efunc + energy_function=efunc, n_particles=6, ) param_values = [] @@ -256,7 +298,9 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: 'phi_C0': geometry.phi_C0, } for key, val in parameter_dict.items(): - param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) + param_values.append( + to_openmm(val).value_in_unit_system(omm_unit.md_unit_system) + ) force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) @@ -264,7 +308,9 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: return force def get_standard_state_correction( - self, thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry + self, + thermodynamic_state: ThermodynamicState, + geometry: BoreschRestraintGeometry ) -> unit.Quantity: self._verify_geometry(geometry) @@ -279,14 +325,16 @@ def get_standard_state_correction( # restraint energies K_r = self.settings.K_r.to('kilojoule_per_mole / nm ** 2') K_thetaA = self.settings.K_thetaA.to('kilojoule_per_mole / radians ** 2') - k_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') + K_thetaB = self.settings.K_thetaB.to('kilojoule_per_mole / radians ** 2') K_phiA = self.settings.K_phiA.to('kilojoule_per_mole / radians ** 2') K_phiB = self.settings.K_phiB.to('kilojoule_per_mole / radians ** 2') K_phiC = self.settings.K_phiC.to('kilojoule_per_mole / radians ** 2') numerator1 = 8.0 * (np.pi**2) * StandardV denum1 = (r_aA0**2) * sin_thetaA0 * sin_thetaB0 - numerator2 = np.sqrt(K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC) + numerator2 = np.sqrt( + K_r * K_thetaA * K_thetaB * K_phiA * K_phiB * K_phiC + ) denum2 = (2.0 * np.pi * kt)**3 dG = -kt * np.log((numerator1/denum1) * (numerator2/denum2)) From fe1308ee4beff476e6d5e0cb967991fa538cb1c1 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sat, 14 Dec 2024 02:00:19 +0000 Subject: [PATCH 27/29] docstring drive --- .../restraints/openmm/omm_forces.py | 13 +++++ .../restraints/openmm/omm_restraints.py | 48 ++++++++++++++++++- 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py index 9c288515d..52cbbec98 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py @@ -14,6 +14,19 @@ def get_boresch_energy_function( control_parameter: str, ) -> str: + """ + Return a Boresch-style energy function for a CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ energy_function = ( f"{control_parameter} * E; " "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py index 2b8898a22..18f9f2f34 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py @@ -96,14 +96,21 @@ def __init__( controlling_parameter_name: str = "lambda_restraints", ): self.settings = restraint_settings + self.controlling_parameter_name = controlling_parameter_name self._verify_settings() @abc.abstractmethod def _verify_settings(self): + """ + Method for validating the settings passed on object construction. + """ pass @abc.abstractmethod def _verify_geometry(self, geometry): + """ + Method for validating that the geometry object passed is correct. + """ pass @abc.abstractmethod @@ -112,6 +119,18 @@ def add_force( thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry ): + """ + Method for in-place adding a force to the System of a + ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + """ pass @abc.abstractmethod @@ -119,7 +138,24 @@ def get_standard_state_correction( self, thermodynamic_state: ThermodynamicState, geometry: BaseRestraintGeometry - ): + ) -> unit.Quantity: + """ + Get the standard state correction for the Force. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + """ pass @abc.abstractmethod @@ -304,7 +340,15 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: force.addPerBondParameter(key) force.addGlobalParameter(self.controlling_parameter_name, 1.0) - force.addBond(geometry.host_atoms + geometry.guest_atoms, param_values) + atoms = [ + geometry.host_atoms[2], + geometry.host_atoms[1], + geometry.host_atoms[0], + geometry.guest_atoms[0], + geometry.guest_atoms[1], + geometry.guest_atoms[2], + ] + force.addBond(atoms, param_values) return force def get_standard_state_correction( From d71b9616055f95c4211bc460afead1446430115e Mon Sep 17 00:00:00 2001 From: IAlibay Date: Sun, 15 Dec 2024 22:59:44 +0000 Subject: [PATCH 28/29] Migrate to restraint_utils --- .../restraints/geometry/harmonic.py | 85 ---- .../__init__.py | 0 .../geometry/__init__.py | 0 .../geometry/base.py | 3 + .../geometry/boresch.py | 289 +++++++++++--- .../geometry/flatbottom.py | 51 ++- .../restraint_utils/geometry/harmonic.py | 144 +++++++ .../geometry/utils.py | 83 ++-- .../openmm/__init__.py | 0 .../openmm/omm_forces.py | 22 +- .../openmm/omm_restraints.py | 371 ++++++++++++++++-- openfe/tests/protocols/restraints/__init__.py | 0 .../restraints/test_geometry_base.py | 25 ++ .../restraints/test_omm_restraints.py | 31 ++ .../restraints/test_openmm_forces.py | 115 ++++++ 15 files changed, 1000 insertions(+), 219 deletions(-) delete mode 100644 openfe/protocols/openmm_utils/restraints/geometry/harmonic.py rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/__init__.py (100%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/__init__.py (100%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/base.py (94%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/boresch.py (74%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/flatbottom.py (58%) create mode 100644 openfe/protocols/restraint_utils/geometry/harmonic.py rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/geometry/utils.py (88%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/openmm/__init__.py (100%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/openmm/omm_forces.py (85%) rename openfe/protocols/{openmm_utils/restraints => restraint_utils}/openmm/omm_restraints.py (50%) create mode 100644 openfe/tests/protocols/restraints/__init__.py create mode 100644 openfe/tests/protocols/restraints/test_geometry_base.py create mode 100644 openfe/tests/protocols/restraints/test_omm_restraints.py create mode 100644 openfe/tests/protocols/restraints/test_openmm_forces.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py b/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py deleted file mode 100644 index 770f86bcb..000000000 --- a/openfe/protocols/openmm_utils/restraints/geometry/harmonic.py +++ /dev/null @@ -1,85 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Restraint Geometry classes - -TODO ----- -* Add relevant duecredit entries. -""" -import pathlib -from typing import Union, Optional -from openmm import app -from openff.units import unit -import MDAnalysis as mda -from MDAnalysis.lib.distances import calc_bonds -from rdkit import Chem - -from .base import HostGuestRestraintGeometry -from .utils import _get_central_atom_idx - - -class DistanceRestraintGeometry(HostGuestRestraintGeometry): - """ - A geometry class for a distance restraint between two groups of atoms. - """ - - def get_distance(self, topology, coordinates) -> unit.Quantity: - u = mda.Universe(topology, coordinates) - ag1 = u.atoms[self.host_atoms] - ag2 = u.atoms[self.guest_atoms] - bond = calc_bonds( - ag1.center_of_mass(), ag2.center_of_mass(), box=u.atoms.dimensions - ) - # convert to float so we avoid having a np.float64 - return float(bond) * unit.angstrom - - -def _get_selection(universe, atom_list, selection): - if atom_list is None: - if selection is None: - raise ValueError( - "one of either the atom lists or selections must be defined" - ) - - ag = universe.select_atoms(selection) - else: - ag = universe.atoms[atom_list] - - return ag - - -def get_distance_restraint( - topology: Union[str, app.Topology], - trajectory: pathlib.Path, - topology_format: Optional[str] = None, - host_atoms: Optional[list[int]] = None, - guest_atoms: Optional[list[int]] = None, - host_selection: Optional[str] = None, - guest_selection: Optional[str] = None, -) -> DistanceRestraintGeometry: - u = mda.Universe(topology, trajectory, topology_format=topology_format) - - guest_ag = _get_selection(u, guest_atoms, guest_selection) - guest_atoms = [a.ix for a in guest_ag] - host_ag = _get_selection(u, host_atoms, host_selection) - host_atoms = [a.ix for a in host_ag] - - return DistanceRestraintGeometry( - guest_atoms=guest_atoms, host_atoms=host_atoms - ) - - -def get_molecule_centers_restraint( - molA_rdmol: Chem.Mol, - molB_rdmol: Chem.Mol, - molA_idxs: list[int], - molB_idxs: list[int], -): - # We assume that the mol idxs are ordered - centerA = molA_idxs[_get_central_atom_idx(molA_rdmol)] - centerB = molB_idxs[_get_central_atom_idx(molB_rdmol)] - - return DistanceRestraintGeometry( - guest_atoms=[centerA], host_atoms=[centerB] - ) diff --git a/openfe/protocols/openmm_utils/restraints/__init__.py b/openfe/protocols/restraint_utils/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/__init__.py rename to openfe/protocols/restraint_utils/__init__.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/__init__.py b/openfe/protocols/restraint_utils/geometry/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/geometry/__init__.py rename to openfe/protocols/restraint_utils/geometry/__init__.py diff --git a/openfe/protocols/openmm_utils/restraints/geometry/base.py b/openfe/protocols/restraint_utils/geometry/base.py similarity index 94% rename from openfe/protocols/openmm_utils/restraints/geometry/base.py rename to openfe/protocols/restraint_utils/geometry/base.py index 5db9225ac..0ca6ae200 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/base.py +++ b/openfe/protocols/restraint_utils/geometry/base.py @@ -12,6 +12,9 @@ class BaseRestraintGeometry(BaseModel, abc.ABC): + """ + A base class for a restraint geometry. + """ class Config: arbitrary_types_allowed = True diff --git a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py b/openfe/protocols/restraint_utils/geometry/boresch.py similarity index 74% rename from openfe/protocols/openmm_utils/restraints/geometry/boresch.py rename to openfe/protocols/restraint_utils/geometry/boresch.py index 0d6806611..6e740f48d 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/boresch.py +++ b/openfe/protocols/restraint_utils/geometry/boresch.py @@ -14,6 +14,7 @@ import openmm from openff.units import unit +from openff.models.types import FloatQuantity import MDAnalysis as mda from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.lib.distances import calc_bonds, calc_angles, calc_dihedrals @@ -51,107 +52,137 @@ class BoreschRestraintGeometry(HostGuestRestraintGeometry): Where HX represents the X index of ``host_atoms`` and GX the X index of ``guest_atoms``. """ + r_aA0: FloatQuantity['nanometer'] + """ + The equilibrium distance between H0 and G0. + """ + theta_A0: FloatQuantity['radians'] + """ + The equilibrium angle value between H1, H0, and G0. + """ + theta_B0: FloatQuantity['radians'] + """ + The equilibrium angle value between H0, G0, and G1. + """ + phi_A0: FloatQuantity['radians'] + """ + The equilibrium dihedral value between H2, H1, H0, and G0. + """ + phi_B0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H1, H0, G0, and G1. + """ + phi_C0: FloatQuantity['radians'] + + """ + The equilibrium dihedral value between H0, G0, G1, and G2. + """ def get_bond_distance( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], - coordinates: Union[str, pathlib.Path, npt.NDArray], + universe: mda.Universe, ) -> unit.Quantity: """ Get the H0 - G0 distance. Parameters ---------- - topology : Union[str, openmm.app.Topology] - coordinates : Union[str, npt.NDArray] - A coordinate file or NDArray in frame-atom-coordinate - order in Angstrom. + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + bond : unit.Quantity + The H0-G0 distance. """ - u = mda.Universe( - topology, - coordinates, - format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology), + at1 = universe.atoms[self.host_atoms[0]] + at2 = universe.atoms[self.guest_atoms[0]] + bond = calc_bonds( + at1.position, + at2.position, + box=universe.atoms.dimensions ) - at1 = u.atoms[self.host_atoms[0]] - at2 = u.atoms[self.guest_atoms[0]] - bond = calc_bonds(at1.position, at2.position, u.atoms.dimensions) # convert to float so we avoid having a np.float64 return float(bond) * unit.angstrom def get_angles( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], - coordinates: Union[str, pathlib.Path, npt.NDArray], - ) -> unit.Quantity: + universe: mda.Universe, + ) -> tuple[unit.Quantity, unit.Quantity]: """ Get the H1-H0-G0, and H0-G0-G1 angles. Parameters ---------- - topology : Union[str, openmm.app.Topology] - coordinates : Union[str, npt.NDArray] - A coordinate file or NDArray in frame-atom-coordinate - order in Angstrom. + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + angleA : unit.Quantity + The H1-H0-G0 angle. + angleB : unit.Quantity + The H0-G0-G1 angle. """ - u = mda.Universe( - topology, - coordinates, - format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology), - ) - at1 = u.atoms[self.host_atoms[1]] - at2 = u.atoms[self.host_atoms[0]] - at3 = u.atoms[self.guest_atoms[0]] - at4 = u.atoms[self.guest_atoms[1]] + at1 = universe.atoms[self.host_atoms[1]] + at2 = universe.atoms[self.host_atoms[0]] + at3 = universe.atoms[self.guest_atoms[0]] + at4 = universe.atoms[self.guest_atoms[1]] angleA = calc_angles( - at1.position, at2.position, at3.position, u.atoms.dimensions + at1.position, + at2.position, + at3.position, + box=universe.atoms.dimensions ) angleB = calc_angles( - at2.position, at3.position, at4.position, u.atoms.dimensions + at2.position, + at3.position, + at4.position, + box=universe.atoms.dimensions ) return angleA, angleB def get_dihedrals( self, - topology: Union[str, pathlib.Path, openmm.app.Topology], - coordinates: Union[str, pathlib.Path, npt.NDArray], - ) -> unit.Quantity: + universe: mda.Universe, + ) -> tuple[unit.Quantity, unit.Quantity, unit.Quantity]: """ Get the H2-H1-H0-G0, H1-H0-G0-G1, and H0-G0-G1-G2 dihedrals. Parameters ---------- - topology : Union[str, openmm.app.Topology] - coordinates : Union[str, npt.NDArray] - A coordinate file or NDArray in frame-atom-coordinate - order in Angstrom. + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + dihA : unit.Quantity + The H2-H1-H0-G0 angle. + dihB : unit.Quantity + The H1-H0-G0-G1 angle. + dihC : unit.Quantity + The H0-G0-G1-G2 angle. """ - u = mda.Universe( - topology, - coordinates, - format=_get_mda_coord_format(coordinates), - topology_format=_get_mda_topology_format(topology), - ) - at1 = u.atoms[self.host_atoms[2]] - at2 = u.atoms[self.host_atoms[1]] - at3 = u.atoms[self.host_atoms[0]] - at4 = u.atoms[self.guest_atoms[0]] - at5 = u.atoms[self.guest_atoms[1]] - at6 = u.atoms[self.guest_atoms[2]] + at1 = universe.atoms[self.host_atoms[2]] + at2 = universe.atoms[self.host_atoms[1]] + at3 = universe.atoms[self.host_atoms[0]] + at4 = universe.atoms[self.guest_atoms[0]] + at5 = universe.atoms[self.guest_atoms[1]] + at6 = universe.atoms[self.guest_atoms[2]] dihA = calc_dihedrals( at1.position, at2.position, at3.position, at4.position, - box=u.dimensions + box=universe.dimensions ) dihB = calc_dihedrals( at2.position, at3.position, at4.position, at5.position, - box=u.dimensions + box=universe.dimensions ) dihC = calc_dihedrals( at3.position, at4.position, at5.position, at6.position, - box=u.dimensions + box=universe.dimensions ) return dihA, dihB, dihC @@ -307,7 +338,7 @@ def get_guest_atom_candidates( TODO ---- - Remember to update the RDMol with the last frame positions. + Should the RDMol have a specific frame position? """ u = mda.Universe( topology, @@ -663,6 +694,66 @@ def _find_host_angle( return None +def _get_restraint_distances( + atomgroup: mda.AtomGroup +) -> tuple[unit.Quantity]: + """ + Get the bond, angle, and dihedral distances for an input atomgroup + defining the six atoms for a Boresch-like restraint. + + The atoms must be in the order of H0, H1, H2, G0, G1, G2. + + Parameters + ---------- + atomgroup : mda.AtomGroup + An AtomGroup defining the restrained atoms in order. + + Returns + ------- + bond : unit.Quantity + The H0-G0 bond value. + angle1 : unit.Quantity + The H1-H0-G0 angle value. + angle2 : unit.Quantity + The H0-G0-G1 angle value. + dihed1 : unit.Quantity + The H2-H1-H0-G0 dihedral value. + dihed2 : unit.Quantity + The H1-H0-G0-G1 dihedral value. + dihed3 : unit.Quantity + The H0-G0-G1-G2 dihedral value. + """ + + bond = calc_bonds( + atomgroup.atoms[0].position, + atomgroup.atoms[3], + box=atomgroup.dimensions + ) + + angles = [] + for idx_set in [[1, 0, 3], [0, 3, 4]]: + angle = calc_angles( + atomgroup.atoms[idx_set[0]].position, + atomgroup.atoms[idx_set[1]].position, + atomgroup.atoms[idx_set[2]].position, + box=atomgroup.dimensions, + ) + angles.append(angle * unit.radians) + + dihedrals = [] + for idx_set in [[2, 1, 0, 3], [1, 0, 3, 4], [0, 3, 4, 5]]: + dihed = calc_dihedrals( + atomgroup.atoms[idx_set[0]].position, + atomgroup.atoms[idx_set[1]].position, + atomgroup.atoms[idx_set[2]].position, + atomgroup.atoms[idx_set[3]].position, + box=atomgroup.dimensions, + ) + dihedrals.append(dihed * unit.radians) + + return bond, angles[0], angles[1], dihedrals[0], dihedrals[1], dihedrals[2] + + def find_boresch_restraint( topology: Union[str, pathlib.Path, openmm.app.Topology], trajectory: Union[str, pathlib.Path], @@ -682,15 +773,60 @@ def find_boresch_restraint( temperature: unit.Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ - Find suitable Boresch-style restraints between a host and guest entity. + Find suitable Boresch-style restraints between a host and guest entity + based on the approach of Baumann et al. [1] with some modifications. Parameters ---------- - ... + topology : Union[str, pathlib.Path, openmm.app.Topology] + A topology of the system. + trajectory : Union[str, pathlib.Path] + A path to a coordinate trajectory file. + guest_rdmol : Chem.Mol + An RDKit Mol for the guest molecule. + guest_idxs : list[int] + Indices in the topology for the guest molecule. + host_idxs : list[int] + Indices in the topology for the host molecule. + guest_restraint_atoms_idxs : Optional[list[int]] + User selected indices of the guest molecule itself (i.e. indexed + starting a 0 for the guest molecule). This overrides the + restraint search and a restraint using these indices will + be retruned. Must be defined alongside ``host_restraint_atoms_idxs``. + host_restraint_atoms_idxs : Optional[list[int]] + User selected indices of the host molecule itself (i.e. indexed + starting a 0 for the hosts molecule). This overrides the + restraint search and a restraint using these indices will + be returnned. Must be defined alongside ``guest_restraint_atoms_idxs``. + host_selection : str + An MDAnalysis selection string to sub-select the host atoms. + dssp_filter : bool + Whether or not to filter the host atoms by their secondary structure. + rmsf_cutoff : unit.Quantity + The cutoff value for atom root mean square fluction. Atoms with RMSF + values above this cutoff will be disregarded. + Must be in units compatible with nanometer. + host_min_distance : unit.Quantity + The minimum distance between any host atom and the guest G0 atom. + Must be in units compatible with nanometer. + host_max_distance : unit.Quantity + The maximum distance between any host atom and the guest G0 atom. + Must be in units compatible with nanometer. + angle_force_constant : unit.Quantity + The force constant for the G1-G0-H0 and G0-H0-H1 angles. Must be + in units compatible with kilojoule / mole / radians ** 2. + temperature : unit.Quantity + The system temperature in units compatible with Kelvin. Returns ------- - ... + BoreschRestraintGeometry + An object defining the parameters of the Boresch-like restraint. + + References + ---------- + [1] Baumann, Hannah M., et al. "Broadening the scope of binding free energy + calculations using a Separated Topologies approach." (2023). """ u = mda.Universe( topology, @@ -698,7 +834,6 @@ def find_boresch_restraint( format=_get_mda_coord_format(trajectory), topology_format=_get_mda_topology_format(topology), ) - u.trajectory[-1] # Work with the final frame if (guest_restraint_atoms_idxs is not None) and (host_restraint_atoms_idxs is not None): # fmt: skip # In this case assume the picked atoms were intentional / @@ -711,9 +846,23 @@ def find_boresch_restraint( host_angle = [ at.ix for at in host_ag.atoms[host_restraint_atoms_idxs] ] - # TODO sort out the return on this + + # Set the equilibrium values as those of the final frame + u.trajectory[-1] + atomgroup = u.atoms[host_angle + guest_angle] + bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( + atomgroup + ) + return BoreschRestraintGeometry( - host_atoms=host_angle, guest_atoms=guest_angle + host_atoms=host_angle, + guest_atoms=guest_angle, + r_aA0=bond, + theta_A0=ang1, + theta_B0=ang2, + phi_A0=dih1, + phi_B0=dih2, + phi_C0=dih3 ) if (guest_restraint_atoms_idxs is not None) ^ (host_restraint_atoms_idxs is not None): # fmt: skip @@ -772,6 +921,20 @@ def find_boresch_restraint( errmsg = "No suitable host atoms could be found" raise ValueError(errmsg) + # Set the equilibrium values as those of the final frame + u.trajectory[-1] + atomgroup = u.atoms[host_angle + guest_angle] + bond, ang1, ang2, dih1, dih2, dih3 = _get_restraint_distances( + atomgroup + ) + return BoreschRestraintGeometry( - host_atoms=host_angle, guest_atoms=guest_angle + host_atoms=host_angle, + guest_atoms=guest_angle, + r_aA0=bond, + theta_A0=ang1, + theta_B0=ang2, + phi_A0=dih1, + phi_B0=dih2, + phi_C0=dih3 ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py b/openfe/protocols/restraint_utils/geometry/flatbottom.py similarity index 58% rename from openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py rename to openfe/protocols/restraint_utils/geometry/flatbottom.py index c9007dd59..3b4599f56 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/flatbottom.py +++ b/openfe/protocols/restraint_utils/geometry/flatbottom.py @@ -19,9 +19,10 @@ from .harmonic import ( DistanceRestraintGeometry, - _get_selection, ) +from .utils import _get_mda_topology_format, _get_mda_selection + class FlatBottomDistanceGeometry(DistanceRestraintGeometry): """ @@ -42,7 +43,6 @@ class COMDistanceAnalysis(AnalysisBase): group2 : MDANalysis.AtomGroup Atoms defining the second centroid. """ - _analysis_algorithm_is_parallelizable = False def __init__(self, group1, group2, **kwargs): @@ -68,18 +68,55 @@ def _conclude(self): def get_flatbottom_distance_restraint( topology: Union[str, app.Topology], - trajectory: pathlib.Path, - topology_format: Optional[str] = None, + trajectory: Union[str, pathlib.Path], host_atoms: Optional[list[int]] = None, guest_atoms: Optional[list[int]] = None, host_selection: Optional[str] = None, guest_selection: Optional[str] = None, padding: unit.Quantity = 0.5 * unit.nanometer, ) -> FlatBottomDistanceGeometry: - u = mda.Universe(topology, trajectory, topology_format=topology_format) + """ + Get a FlatBottomDistanceGeometry by analyzing the COM distance + change between two sets of atoms. + + The ``well_radius`` is defined as the maximum COM distance plus + ``padding``. + + Parameters + ---------- + topology : Union[str, app.Topology] + A topology defining the system. + trajectory : Union[str, pathlib.Path] + A coordinate trajectory for the system. + host_atoms : Optional[list[int]] + A list of host atoms indices. Either ``host_atoms`` or + ``host_selection`` must be defined. + guest_atoms : Optional[list[int]] + A list of guest atoms indices. Either ``guest_atoms`` or + ``guest_selection`` must be defined. + host_selection : Optional[str] + An MDAnalysis selection string to define the host atoms. + Either ``host_atoms`` or ``host_selection`` must be defined. + guest_selection : Optional[str] + An MDAnalysis selection string to define the guest atoms. + Either ``guest_atoms`` or ``guest_selection`` must be defined. + padding : unit.Quantity + A padding value to add to the ``well_radius`` definition. + Must be in units compatible with nanometers. + + Returns + ------- + FlatBottomDistanceGeometry + An object defining a flat bottom restraint geometry. + """ + u = mda.Universe( + topology, + trajectory, + topology_format=_get_mda_topology_format(topology) + ) - guest_ag = _get_selection(u, guest_atoms, guest_selection) - host_ag = _get_selection(u, host_atoms, host_selection) + guest_ag = _get_mda_selection(u, guest_atoms, guest_selection) + host_ag = _get_mda_selection(u, host_atoms, host_selection) com_dists = COMDistanceAnalysis(guest_ag, host_ag) com_dists.run() diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/openfe/protocols/restraint_utils/geometry/harmonic.py new file mode 100644 index 000000000..197a8bc44 --- /dev/null +++ b/openfe/protocols/restraint_utils/geometry/harmonic.py @@ -0,0 +1,144 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Restraint Geometry classes + +TODO +---- +* Add relevant duecredit entries. +""" +import pathlib +from typing import Union, Optional +from openmm import app +from openff.units import unit +import MDAnalysis as mda +from MDAnalysis.lib.distances import calc_bonds +from rdkit import Chem + +from .base import HostGuestRestraintGeometry +from .utils import ( + get_central_atom_idx, + _get_mda_selection, + _get_mda_topology_format, +) + + +class DistanceRestraintGeometry(HostGuestRestraintGeometry): + """ + A geometry class for a distance restraint between two groups of atoms. + """ + + def get_distance(self, universe: mda.Universe) -> unit.Quantity: + """ + Get the center of mass distance between the host and guest atoms. + + Parameters + ---------- + universe : mda.Universe + A Universe representing the system of interest. + + Returns + ------- + bond : unit.Quantity + The center of mass distance between the two groups of atoms. + """ + ag1 = universe.atoms[self.host_atoms] + ag2 = universe.atoms[self.guest_atoms] + bond = calc_bonds( + ag1.center_of_mass(), + ag2.center_of_mass(), + box=universe.atoms.dimensions + ) + # convert to float so we avoid having a np.float64 + return float(bond) * unit.angstrom + + +def get_distance_restraint( + topology: Union[str, pathlib.Path, app.Topology], + trajectory: Union[str, pathlib.Path], + host_atoms: Optional[list[int]] = None, + guest_atoms: Optional[list[int]] = None, + host_selection: Optional[str] = None, + guest_selection: Optional[str] = None, +) -> DistanceRestraintGeometry: + """ + Get a DistanceRestraintGeometry between two groups of atoms. + + You can either select the groups by passing through a set of indices + or an MDAnalysis selection. + + Parameters + ---------- + topology : Union[str, pathlib.Path, app.Topology] + A path or object defining the system topology. + trajectory : Union[str, pathlib.Path] + Coordinates for the system. + host_atoms : Optional[list[int]] + A list of host atoms indices. Either ``host_atoms`` or + ``host_selection`` must be defined. + guest_atoms : Optional[list[int]] + A list of guest atoms indices. Either ``guest_atoms`` or + ``guest_selection`` must be defined. + host_selection : Optional[str] + An MDAnalysis selection string to define the host atoms. + Either ``host_atoms`` or ``host_selection`` must be defined. + guest_selection : Optional[str] + An MDAnalysis selection string to define the guest atoms. + Either ``guest_atoms`` or ``guest_selection`` must be defined. + + Returns + ------- + DistanceRestraintGeometry + An object that defines a distance restraint geometry. + """ + u = mda.Universe( + topology, + trajectory, + topology_format=_get_mda_topology_format(topology) + ) + + guest_ag = _get_mda_selection(u, guest_atoms, guest_selection) + guest_atoms = [a.ix for a in guest_ag] + host_ag = _get_mda_selection(u, host_atoms, host_selection) + host_atoms = [a.ix for a in host_ag] + + return DistanceRestraintGeometry( + guest_atoms=guest_atoms, host_atoms=host_atoms + ) + + +def get_molecule_centers_restraint( + molA_rdmol: Chem.Mol, + molB_rdmol: Chem.Mol, + molA_idxs: list[int], + molB_idxs: list[int], +): + """ + Get a DistanceRestraintGeometry between the central atoms of + two molecules. + + Parameters + ---------- + molA_rdmol : Chem.Mol + An RDKit Molecule for the first molecule. + molB_rdmol : Chem.Mol + An RDKit Molecule for the first molecule. + molA_idxs : list[int] + The indices of the first molecule in the system. Note we assume these + to be sorted in the same order as the input rdmol. + molB_idxs : list[int] + The indices of the first molecule in the system. Note we assume these + to be sorted in the same order as the input rdmol. + + Returns + ------- + DistanceRestraintGeometry + An object that defines a distance restraint geometry. + """ + # We assume that the mol idxs are ordered + centerA = molA_idxs[get_central_atom_idx(molA_rdmol)] + centerB = molB_idxs[get_central_atom_idx(molB_rdmol)] + + return DistanceRestraintGeometry( + guest_atoms=[centerA], host_atoms=[centerB] + ) diff --git a/openfe/protocols/openmm_utils/restraints/geometry/utils.py b/openfe/protocols/restraint_utils/geometry/utils.py similarity index 88% rename from openfe/protocols/openmm_utils/restraints/geometry/utils.py rename to openfe/protocols/restraint_utils/geometry/utils.py index 7d6906650..4b734b410 100644 --- a/openfe/protocols/openmm_utils/restraints/geometry/utils.py +++ b/openfe/protocols/restraint_utils/geometry/utils.py @@ -22,13 +22,59 @@ from MDAnalysis.analysis.rms import RMSF from MDAnalysis.lib.distances import minimize_vectors, capped_distance from MDAnalysis.coordinates.memory import MemoryReader +from MDAnalysis.transformations.nojump import NoJump -from openfe_analysis.transformations import Aligner, NoJump +from openfe_analysis.transformations import Aligner DEFAULT_ANGLE_FRC_CONSTANT = 83.68 * unit.kilojoule_per_mole / unit.radians**2 +def _get_mda_selection( + universe: mda.Universe, + atom_list: Optional[list[int]], + selection: Optional[str] +) -> mda.AtomGroup: + """ + Return an AtomGroup based on either a list of atom indices or an + mdanalysis string selection. + + Parameters + ---------- + universe : mda.Universe + The MDAnalysis Universe to get the AtomGroup from. + atom_list : Optional[list[int]] + A list of atom indices. + selection : Optional[str] + An MDAnalysis selection string. + + Returns + ------- + ag : mda.AtomGroup + An atom group selected from the inputs. + + Raises + ------ + ValueError + If both ``atom_list`` and ``selection`` are ``None`` + or are defined. + """ + if atom_list is None: + if selection is None: + raise ValueError( + "one of either the atom lists or selections must be defined" + ) + + ag = universe.select_atoms(selection) + else: + if selection is not None: + raise ValueError( + "both atom_list and selection cannot be defined together" + ) + ag = universe.atoms[atom_list] + return ag + + def _get_mda_coord_format( coordinates: Union[str, npt.NDArray] ) -> Optional[MemoryReader]: @@ -224,7 +270,8 @@ def check_angle_not_flat( angle : unit.Quantity The angle to check in units compatible with radians. force_constant : unit.Quantity - Force constant of the angle in units compatible with kilojoule_per_mole / radians ** 2. + Force constant of the angle in units compatible with + kilojoule_per_mole / radians ** 2. temperature : unit.Quantity The system temperature in units compatible with Kelvin. @@ -334,7 +381,6 @@ class FindHostAtoms(AnalysisBase): max_search_distance: unit.Quantity Maximum distance to filter atoms within. """ - _analysis_algorithm_is_parallelizable = False def __init__( @@ -372,34 +418,7 @@ def _conclude(self): self.results.host_idxs = np.array(self.results.host_idxs) -def find_host_atoms( - topology, trajectory, host_selection, guest_selection, cutoff -) -> mda.AtomGroup: - """ - Get an AtomGroup of the host atoms based on their distances from the guest atoms. - """ - u = mda.Universe(topology, trajectory) - - def _get_selection(selection): - """ - If it's a str, call select_atoms, if not a list of atom idxs - """ - if isinstance(selection, str): - ag = u.select_atoms(host_selection) - else: - ag = u.atoms[host_ag] - return ag - - host_ag = _get_selection(host_selection) - guest_ag = _get_selection(guest_selection) - - finder = FindHostAtoms(host_ag, guest_ag, cutoff) - finder.run() - - return u.atoms[list(finder.results.host_idxs)] - - -def get_local_rmsf(atomgroup: mda.AtomGroup): +def get_local_rmsf(atomgroup: mda.AtomGroup) -> unit.Quantity: """ Get the RMSF of an AtomGroup when aligned upon itself. @@ -416,7 +435,7 @@ def get_local_rmsf(atomgroup: mda.AtomGroup): copy_u = atomgroup.universe.copy() ag = copy_u.atoms[atomgroup.atoms.ix] - nojump = NoJump(ag) + nojump = NoJump() align = Aligner(ag) copy_u.trajectory.add_transformations(nojump, align) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/__init__.py b/openfe/protocols/restraint_utils/openmm/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/restraints/openmm/__init__.py rename to openfe/protocols/restraint_utils/openmm/__init__.py diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py b/openfe/protocols/restraint_utils/openmm/omm_forces.py similarity index 85% rename from openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py rename to openfe/protocols/restraint_utils/openmm/omm_forces.py index 52cbbec98..2947c8e03 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_forces.py +++ b/openfe/protocols/restraint_utils/openmm/omm_forces.py @@ -43,6 +43,20 @@ def get_boresch_energy_function( def get_periodic_boresch_energy_function( control_parameter: str, ) -> str: + """ + Return a Boresch-style energy function with a periodic torsion for a + CustomCompoundForce. + + Parameters + ---------- + control_parameter : str + A string for the lambda scaling control parameter + + Returns + ------- + str + The energy function string. + """ energy_function = ( f"{control_parameter} * E; " "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " @@ -104,8 +118,12 @@ def add_force_in_separate_group( Mostly reproduced from `Yank `_. """ available_force_groups = set(range(32)) - for force in system.getForces(): - available_force_groups.discard(force.getForceGroup()) + for existing_force in system.getForces(): + available_force_groups.discard(existing_force.getForceGroup()) + + if len(available_force_groups) == 0: + errmsg = "No available force groups could be found" + raise ValueError(errmsg) force.setForceGroup(min(available_force_groups)) system.addForce(force) diff --git a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py similarity index 50% rename from openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py rename to openfe/protocols/restraint_utils/openmm/omm_restraints.py index 18f9f2f34..c77b1cd0b 100644 --- a/openfe/protocols/openmm_utils/restraints/openmm/omm_restraints.py +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -31,7 +31,7 @@ from gufe.settings.models import SettingsBaseModel -from openfe.protocols.openmm_utils.restraints.geometry import ( +from openfe.protocols.restraint_utils.geometry import ( BaseRestraintGeometry, DistanceRestraintGeometry, BoreschRestraintGeometry @@ -87,22 +87,20 @@ class BaseHostGuestRestraints(abc.ABC): TODO ---- - Add some examples here. + Add some developer examples here. """ def __init__( self, restraint_settings: SettingsBaseModel, - controlling_parameter_name: str = "lambda_restraints", ): self.settings = restraint_settings - self.controlling_parameter_name = controlling_parameter_name - self._verify_settings() + self._verify_inputs() @abc.abstractmethod - def _verify_settings(self): + def _verify_inputs(self): """ - Method for validating the settings passed on object construction. + Method for validating that the inputs to the class are correct. """ pass @@ -117,10 +115,11 @@ def _verify_geometry(self, geometry): def add_force( self, thermodynamic_state: ThermodynamicState, - geometry: BaseRestraintGeometry + geometry: BaseRestraintGeometry, + controlling_parameter_name: str, ): """ - Method for in-place adding a force to the System of a + Method for in-place adding the Force to the System of a ThermodynamicState. Parameters @@ -130,6 +129,8 @@ def add_force( new force. geometry : BaseRestraintGeometry A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. """ pass @@ -140,7 +141,8 @@ def get_standard_state_correction( geometry: BaseRestraintGeometry ) -> unit.Quantity: """ - Get the standard state correction for the Force. + Get the standard state correction for the Force when + applied to the input ThermodynamicState. Parameters ---------- @@ -159,11 +161,23 @@ def get_standard_state_correction( pass @abc.abstractmethod - def _get_force(self, geometry: BaseRestraintGeometry): + def _get_force( + self, + geometry: BaseRestraintGeometry, + controlling_parameter_name: str, + ): + """ + Helper method to get the relevant OpenMM Force for this + class, given an input geometry. + """ pass class SingleBondMixin: + """ + A mixin to extend geometry checks for Forces that can only hold + a single atom. + """ def _verify_geometry(self, geometry: BaseRestraintGeometry): if len(geometry.host_atoms) != 1 or len(geometry.guest_atoms) != 1: errmsg = ( @@ -176,6 +190,12 @@ def _verify_geometry(self, geometry: BaseRestraintGeometry): class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): + """ + A base class for all radially symmetic Forces acting between + two sets of atoms. + + Must be subclassed. + """ def _verify_inputs(self) -> None: if not isinstance(self.settings, DistanceRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" @@ -189,10 +209,25 @@ def _verify_geometry(self, geometry: DistanceRestraintGeometry): def add_force( self, thermodynamic_state: ThermodynamicState, - geometry: DistanceRestraintGeometry + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str = "lambda_restraints", ) -> None: + """ + Method for in-place adding the Force to the System of the + given ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ self._verify_geometry(geometry) - force = self._get_force(geometry) + force = self._get_force(geometry, controlling_parameter_name) force.setUsesPeriodicBoundaryConditions( thermodynamic_state.is_periodic ) @@ -206,6 +241,24 @@ def get_standard_state_correction( thermodynamic_state: ThermodynamicState, geometry: DistanceRestraintGeometry, ) -> unit.Quantity: + """ + Get the standard state correction for the Force when + applied to the input ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + """ self._verify_geometry(geometry) force = self._get_force(geometry) corr = force.compute_standard_state_correction( @@ -214,14 +267,50 @@ def get_standard_state_correction( dg = corr * thermodynamic_state.kT return from_openmm(dg).to('kilojoule_per_mole') - def _get_force(self, geometry: DistanceRestraintGeometry): + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str + ): raise NotImplementedError("only implemented in child classes") class HarmonicBondRestraint( BaseRadiallySymmetricRestraintForce, SingleBondMixin ): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a harmonic restraint between two atoms + in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintBondForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintBondForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintBondForce + An OpenMM Force that applies a harmonic restraint between + two atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -229,14 +318,46 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms[0], restrained_atom_index2=geometry.guest_atoms[0], - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class FlatBottomBondRestraint( BaseRadiallySymmetricRestraintForce, SingleBondMixin ): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a flat bottom restraint between two atoms + in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintBondForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintBondForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + FlatBottomRestraintBondForce + An OpenMM Force that applies a flat bottom restraint between + two atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -248,12 +369,44 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: well_radius=well_radius, restrained_atom_index1=geometry.host_atoms[0], restrained_atom_index2=geometry.guest_atoms[0], - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class CentroidHarmonicRestraint(BaseRadiallySymmetricRestraintForce): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a harmonic restraint between the centroid of + two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.HarmonicRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the HarmonicRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + HarmonicRestraintForce + An OpenMM Force that applies a harmonic restraint between + the centroid of two sets of atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -261,12 +414,44 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: spring_constant=spring_constant, restrained_atom_index1=geometry.host_atoms, restrained_atom_index2=geometry.guest_atoms, - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): - def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: + """ + A class to add a flat bottom restraint between the centroid + of two sets of atoms in an OpenMM system. + + The restraint is defined as a + :class:`openmmtools.forces.FlatBottomRestraintForce`. + + Notes + ----- + * Settings must contain a ``spring_constant`` for the + Force in units compatible with kilojoule/mole. + """ + def _get_force( + self, + geometry: DistanceRestraintGeometry, + controlling_parameter_name: str, + ) -> openmm.Force: + """ + Get the FlatBottomRestraintForce given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + FlatBottomRestraintForce + An OpenMM Force that applies a flat bottom restraint between + the centroid of two sets of atoms. + """ spring_constant = to_openmm( self.settings.spring_constant ).value_in_unit_system(omm_unit.md_unit_system) @@ -278,17 +463,80 @@ def _get_force(self, geometry: DistanceRestraintGeometry) -> openmm.Force: well_radius=well_radius, restrained_atom_index1=geometry.host_atoms, restrained_atom_index2=geometry.guest_atoms, - controlling_parameter_name=self.controlling_parameter_name, + controlling_parameter_name=controlling_parameter_name, ) class BoreschRestraint(BaseHostGuestRestraints): - def _verify_settings(self) -> None: + """ + A class to add a Boresch-like restraint between six atoms, + + The restraint is defined as a + :class:`openmmtools.forces.CustomCompoundForce` with the + following energy function: + + lambda_control_parameter * E; + E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 + + (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 + + (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + + (K_phiC/2)*dphi_C^2; + dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); + dA = dihedral(p1,p2,p3,p4) - phi_A0; + dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); + dB = dihedral(p2,p3,p4,p5) - phi_B0; + dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); + dC = dihedral(p3,p4,p5,p6) - phi_C0; + + Where p1, p2, p3, p4, p5, p6 represent host atoms 2, 1, 0, + and guest atoms 0, 1, 2 respectively. + + ``lambda_control_parameter`` is a control parameter for + scaling the Force. + + ``K_r`` is defined as the bond spring constant between + p3 and p4 and must be provided in the settings in units + compatible with kilojoule / mole. + + ``r_aA0`` is the equilibrium distance of the bond between + p3 and p4. This must be provided by the Geometry class in + units compatiblle with nanometer. + + ``K_thetaA`` and ``K_thetaB`` are the spring constants for the angles + formed by (p2, p3, p4) and (p3, p4, p5). They must be provided in the + settings in units compatible with kilojoule / mole / radians**2. + + ``theta_A0`` and ``theta_B0`` are the equilibrium values for angles + (p2, p3, p4) and (p3, p4, p5). They must be provided by the + Geometry class in units compatible with radians. + + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium constants + for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and + (p3, p4, p5, p6). They must be provided in the settings in units + compatible with kilojoule / mole / radians ** 2. + + ``phi_A0``, ``phi_B0``, and ``phi_C0`` are the equilibrium values + for the dihedrals formed by (p1, p2, p3, p4), (p2, p3, p4, p5), and + (p3, p4, p5, p6). They must be provided in the Geometry class in + units compatible with radians. + + + Notes + ----- + * Settings must define the ``K_r`` (d) + """ + def _verify_inputs(self) -> None: + """ + Method for validating that the geometry object is correct. + """ if not isinstance(self.settings, BoreschRestraintSettings): errmsg = f"Incorrect settings type {self.settings} passed through" raise ValueError(errmsg) def _verify_geometry(self, geometry: BoreschRestraintGeometry): + """ + Method for validating that the geometry object is correct. + """ if not isinstance(geometry, BoreschRestraintGeometry): errmsg = f"Incorrect geometry class type {geometry} passed through" raise ValueError(errmsg) @@ -296,10 +544,28 @@ def _verify_geometry(self, geometry: BoreschRestraintGeometry): def add_force( self, thermodynamic_state: ThermodynamicState, - geometry: BoreschRestraintGeometry + geometry: BoreschRestraintGeometry, + controlling_parameter_name: str, ) -> None: + """ + Method for in-place adding the Boresch CustomCompoundForce + to the System of the given ThermodynamicState. + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + """ self._verify_geometry(geometry) - force = self._get_force(geometry) + force = self._get_force( + geometry, + controlling_parameter_name, + ) force.setUsesPeriodicBoundaryConditions( thermodynamic_state.is_periodic ) @@ -308,10 +574,29 @@ def add_force( add_force_in_separate_group(system, force) thermodynamic_state.system = system - def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: - efunc = get_boresch_energy_function( - self.controlling_parameter_name, - ) + def _get_force( + self, + geometry: BoreschRestraintGeometry, + controlling_parameter_name: str + ) -> openmm.CustomCompoundBondForce: + """ + Get the CustomCompoundForce with a Boresch-like energy function + given an input geometry. + + Parameters + ---------- + geometry : DistanceRestraintGeometry + A geometry class that defines how the Force is applied. + controlling_parameter_name : str + The name of the controlling parameter for the Force. + + Returns + ------- + CustomCompoundForce + An OpenMM CustomCompoundForce that applies a Boresch-like + restraint between 6 atoms. + """ + efunc = get_boresch_energy_function(controlling_parameter_name) force = get_custom_compound_bond_force( energy_function=efunc, n_particles=6, @@ -339,7 +624,7 @@ def _get_force(self, geometry: BoreschRestraintGeometry) -> openmm.Force: ) force.addPerBondParameter(key) - force.addGlobalParameter(self.controlling_parameter_name, 1.0) + force.addGlobalParameter(controlling_parameter_name, 1.0) atoms = [ geometry.host_atoms[2], geometry.host_atoms[1], @@ -356,6 +641,32 @@ def get_standard_state_correction( thermodynamic_state: ThermodynamicState, geometry: BoreschRestraintGeometry ) -> unit.Quantity: + """ + Get the standard state correction for the Boresch-like + restraint when applied to the input ThermodynamicState. + + The correction is calculated using the analytical method + as defined by Boresch et al. [1] + + Parameters + ---------- + thermodymamic_state : ThermodynamicState + The ThermodynamicState with a System to inplace modify with the + new force. + geometry : BaseRestraintGeometry + A geometry object defining the restraint parameters. + + Returns + ------- + correction : unit.Quantity + The standard state correction free energy in units compatible + with kilojoule per mole. + + References + ---------- + [1] Boresch S, Tettinger F, Leitgeb M, Karplus M. J Phys Chem B. 107:9535, 2003. + http://dx.doi.org/10.1021/jp0217839 + """ self._verify_geometry(geometry) StandardV = 1.66053928 * unit.nanometer**3 diff --git a/openfe/tests/protocols/restraints/__init__.py b/openfe/tests/protocols/restraints/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/openfe/tests/protocols/restraints/test_geometry_base.py b/openfe/tests/protocols/restraints/test_geometry_base.py new file mode 100644 index 000000000..139c57dc5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_geometry_base.py @@ -0,0 +1,25 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.geometry.base import ( + HostGuestRestraintGeometry +) + + +def test_hostguest_geometry(): + """ + A very basic will it build test. + """ + geom = HostGuestRestraintGeometry(guest_atoms=[1, 2, 3], host_atoms=[4]) + + assert isinstance(geom, HostGuestRestraintGeometry) + + +def test_hostguest_positiveidxs_validator(): + """ + Check that the validator is working as intended. + """ + with pytest.raises(ValueError, match="negative indices passed"): + geom = HostGuestRestraintGeometry(guest_atoms=[-1, 1], host_atoms=[0]) diff --git a/openfe/tests/protocols/restraints/test_omm_restraints.py b/openfe/tests/protocols/restraints/test_omm_restraints.py new file mode 100644 index 000000000..0e346f9c5 --- /dev/null +++ b/openfe/tests/protocols/restraints/test_omm_restraints.py @@ -0,0 +1,31 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe.protocols.restraint_utils.openmm.omm_restraints import ( + RestraintParameterState, +) + + +def test_parameter_state_default(): + param_state = RestraintParameterState() + assert param_state.lambda_restraints is None + + +@pytest.mark.parametrize('suffix', [None, 'foo']) +@pytest.mark.parametrize('lambda_var', [0, 0.5, 1.0]) +def test_parameter_state_suffix(suffix, lambda_var): + param_state = RestraintParameterState( + parameters_name_suffix=suffix, lambda_restraints=lambda_var + ) + + if suffix is not None: + param_name = f'lambda_restraints_{suffix}' + else: + param_name = 'lambda_restraints' + + assert getattr(param_state, param_name) == lambda_var + assert len(param_state._parameters.keys()) == 1 + assert param_state._parameters[param_name] == lambda_var + assert param_state._parameters_name_suffix == suffix diff --git a/openfe/tests/protocols/restraints/test_openmm_forces.py b/openfe/tests/protocols/restraints/test_openmm_forces.py new file mode 100644 index 000000000..cd2a7f21e --- /dev/null +++ b/openfe/tests/protocols/restraints/test_openmm_forces.py @@ -0,0 +1,115 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest +import numpy as np +import openmm +from openfe.protocols.restraint_utils.openmm.omm_forces import ( + get_boresch_energy_function, + get_periodic_boresch_energy_function, + get_custom_compound_bond_force, + add_force_in_separate_group, +) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*dphi_A^2 + (K_phiB/2)*dphi_B^2 + (K_phiC/2)*dphi_C^2; " + "dphi_A = dA - floor(dA/(2.0*pi)+0.5)*(2.0*pi); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "dphi_B = dB - floor(dB/(2.0*pi)+0.5)*(2.0*pi); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "dphi_C = dC - floor(dC/(2.0*pi)+0.5)*(2.0*pi); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('param', ['foo', 'bar']) +def test_periodic_boresch_energy_function(param): + """ + Base regression test for the energy function + """ + fn = get_periodic_boresch_energy_function(param) + assert fn == ( + f"{param} * E; " + "E = (K_r/2)*(distance(p3,p4) - r_aA0)^2 " + "+ (K_thetaA/2)*(angle(p2,p3,p4)-theta_A0)^2 + (K_thetaB/2)*(angle(p3,p4,p5)-theta_B0)^2 " + "+ (K_phiA/2)*uphi_A + (K_phiB/2)*uphi_B + (K_phiC/2)*uphi_C; " + "uphi_A = (1-cos(dA)); dA = dihedral(p1,p2,p3,p4) - phi_A0; " + "uphi_B = (1-cos(dB)); dB = dihedral(p2,p3,p4,p5) - phi_B0; " + "uphi_C = (1-cos(dC)); dC = dihedral(p3,p4,p5,p6) - phi_C0; " + f"pi = {np.pi}; " + ) + + +@pytest.mark.parametrize('num_atoms', [6, 20]) +def test_custom_compound_force(num_atoms): + fn = get_boresch_energy_function('lambda_restraints') + force = get_custom_compound_bond_force(fn, num_atoms) + + # Check we have the right object + assert isinstance(force, openmm.CustomCompoundBondForce) + + # Check the energy function + assert force.getEnergyFunction() == fn + + # Check the number of particles + assert force.getNumParticlesPerBond() == num_atoms + + +@pytest.mark.parametrize('groups, expected', [ + [[0, 1, 2, 3, 4], 5], + [[1, 2, 3, 4, 5], 0], +]) +def test_add_force_in_separate_group(groups, expected): + # Create an empty system + system = openmm.System() + + # Create some forces with some force groups + base_forces = [ + openmm.NonbondedForce(), + openmm.HarmonicBondForce(), + openmm.HarmonicAngleForce(), + openmm.PeriodicTorsionForce(), + openmm.CMMotionRemover(), + ] + + for force, group in zip(base_forces, groups): + force.setForceGroup(group) + + [system.addForce(force) for force in base_forces] + + # Get your CustomCompoundBondForce + fn = get_boresch_energy_function('lambda_restraints') + new_force = get_custom_compound_bond_force(fn, 6) + # new_force.setForceGroup(5) + # system.addForce(new_force) + add_force_in_separate_group(system=system, force=new_force) + + # Loop through and check that we go assigned the expected force group + for force in system.getForces(): + if isinstance(force, openmm.CustomCompoundBondForce): + assert force.getForceGroup() == expected + + +def test_add_too_many_force_groups(): + # Create a system + system = openmm.System() + + # Fill it upu with 32 forces with different groups + for i in range(32): + f = openmm.HarmonicBondForce() + f.setForceGroup(i) + system.addForce(f) + + # Now try to add another force + with pytest.raises(ValueError, match="No available force group"): + add_force_in_separate_group( + system=system, force=openmm.HarmonicBondForce() + ) \ No newline at end of file From c914b18c63026524de550303f8662819c354bcd0 Mon Sep 17 00:00:00 2001 From: IAlibay Date: Mon, 16 Dec 2024 09:22:31 +0000 Subject: [PATCH 29/29] base for restraint settings --- openfe/protocols/restraint_utils/settings.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 openfe/protocols/restraint_utils/settings.py diff --git a/openfe/protocols/restraint_utils/settings.py b/openfe/protocols/restraint_utils/settings.py new file mode 100644 index 000000000..0c12aef17 --- /dev/null +++ b/openfe/protocols/restraint_utils/settings.py @@ -0,0 +1,23 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Settings for adding restraints. +""" +from typing import Optional, Literal +from openff.units import unit +from openff.models.types import FloatQuantity, ArrayQuantity + +from gufe.settings import ( + SettingsBaseModel, +) + + +from pydantic.v1 import validator + + +class BaseRestraintSettings(SettingsBaseModel): + """ + Base class for RestraintSettings objects. + """ + class Config: + arbitrary_types_allowed = True