diff --git a/Yank/experiment.py b/Yank/experiment.py index 1186ec68..f87cfbf2 100644 --- a/Yank/experiment.py +++ b/Yank/experiment.py @@ -651,7 +651,6 @@ def parse(self, script): yaml_content = script.copy() self._raw_yaml = yaml_content.copy() - # Check that YAML loading was successful if yaml_content is None: raise YamlParseError('The YAML file is empty!') @@ -1678,7 +1677,7 @@ def _is_pipeline_solvent_with_receptor(field, solvent_id, error): # DSL Schema ligand_dsl: required: no - type: string + type: [string, list] dependencies: [phase1_path, phase2_path] solvent_dsl: required: no @@ -1746,7 +1745,7 @@ def _is_pipeline_solvent_with_receptor(field, solvent_id, error): check_with: REGION_CLASH_DETERMINED_AT_RUNTIME_WITH_LIGAND ligand: required: no - type: string + type: [string, list] dependencies: [receptor, solvent] allowed: MOLECULE_IDS_POPULATED_AT_RUNTIME excludes: [solute, phase1_path, phase2_path] @@ -2617,16 +2616,16 @@ def _generate_auto_state_parameters(phase_factory): state_parameters.append(('lambda_restraints', [0.0, 1.0])) # We support only lambda sterics and electrostatics for now. if is_vacuum and not phase_factory.alchemical_regions.annihilate_electrostatics: - state_parameters.append(('lambda_electrostatics', [1.0, 1.0])) + state_parameters.append(('lambda_electrostatics', [1.0, 1.0])) else: - state_parameters.append(('lambda_electrostatics', [1.0, 0.0])) + state_parameters.append(('lambda_electrostatics', [1.0, 0.0])) if is_vacuum and not phase_factory.alchemical_regions.annihilate_sterics: - state_parameters.append(('lambda_sterics', [1.0, 1.0])) + state_parameters.append(('lambda_sterics', [1.0, 1.0])) else: - state_parameters.append(('lambda_sterics', [1.0, 0.0])) + state_parameters.append(('lambda_sterics', [1.0, 0.0])) # Turn the RMSD restraints off slowly at the end if isinstance(phase_factory.restraint, restraints.RMSD): - state_parameters.append(('lambda_restraints', [1.0, 0.0])) + state_parameters.append(('lambda_restraints', [1.0, 0.0])) return state_parameters @@ -3089,6 +3088,7 @@ def _build_experiment(self, experiment_path, experiment, use_dummy_protocol=Fals assert isinstance(protocol, collections.OrderedDict) phase_names = list(protocol.keys()) phase_paths = self._get_nc_file_paths(experiment_path, experiment) + is_relative = False for phase_idx, (phase_name, phase_path) in enumerate(zip(phase_names, phase_paths)): # Check if we need to resume a phase. If the phase has been # already created, Experiment will resume from the storage. @@ -3112,8 +3112,22 @@ def _build_experiment(self, experiment_path, experiment, use_dummy_protocol=Fals # Identify system components. There is a ligand only in the complex phase. if phase_idx == 0: ligand_atoms = ligand_dsl + if len(ligand_atoms) == 2 and isinstance(ligand_atoms[0], list) and isinstance(ligand_atoms[1], list): + is_relative = True else: ligand_atoms = None + topography = Topography(topology, ligand_atoms=ligand_atoms, + solvent_atoms=solvent_dsl) + + solute_atoms = getattr(topography, 'solute_atoms') + topo = topography.topology.subset(solute_atoms) + resnames = [] + for chain in topo._chains: + for residue in chain._residues: + resnames.append(f'resname {residue.name}') + if len(set(resnames)) == 2: + is_relative = True + topography = Topography(topology, ligand_atoms=ligand_atoms, solvent_atoms=solvent_dsl) @@ -3134,8 +3148,8 @@ def _build_experiment(self, experiment_path, experiment, use_dummy_protocol=Fals # and modified it according to the user options. phase_protocol = protocol[phase_name]['alchemical_path'] alchemical_region = AlchemicalPhase._build_default_alchemical_region(system, topography, - phase_protocol) - alchemical_region = alchemical_region._replace(**alchemical_region_opts) + phase_protocol, is_relative) + alchemical_region = [region._replace(**alchemical_region_opts) for region in alchemical_region] # Apply restraint only if this is the first phase. AlchemicalPhase # will take care of raising an error if the phase type does not support it. diff --git a/Yank/pipeline.py b/Yank/pipeline.py index 45a404a8..bfe91dd8 100644 --- a/Yank/pipeline.py +++ b/Yank/pipeline.py @@ -234,7 +234,7 @@ def compute_net_charge(system, atom_indices): return net_charge -def find_alchemical_counterions(system, topography, region_name): +def find_alchemical_counterions(system, topography, region_name, is_relative): """Return the atom indices of the ligand or solute counter ions. In periodic systems, the solvation box needs to be neutral, and @@ -273,32 +273,92 @@ def find_alchemical_counterions(system, topography, region_name): if len(atom_indices) == 0: raise ValueError("Cannot find counterions for region {}. " "The region has no atoms.") - - # If the net charge of alchemical atoms is 0, we don't need counterions. - mol_net_charge = compute_net_charge(system, atom_indices) - logger.debug('{} net charge: {}'.format(region_name, mol_net_charge)) - if mol_net_charge == 0: - return [] - # Find net charge of all ions in the system. ions_net_charges = [(ion_id, compute_net_charge(system, [ion_id])) - for ion_id in topography.ions_atoms] + for ion_id in topography.ions_atoms] topology = topography.topology - ions_names_charges = [(topology.atom(ion_id).residue.name, ion_net_charge) - for ion_id, ion_net_charge in ions_net_charges] - logger.debug('Ions net charges: {}'.format(ions_names_charges)) + if not is_relative: + # If the net charge of alchemical atoms is 0, we don't need counterions. + mol_net_charge = compute_net_charge(system, atom_indices) + logger.debug('{} net charge: {}'.format(region_name, mol_net_charge)) + if mol_net_charge == 0: + return [] + + ions_names_charges = [(topology.atom(ion_id).residue.name, ion_net_charge) + for ion_id, ion_net_charge in ions_net_charges] + logger.debug('Ions net charges: {}'.format(ions_names_charges)) + + # Find minimal subset of counterions whose charges sums to -mol_net_charge. + return ions_subset(ions_net_charges, -mol_net_charge) + + # We couldn't find any subset of counterions neutralizing the region. + raise ValueError('Impossible to find a solution for region {}. ' + 'Net charge: {}, system ions: {}.'.format( + region_name, mol_net_charge, ions_names_charges)) + + else: + if region_name == 'solute_atoms': + topo = topology.subset(atom_indices) + resnames = [] + for chain in topo._chains: + for residue in chain._residues: + resnames.append(f'resname {residue.name}') + atom_indices = tuple([topography.select(r) for r in resnames]) + + ions_net_charges = [(ion_id, compute_net_charge(system, [ion_id])) + for ion_id in topography.ions_atoms] + ignore = [ions[0] for ions in ions_net_charges] + atom_indices[0] + atom_indices[1] + indices = [atom.index for atom in topology.atoms if atom.index not in ignore] + net_charge = compute_net_charge(system, indices) + charge_init = net_charge + compute_net_charge(system, atom_indices[0]) + charge_end = net_charge + compute_net_charge(system, atom_indices[1]) + if (charge_init*charge_end) > 0: + if abs(charge_init) > abs(charge_end): + ions_to_dummies = int(charge_end - charge_init) + dummies_to_ions = None + elif abs(charge_init) < abs(charge_end): + ions_to_dummies = None + dummies_to_ions = int(-(charge_end - charge_init)) + else: + ions_to_dummies = None + dummies_to_ions = None + else: + ions_to_dummies = int(-charge_init) + dummies_to_ions = int(-charge_end) + + counterions = ions_subset(ions_net_charges, -charge_init) + avail_ions = [(id, c) for (id, c) in ions_net_charges if id not in counterions] + ions_to_dummies_idx = None + dummies_to_ions_idx = None + if ions_to_dummies: + ions_to_dummies_idx = counterions[:abs(ions_to_dummies)] + if dummies_to_ions: + dummies_to_ions_idx = ions_subset(avail_ions, dummies_to_ions) + + return ions_to_dummies_idx, dummies_to_ions_idx + +def ions_subset(ions_net_charges, counterions): + """ + Finds minimal subset of ion indexes whose charge sums to counterions + Parameters + ---------- + ions_net_charges : list of tuples + (index, charge) of ions in system. + counterions : int + Total charge. + + Returns + ------- + counterions_indices : list of int + Indices of ions. + + """ - # Find minimal subset of counterions whose charges sums to -mol_net_charge. for n_ions in range(1, len(ions_net_charges) + 1): for ion_subset in itertools.combinations(ions_net_charges, n_ions): counterions_indices, counterions_charges = zip(*ion_subset) - if sum(counterions_charges) == -mol_net_charge: - return counterions_indices - - # We couldn't find any subset of counterions neutralizing the region. - raise ValueError('Impossible to find a solution for region {}. ' - 'Net charge: {}, system ions: {}.'.format( - region_name, mol_net_charge, ions_names_charges)) + if sum(counterions_charges) == counterions: + return(counterions_indices) # See Amber manual Table 4.1 http://ambermd.org/doc12/Amber15.pdf diff --git a/Yank/restraints.py b/Yank/restraints.py index 13599f3b..bc7398c6 100644 --- a/Yank/restraints.py +++ b/Yank/restraints.py @@ -251,7 +251,11 @@ def compute_atom_intersect(self, input_atoms, topography_key: str, *additional_s """ topography = self.topography - topography_set = set(getattr(topography, topography_key)) + if isinstance(getattr(topography, topography_key)[0], list) and isinstance(getattr(topography, topography_key)[1], list): + topography_set = set([index for sublist in getattr(topography, topography_key) for + index in sublist]) + else: + topography_set = set(getattr(topography, topography_key)) # Ensure additions are sets additional_sets = [set(additional_set) for additional_set in additional_sets] if len(additional_sets) == 0: @@ -259,7 +263,6 @@ def compute_atom_intersect(self, input_atoms, topography_key: str, *additional_s additional_intersect = topography_set else: additional_intersect = set.intersection(*additional_sets) - @functools.singledispatch def compute_atom_set(passed_atoms): """Helper function for doing set operations on heavy ligand atoms of all other types""" @@ -412,6 +415,8 @@ def _parameters(self): for parameter_name in parameter_names} return parameters + def _is_ligand_ligand_restraint(self, restrained_ligand_atoms): + return len(restrained_ligand_atoms) == 2 and isinstance(restrained_ligand_atoms[0], list) and isinstance(restrained_ligand_atoms[1], list) class _RestrainedAtomsProperty(object): """ @@ -433,8 +438,11 @@ def __get__(self, instance, owner_class=None): def __set__(self, instance, new_restrained_atoms): # If we set the restrained attributes to None, no reason to check things. - if new_restrained_atoms is not None: - new_restrained_atoms = self._validate_atoms(new_restrained_atoms) + if (new_restrained_atoms): + if instance._is_ligand_ligand_restraint(new_restrained_atoms): + new_restrained_atoms = [self._validate_atoms(atoms) for atoms in new_restrained_atoms] + else: + new_restrained_atoms = self._validate_atoms(new_restrained_atoms) setattr(instance, self._attribute_name, new_restrained_atoms) @methoddispatch @@ -446,7 +454,6 @@ def _validate_atoms(self, restrained_atoms): restrained_atoms = list(restrained_atoms) return restrained_atoms - # ============================================================================== # Base class for radially-symmetric receptor-ligand restraints. # ============================================================================== @@ -528,7 +535,7 @@ class _RadiallySymmetricRestrainedAtomsProperty(_RestrainedAtomsProperty): def _validate_atoms(self, restrained_atoms): restrained_atoms = super()._validate_atoms(restrained_atoms) if len(restrained_atoms) > 1: - logger.debug(self._CENTROID_COMPUTE_STRING.format("more than one", self._atoms_type)) + logger.debug(self._CENTROID_COMPUTE_STRING.format("more than one", self._atoms_type)) return restrained_atoms @_validate_atoms.register(str) @@ -571,9 +578,12 @@ def restrain_state(self, thermodynamic_state): raise RestraintParameterError('Restraint {}: Undefined restrained ' 'atoms.'.format(self.__class__.__name__)) - # Create restraint force. - restraint_force = self._get_restraint_force(self.restrained_receptor_atoms, - self.restrained_ligand_atoms) + # Create restraint force + if self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + restraint_force = self._get_restraint_force(*self.restrained_ligand_atoms) + else: + restraint_force = self._get_restraint_force(self.restrained_receptor_atoms, + self.restrained_ligand_atoms) # Set periodic conditions on the force if necessary. restraint_force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) @@ -700,10 +710,17 @@ def _determine_restraint_parameters(self, thermodynamic_state, sampler_state, to @property def _are_restrained_atoms_defined(self): """Check if the restrained atoms are defined well enough to make a restraint""" - for atoms in [self.restrained_receptor_atoms, self.restrained_ligand_atoms]: - # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class - if atoms is None or not (isinstance(atoms, list) and len(atoms) > 0): - return False + if self.restrained_ligand_atoms: + if self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + for atoms in self.restrained_ligand_atoms: + # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class + if atoms is None or any(isinstance(atom, str) for atom in atoms) or not (isinstance(atoms, list) and len(atoms) > 0): + return False + else: + for atoms in [self.restrained_receptor_atoms, self.restrained_ligand_atoms]: + if atoms is None or not (isinstance(atoms, list) and len(atoms) > 0): + return False + return True def _get_restraint_force(self, particles1, particles2): @@ -743,7 +760,9 @@ def _determine_restrained_atoms(self, sampler_state, topography): # If receptor and ligand atoms are explicitly provided, use those. restrained_ligand_atoms = self.restrained_ligand_atoms - restrained_receptor_atoms = self.restrained_receptor_atoms + if not self._is_ligand_ligand_restraint(restrained_ligand_atoms): + restrained_receptor_atoms = self.restrained_receptor_atoms + @functools.singledispatch def compute_atom_set(input_atoms, topography_key, mapping_function): @@ -760,7 +779,7 @@ def compute_atom_set(input_atoms, topography_key, mapping_function): "Atoms not part of {0} will be ignored.".format(topography_key)) final_atoms = list(intersect_set) else: - final_atoms = list(input_atoms) + final_atoms = list(input_atoms_set) return final_atoms @compute_atom_set.register(type(None)) @@ -775,18 +794,30 @@ def compute_atom_none(_, topography_key, mapping_function): def compute_atom_str(input_string, topography_key, _): """Helper for string parsing""" selection = topography.select(input_string, as_set=True) - selection_with_top = selection & set(getattr(topography, topography_key)) + if isinstance(getattr(topography, topography_key)[0], list) and isinstance(getattr(topography, topography_key)[1], list): + selection_with_top = selection & set([index for sublist in getattr(topography, topography_key) for + index in sublist]) + else: + selection_with_top = selection & set(getattr(topography, topography_key)) # Force output to be a normal int, dont need to worry about floats at this point, there should not be any # If they come out as np.int64's, OpenMM complains return [*map(int, selection_with_top)] - - self.restrained_ligand_atoms = compute_atom_set(restrained_ligand_atoms, + + if self._is_ligand_ligand_restraint(restrained_ligand_atoms): + self.restrained_ligand_atoms = [] + for element in restrained_ligand_atoms: + for value in element: + self.restrained_ligand_atoms.append(compute_atom_set(value, 'ligand_atoms', - self._closest_atom_to_centroid) - self.restrained_receptor_atoms = compute_atom_set(restrained_receptor_atoms, - 'receptor_atoms', - self._closest_atom_to_centroid) + self._closest_atom_to_centroid)) + else: + self.restrained_ligand_atoms = compute_atom_set(restrained_ligand_atoms, + 'ligand_atoms', + self._closest_atom_to_centroid) + self.restrained_receptor_atoms = compute_atom_set(restrained_receptor_atoms, + 'receptor_atoms', + self._closest_atom_to_centroid) @staticmethod def _closest_atom_to_centroid(positions, indices=None, masses=None): """ @@ -962,6 +993,7 @@ def _create_restraint_force(self, particles1, particles2): The created restraint force. """ + # Create bond force and lambda_restraints parameter to control it. if len(particles1) == 1 and len(particles2) == 1: # CustomCentroidBondForce works only on 64bit platforms. When the @@ -1172,13 +1204,13 @@ def _determine_restraint_parameters(self, thermodynamic_state, sampler_state, to The topography with labeled receptor and ligand atoms. """ - # Determine number of atoms. - n_atoms = len(topography.receptor_atoms) - - # Check that restrained receptor atoms are in expected range. - if any(atom_id >= n_atoms for atom_id in self.restrained_receptor_atoms): - raise ValueError('Receptor atoms {} were selected for restraint, but system ' - 'only has {} atoms.'.format(self.restrained_receptor_atoms, n_atoms)) + if not self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + # Determine number of atoms. + n_atoms = len(topography.receptor_atoms) + # Check that restrained receptor atoms are in expected range. + if any(atom_id >= n_atoms for atom_id in self.restrained_receptor_atoms): + raise ValueError('Receptor atoms {} were selected for restraint, but system ' + 'only has {} atoms.'.format(self.restrained_receptor_atoms, n_atoms)) # Compute well radius if the user hasn't specified it in the constructor. if self.well_radius is None: @@ -1471,7 +1503,11 @@ def restrain_state(self, thermodynamic_state): n_particles = 6 # number of particles involved in restraint: p1 ... p6 restraint_force = openmm.CustomCompoundBondForce(n_particles, energy_function) restraint_force.addGlobalParameter('lambda_restraints', 1.0) - restraint_force.addBond(self.restrained_receptor_atoms + self.restrained_ligand_atoms, []) + + if self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + restraint_force.addBond(self.restrained_ligand_atoms[0] + self.restrained_ligand_atoms[1], []) + else: + restraint_force.addBond(self.restrained_receptor_atoms + self.restrained_ligand_atoms, []) restraint_force.setUsesPeriodicBoundaryConditions(thermodynamic_state.is_periodic) # Get a copy of the system of the ThermodynamicState, modify it and set it back. @@ -1705,10 +1741,17 @@ def _check_parameters_defined(self): @property def _are_restrained_atoms_defined(self): """Check if the restrained atoms are defined well enough to make a restraint""" - for atoms in [self.restrained_receptor_atoms, self.restrained_ligand_atoms]: - # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class - if atoms is None or not (isinstance(atoms, list) and len(atoms) == 3): - return False + if self.restrained_ligand_atoms: + if self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + for atoms in self.restrained_ligand_atoms: + # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class + if atoms is None or any(isinstance(atom, str) for atom in atoms) or not (isinstance(atoms, list) and len(atoms) > 0): + return False + else: + for atoms in [self.restrained_receptor_atoms, self.restrained_ligand_atoms]: + # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class + if atoms is None or not (isinstance(atoms, list) and len(atoms) == 3): + return False return True @staticmethod @@ -1912,7 +1955,10 @@ def _assign_if_undefined(attr_name, attr_value): setattr(self, attr_name, attr_value) # Merge receptor and ligand atoms in a single array for easy manipulation. - restrained_atoms = self.restrained_receptor_atoms + self.restrained_ligand_atoms + if self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + restrained_atoms = self.restrained_ligand_atoms[0] + self.restrained_ligand_atoms[1] + else: + restrained_atoms = self.restrained_receptor_atoms + self.restrained_ligand_atoms # Set spring constants uniformly, as in Ref [1] Table 1 caption. _assign_if_undefined('K_r', 20.0 * unit.kilocalories_per_mole / unit.angstrom**2) @@ -2496,7 +2542,7 @@ def __init__(self, atoms_type, allowed_empty=False): def _validate_atoms(self, restrained_atoms): restrained_atoms = super()._validate_atoms(restrained_atoms) # TODO: Determine the minimum number of atoms needed for this restraint (can it be 0?) - if len(restrained_atoms) < 3 and not (len(restrained_atoms) == 0 and self._allowed_empty): + if len(restrained_atoms) < 3 and not (len(restrained_atoms) == 0 and self._allowed_empty) and not (isinstance(element, str) for element in restrained_atoms): raise ValueError('At least three {} atoms are required to impose an ' 'RMSD restraint.'.format(self._atoms_type)) return restrained_atoms @@ -2620,7 +2666,7 @@ def _check_parameters_defined(self): def _are_restrained_atoms_defined(self): """Check if the restrained atoms are defined well enough to make a restraint""" for atoms in [self.restrained_receptor_atoms, self.restrained_ligand_atoms]: - # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class + # Atoms should be a list or None at this point due to the _RestrainedAtomsProperty class if not self._are_single_atoms_defined(atoms): return False return True @@ -2635,14 +2681,23 @@ def _are_single_atoms_defined(atom_list): def _pick_restrained_atoms(self, topography): """Select the restrained atoms to use for this system""" atom_selector = _AtomSelector(topography) - for atom_word, top_key in zip(["restrained_ligand_atoms", "restrained_receptor_atoms"], - ["ligand_atoms", "receptor_atoms"]): - atoms = getattr(self, atom_word) - if self._are_single_atoms_defined(atoms): - continue - defined_atoms = atom_selector.compute_atom_intersect(atoms, top_key) - setattr(self, atom_word, defined_atoms) - + atoms = getattr(self, "restrained_receptor_atoms") + if not self._are_single_atoms_defined(atoms): + defined_atoms = atom_selector.compute_atom_intersect(atoms, "receptor_atoms") + setattr(self, "restrained_receptor_atoms", defined_atoms) + atoms = getattr(self, "restrained_ligand_atoms") + if self._is_ligand_ligand_restraint(self.restrained_ligand_atoms): + defined_atoms = [] + for element in self.restrained_ligand_atoms: + if isinstance(element[0], str): + defined_atoms += atom_selector.compute_atom_intersect(element[0], "ligand_atoms") + elif isinstance(element, list) and len(element) > 0: + defined_atoms += element + setattr(self, "restrained_ligand_atoms", defined_atoms) + else: + if not self._are_single_atoms_defined(atoms): + defined_atoms = atom_selector.compute_atom_intersect(atoms, "ligand_atoms") + setattr(self, "restrained_ligand_atoms", defined_atoms) if __name__ == '__main__': import doctest diff --git a/Yank/schema/validator.py b/Yank/schema/validator.py index 83332bbd..72c5dac2 100644 --- a/Yank/schema/validator.py +++ b/Yank/schema/validator.py @@ -187,8 +187,11 @@ def _check_with_only_with_no_cutoff(self, field, value): def _check_with_specify_lambda_electrostatics_and_sterics(self, field, value): """Check that the keys of a dictionary contain both lambda_electrostatics and lambda_sterics.""" + if ((isinstance(value, dict) or isinstance(value, collections.OrderedDict)) and - not ('lambda_sterics' in value and 'lambda_electrostatics' in value)): + not (('lambda_sterics' in value and 'lambda_electrostatics' in value) or + ('lambda_sterics_0' in value and 'lambda_electrostatics_0' in value and + 'lambda_sterics_1' in value and 'lambda_electrostatics_1' in value))): self._error(field, "Missing required keys lambda_sterics and/or lambda_electrostatics") def _check_with_math_expressions_variables_are_given(self, field, value): diff --git a/Yank/yank.py b/Yank/yank.py index b117365a..9275e22e 100644 --- a/Yank/yank.py +++ b/Yank/yank.py @@ -120,8 +120,15 @@ def ligand_atoms(self): @ligand_atoms.setter def ligand_atoms(self, value): - self._ligand_atoms = self.select(value) - + relative_system = False + if len(value) == 2: + for v in value: + if isinstance(v, list): + relative_system = True + if (relative_system): + self._ligand_atoms = tuple([self.select(v[0]) for v in value]) + else: + self._ligand_atoms = self.select(value) # Safety check: with a ligand there should always be a receptor. if len(self._ligand_atoms) > 0 and len(self.receptor_atoms) == 0: raise ValueError('Specified ligand but cannot find ' @@ -140,8 +147,11 @@ def receptor_atoms(self): if len(self._ligand_atoms) == 0: return [] - # Create a set for fast searching. - ligand_atomset = frozenset(self._ligand_atoms) + # Create a set for fast searching. + if isinstance(self._ligand_atoms, tuple): + ligand_atomset = frozenset(self._ligand_atoms[0] + self._ligand_atoms[1]) + else: + ligand_atomset = frozenset(self._ligand_atoms) # Receptor atoms are all solute atoms that are not ligand. return [i for i in self.solute_atoms if i not in ligand_atomset] @@ -1010,8 +1020,9 @@ def create(self, thermodynamic_state, sampler_states, topography, protocol, reference_system, topography, protocol) # Check that we have atoms to alchemically modify. - if len(alchemical_regions.alchemical_atoms) == 0: - raise ValueError("Couldn't find atoms to alchemically modify.") + for region in alchemical_regions: + if region.alchemical_atoms == 0: + raise ValueError("Couldn't find atoms to alchemically modify.") # Create alchemically-modified system using alchemical factory. logger.debug("Creating alchemically-modified states...") @@ -1022,11 +1033,16 @@ def create(self, thermodynamic_state, sampler_states, topography, protocol, # Create compound alchemically modified state to pass to sampler. thermodynamic_state.system = alchemical_system - alchemical_state = mmtools.alchemy.AlchemicalState.from_system(alchemical_system) - if restraint_state is not None: - composable_states = [alchemical_state, restraint_state] + if len(alchemical_regions) > 1: + alchemical_state = [mmtools.alchemy.AlchemicalState.from_system(alchemical_system, parameters_name_suffix = '0'), + mmtools.alchemy.AlchemicalState.from_system(alchemical_system, parameters_name_suffix = '1')] else: - composable_states = [alchemical_state] + alchemical_state = [mmtools.alchemy.AlchemicalState.from_system(alchemical_system)] + + composable_states = [state for state in alchemical_state] + if restraint_state is not None: + composable_states += [restraint_state] + compound_state = mmtools.states.CompoundThermodynamicState( thermodynamic_state=thermodynamic_state, composable_states=composable_states) @@ -1331,43 +1347,70 @@ def _expand_state_cutoff(thermodynamic_state, expanded_cutoff_distance, return thermodynamic_state @staticmethod - def _build_default_alchemical_region(system, topography, protocol): + def _build_default_alchemical_region(system, topography, protocol, is_relative): """Create a default AlchemicalRegion if the user hasn't provided one.""" # TODO: we should probably have a second region that annihilate sterics of counterions. - alchemical_region_kwargs = {} - + alchemical_region_0_kwargs = {} + if is_relative: + alchemical_region_1_kwargs = {} # Modify ligand if this is a receptor-ligand phase, or # solute if this is a transfer free energy calculation. if len(topography.ligand_atoms) > 0: alchemical_region_name = 'ligand_atoms' + alchemical_atoms = getattr(topography, alchemical_region_name) else: alchemical_region_name = 'solute_atoms' - alchemical_atoms = getattr(topography, alchemical_region_name) - + solute_atoms = getattr(topography, alchemical_region_name) + topo = topography.topology.subset(solute_atoms) + resnames = [] + for chain in topo._chains: + for residue in chain._residues: + resnames.append(f'resname {residue.name}') + alchemical_atoms = tuple([topography.select(r) for r in resnames]) # In periodic systems, we alchemically modify the ligand/solute # counterions to make sure that the solvation box is always neutral. - if system.usesPeriodicBoundaryConditions(): + if system.usesPeriodicBoundaryConditions(): alchemical_counterions = mpiplus.run_single_node(0, pipeline.find_alchemical_counterions, system, topography, alchemical_region_name, - broadcast_result=True) - alchemical_atoms += alchemical_counterions - - # Sort them by index for safety. We don't want to - # accidentally exchange two atoms' positions. - alchemical_atoms = sorted(alchemical_atoms) - - alchemical_region_kwargs['alchemical_atoms'] = alchemical_atoms + is_relative, broadcast_result=True) + + if is_relative: + ions_to_dummies, dummies_to_ions = mpiplus.run_single_node(0, pipeline.find_alchemical_counterions, + system, topography, alchemical_region_name, + is_relative, broadcast_result=True) + if ions_to_dummies: + alchemical_atoms[0] += ions_to_dummies + if dummies_to_ions: + alchemical_atoms[1] += dummies_to_ions + # Sort them by index for safety. We don't want to + # accidentally exchange two atoms' positions. + alchemical_atoms = [sorted(atoms) for atoms in alchemical_atoms] + alchemical_region_0_kwargs['alchemical_atoms'] = alchemical_atoms[0] + alchemical_region_1_kwargs['alchemical_atoms'] = alchemical_atoms[1] + else: + alchemical_atoms += alchemical_counterions + alchemical_atoms = sorted(alchemical_atoms) + alchemical_region_0_kwargs[0]['alchemical_atoms'] = alchemical_atoms + # Check if we need to modify bonds/angles/torsions. for element_type in ['bonds', 'angles', 'torsions']: if 'lambda_' + element_type in protocol: modify_it = True else: modify_it = None - alchemical_region_kwargs['alchemical_' + element_type] = modify_it - + if is_relative: + alchemical_region_0_kwargs['alchemical_' + element_type] = modify_it + alchemical_region_1_kwargs['alchemical_' + element_type] = modify_it + else: + alchemical_region_0_kwargs['alchemical_' + element_type] = modify_it # Create alchemical region. - alchemical_region = mmtools.alchemy.AlchemicalRegion(**alchemical_region_kwargs) + if is_relative: + alchemical_region = [mmtools.alchemy.AlchemicalRegion(name='0', **alchemical_region_0_kwargs), + mmtools.alchemy.AlchemicalRegion(name='1', **alchemical_region_1_kwargs)] + + else: + alchemical_region = [mmtools.alchemy.AlchemicalRegion(**alchemical_region_0_kwargs)] return alchemical_region