Skip to content

Commit

Permalink
Merge pull request #18 from choderalab/dev-mw
Browse files Browse the repository at this point in the history
updating scripts to run sampling for tautomers given two input smiles
  • Loading branch information
wiederm authored Nov 2, 2020
2 parents be8ddce + d8bc9d7 commit f7e8def
Show file tree
Hide file tree
Showing 14 changed files with 707 additions and 199 deletions.
Binary file added data/test_data/exp_results.pickle
Binary file not shown.
152 changes: 128 additions & 24 deletions neutromeratio/analysis.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,46 @@
import copy
import logging
import pickle
from glob import glob

import matplotlib.pyplot as plt
import mdtraj as md
import networkx as nx
import numpy as np
import pandas as pd
import pkg_resources
import scipy.stats as scs
import seaborn as sns
import torch
import torchani
from rdkit import Chem, Geometry
from rdkit.Chem import AllChem
from rdkit.Chem import AllChem, rdFMCS
from scipy.special import logsumexp
from simtk import unit
import mdtraj as md
import torchani
import torch
from rdkit.Chem import rdFMCS
import pkg_resources
import scipy.stats as scs

from .constants import (
num_threads,
kT,
from neutromeratio.ani import (
ANI,
AlchemicalANI1ccx,
AlchemicalANI1x,
AlchemicalANI2x,
ANI1_force_and_energy,
)
from neutromeratio.constants import (
device,
exclude_set_ANI,
gas_constant,
temperature,
kT,
mols_with_charge,
exclude_set_ANI,
multiple_stereobonds,
device,
num_threads,
temperature,
)
from .tautomers import Tautomer
from .parameter_gradients import FreeEnergyCalculator
from .utils import generate_tautomer_class_stereobond_aware
from .ani import (
ANI1_force_and_energy,
AlchemicalANI2x,
AlchemicalANI1ccx,
ANI,
AlchemicalANI1x,
from neutromeratio.parameter_gradients import FreeEnergyCalculator
from neutromeratio.tautomers import Tautomer
from neutromeratio.utils import (
generate_new_tautomer_pair,
generate_tautomer_class_stereobond_aware,
)
from glob import glob

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -306,7 +308,7 @@ def compare_confomer_generator_and_trajectory_minimum_structures(
mol.RemoveAllConformers()

# generate energy function, use atom symbols of rdkti mol
from .ani import ANI1ccx, ANI1_force_and_energy
from .ani import ANI1_force_and_energy, ANI1ccx

model = ANI1ccx()
energy_function = ANI1_force_and_energy(
Expand Down Expand Up @@ -557,6 +559,108 @@ def setup_alchemical_system_and_energy_function(
return energy_function, tautomer, flipped


def setup_new_alchemical_system_and_energy_function(
name: str,
t1_smiles: str,
t2_smiles: str,
env: str,
ANImodel: ANI,
base_path: str = None,
diameter: int = -1,
checkpoint_file: str = "",
):

import os

if not (
issubclass(ANImodel, (AlchemicalANI2x, AlchemicalANI1ccx, AlchemicalANI1x))
):
raise RuntimeError("Only Alchemical ANI objects allowed! Aborting.")

#######################

####################
# Set up the system, set the restraints
tautomer = generate_new_tautomer_pair(name, t1_smiles, t2_smiles)
tautomer.perform_tautomer_transformation()

# if base_path is defined write out the topology
if base_path:
base_path = os.path.abspath(base_path)
logger.debug(base_path)
if not os.path.exists(base_path):
os.makedirs(base_path)

if env == "droplet":
if diameter == -1:
raise RuntimeError("Droplet is not specified. Aborting.")
# for droplet topology is written in every case
m = tautomer.add_droplet(
tautomer.hybrid_topology,
tautomer.get_hybrid_coordinates(),
diameter=diameter * unit.angstrom,
restrain_hydrogen_bonds=True,
restrain_hydrogen_angles=False,
top_file=f"{base_path}/{name}_in_droplet.pdb",
)
else:
if base_path:
# for vacuum only if base_path is defined
pdb_filepath = f"{base_path}/{name}.pdb"
try:
traj = md.load(pdb_filepath)
except OSError:
coordinates = tautomer.get_hybrid_coordinates()
traj = md.Trajectory(
coordinates.value_in_unit(unit.nanometer), tautomer.hybrid_topology
)
traj.save_pdb(pdb_filepath)
tautomer.set_hybrid_coordinates(traj.xyz[0] * unit.nanometer)

# define the alchemical atoms
alchemical_atoms = [
tautomer.hybrid_hydrogen_idx_at_lambda_1,
tautomer.hybrid_hydrogen_idx_at_lambda_0,
]

model = ANImodel(alchemical_atoms=alchemical_atoms).to(device)
# if specified, load nn parameters
if checkpoint_file:
logger.debug("Loading nn parameters ...")
model.load_nn_parameters(checkpoint_file)

# setup energy function
if env == "vacuum":
energy_function = ANI1_force_and_energy(
model=model,
atoms=tautomer.hybrid_atoms,
mol=None,
)
else:
energy_function = ANI1_force_and_energy(
model=model,
atoms=tautomer.ligand_in_water_atoms,
mol=None,
)

# add restraints
for r in tautomer.ligand_restraints:
energy_function.add_restraint_to_lambda_protocol(r)
for r in tautomer.hybrid_ligand_restraints:
energy_function.add_restraint_to_lambda_protocol(r)

if env == "droplet":
tautomer.add_COM_for_hybrid_ligand(
np.array([diameter / 2, diameter / 2, diameter / 2]) * unit.angstrom
)
for r in tautomer.solvent_restraints:
energy_function.add_restraint_to_lambda_protocol(r)
for r in tautomer.com_restraints:
energy_function.add_restraint_to_lambda_protocol(r)

return energy_function, tautomer


def _error(x: np.ndarray, y: np.ndarray):
""" Simple error """
return x - y
Expand Down
2 changes: 1 addition & 1 deletion neutromeratio/ani.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def load_nn_parameters(
else:
self.tweaked_neural_network.load_state_dict(parameters)
else:
logger.info(f"Parameter file {parameters} does not exist.")
logger.info(f"Parameter file {parameter_path} does not exist.")

def _from_neurochem_resources(self, info_file_path, periodic_table_index):
(
Expand Down
114 changes: 108 additions & 6 deletions neutromeratio/parameter_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def __init__(
bulk_energy_calculation: bool,
potential_energy_trajs: list,
lambdas: list,
n_atoms: int,
max_snapshots_per_window: int = 200,
pickle_path: str = "",
):
Expand Down Expand Up @@ -105,9 +104,9 @@ def get_mix(lambda0, lambda1, lam=0.0):
f"There are {len(snapshots)} snapshots per lambda state (max: {max_snapshots_per_window}). Aborting."
)

# test that we have not less than 80% of max_snapshots_per_window
# test that we have not less than 60% of max_snapshots_per_window
if max_snapshots_per_window != -1 and len(snapshots) < (
int(max_snapshots_per_window * 0.8)
int(max_snapshots_per_window * 0.6)
):
raise RuntimeError(
f"There are only {len(snapshots)} snapshots per lambda state. Aborting."
Expand All @@ -126,7 +125,7 @@ def get_mix(lambda0, lambda1, lam=0.0):
snapshots.extend(ani_trajs[lam])
logger.debug(f"Snapshots per lambda {lam}: {len(ani_trajs[lam])}")

if len(snapshots) < 100:
if len(snapshots) < 300:
logger.critical(
f"Total number of snapshots is {len(snapshots)} -- is this enough?"
)
Expand Down Expand Up @@ -1048,17 +1047,120 @@ def parse_lambda_from_dcd_filename(dcd_filename):
assert len(lambdas) == len(energies)
assert len(lambdas) == len(md_trajs)

if env == "vacuum":
pickle_path = f"{data_path}/{name}/{name}_{ANImodel.name}_{max_snapshots_per_window}_{len(tautomer.hybrid_atoms)}_atoms.pickle"
else:
pickle_path = f"{data_path}/{name}/{name}_{ANImodel.name}_{max_snapshots_per_window}_{diameter}A_{len(tautomer.ligand_in_water_atoms)}_atoms.pickle"

# calculate free energy in kT
fec = FreeEnergyCalculator(
ani_model=energy_function,
md_trajs=md_trajs,
potential_energy_trajs=energies,
lambdas=lambdas,
pickle_path=f"{data_path}/{name}/{name}_{ANImodel.name}_{max_snapshots_per_window}.pickle",
n_atoms=len(tautomer.hybrid_atoms),
pickle_path=pickle_path,
bulk_energy_calculation=bulk_energy_calculation,
max_snapshots_per_window=max_snapshots_per_window,
)

fec.flipped = flipped
return fec


def setup_mbar_for_new_tautomer_pairs(
name: str,
t1_smiles: str,
t2_smiles: str,
max_snapshots_per_window: int,
ANImodel: ANI,
bulk_energy_calculation: bool,
env: str = "vacuum",
checkpoint_file: str = "",
data_path: str = "../data/",
diameter: int = -1,
):

from neutromeratio.analysis import setup_new_alchemical_system_and_energy_function
import os

if not (env == "vacuum" or env == "droplet"):
raise RuntimeError("Only keyword vacuum or droplet are allowed as environment.")
if env == "droplet" and diameter == -1:
raise RuntimeError("Something went wrong.")

def parse_lambda_from_dcd_filename(dcd_filename):
"""parsed the dcd filename
Arguments:
dcd_filename {str} -- how is the dcd file called?
Returns:
[float] -- lambda value
"""
l = dcd_filename[: dcd_filename.find(f"_energy_in_{env}")].split("_")
lam = l[-3]
return float(lam)

data_path = os.path.abspath(data_path)
if not os.path.exists(data_path):
raise RuntimeError(f"{data_path} does not exist!")

#######################
(energy_function, tautomer,) = setup_new_alchemical_system_and_energy_function(
name=name,
t1_smiles=t1_smiles,
t2_smiles=t2_smiles,
ANImodel=ANImodel,
checkpoint_file=checkpoint_file,
env=env,
diameter=diameter,
base_path=f"{data_path}/{name}/",
)
# and lambda values in list
dcds = glob(f"{data_path}/{name}/*.dcd")

lambdas = []
md_trajs = []
energies = []

# read in all the frames from the trajectories
if env == "vacuum":
top = tautomer.hybrid_topology
else:
top = f"{data_path}/{name}/{name}_in_droplet.pdb"

for dcd_filename in dcds:
lam = parse_lambda_from_dcd_filename(dcd_filename)
lambdas.append(lam)
traj = md.load_dcd(dcd_filename, top=top)
logger.debug(f"Nr of frames in trajectory: {len(traj)}")
md_trajs.append(traj)
f = open(
f"{data_path}/{name}/{name}_lambda_{lam:0.4f}_energy_in_{env}.csv", "r"
)
energies.append(np.array([float(e) * kT for e in f]))
f.close()

if len(lambdas) < 5:
raise RuntimeError(f"Below 5 lambda states for {name}")

assert len(lambdas) == len(energies)
assert len(lambdas) == len(md_trajs)

if env == "vacuum":
pickle_path = f"{data_path}/{name}/{name}_{ANImodel.name}_{max_snapshots_per_window}_{len(tautomer.hybrid_atoms)}_atoms.pickle"
else:
pickle_path = f"{data_path}/{name}/{name}_{ANImodel.name}_{max_snapshots_per_window}_{diameter}A_{len(tautomer.ligand_in_water_atoms)}_atoms.pickle"

# calculate free energy in kT
fec = FreeEnergyCalculator(
ani_model=energy_function,
md_trajs=md_trajs,
potential_energy_trajs=energies,
lambdas=lambdas,
pickle_path=pickle_path,
bulk_energy_calculation=bulk_energy_calculation,
max_snapshots_per_window=max_snapshots_per_window,
)

return fec
6 changes: 4 additions & 2 deletions neutromeratio/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,13 @@ def plot_correlation_analysis(
Y,
yerr=error,
mfc="blue",
ms=3,
mec="blue",
ms=4,
fmt="o",
capthick=2,
capsize=2,
alpha=0.6,
ecolor="r",
ecolor="red",
)

else:
Expand Down
Loading

0 comments on commit f7e8def

Please sign in to comment.