diff --git a/Yank/analyze.py b/Yank/analyze.py index cc482a3f..4897fb47 100644 --- a/Yank/analyze.py +++ b/Yank/analyze.py @@ -22,12 +22,17 @@ import os import abc import yaml -import mdtraj import logging +import itertools + import numpy as np -import simtk.unit as units +import mdtraj as md + +import simtk.unit as unit import openmmtools as mmtools -from pymbar import timeseries +from pymbar import timeseries, MBAR +from msmbuilder.cluster import RegularSpatial + from . import multistate logger = logging.getLogger(__name__) @@ -253,8 +258,8 @@ def analyze_directory(source_directory, **analyzer_kwargs): # Print energies logger.info('') logger.info('Free energy{:<13}: {:9.3f} +- {:.3f} kT ({:.3f} +- {:.3f} kcal/mol)'.format( - calculation_type, DeltaF, dDeltaF, DeltaF * kT / units.kilocalories_per_mole, - dDeltaF * kT / units.kilocalories_per_mole)) + calculation_type, DeltaF, dDeltaF, DeltaF * kT / unit.kilocalories_per_mole, + dDeltaF * kT / unit.kilocalories_per_mole)) logger.info('') for phase in phase_names: @@ -265,8 +270,8 @@ def analyze_directory(source_directory, **analyzer_kwargs): data[phase]['DeltaF_standard_state_correction'])) logger.info('') logger.info('Enthalpy{:<16}: {:9.3f} +- {:.3f} kT ({:.3f} +- {:.3f} kcal/mol)'.format( - calculation_type, DeltaH, dDeltaH, DeltaH * kT / units.kilocalories_per_mole, - dDeltaH * kT / units.kilocalories_per_mole)) + calculation_type, DeltaH, dDeltaH, DeltaH * kT / unit.kilocalories_per_mole, + dDeltaH * kT / unit.kilocalories_per_mole)) # ========================================== @@ -479,7 +484,7 @@ def extract_trajectory(nc_path, nc_checkpoint_file=None, state_index=None, repli # Create trajectory object logger.info('Creating trajectory object...') - trajectory = mdtraj.Trajectory(positions, topology) + trajectory = md.Trajectory(positions, topology) if is_periodic: trajectory.unitcell_vectors = box_vectors @@ -496,3 +501,256 @@ def extract_trajectory(nc_path, nc_checkpoint_file=None, state_index=None, repli logger.warning('The molecules will not be imaged because the system is non-periodic.') return trajectory + +# ============================================================================== +# Cluster ligand conformations and estimate populations in fully-interacting state +# ============================================================================== + +# TODO: This is a preliminary draft. This can be heavily refactored after generalizing the analysis code in the MultiStateAnalyzer + +def cluster(reference_pdb_filename, netcdf_filename, output_prefix='cluster', nsnapshots_per_cluster=5, + receptor_dsl_selection='protein and name CA', ligand_dsl_selection='not protein and (mass > 1.5)', + fully_interacting_state=0, ligand_rmsd_cutoff=0.3, ligand_filter_cutoff=0.3, + cluster_filter_threshold=0.95): + """ + Cluster ligand conformations and estimate populations in fully-interacting state + + Parameters + ---------- + reference_pdb_filename : str + The name of the PDB file for the solvated complex + netcdf_filename : str + The complex NetCDF file to read + output_prefix : str + String to prepend to cluster PDB files and populations written + nsnapshots_per_cluster : int, optional, default=5 + The number of snapshots per state to write + receptor_dsl_selection : str, optional, default='protein and name CA' + MDTraj DSL to use for selecting receptor atoms for alignment + ligand_dsl_selection : str, optional, default='not protein and (mass > 1.5)' + MDTraj DSL to use for selectinf ligand atoms to cluster + fully_interacting_state : int, optional, default=0 + 0 specifies the fully-interacting state + 1 species the first alchemical state + ligand_rmsd_cutoff : float, optional, default=0.3 + RMSD cutoff to use for ligand custering (in nanometers) + ligand_filter_cutoff : float, optional, default=0.3 + Snapshots where ligand atoms are greater than this cutoff from the receptor are filtered out + cluster_filter_threshold : float, optional, default=0.95 + Only the most populous clusters that add up to more than this threshold in population are written. + + The algorithm + ------------- + * Compute per-snapshot weights (using MBAR) representing the relative weight of each snapshot in the fully interacting state + * Cluster the remaining snapshots + * Assign relative populations to the clusters + * Sort clusters by population, writing only most populous clusters + * Sample representative snapshots from the clusters proportional to their weights, writing out PDB files + * Write out cluster populations + + """ + # mdtraj works in nanometers + ligand_rmsd_cutoff /= unit.nanometers + ligand_filter_cutoff /= unit.nanometers + + topology = md.load(reference_pdb_filename) + solute_indices = topology.top.select('not water') + logger.info('There are {:d} non-water atoms'.format(len(solute_indices))) + topology = topology.atom_slice(solute_indices) # check that this is the same as w + + from netCDF4 import Dataset + ncfile = Dataset(netcdf_filename, 'r') + + # TODO: Extend this to handle more than one replica + replica_index = 0 + + # Extract energy trajectories + sampled_energy_matrix = np.transpose(np.array(ncfile.variables['energies'][:,replica_index:(replica_index+1),:], np.float32), axes=[1,2,0]) + unsampled_energy_matrix = np.transpose(np.array(ncfile.variables['unsampled_energies'][:,replica_index:(replica_index+1),:], np.float32), axes=[1,2,0]) + + # Initialize the MBAR matrices in ln form. + n_replicas, n_sampled_states, n_iterations = sampled_energy_matrix.shape + _, n_unsampled_states, _ = unsampled_energy_matrix.shape + logger.info('n_replicas: {:d}'.format(n_replicas)) + logger.info('n_sampled_states: {:d}'.format(n_sampled_states)) + logger.info('n_iterations: {:d}'.format(n_iterations)) + + # Remove some frames + # TODO: Change this to instead extract the "good" portion of the trajectory that isn't corrupted + ntrim = 100 + logger.info('Trimming {:d} frames from either end'.format(ntrim)) + retained_snapshot_indices = list(range(ntrim, (n_iterations-ntrim))) + n_iterations = len(retained_snapshot_indices) + sampled_energy_matrix = sampled_energy_matrix[:,:,retained_snapshot_indices] + unsampled_energy_matrix = unsampled_energy_matrix[:,:,retained_snapshot_indices] + + # Extract thermodynamic state indices + replicas_state_indices = np.transpose(np.array(ncfile.variables['states'][retained_snapshot_indices,replica_index:(replica_index+1)], np.int64), axes=[1,0]) + + # TODO: Pre-filter all states remote from fully-interacting state + + # TODO: We could detect the equilibration time and discard data to equilibration here + #[t0, g, Neff_max] = timeseries.detectEquilibration(replicas_state_indices[replica_index,:], nskip=100) + + # Compute snapshot weights with MBAR + + # + # Note: This comes from multistateanalyzer.py L1445-1479. + # That section could be refactored to be more general to avoid code duplication + # + + logger.info('Reformatting energies...') + n_total_states = n_sampled_states + n_unsampled_states + energy_matrix = np.zeros([n_total_states, n_iterations*n_replicas]) + samples_per_state = np.zeros([n_total_states], dtype=int) + # Compute shift index for how many unsampled states there were. + # This assume that we set an equal number of unsampled states at the end points. + first_sampled_state = int(n_unsampled_states/2.0) + last_sampled_state = n_total_states - first_sampled_state + # Cast the sampled energy matrix from kln' to ln form. + energy_matrix[first_sampled_state:last_sampled_state, :] = multistate.MultiStateSamplerAnalyzer.reformat_energies_for_mbar(sampled_energy_matrix) + # Determine how many samples and which states they were drawn from. + unique_sampled_states, counts = np.unique(replicas_state_indices, return_counts=True) + # Assign those counts to the correct range of states. + samples_per_state[first_sampled_state:last_sampled_state][unique_sampled_states] = counts + # Add energies of unsampled states to the end points. + if n_unsampled_states > 0: + energy_matrix[[0, -1], :] = multistate.MultiStateSamplerAnalyzer.reformat_energies_for_mbar(unsampled_energy_matrix) + + # TODO: Should we instead only run MBAR *after* we have already filtered out problematic snapshots? + logger.info('Estimating weights...') + mbar = MBAR(energy_matrix, samples_per_state) + # Extract weights + w_n = mbar.W_nk[:,fully_interacting_state] + + # Extract unitcell lengths and angles + # TODO: Make this more general for non-rectilinear boxes + x = np.array(ncfile.variables['box_vectors'][retained_snapshot_indices,replica_index,:,:]) + unitcell_lengths = x[:,[0,1,2],[0,1,2]] + unitcell_angles = 90.0 * np.ones(unitcell_lengths.shape, np.float32) + + # Extract solute trajectory as MDTraj Trajectory + # NOTE: Only retained snapshots are extracted to speed things up + # NOTE: This will store the whole trajectory in memory + traj = md.Trajectory(ncfile.variables['positions'][retained_snapshot_indices,replica_index,solute_indices,:], topology.top, unitcell_lengths=unitcell_lengths, unitcell_angles=unitcell_angles) + + + # Remove counterions + # TODO: Is there a better way to eliminate everything but receptor and ligand? + logger.info(traj) + ion_dsl_selection = 'not (resname "Na+" or resname "Cl-")' + indices = traj.top.select(ion_dsl_selection) + traj = traj.atom_slice(indices) + logger.info(traj) + + # Check snapshot weights are small + logger.info('Maximum weight from any given snapshot (SHOULD BE SMALL!): {:f}'.format(w_n.max())) + indices = np.argsort(-w_n) + MAX_SNAPSHOT_WEIGHT = 0.01 + if w_n.max() > MAX_SNAPSHOT_WEIGHT: + filename = '%s-outlier.pdb' % (output_prefix) + logger.warning('WARNING: One snapshot is dominating the weights so clusters and populations will be unreliable') + logger.warning('Writing outlier to {}'.format(filename)) + snapshot_index = w_n.argmax() + logger.warning('snaphot {:d} has weight {:f}'.format(snapshot_index, w_n.max())) + traj[snapshot_index].save(filename) + + # Image molecules into periodic box, ensuring ligand is in closest image to receptor + traj.image_molecules(inplace=True) + + # Compute minimum heavy atom distance from ligand to protein + residues = [residue for residue in traj.top.residues] + protein_residues = [residue.index for residue in traj.top.residues if residue.is_protein] + logger.info('There are {:d} protein residues'.format(len(protein_residues))) + ligand_residues = [residue.index for residue in traj.top.residues if not residue.is_protein] + logger.info('There are {:d} ligand residues'.format(len(ligand_residues))) + + pairs = list(itertools.product(ligand_residues, protein_residues)) + distances, contacts = md.compute_contacts(traj, contacts=pairs, scheme='closest-heavy', ignore_nonprotein=False) + min_distances = distances.min(1) + logger.info('Maximum ligand heavy atom distance from protein: {:f} nm'.format(min_distances.max())) + logger.info('Minimum ligand heavy atom distance from protein: {:f} nm'.format(min_distances.min())) + + # Filter out snapshots where ligand is too far from the protein + filtered_snapshot_indices = np.where(min_distances <= ligand_filter_cutoff)[0] + logger.info('Retaining {:d} of {:d} snapshots where ligand heavy atoms are less than {:f} nm from protein'.format(len(filtered_snapshot_indices), len(traj), ligand_filter_cutoff)) + + # Filter out snapshots where ligand is too far from the protein + filtered_traj = traj[filtered_snapshot_indices] + filtered_w_n = np.array(w_n[filtered_snapshot_indices]) + filtered_w_n /= filtered_w_n.sum() # renormalize + + # Align receptor to first frame + atoms_to_align = filtered_traj.top.select(receptor_dsl_selection) + if (len(atoms_to_align) == 0): + raise Exception("Please check receptor_dsl_selection since no atoms were found in selection!") + logger.info('aligning on {:d} atoms from receptor_dsl_selection'.format(len(atoms_to_align))) + aligned_traj = filtered_traj.superpose(filtered_traj[0], frame=0, atom_indices=atoms_to_align) + + # Extract ligand trajectory + ligand_atom_indices = aligned_traj.topology.select(ligand_dsl_selection) + ligand_trajectory = aligned_traj.atom_slice(ligand_atom_indices) + logger.info('{:d} atoms in ligand trajectory'.format(len(ligand_atom_indices))) + + # Perform regular spatial clustering on ligand trajectory + nsnapshots, natoms, _ = ligand_trajectory.xyz.shape + x = np.array(ligand_trajectory.xyz).reshape([nsnapshots, natoms*3], order='C') + reg_space = RegularSpatial(d_min=3*natoms*ligand_rmsd_cutoff**2, metric='sqeuclidean').fit([x]) + cluster_assignments = reg_space.fit_predict([x])[0] + nclusters = cluster_assignments.max() + 1 + logger.info('There are {:d} clusters'.format(nclusters)) + + # Sort clusters by probability + cluster_probabilities = np.zeros([nclusters], np.float64) + for cluster_index in range(nclusters): + snapshot_indices = np.where(cluster_assignments == cluster_index)[0] + cluster_probabilities[cluster_index] = filtered_w_n[snapshot_indices].sum() + # Permute clusters + sorted_indices = np.argsort(-cluster_probabilities) + new_cluster_assignments = np.array(cluster_assignments) + for cluster_index in range(nclusters): + indices = np.where(cluster_assignments == sorted_indices[cluster_index])[0] + new_cluster_assignments[indices] = cluster_index + cluster_assignments = new_cluster_assignments + cluster_probabilities = cluster_probabilities[sorted_indices] + + # Write cluster populations + for cluster_index in range(nclusters): + logger.info('Cluster {:5d} : {:12.8f}'.format(cluster_index, cluster_probabilities[cluster_index])) + + cumsum = np.cumsum(cluster_probabilities) + cutoff_index = np.where(cumsum > cluster_filter_threshold)[0][0] # first index where weight is below threshold + nclusters = max(cutoff_index, 1) + logger.info('There are {:d} clusters after retaining only those where the cumulative weight exceeds {:f}'.format(nclusters, cluster_filter_threshold)) + + # Write reference protein conformation + receptor_atom_indices = aligned_traj.topology.select('protein') + filename = '%s-reference.pdb' % (output_prefix) + logger.info('Writing reference coordinates to {}'.format(filename)) + aligned_traj[0].atom_slice(receptor_atom_indices).save(filename) + + # Write aligned frames + for cluster_index in range(nclusters): + indices = np.where(cluster_assignments == cluster_index)[0] + # Remove indices with zero probability + indices = indices[filtered_w_n[indices] > 0.0] + + nsnapshots = len(indices) + logger.info('Cluster {:5d} : pop {:12.8f} : {:8d} members'.format(cluster_index, cluster_probabilities[cluster_index], nsnapshots)) + + # Sample frames + filename = '%s-cluster%03d.pdb' % (output_prefix, cluster_index) + logger.info(' writing {}'.format(filename)) + if nsnapshots <= nsnapshots_per_cluster: + aligned_traj[indices].save(filename) + else: + p = filtered_w_n[indices] / filtered_w_n[indices].sum() + sampled_indices = np.random.choice(indices, size=nsnapshots_per_cluster, p=p, replace=False) + aligned_traj[sampled_indices].save('%s-cluster%03d.pdb' % (output_prefix, cluster_index)) + # Write cluster populations to a file + filename = '%s-populations.txt' % (output_prefix) + logger.info('Writing populations to {}'.format(filename)) + outfile = open(filename, 'w') + for cluster_index in range(nclusters): + outfile.write('%05d %12.8f\n' % (cluster_index, cluster_probabilities[cluster_index])) + outfile.close() diff --git a/Yank/commands/analyze.py b/Yank/commands/analyze.py index 8e62f7bb..86097de7 100644 --- a/Yank/commands/analyze.py +++ b/Yank/commands/analyze.py @@ -33,6 +33,7 @@ yank analyze (-s STORE | --store=STORE) [--skipunbiasing] [--distcutoff=DISTANCE] [--energycutoff=ENERGY] [-v | --verbose] [--fulltraj] yank analyze report (-s STORE | --store=STORE) (-o REPORT | --output=REPORT) [-e | --serial] [--skipunbiasing] [--distcutoff=DISTANCE] [--energycutoff=ENERGY] [--fulltraj] yank analyze extract-trajectory --netcdf=FILEPATH [--checkpoint=FILEPATH ] (--state=STATE | --replica=REPLICA) --trajectory=FILEPATH [--start=START_FRAME] [--skip=SKIP_FRAME] [--end=END_FRAME] [--nosolvent] [--discardequil] [--imagemol] [-v | --verbose] + yank analyze cluster --refpdb=REFPDB --complexnetcdf=FILEPATH [--prefix=PREFIX] [--filter=FILTERDIST] [--cutoff=CUTOFFDIST] [--nsnapshots=NSNAPSHOTS] [--threshold=THRESHOLD] [-v | --verbose] Description: Analyze the data to compute Free Energies OR extract the trajectory from the NetCDF file into a common fortmat. @@ -80,11 +81,20 @@ --discardequil Detect and discard equilibration frames --imagemol Reprocess trajectory to enforce periodic boundary conditions to molecules positions +Cluster Required Arguments: + --refpdb=REFPDB Reference PDB filename for solvated complex + --complexnetcdf=FILEPATH Path to the complex analysis NetCDF file + --prefix=PREFIX Prefix to use for output cluster PDB files and populations (default: cluster) + --filter=FILTERDIST Discard snapshots where the ligand is farther than this minimum heavy atom distance from the protein, in nanometers (default: 0.3) + --cutoff=CUTOFFDIST Heavy-atom RMSD separation between clusters, in nanometers (default: 0.3) + --nsnapshots=NUM_SNAPSHOTS Number of snapshots per cluster to write (default: 5) + --cluster_filter_threshold=THRESHOLD Threshold to use for which clusters to include (default: 0.95) + General Options: -v, --verbose Print verbose output - --fulltraj Force ALL analysis run from this command to rely on the full trajectory and not do any - automatic equilibration detection or decorrelation subsampling. Although the - equilibration and correlation times will still be computed, no calculation depending on + --fulltraj Force ALL analysis run from this command to rely on the full trajectory and not do any + automatic equilibration detection or decorrelation subsampling. Although the + equilibration and correlation times will still be computed, no calculation depending on them will use this information. """ @@ -103,6 +113,9 @@ def dispatch(args): if args['extract-trajectory']: return dispatch_extract_trajectory(args) + if args['cluster']: + return dispatch_cluster(args) + # Configure analyzer keyword arguments. analyzer_kwargs = extract_analyzer_kwargs(args) analyze.analyze_directory(args['--store'], **analyzer_kwargs) @@ -167,6 +180,26 @@ def dispatch_extract_trajectory(args): return True +def dispatch_cluster(args): + """ + Cluster ligand conformations and estimate populations in fully-interacting state. + """ + refpdb_filename = args['--refpdb'] + nc_path = args['--complexnetcdf'] + prefix = args['--prefix'] if (args['--prefix'] is not None) else 'cluster' + filter = args['--filter'] if (args['--filter'] is not None) else (0.3 * unit.nanometers) + cutoff = args['--cutoff'] if (args['--cutoff'] is not None) else (0.3 * unit.nanometers) + nsnapshots_per_cluster = args['--nsnapshots'] if (args['--nsnapshots'] is not None) else 5 + cluster_filter_threshold = args['--threshold'] if (args['--threshold'] is not None) else 0.95 + + # TODO: Refine this API + from yank.analyze import cluster + cluster(reference_pdb_filename=refpdb_filename, netcdf_filename=nc_path, output_prefix=prefix, + nsnapshots_per_cluster=nsnapshots_per_cluster, cluster_filter_threshold=cluster_filter_threshold, + receptor_dsl_selection = 'protein and name CA', ligand_dsl_selection = 'not protein and (mass > 1.5)', + fully_interacting_state=0, ligand_rmsd_cutoff=cutoff, ligand_filter_cutoff=filter) + + return True def dispatch_report(args): # Check modules for render diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml index fbae8be4..271dc0ed 100644 --- a/devtools/conda-recipe/meta.yaml +++ b/devtools/conda-recipe/meta.yaml @@ -39,6 +39,7 @@ requirements: - jupyter - pdbfixer - libnetcdf >=4.6.0 + - msmbuilder test: requires: