Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ABFE protocol #1045

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
3e848a6
first pass at abstract classes
IAlibay Dec 4, 2024
ef050e5
A start at restraints and forces
IAlibay Dec 5, 2024
420f3e5
Add boresch restraint class
IAlibay Dec 6, 2024
2e2b52e
Fix units
IAlibay Dec 6, 2024
76c5fcf
Fix correction return in kj/mole
IAlibay Dec 9, 2024
e9cd918
Add more restraint API bits
IAlibay Dec 9, 2024
f1bbd8a
move some things around
IAlibay Dec 9, 2024
4f4d58e
Merge branch 'main' into omm-restraints
IAlibay Dec 9, 2024
ac452e9
Some changes
IAlibay Dec 11, 2024
536a76e
Merge branch 'omm-restraints' of github.com:OpenFreeEnergy/openfe int…
IAlibay Dec 11, 2024
5d0b683
Towards ABFE protocol
IAlibay Dec 11, 2024
22fbc37
Add base code for settings
IAlibay Dec 11, 2024
e1ef18c
Fix up the units a bit
IAlibay Dec 11, 2024
158ce40
Deal with the restraint addition
IAlibay Dec 11, 2024
a19c86c
refactor restraints
IAlibay Dec 12, 2024
20dd1dc
add some angle checks
IAlibay Dec 12, 2024
9ab74a8
only construct with settings
IAlibay Dec 12, 2024
8f2e1e0
Add more checks to utilities
IAlibay Dec 12, 2024
0a480aa
host finding code
IAlibay Dec 12, 2024
7a7be90
fix up weird black wrapping
IAlibay Dec 12, 2024
733f3b3
remove old search file, add more changes to boresch search
IAlibay Dec 12, 2024
bbef017
Merge branch 'main' into omm-restraints
IAlibay Dec 12, 2024
96decff
Remove duplicate methods
IAlibay Dec 12, 2024
2d97de8
Apply suggestions from code review
IAlibay Dec 12, 2024
116ba64
Add some more docstring
IAlibay Dec 12, 2024
9ae60da
Add minimized vectors on the collinear checks
IAlibay Dec 13, 2024
3cce308
add host atom finding routine
IAlibay Dec 14, 2024
9171d39
autoformatting
IAlibay Dec 14, 2024
033a1e4
various fixes
IAlibay Dec 14, 2024
fe1308e
docstring drive
IAlibay Dec 14, 2024
deac61d
Merge branch 'omm-restraints' into abfe_protocol
IAlibay Dec 14, 2024
d71b961
Migrate to restraint_utils
IAlibay Dec 15, 2024
c914b18
base for restraint settings
IAlibay Dec 16, 2024
2dd883a
Merge branch 'omm-restraints' into abfe_protocol
IAlibay Dec 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 97 additions & 52 deletions openfe/protocols/openmm_afe/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from openmmtools import multistate
from openmmtools.states import (SamplerState,
ThermodynamicState,
GlobalParameterState,
create_thermodynamic_state_protocol,)
from openmmtools.alchemy import (AlchemicalRegion, AbsoluteAlchemicalFactory,
AlchemicalState,)
Expand Down Expand Up @@ -469,45 +470,70 @@ def _get_modeller(

def _get_omm_objects(
self,
system_modeller: app.Modeller,
system_generator: SystemGenerator,
smc_components: list[OFFMolecule],
) -> tuple[app.Topology, openmm.unit.Quantity, openmm.System]:
settings: dict[str, SettingsBaseModel],
protein_component: Optional[ProteinComponent],
solvent_component: Optional[SolventComponent],
smc_components: dict[SmallMoleculeComponent, OFFMolecule],
) -> tuple[
app.Topology,
openmm.System,
openmm.unit.Quantity,
dict[str, npt.NDArray],
]:
"""
Get the OpenMM Topology, Positions and System of the
parameterised system.

Parameters
----------
system_modeller : app.Modeller
OpenMM Modeller object representing the system to be
parametrized.
system_generator : SystemGenerator
SystemGenerator object to create a System with.
smc_components : list[openff.toolkit.Molecule]
A list of openff Molecules to add to the system.
settings : dict[str, SettingsBaseModel]
Protocol settings
protein_component : Optional[ProteinComponent]
Protein component for the system.
solvent_component : Optional[SolventComponent]
Solvent component for the system.
smc_components : dict[str, OFFMolecule]
SmallMoleculeComponents defining ligands to be added to the system

Returns
-------
topology : app.Topology
Topology object describing the parameterized system
OpenMM Topology object describing the parameterized system.
system : openmm.System
An OpenMM System of the alchemical system.
positionns : openmm.unit.Quantity
An non-alchemical OpenMM System of the simulated system.
positions : openmm.unit.Quantity
Positions of the system.
comp_resids : dict[str, npt.NDArray]
A dictionary of residues for each component in the System.
"""
topology = system_modeller.getTopology()
if self.verbose:
self.logger.info("Parameterizing system")

system_generator = self._get_system_generator(
settings, solvent_component
)

modeller, comp_resids = self._get_modeller(
protein_component,
solvent_component,
smc_components,
system_generator,
settings['charge_settings'],
settings['solvation_settings']
)

topology = modeller.getTopology()
# roundtrip positions to remove vec3 issues
positions = to_openmm(from_openmm(system_modeller.getPositions()))
positions = to_openmm(from_openmm(modeller.getPositions()))

# Block out oechem backend to avoid any issues with
# smiles roundtripping between rdkit and oechem
with without_oechem_backend():
system = system_generator.create_system(
system_modeller.topology,
modeller.topology,
molecules=smc_components,
)
return topology, system, positions
return topology, system, positions, comp_resids

def _get_lambda_schedule(
self, settings: dict[str, SettingsBaseModel]
Expand All @@ -533,21 +559,24 @@ def _get_lambda_schedule(

lambda_elec = settings['lambda_settings'].lambda_elec
lambda_vdw = settings['lambda_settings'].lambda_vdw
lambda_rest = settings['lambda_settings'].lambda_restraints

# Reverse lambda schedule since in AbsoluteAlchemicalFactory 1
# means fully interacting, not stateB
lambda_elec = [1-x for x in lambda_elec]
lambda_vdw = [1-x for x in lambda_vdw]
lambdas['lambda_electrostatics'] = lambda_elec
lambdas['lambda_sterics'] = lambda_vdw
for name, schedule in [
('lambda_electrostatics', lambda_elec),
('lambda_sterics', lambda_vdw),
('lambda_restraints', lambda_rest),
]:
lambdas[name] = [1-x for x in schedule]

return lambdas

def _add_restraints(self, system, topology, settings):
"""
Placeholder method to add restraints if necessary
"""
return
return None, None, system

def _get_alchemical_system(
self,
Expand Down Expand Up @@ -607,6 +636,7 @@ def _get_states(
settings: dict[str, SettingsBaseModel],
lambdas: dict[str, npt.NDArray],
solvent_comp: Optional[SolventComponent],
restraint_state: Optional[GlobalParameterState],
) -> tuple[list[SamplerState], list[ThermodynamicState]]:
"""
Get a list of sampler and thermodynmic states from an
Expand All @@ -624,6 +654,8 @@ def _get_states(
A dictionary of lambda scales.
solvent_comp : Optional[SolventComponent]
The solvent component of the system, if there is one.
restraint_state : Optional[GlobalParameterState]
The restraint parameter control state, if there is one.

Returns
-------
Expand All @@ -641,9 +673,14 @@ def _get_states(
if solvent_comp is not None:
constants['pressure'] = ensure_quantity(pressure, 'openmm')

if restraint_state is not None:
composable_states = [alchemical_state, restraint_state]
else:
composable_states = [alchemical_state,]

cmp_states = create_thermodynamic_state_protocol(
alchemical_system, protocol=lambdas,
constants=constants, composable_states=[alchemical_state],
constants=constants, composable_states=composable_states,
)

sampler_state = SamplerState(positions=positions)
Expand Down Expand Up @@ -873,6 +910,7 @@ def _run_simulation(
sampler: multistate.MultiStateSampler,
reporter: multistate.MultiStateReporter,
settings: dict[str, SettingsBaseModel],
standard_state_corr: Optional[unit.Quantity]
dry: bool
):
"""
Expand All @@ -886,6 +924,8 @@ def _run_simulation(
The reporter associated with the sampler.
settings : dict[str, SettingsBaseModel]
The dictionary of settings for the protocol.
standard_state_corr : Optional[unit.Quantity]
The standard state correction, if available.
dry : bool
Whether or not to dry run the simulation

Expand Down Expand Up @@ -944,7 +984,12 @@ def _run_simulation(
analyzer.plot(filepath=self.shared_basepath, filename_prefix="")
analyzer.close()

return analyzer.unit_results_dict
return_dict = analyzer.unit_results_dict

if standard_state_corr is not None:
return_dict['standard_state_correction'] = standard_state_corr

return return_dict

else:
# close reporter when you're done, prevent file handle clashes
Expand Down Expand Up @@ -991,44 +1036,40 @@ def run(self, dry=False, verbose=True,
# 2. Get settings
settings = self._handle_settings()

# 3. Get system generator
system_generator = self._get_system_generator(settings, solv_comp)

# 4. Get modeller
system_modeller, comp_resids = self._get_modeller(
prot_comp, solv_comp, smc_comps, system_generator,
settings['charge_settings'],
settings['solvation_settings'],
# 3. Get OpenMM topology, positions, and system
omm_topology, omm_system, position, comp_resids = self._get_omm_objects(
settings, prot_comps, solv_comps, smc_comps,
)

# 5. Get OpenMM topology, positions and system
omm_topology, omm_system, positions = self._get_omm_objects(
system_modeller, system_generator, list(smc_comps.values())
)

# 6. Pre-equilbrate System (Test + Avoid NaNs + get stable system)
# 4. Pre-equilbrate System (Test + Avoid NaNs + get stable system)
positions = self._pre_equilibrate(
omm_system, omm_topology, positions, settings, dry
)

# 7. Get lambdas
# 5. Get lambdas
lambdas = self._get_lambda_schedule(settings)

# 8. Add restraints
self._add_restraints(omm_system, omm_topology, settings)
# 6. Add restraints
restraint_parameter_state, standard_state_corr, omm_system = self._add_restraints(
omm_system, omm_topology, settings
)

# 9. Get alchemical system
# 7. Get alchemical system
alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system(
omm_topology, omm_system, comp_resids, alchem_comps
)

# 10. Get compound and sampler states
# 7. Get compound and sampler states
sampler_states, cmp_states = self._get_states(
alchem_system, positions, settings,
lambdas, solv_comp
alchem_system,
positions,
settings,
lambdas,
solv_comp,
restraint_parameter_state,
)

# 11. Create the multistate reporter & create PDB
# 9. Create the multistate reporter & create PDB
reporter = self._get_reporter(
omm_topology, positions,
settings['simulation_settings'],
Expand All @@ -1037,29 +1078,33 @@ def run(self, dry=False, verbose=True,

# Wrap in try/finally to avoid memory leak issues
try:
# 12. Get context caches
# 10. Get context caches
energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches(
settings['forcefield_settings'],
settings['engine_settings']
)

# 13. Get integrator
# 11. Get integrator
integrator = self._get_integrator(
settings['integrator_settings'],
settings['simulation_settings'],
)

# 14. Get sampler
# 12. Get sampler
sampler = self._get_sampler(
integrator, reporter, settings['simulation_settings'],
settings['thermo_settings'],
cmp_states, sampler_states,
energy_ctx_cache, sampler_ctx_cache
)

# 15. Run simulation
# 13. Run simulation
unit_result_dict = self._run_simulation(
sampler, reporter, settings, dry
sampler,
reporter,
settings,
standard_state_corr,
dry
)

finally:
Expand Down
Loading
Loading