From 562e0d305b96082cac337ce2320f5f7835385618 Mon Sep 17 00:00:00 2001 From: Mike Henry <11765982+mikemhenry@users.noreply.github.com> Date: Fri, 16 Aug 2024 14:34:10 -0700 Subject: [PATCH] pyupgrade ruff-ed --- openmmtools/tests/test_alchemy.py | 2036 ++++++++++++----- openmmtools/tests/test_cache.py | 200 +- openmmtools/tests/test_forcefactories.py | 89 +- openmmtools/tests/test_forces.py | 346 ++- openmmtools/tests/test_integrators.py | 693 ++++-- .../tests/test_integrators_and_testsystems.py | 39 +- openmmtools/tests/test_mcmc.py | 280 ++- openmmtools/tests/test_mixing.py | 15 +- openmmtools/tests/test_platforms.py | 6 +- openmmtools/tests/test_sampling.py | 1661 +++++++++----- openmmtools/tests/test_states.py | 1221 ++++++---- openmmtools/tests/test_storage_interface.py | 45 +- openmmtools/tests/test_storage_iodrivers.py | 96 +- openmmtools/tests/test_testsystems.py | 96 +- openmmtools/tests/test_utils.py | 539 +++-- 15 files changed, 5045 insertions(+), 2317 deletions(-) diff --git a/openmmtools/tests/test_alchemy.py b/openmmtools/tests/test_alchemy.py index bb5010f74..46aed8f4f 100644 --- a/openmmtools/tests/test_alchemy.py +++ b/openmmtools/tests/test_alchemy.py @@ -14,8 +14,6 @@ # GLOBAL IMPORTS # ============================================================================= -from __future__ import print_function - import copy import logging import os @@ -34,9 +32,18 @@ from openmm import unit from openmmtools import forces, forcefactories, states, testsystems, utils from openmmtools.constants import kB, ONE_4PI_EPS0 -from openmmtools.alchemy import AlchemicalFunction, AlchemicalState, AbsoluteAlchemicalFactory, \ - AlchemicalRegion, AlchemicalStateError -from openmmtools.multistate.pymbar import subsample_correlated_data, detect_equilibration, _pymbar_exp +from openmmtools.alchemy import ( + AlchemicalFunction, + AlchemicalState, + AbsoluteAlchemicalFactory, + AlchemicalRegion, + AlchemicalStateError, +) +from openmmtools.multistate.pymbar import ( + subsample_correlated_data, + detect_equilibration, + _pymbar_exp, +) logger = logging.getLogger(__name__) @@ -57,6 +64,7 @@ # TESTING UTILITIES # ============================================================================= + def create_context(system, integrator, platform=None): """Create a Context. @@ -93,7 +101,13 @@ def compute_energy(system, positions, platform=None, force_group=-1): return potential -def minimize(system, positions, platform=None, tolerance=1.0*unit.kilocalories_per_mole/unit.angstroms, maxIterations=500): +def minimize( + system, + positions, + platform=None, + tolerance=1.0 * unit.kilocalories_per_mole / unit.angstroms, + maxIterations=500, +): """Minimize the energy of the given system. Parameters @@ -145,13 +159,20 @@ def compute_force_energy(system, positions, force_name): def assert_almost_equal(energy1, energy2, err_msg): delta = energy1 - energy2 - err_msg += ' interactions do not match! Reference {}, alchemical {},' \ - ' difference {}'.format(energy1, energy2, delta) + err_msg += ( + " interactions do not match! Reference {}, alchemical {}," + " difference {}".format(energy1, energy2, delta) + ) assert abs(delta) < MAX_DELTA, err_msg -def turn_off_nonbonded(system, sterics=False, electrostatics=False, - exceptions=False, only_atoms=frozenset()): +def turn_off_nonbonded( + system, + sterics=False, + electrostatics=False, + exceptions=False, + only_atoms=frozenset(), +): """Turn off sterics and/or electrostatics interactions. This affects only NonbondedForce and non-alchemical CustomNonbondedForces. @@ -168,31 +189,56 @@ def turn_off_nonbonded(system, sterics=False, electrostatics=False, charge_coeff = 0.0 if electrostatics else 1.0 if exceptions: # Turn off exceptions - force_idx, nonbonded_force = forces.find_forces(system, openmm.NonbondedForce, only_one=True) + force_idx, nonbonded_force = forces.find_forces( + system, openmm.NonbondedForce, only_one=True + ) # Exceptions. for exception_index in range(nonbonded_force.getNumExceptions()): - iatom, jatom, charge, sigma, epsilon = nonbonded_force.getExceptionParameters(exception_index) + iatom, jatom, charge, sigma, epsilon = ( + nonbonded_force.getExceptionParameters(exception_index) + ) if iatom in only_atoms or jatom in only_atoms: - nonbonded_force.setExceptionParameters(exception_index, iatom, jatom, - charge_coeff*charge, sigma, epsilon_coeff*epsilon) + nonbonded_force.setExceptionParameters( + exception_index, + iatom, + jatom, + charge_coeff * charge, + sigma, + epsilon_coeff * epsilon, + ) # Offset exceptions. for offset_index in range(nonbonded_force.getNumExceptionParameterOffsets()): - (parameter, exception_index, chargeprod_scale, - sigma_scale, epsilon_scale) = nonbonded_force.getExceptionParameterOffset(offset_index) - iatom, jatom, _, _, _ = nonbonded_force.getExceptionParameters(exception_index) + ( + parameter, + exception_index, + chargeprod_scale, + sigma_scale, + epsilon_scale, + ) = nonbonded_force.getExceptionParameterOffset(offset_index) + iatom, jatom, _, _, _ = nonbonded_force.getExceptionParameters( + exception_index + ) if iatom in only_atoms or jatom in only_atoms: - nonbonded_force.setExceptionParameterOffset(offset_index, parameter, exception_index, - charge_coeff*chargeprod_scale, sigma_scale, - epsilon_coeff*epsilon_scale) + nonbonded_force.setExceptionParameterOffset( + offset_index, + parameter, + exception_index, + charge_coeff * chargeprod_scale, + sigma_scale, + epsilon_coeff * epsilon_scale, + ) else: # Turn off particle interactions for force in system.getForces(): # Handle only a Nonbonded and a CustomNonbonded (for RF). - if not (isinstance(force, openmm.CustomNonbondedForce) and 'lambda' not in force.getEnergyFunction() or - isinstance(force, openmm.NonbondedForce)): + if not ( + isinstance(force, openmm.CustomNonbondedForce) + and "lambda" not in force.getEnergyFunction() + or isinstance(force, openmm.NonbondedForce) + ): continue # Particle interactions. @@ -210,15 +256,27 @@ def turn_off_nonbonded(system, sterics=False, electrostatics=False, # Offset particle interactions. if isinstance(force, openmm.NonbondedForce): for offset_index in range(force.getNumParticleParameterOffsets()): - (parameter, particle_index, charge_scale, - sigma_scale, epsilon_scale) = force.getParticleParameterOffset(offset_index) + ( + parameter, + particle_index, + charge_scale, + sigma_scale, + epsilon_scale, + ) = force.getParticleParameterOffset(offset_index) if particle_index in only_atoms: - force.setParticleParameterOffset(offset_index, parameter, particle_index, - charge_coeff*charge_scale, sigma_scale, - epsilon_coeff*epsilon_scale) - - -def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, other_alchemical_atoms): + force.setParticleParameterOffset( + offset_index, + parameter, + particle_index, + charge_coeff * charge_scale, + sigma_scale, + epsilon_coeff * epsilon_scale, + ) + + +def dissect_nonbonded_energy( + reference_system, positions, alchemical_atoms, other_alchemical_atoms +): """Dissect the nonbonded energy contributions of the reference system by atom group and sterics/electrostatics. @@ -258,7 +316,9 @@ def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, othe """ all_alchemical_atoms = set(alchemical_atoms).union(other_alchemical_atoms) - nonalchemical_atoms = set(range(reference_system.getNumParticles())).difference(all_alchemical_atoms) + nonalchemical_atoms = set(range(reference_system.getNumParticles())).difference( + all_alchemical_atoms + ) # Remove all forces but NonbondedForce and eventually the # CustomNonbondedForce used to model reaction field. @@ -267,10 +327,14 @@ def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, othe for force_index, force in enumerate(reference_system.getForces()): force.setForceGroup(0) if isinstance(force, openmm.NonbondedForce): - force.setReciprocalSpaceForceGroup(30) # separate PME reciprocal from direct space + force.setReciprocalSpaceForceGroup( + 30 + ) # separate PME reciprocal from direct space # We keep only CustomNonbondedForces that are not alchemically modified. - elif not (isinstance(force, openmm.CustomNonbondedForce) and - 'lambda' not in force.getEnergyFunction()): + elif not ( + isinstance(force, openmm.CustomNonbondedForce) + and "lambda" not in force.getEnergyFunction() + ): forces_to_remove.append(force_index) for force_index in reversed(forces_to_remove): @@ -281,8 +345,19 @@ def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, othe # ---------------------------------------------------------------- # Turn off other alchemical regions if len(other_alchemical_atoms) > 0: - turn_off_nonbonded(reference_system, sterics=True, electrostatics=True, only_atoms=other_alchemical_atoms) - turn_off_nonbonded(reference_system, sterics=True, electrostatics=True, exceptions=True, only_atoms=other_alchemical_atoms) + turn_off_nonbonded( + reference_system, + sterics=True, + electrostatics=True, + only_atoms=other_alchemical_atoms, + ) + turn_off_nonbonded( + reference_system, + sterics=True, + electrostatics=True, + exceptions=True, + only_atoms=other_alchemical_atoms, + ) system = copy.deepcopy(reference_system) @@ -300,9 +375,15 @@ def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, othe tot_energy_no_particle_sterics = compute_energy(system, positions) tot_particle_sterics = tot_energy - tot_energy_no_particle_sterics - nn_particle_sterics = tot_energy_no_alchem_particle_sterics - tot_energy_no_particle_sterics - aa_particle_sterics = tot_energy_no_nonalchem_particle_sterics - tot_energy_no_particle_sterics - na_particle_sterics = tot_particle_sterics - nn_particle_sterics - aa_particle_sterics + nn_particle_sterics = ( + tot_energy_no_alchem_particle_sterics - tot_energy_no_particle_sterics + ) + aa_particle_sterics = ( + tot_energy_no_nonalchem_particle_sterics - tot_energy_no_particle_sterics + ) + na_particle_sterics = ( + tot_particle_sterics - nn_particle_sterics - aa_particle_sterics + ) # Compute contributions from particle electrostatics system = copy.deepcopy(reference_system) # Restore sterics @@ -316,12 +397,20 @@ def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, othe turn_off_nonbonded(system, electrostatics=True) tot_energy_no_particle_electro = compute_energy(system, positions) - na_reciprocal_energy = tot_reciprocal_energy - nn_reciprocal_energy - aa_reciprocal_energy + na_reciprocal_energy = ( + tot_reciprocal_energy - nn_reciprocal_energy - aa_reciprocal_energy + ) tot_particle_electro = tot_energy - tot_energy_no_particle_electro - nn_particle_electro = tot_energy_no_alchem_particle_electro - tot_energy_no_particle_electro - aa_particle_electro = tot_energy_no_nonalchem_particle_electro - tot_energy_no_particle_electro - na_particle_electro = tot_particle_electro - nn_particle_electro - aa_particle_electro + nn_particle_electro = ( + tot_energy_no_alchem_particle_electro - tot_energy_no_particle_electro + ) + aa_particle_electro = ( + tot_energy_no_nonalchem_particle_electro - tot_energy_no_particle_electro + ) + na_particle_electro = ( + tot_particle_electro - nn_particle_electro - aa_particle_electro + ) nn_particle_electro -= nn_reciprocal_energy aa_particle_electro -= aa_reciprocal_energy na_particle_electro -= na_reciprocal_energy @@ -331,49 +420,102 @@ def dissect_nonbonded_energy(reference_system, positions, alchemical_atoms, othe # Compute contributions from exceptions sterics system = copy.deepcopy(reference_system) # Restore particle interactions - turn_off_nonbonded(system, sterics=True, exceptions=True, only_atoms=alchemical_atoms) + turn_off_nonbonded( + system, sterics=True, exceptions=True, only_atoms=alchemical_atoms + ) tot_energy_no_alchem_exception_sterics = compute_energy(system, positions) system = copy.deepcopy(reference_system) # Restore alchemical sterics - turn_off_nonbonded(system, sterics=True, exceptions=True, only_atoms=nonalchemical_atoms) + turn_off_nonbonded( + system, sterics=True, exceptions=True, only_atoms=nonalchemical_atoms + ) tot_energy_no_nonalchem_exception_sterics = compute_energy(system, positions) turn_off_nonbonded(system, sterics=True, exceptions=True) tot_energy_no_exception_sterics = compute_energy(system, positions) tot_exception_sterics = tot_energy - tot_energy_no_exception_sterics - nn_exception_sterics = tot_energy_no_alchem_exception_sterics - tot_energy_no_exception_sterics - aa_exception_sterics = tot_energy_no_nonalchem_exception_sterics - tot_energy_no_exception_sterics - na_exception_sterics = tot_exception_sterics - nn_exception_sterics - aa_exception_sterics + nn_exception_sterics = ( + tot_energy_no_alchem_exception_sterics - tot_energy_no_exception_sterics + ) + aa_exception_sterics = ( + tot_energy_no_nonalchem_exception_sterics - tot_energy_no_exception_sterics + ) + na_exception_sterics = ( + tot_exception_sterics - nn_exception_sterics - aa_exception_sterics + ) # Compute contributions from exceptions electrostatics system = copy.deepcopy(reference_system) # Restore exceptions sterics - turn_off_nonbonded(system, electrostatics=True, exceptions=True, only_atoms=alchemical_atoms) + turn_off_nonbonded( + system, electrostatics=True, exceptions=True, only_atoms=alchemical_atoms + ) tot_energy_no_alchem_exception_electro = compute_energy(system, positions) system = copy.deepcopy(reference_system) # Restore alchemical electrostatics - turn_off_nonbonded(system, electrostatics=True, exceptions=True, only_atoms=nonalchemical_atoms) + turn_off_nonbonded( + system, electrostatics=True, exceptions=True, only_atoms=nonalchemical_atoms + ) tot_energy_no_nonalchem_exception_electro = compute_energy(system, positions) turn_off_nonbonded(system, electrostatics=True, exceptions=True) tot_energy_no_exception_electro = compute_energy(system, positions) tot_exception_electro = tot_energy - tot_energy_no_exception_electro - nn_exception_electro = tot_energy_no_alchem_exception_electro - tot_energy_no_exception_electro - aa_exception_electro = tot_energy_no_nonalchem_exception_electro - tot_energy_no_exception_electro - na_exception_electro = tot_exception_electro - nn_exception_electro - aa_exception_electro - - assert tot_particle_sterics == nn_particle_sterics + aa_particle_sterics + na_particle_sterics - assert_almost_equal(tot_particle_electro, nn_particle_electro + aa_particle_electro + - na_particle_electro + nn_reciprocal_energy + aa_reciprocal_energy + na_reciprocal_energy, - 'Inconsistency during dissection of nonbonded contributions:') - assert tot_exception_sterics == nn_exception_sterics + aa_exception_sterics + na_exception_sterics - assert tot_exception_electro == nn_exception_electro + aa_exception_electro + na_exception_electro - assert_almost_equal(tot_energy, tot_particle_sterics + tot_particle_electro + - tot_exception_sterics + tot_exception_electro, - 'Inconsistency during dissection of nonbonded contributions:') - - return nn_particle_sterics, aa_particle_sterics, na_particle_sterics,\ - nn_particle_electro, aa_particle_electro, na_particle_electro,\ - nn_exception_sterics, aa_exception_sterics, na_exception_sterics,\ - nn_exception_electro, aa_exception_electro, na_exception_electro,\ - nn_reciprocal_energy, aa_reciprocal_energy, na_reciprocal_energy + nn_exception_electro = ( + tot_energy_no_alchem_exception_electro - tot_energy_no_exception_electro + ) + aa_exception_electro = ( + tot_energy_no_nonalchem_exception_electro - tot_energy_no_exception_electro + ) + na_exception_electro = ( + tot_exception_electro - nn_exception_electro - aa_exception_electro + ) + + assert ( + tot_particle_sterics + == nn_particle_sterics + aa_particle_sterics + na_particle_sterics + ) + assert_almost_equal( + tot_particle_electro, + nn_particle_electro + + aa_particle_electro + + na_particle_electro + + nn_reciprocal_energy + + aa_reciprocal_energy + + na_reciprocal_energy, + "Inconsistency during dissection of nonbonded contributions:", + ) + assert ( + tot_exception_sterics + == nn_exception_sterics + aa_exception_sterics + na_exception_sterics + ) + assert ( + tot_exception_electro + == nn_exception_electro + aa_exception_electro + na_exception_electro + ) + assert_almost_equal( + tot_energy, + tot_particle_sterics + + tot_particle_electro + + tot_exception_sterics + + tot_exception_electro, + "Inconsistency during dissection of nonbonded contributions:", + ) + + return ( + nn_particle_sterics, + aa_particle_sterics, + na_particle_sterics, + nn_particle_electro, + aa_particle_electro, + na_particle_electro, + nn_exception_sterics, + aa_exception_sterics, + na_exception_sterics, + nn_exception_electro, + aa_exception_electro, + na_exception_electro, + nn_reciprocal_energy, + aa_reciprocal_energy, + na_reciprocal_energy, + ) def compute_direct_space_correction(nonbonded_force, alchemical_atoms, positions): @@ -407,7 +549,10 @@ def compute_direct_space_correction(nonbonded_force, alchemical_atoms, positions positions = positions.value_in_unit_system(unit.md_unit_system) # If there is no reciprocal space, the correction is 0.0 - if nonbonded_force.getNonbondedMethod() not in [openmm.NonbondedForce.Ewald, openmm.NonbondedForce.PME]: + if nonbonded_force.getNonbondedMethod() not in [ + openmm.NonbondedForce.Ewald, + openmm.NonbondedForce.PME, + ]: return aa_correction * energy_unit, na_correction * energy_unit # Get alpha ewald parameter @@ -415,7 +560,7 @@ def compute_direct_space_correction(nonbonded_force, alchemical_atoms, positions if alpha_ewald / alpha_ewald.unit == 0.0: cutoff_distance = nonbonded_force.getCutoffDistance() tolerance = nonbonded_force.getEwaldErrorTolerance() - alpha_ewald = (1.0 / cutoff_distance) * np.sqrt(-np.log(2.0*tolerance)) + alpha_ewald = (1.0 / cutoff_distance) * np.sqrt(-np.log(2.0 * tolerance)) alpha_ewald = alpha_ewald.value_in_unit_system(unit.md_unit_system) assert alpha_ewald != 0.0 @@ -428,12 +573,18 @@ def compute_direct_space_correction(nonbonded_force, alchemical_atoms, positions jcharge = jcharge.value_in_unit_system(unit.md_unit_system) # Compute the correction and take care of numerical instabilities - r = np.linalg.norm(positions[iatom] - positions[jatom]) # distance between atoms + r = np.linalg.norm( + positions[iatom] - positions[jatom] + ) # distance between atoms alpha_r = alpha_ewald * r if alpha_r > 1e-6: - correction = ONE_4PI_EPS0 * icharge * jcharge * scipy.special.erf(alpha_r) / r + correction = ( + ONE_4PI_EPS0 * icharge * jcharge * scipy.special.erf(alpha_r) / r + ) else: # for small alpha_r we linearize erf() - correction = ONE_4PI_EPS0 * alpha_ewald * icharge * jcharge * 2.0 / np.sqrt(np.pi) + correction = ( + ONE_4PI_EPS0 * alpha_ewald * icharge * jcharge * 2.0 / np.sqrt(np.pi) + ) # Assign correction to correct group if iatom in alchemical_atoms and jatom in alchemical_atoms: @@ -448,12 +599,13 @@ def is_alchemical_pme_treatment_exact(alchemical_system): """Return True if the given alchemical system models PME exactly.""" # If exact PME is here, the NonbondedForce defines a # lambda_electrostatics variable. - _, nonbonded_force = forces.find_forces(alchemical_system, openmm.NonbondedForce, - only_one=True) + _, nonbonded_force = forces.find_forces( + alchemical_system, openmm.NonbondedForce, only_one=True + ) for parameter_idx in range(nonbonded_force.getNumGlobalParameters()): parameter_name = nonbonded_force.getGlobalParameterName(parameter_idx) # With multiple alchemical regions, lambda_electrostatics might have a suffix. - if parameter_name.startswith('lambda_electrostatics'): + if parameter_name.startswith("lambda_electrostatics"): return True return False @@ -462,7 +614,10 @@ def is_alchemical_pme_treatment_exact(alchemical_system): # SUBROUTINES FOR TESTING # ============================================================================= -def compare_system_energies(reference_system, alchemical_system, alchemical_regions, positions): + +def compare_system_energies( + reference_system, alchemical_system, alchemical_regions, positions +): """Check that the energies of reference and alchemical systems are close. This takes care of ignoring the reciprocal space when the nonbonded @@ -478,10 +633,14 @@ def compare_system_energies(reference_system, alchemical_system, alchemical_regi # Check nonbonded method. Comparing with PME is more complicated # because the alchemical system with direct-space treatment of PME # does not take into account the reciprocal space. - force_idx, nonbonded_force = forces.find_forces(reference_system, openmm.NonbondedForce, only_one=True) + force_idx, nonbonded_force = forces.find_forces( + reference_system, openmm.NonbondedForce, only_one=True + ) nonbonded_method = nonbonded_force.getNonbondedMethod() - is_direct_space_pme = (nonbonded_method in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald] and - not is_alchemical_pme_treatment_exact(alchemical_system)) + is_direct_space_pme = nonbonded_method in [ + openmm.NonbondedForce.PME, + openmm.NonbondedForce.Ewald, + ] and not is_alchemical_pme_treatment_exact(alchemical_system) if is_direct_space_pme: # Separate the reciprocal space force in a different group. @@ -502,13 +661,17 @@ def compare_system_energies(reference_system, alchemical_system, alchemical_regi na_correction = 0.0 * unit.kilojoule_per_mole for region in alchemical_regions: alchemical_atoms = region.alchemical_atoms - aa, na = compute_direct_space_correction(nonbonded_force, alchemical_atoms, positions) + aa, na = compute_direct_space_correction( + nonbonded_force, alchemical_atoms, positions + ) aa_correction += aa na_correction += na # Compute potential of the direct space. - potentials = [compute_energy(system, positions, force_group=force_group) - for system in [reference_system, alchemical_system]] + potentials = [ + compute_energy(system, positions, force_group=force_group) + for system in [reference_system, alchemical_system] + ] # Add the direct space correction. if is_direct_space_pme: @@ -520,14 +683,23 @@ def compare_system_energies(reference_system, alchemical_system, alchemical_regi delta = potentials[1] - potentials[2] - potentials[0] if abs(delta) > MAX_DELTA: print("========") - for description, potential in zip(['reference', 'alchemical', 'PME correction'], potentials): - print("{}: {} ".format(description, potential)) - print("delta : {}".format(delta)) + for description, potential in zip( + ["reference", "alchemical", "PME correction"], potentials + ): + print(f"{description}: {potential} ") + print(f"delta : {delta}") err_msg = "Maximum allowable deviation exceeded (was {:.8f} kcal/mol; allowed {:.8f} kcal/mol)." - raise Exception(err_msg.format(delta / unit.kilocalories_per_mole, MAX_DELTA / unit.kilocalories_per_mole)) + raise Exception( + err_msg.format( + delta / unit.kilocalories_per_mole, + MAX_DELTA / unit.kilocalories_per_mole, + ) + ) -def check_multi_interacting_energy_components(reference_system, alchemical_system, alchemical_regions, positions): +def check_multi_interacting_energy_components( + reference_system, alchemical_system, alchemical_regions, positions +): """wrapper around check_interacting_energy_components for multiple regions Parameters @@ -553,12 +725,23 @@ def check_multi_interacting_energy_components(reference_system, alchemical_syste all_alchemical_atoms.add(atom) for region in alchemical_regions: check_interacting_energy_components( - reference_system, alchemical_system, region, positions, - all_alchemical_atoms, multi_regions=True) - - -def check_interacting_energy_components(reference_system, alchemical_system, alchemical_regions, positions, - all_alchemical_atoms=None, multi_regions=False): + reference_system, + alchemical_system, + region, + positions, + all_alchemical_atoms, + multi_regions=True, + ) + + +def check_interacting_energy_components( + reference_system, + alchemical_system, + alchemical_regions, + positions, + all_alchemical_atoms=None, + multi_regions=False, +): """Compare full and alchemically-modified system energies by energy component. Parameters @@ -581,64 +764,133 @@ def check_interacting_energy_components(reference_system, alchemical_system, alc is_exact_pme = is_alchemical_pme_treatment_exact(alchemical_system) # Find nonbonded method - _, nonbonded_force = forces.find_forces(reference_system, openmm.NonbondedForce, only_one=True) + _, nonbonded_force = forces.find_forces( + reference_system, openmm.NonbondedForce, only_one=True + ) nonbonded_method = nonbonded_force.getNonbondedMethod() # Get energy components of reference system's nonbonded force if multi_regions: - other_alchemical_atoms = all_alchemical_atoms.difference(alchemical_regions.alchemical_atoms) - print("Dissecting reference system's nonbonded force for region {}".format(alchemical_regions.name)) + other_alchemical_atoms = all_alchemical_atoms.difference( + alchemical_regions.alchemical_atoms + ) + print( + f"Dissecting reference system's nonbonded force for region {alchemical_regions.name}" + ) else: other_alchemical_atoms = set() print("Dissecting reference system's nonbonded force") - energy_components = dissect_nonbonded_energy(reference_system, positions, - alchemical_regions.alchemical_atoms, other_alchemical_atoms) - nn_particle_sterics, aa_particle_sterics, na_particle_sterics,\ - nn_particle_electro, aa_particle_electro, na_particle_electro,\ - nn_exception_sterics, aa_exception_sterics, na_exception_sterics,\ - nn_exception_electro, aa_exception_electro, na_exception_electro,\ - nn_reciprocal_energy, aa_reciprocal_energy, na_reciprocal_energy = energy_components + energy_components = dissect_nonbonded_energy( + reference_system, + positions, + alchemical_regions.alchemical_atoms, + other_alchemical_atoms, + ) + ( + nn_particle_sterics, + aa_particle_sterics, + na_particle_sterics, + nn_particle_electro, + aa_particle_electro, + na_particle_electro, + nn_exception_sterics, + aa_exception_sterics, + na_exception_sterics, + nn_exception_electro, + aa_exception_electro, + na_exception_electro, + nn_reciprocal_energy, + aa_reciprocal_energy, + na_reciprocal_energy, + ) = energy_components # Dissect unmodified nonbonded force in alchemical system if multi_regions: - print("Dissecting alchemical system's unmodified nonbonded force for region {}".format(alchemical_regions.name)) + print( + f"Dissecting alchemical system's unmodified nonbonded force for region {alchemical_regions.name}" + ) else: print("Dissecting alchemical system's unmodified nonbonded force") - energy_components = dissect_nonbonded_energy(alchemical_system, positions, - alchemical_regions.alchemical_atoms, other_alchemical_atoms) - unmod_nn_particle_sterics, unmod_aa_particle_sterics, unmod_na_particle_sterics,\ - unmod_nn_particle_electro, unmod_aa_particle_electro, unmod_na_particle_electro,\ - unmod_nn_exception_sterics, unmod_aa_exception_sterics, unmod_na_exception_sterics,\ - unmod_nn_exception_electro, unmod_aa_exception_electro, unmod_na_exception_electro,\ - unmod_nn_reciprocal_energy, unmod_aa_reciprocal_energy, unmod_na_reciprocal_energy = energy_components + energy_components = dissect_nonbonded_energy( + alchemical_system, + positions, + alchemical_regions.alchemical_atoms, + other_alchemical_atoms, + ) + ( + unmod_nn_particle_sterics, + unmod_aa_particle_sterics, + unmod_na_particle_sterics, + unmod_nn_particle_electro, + unmod_aa_particle_electro, + unmod_na_particle_electro, + unmod_nn_exception_sterics, + unmod_aa_exception_sterics, + unmod_na_exception_sterics, + unmod_nn_exception_electro, + unmod_aa_exception_electro, + unmod_na_exception_electro, + unmod_nn_reciprocal_energy, + unmod_aa_reciprocal_energy, + unmod_na_reciprocal_energy, + ) = energy_components # Get alchemically-modified energy components if multi_regions: - print("Computing alchemical system components energies for region {}".format(alchemical_regions.name)) + print( + f"Computing alchemical system components energies for region {alchemical_regions.name}" + ) else: print("Computing alchemical system components energies") - alchemical_state = AlchemicalState.from_system(alchemical_system, parameters_name_suffix=alchemical_regions.name) + alchemical_state = AlchemicalState.from_system( + alchemical_system, parameters_name_suffix=alchemical_regions.name + ) alchemical_state.set_alchemical_parameters(1.0) - energy_components = AbsoluteAlchemicalFactory.get_energy_components(alchemical_system, alchemical_state, - positions, platform=GLOBAL_ALCHEMY_PLATFORM) + energy_components = AbsoluteAlchemicalFactory.get_energy_components( + alchemical_system, alchemical_state, positions, platform=GLOBAL_ALCHEMY_PLATFORM + ) if multi_regions: - region_label = ' for region {}'.format(alchemical_regions.name) + region_label = f" for region {alchemical_regions.name}" else: - region_label = '' + region_label = "" # Sterics particle and exception interactions are always modeled with a custom force. - na_custom_particle_sterics = energy_components['alchemically modified NonbondedForce for non-alchemical/alchemical sterics' + region_label] - aa_custom_particle_sterics = energy_components['alchemically modified NonbondedForce for alchemical/alchemical sterics' + region_label] - na_custom_exception_sterics = energy_components['alchemically modified BondForce for non-alchemical/alchemical sterics exceptions' + region_label] - aa_custom_exception_sterics = energy_components['alchemically modified BondForce for alchemical/alchemical sterics exceptions' + region_label] + na_custom_particle_sterics = energy_components[ + "alchemically modified NonbondedForce for non-alchemical/alchemical sterics" + + region_label + ] + aa_custom_particle_sterics = energy_components[ + "alchemically modified NonbondedForce for alchemical/alchemical sterics" + + region_label + ] + na_custom_exception_sterics = energy_components[ + "alchemically modified BondForce for non-alchemical/alchemical sterics exceptions" + + region_label + ] + aa_custom_exception_sterics = energy_components[ + "alchemically modified BondForce for alchemical/alchemical sterics exceptions" + + region_label + ] # With exact treatment of PME, we use the NonbondedForce offset for electrostatics. try: - na_custom_particle_electro = energy_components['alchemically modified NonbondedForce for non-alchemical/alchemical electrostatics' + region_label] - aa_custom_particle_electro = energy_components['alchemically modified NonbondedForce for alchemical/alchemical electrostatics' + region_label] - na_custom_exception_electro = energy_components['alchemically modified BondForce for non-alchemical/alchemical electrostatics exceptions' + region_label] - aa_custom_exception_electro = energy_components['alchemically modified BondForce for alchemical/alchemical electrostatics exceptions' + region_label] + na_custom_particle_electro = energy_components[ + "alchemically modified NonbondedForce for non-alchemical/alchemical electrostatics" + + region_label + ] + aa_custom_particle_electro = energy_components[ + "alchemically modified NonbondedForce for alchemical/alchemical electrostatics" + + region_label + ] + na_custom_exception_electro = energy_components[ + "alchemically modified BondForce for non-alchemical/alchemical electrostatics exceptions" + + region_label + ] + aa_custom_exception_electro = energy_components[ + "alchemically modified BondForce for alchemical/alchemical electrostatics exceptions" + + region_label + ] except KeyError: assert is_exact_pme @@ -646,7 +898,7 @@ def check_interacting_energy_components(reference_system, alchemical_system, alc # ------------------------------------------------- # All contributions from alchemical atoms in unmodified nonbonded force are turned off - err_msg = 'Non-zero contribution from unmodified NonbondedForce alchemical atoms: ' + err_msg = "Non-zero contribution from unmodified NonbondedForce alchemical atoms: " assert_almost_equal(unmod_aa_particle_sterics, 0.0 * energy_unit, err_msg) assert_almost_equal(unmod_na_particle_sterics, 0.0 * energy_unit, err_msg) assert_almost_equal(unmod_aa_exception_sterics, 0.0 * energy_unit, err_msg) @@ -661,69 +913,126 @@ def check_interacting_energy_components(reference_system, alchemical_system, alc assert_almost_equal(unmod_na_exception_electro, 0.0 * energy_unit, err_msg) # Check sterics interactions match - assert_almost_equal(nn_particle_sterics, unmod_nn_particle_sterics, - 'Non-alchemical/non-alchemical atoms particle sterics' + region_label) - assert_almost_equal(nn_exception_sterics, unmod_nn_exception_sterics, - 'Non-alchemical/non-alchemical atoms exceptions sterics' + region_label) - assert_almost_equal(aa_particle_sterics, aa_custom_particle_sterics, - 'Alchemical/alchemical atoms particle sterics' + region_label) - assert_almost_equal(aa_exception_sterics, aa_custom_exception_sterics, - 'Alchemical/alchemical atoms exceptions sterics' + region_label) - assert_almost_equal(na_particle_sterics, na_custom_particle_sterics, - 'Non-alchemical/alchemical atoms particle sterics' + region_label) - assert_almost_equal(na_exception_sterics, na_custom_exception_sterics, - 'Non-alchemical/alchemical atoms exceptions sterics' + region_label) + assert_almost_equal( + nn_particle_sterics, + unmod_nn_particle_sterics, + "Non-alchemical/non-alchemical atoms particle sterics" + region_label, + ) + assert_almost_equal( + nn_exception_sterics, + unmod_nn_exception_sterics, + "Non-alchemical/non-alchemical atoms exceptions sterics" + region_label, + ) + assert_almost_equal( + aa_particle_sterics, + aa_custom_particle_sterics, + "Alchemical/alchemical atoms particle sterics" + region_label, + ) + assert_almost_equal( + aa_exception_sterics, + aa_custom_exception_sterics, + "Alchemical/alchemical atoms exceptions sterics" + region_label, + ) + assert_almost_equal( + na_particle_sterics, + na_custom_particle_sterics, + "Non-alchemical/alchemical atoms particle sterics" + region_label, + ) + assert_almost_equal( + na_exception_sterics, + na_custom_exception_sterics, + "Non-alchemical/alchemical atoms exceptions sterics" + region_label, + ) # Check electrostatics interactions - assert_almost_equal(nn_particle_electro, unmod_nn_particle_electro, - 'Non-alchemical/non-alchemical atoms particle electrostatics' + region_label) - assert_almost_equal(nn_exception_electro, unmod_nn_exception_electro, - 'Non-alchemical/non-alchemical atoms exceptions electrostatics' + region_label) + assert_almost_equal( + nn_particle_electro, + unmod_nn_particle_electro, + "Non-alchemical/non-alchemical atoms particle electrostatics" + region_label, + ) + assert_almost_equal( + nn_exception_electro, + unmod_nn_exception_electro, + "Non-alchemical/non-alchemical atoms exceptions electrostatics" + region_label, + ) # With exact treatment of PME, the electrostatics of alchemical-alchemical # atoms is modeled with NonbondedForce offsets. if is_exact_pme: # Reciprocal space. - assert_almost_equal(aa_reciprocal_energy, unmod_aa_reciprocal_energy, - 'Alchemical/alchemical atoms reciprocal space energy' + region_label) - assert_almost_equal(na_reciprocal_energy, unmod_na_reciprocal_energy, - 'Non-alchemical/alchemical atoms reciprocal space energy' + region_label) + assert_almost_equal( + aa_reciprocal_energy, + unmod_aa_reciprocal_energy, + "Alchemical/alchemical atoms reciprocal space energy" + region_label, + ) + assert_almost_equal( + na_reciprocal_energy, + unmod_na_reciprocal_energy, + "Non-alchemical/alchemical atoms reciprocal space energy" + region_label, + ) # Direct space. - assert_almost_equal(aa_particle_electro, unmod_aa_particle_electro, - 'Alchemical/alchemical atoms particle electrostatics' + region_label) - assert_almost_equal(na_particle_electro, unmod_na_particle_electro, - 'Non-alchemical/alchemical atoms particle electrostatics' + region_label) + assert_almost_equal( + aa_particle_electro, + unmod_aa_particle_electro, + "Alchemical/alchemical atoms particle electrostatics" + region_label, + ) + assert_almost_equal( + na_particle_electro, + unmod_na_particle_electro, + "Non-alchemical/alchemical atoms particle electrostatics" + region_label, + ) # Exceptions. - assert_almost_equal(aa_exception_electro, unmod_aa_exception_electro, - 'Alchemical/alchemical atoms exceptions electrostatics' + region_label) - assert_almost_equal(na_exception_electro, unmod_na_exception_electro, - 'Non-alchemical/alchemical atoms exceptions electrostatics' + region_label) + assert_almost_equal( + aa_exception_electro, + unmod_aa_exception_electro, + "Alchemical/alchemical atoms exceptions electrostatics" + region_label, + ) + assert_almost_equal( + na_exception_electro, + unmod_na_exception_electro, + "Non-alchemical/alchemical atoms exceptions electrostatics" + region_label, + ) # With direct space PME, the custom forces model only the # direct space of alchemical-alchemical interactions. else: # Get direct space correction due to reciprocal space exceptions - aa_correction, na_correction = compute_direct_space_correction(nonbonded_force, - alchemical_regions.alchemical_atoms, - positions) + aa_correction, na_correction = compute_direct_space_correction( + nonbonded_force, alchemical_regions.alchemical_atoms, positions + ) aa_particle_electro += aa_correction na_particle_electro += na_correction # Check direct space energy - assert_almost_equal(aa_particle_electro, aa_custom_particle_electro, - 'Alchemical/alchemical atoms particle electrostatics' + region_label) - assert_almost_equal(na_particle_electro, na_custom_particle_electro, - 'Non-alchemical/alchemical atoms particle electrostatics' + region_label) + assert_almost_equal( + aa_particle_electro, + aa_custom_particle_electro, + "Alchemical/alchemical atoms particle electrostatics" + region_label, + ) + assert_almost_equal( + na_particle_electro, + na_custom_particle_electro, + "Non-alchemical/alchemical atoms particle electrostatics" + region_label, + ) # Check exceptions. - assert_almost_equal(aa_exception_electro, aa_custom_exception_electro, - 'Alchemical/alchemical atoms exceptions electrostatics' + region_label) - assert_almost_equal(na_exception_electro, na_custom_exception_electro, - 'Non-alchemical/alchemical atoms exceptions electrostatics' + region_label) + assert_almost_equal( + aa_exception_electro, + aa_custom_exception_electro, + "Alchemical/alchemical atoms exceptions electrostatics" + region_label, + ) + assert_almost_equal( + na_exception_electro, + na_custom_exception_electro, + "Non-alchemical/alchemical atoms exceptions electrostatics" + region_label, + ) # With Ewald methods, the NonbondedForce should always hold the # reciprocal space energy of nonalchemical-nonalchemical atoms. if nonbonded_method in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald]: # Reciprocal space. - assert_almost_equal(nn_reciprocal_energy, unmod_nn_reciprocal_energy, - 'Non-alchemical/non-alchemical atoms reciprocal space energy') + assert_almost_equal( + nn_reciprocal_energy, + unmod_nn_reciprocal_energy, + "Non-alchemical/non-alchemical atoms reciprocal space energy", + ) else: # Reciprocal space energy should be null in this case assert nn_reciprocal_energy == unmod_nn_reciprocal_energy == 0.0 * energy_unit @@ -732,10 +1041,19 @@ def check_interacting_energy_components(reference_system, alchemical_system, alc # Check forces other than nonbonded # ---------------------------------- - for force_name in ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', - 'GBSAOBCForce', 'CustomGBForce']: - alchemical_forces_energies = [energy for label, energy in energy_components.items() if force_name in label] - reference_force_energy = compute_force_energy(reference_system, positions, force_name) + for force_name in [ + "HarmonicBondForce", + "HarmonicAngleForce", + "PeriodicTorsionForce", + "GBSAOBCForce", + "CustomGBForce", + ]: + alchemical_forces_energies = [ + energy for label, energy in energy_components.items() if force_name in label + ] + reference_force_energy = compute_force_energy( + reference_system, positions, force_name + ) # There should be no force in the alchemical system if force_name is missing from the reference if reference_force_energy is None: @@ -746,11 +1064,16 @@ def check_interacting_energy_components(reference_system, alchemical_system, alc tot_alchemical_forces_energies = 0.0 * energy_unit for energy in alchemical_forces_energies: tot_alchemical_forces_energies += energy - assert_almost_equal(reference_force_energy, tot_alchemical_forces_energies, - '{} energy '.format(force_name)) + assert_almost_equal( + reference_force_energy, + tot_alchemical_forces_energies, + f"{force_name} energy ", + ) -def check_multi_noninteracting_energy_components(reference_system, alchemical_system, alchemical_regions, positions): +def check_multi_noninteracting_energy_components( + reference_system, alchemical_system, alchemical_regions, positions +): """wrapper around check_noninteracting_energy_components for multiple regions Parameters ---------- @@ -764,10 +1087,18 @@ def check_multi_noninteracting_energy_components(reference_system, alchemical_sy The positions to test (units of length). """ for region in alchemical_regions: - check_noninteracting_energy_components(reference_system, alchemical_system, region, positions, True) - - -def check_noninteracting_energy_components(reference_system, alchemical_system, alchemical_regions, positions, multi_regions=False): + check_noninteracting_energy_components( + reference_system, alchemical_system, region, positions, True + ) + + +def check_noninteracting_energy_components( + reference_system, + alchemical_system, + alchemical_regions, + positions, + multi_regions=False, +): """Check non-interacting energy components are zero when appropriate. Parameters ---------- @@ -786,52 +1117,89 @@ def check_noninteracting_energy_components(reference_system, alchemical_system, is_exact_pme = is_alchemical_pme_treatment_exact(alchemical_system) # Set state to non-interacting. - alchemical_state = AlchemicalState.from_system(alchemical_system, parameters_name_suffix=alchemical_regions.name) + alchemical_state = AlchemicalState.from_system( + alchemical_system, parameters_name_suffix=alchemical_regions.name + ) alchemical_state.set_alchemical_parameters(0.0) - energy_components = AbsoluteAlchemicalFactory.get_energy_components(alchemical_system, alchemical_state, - positions, platform=GLOBAL_ALCHEMY_PLATFORM) + energy_components = AbsoluteAlchemicalFactory.get_energy_components( + alchemical_system, alchemical_state, positions, platform=GLOBAL_ALCHEMY_PLATFORM + ) def assert_zero_energy(label): # Handle multiple alchemical regions. if multi_regions: - label = label + ' for region ' + alchemical_regions.name + label = label + " for region " + alchemical_regions.name # Testing energy component of each region. - print('testing {}'.format(label)) + print(f"testing {label}") value = energy_components[label] - assert abs(value / GLOBAL_ENERGY_UNIT) == 0.0, ("'{}' should have zero energy in annihilated alchemical" - " state, but energy is {}").format(label, str(value)) + assert abs(value / GLOBAL_ENERGY_UNIT) == 0.0, ( + "'{}' should have zero energy in annihilated alchemical" + " state, but energy is {}" + ).format(label, str(value)) # Check that non-alchemical/alchemical particle interactions and 1,4 exceptions have been annihilated - assert_zero_energy('alchemically modified BondForce for non-alchemical/alchemical sterics exceptions') - assert_zero_energy('alchemically modified NonbondedForce for non-alchemical/alchemical sterics') + assert_zero_energy( + "alchemically modified BondForce for non-alchemical/alchemical sterics exceptions" + ) + assert_zero_energy( + "alchemically modified NonbondedForce for non-alchemical/alchemical sterics" + ) if is_exact_pme: - assert 'alchemically modified NonbondedForce for non-alchemical/alchemical electrostatics' not in energy_components - assert 'alchemically modified BondForce for non-alchemical/alchemical electrostatics exceptions' not in energy_components + assert ( + "alchemically modified NonbondedForce for non-alchemical/alchemical electrostatics" + not in energy_components + ) + assert ( + "alchemically modified BondForce for non-alchemical/alchemical electrostatics exceptions" + not in energy_components + ) else: - assert_zero_energy('alchemically modified NonbondedForce for non-alchemical/alchemical electrostatics') - assert_zero_energy('alchemically modified BondForce for non-alchemical/alchemical electrostatics exceptions') + assert_zero_energy( + "alchemically modified NonbondedForce for non-alchemical/alchemical electrostatics" + ) + assert_zero_energy( + "alchemically modified BondForce for non-alchemical/alchemical electrostatics exceptions" + ) # Check that alchemical/alchemical particle interactions and 1,4 exceptions have been annihilated if alchemical_regions.annihilate_sterics: - assert_zero_energy('alchemically modified NonbondedForce for alchemical/alchemical sterics') - assert_zero_energy('alchemically modified BondForce for alchemical/alchemical sterics exceptions') + assert_zero_energy( + "alchemically modified NonbondedForce for alchemical/alchemical sterics" + ) + assert_zero_energy( + "alchemically modified BondForce for alchemical/alchemical sterics exceptions" + ) if alchemical_regions.annihilate_electrostatics: if is_exact_pme: - assert 'alchemically modified NonbondedForce for alchemical/alchemical electrostatics' not in energy_components - assert 'alchemically modified BondForce for alchemical/alchemical electrostatics exceptions' not in energy_components + assert ( + "alchemically modified NonbondedForce for alchemical/alchemical electrostatics" + not in energy_components + ) + assert ( + "alchemically modified BondForce for alchemical/alchemical electrostatics exceptions" + not in energy_components + ) else: - assert_zero_energy('alchemically modified NonbondedForce for alchemical/alchemical electrostatics') - assert_zero_energy('alchemically modified BondForce for alchemical/alchemical electrostatics exceptions') + assert_zero_energy( + "alchemically modified NonbondedForce for alchemical/alchemical electrostatics" + ) + assert_zero_energy( + "alchemically modified BondForce for alchemical/alchemical electrostatics exceptions" + ) # Check valence terms - for force_name in ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce']: - force_label = 'alchemically modified ' + force_name + for force_name in [ + "HarmonicBondForce", + "HarmonicAngleForce", + "PeriodicTorsionForce", + ]: + force_label = "alchemically modified " + force_name if force_label in energy_components: assert_zero_energy(force_label) # Check implicit solvent force. - for force_name in ['CustomGBForce', 'GBSAOBCForce']: - label = 'alchemically modified ' + force_name + for force_name in ["CustomGBForce", "GBSAOBCForce"]: + label = "alchemically modified " + force_name # Check if the system has an implicit solvent force. try: @@ -840,7 +1208,10 @@ def assert_zero_energy(label): continue # If all alchemical particles are modified, the alchemical energy should be zero. - if len(alchemical_regions.alchemical_atoms) == reference_system.getNumParticles(): + if ( + len(alchemical_regions.alchemical_atoms) + == reference_system.getNumParticles() + ): assert_zero_energy(label) continue @@ -896,13 +1267,21 @@ def assert_zero_energy(label): system.addForce(force) # Get positions for all non-alchemical particles. - non_alchemical_positions = [pos for i, pos in enumerate(positions) - if i not in alchemical_regions.alchemical_atoms] + non_alchemical_positions = [ + pos + for i, pos in enumerate(positions) + if i not in alchemical_regions.alchemical_atoms + ] # Compute reference force energy. - reference_force_energy = compute_force_energy(system, non_alchemical_positions, force_name) - assert_almost_equal(reference_force_energy, alchemical_energy, - 'reference {}, alchemical {}'.format(reference_force_energy, alchemical_energy)) + reference_force_energy = compute_force_energy( + system, non_alchemical_positions, force_name + ) + assert_almost_equal( + reference_force_energy, + alchemical_energy, + f"reference {reference_force_energy}, alchemical {alchemical_energy}", + ) def check_split_force_groups(system, region_names=None): @@ -915,7 +1294,8 @@ def check_split_force_groups(system, region_names=None): force_groups_by_lambda = {} lambdas_by_force_group = {} for force, lambda_name, _ in AlchemicalState._get_system_controlled_parameters( - system, parameters_name_suffix=region): + system, parameters_name_suffix=region + ): force_group = force.getForceGroup() try: force_groups_by_lambda[lambda_name].add(force_group) @@ -930,9 +1310,16 @@ def check_split_force_groups(system, region_names=None): assert 0 not in force_groups_by_lambda # There are as many alchemical force groups as not-None lambda variables. - alchemical_state = AlchemicalState.from_system(system, parameters_name_suffix=region) - valid_lambdas = {lambda_name for lambda_name in alchemical_state._get_controlled_parameters(parameters_name_suffix=region) - if getattr(alchemical_state, lambda_name) is not None} + alchemical_state = AlchemicalState.from_system( + system, parameters_name_suffix=region + ) + valid_lambdas = { + lambda_name + for lambda_name in alchemical_state._get_controlled_parameters( + parameters_name_suffix=region + ) + if getattr(alchemical_state, lambda_name) is not None + } assert valid_lambdas == set(force_groups_by_lambda.keys()) # Check that force groups and lambda variables are in 1-to-1 correspondence. @@ -944,16 +1331,26 @@ def check_split_force_groups(system, region_names=None): # With exact treatment of PME, the NonbondedForce must # be in the lambda_electrostatics force group. if is_alchemical_pme_treatment_exact(system): - force_idx, nonbonded_force = forces.find_forces(system, openmm.NonbondedForce, only_one=True) - assert force_groups_by_lambda['lambda_electrostatics_{}'.format(region)] == {nonbonded_force.getForceGroup()} + force_idx, nonbonded_force = forces.find_forces( + system, openmm.NonbondedForce, only_one=True + ) + assert force_groups_by_lambda[f"lambda_electrostatics_{region}"] == { + nonbonded_force.getForceGroup() + } # ============================================================================= # BENCHMARKING AND DEBUG FUNCTIONS # ============================================================================= -def benchmark(reference_system, alchemical_regions, positions, nsteps=500, - timestep=1.0*unit.femtoseconds): + +def benchmark( + reference_system, + alchemical_regions, + positions, + nsteps=500, + timestep=1.0 * unit.femtoseconds, +): """ Benchmark performance of alchemically modified system relative to original system. @@ -975,9 +1372,11 @@ def benchmark(reference_system, alchemical_regions, positions, nsteps=500, # Create the perturbed system. factory = AbsoluteAlchemicalFactory() - timer.start('Create alchemical system') - alchemical_system = factory.create_alchemical_system(reference_system, alchemical_regions) - timer.stop('Create alchemical system') + timer.start("Create alchemical system") + alchemical_system = factory.create_alchemical_system( + reference_system, alchemical_regions + ) + timer.stop("Create alchemical system") # Create an alchemically-perturbed state corresponding to nearly fully-interacting. # NOTE: We use a lambda slightly smaller than 1.0 because the AbsoluteAlchemicalFactory @@ -991,8 +1390,12 @@ def benchmark(reference_system, alchemical_regions, positions, nsteps=500, # Create contexts for sampling. if GLOBAL_ALCHEMY_PLATFORM: - reference_context = openmm.Context(reference_system, reference_integrator, GLOBAL_ALCHEMY_PLATFORM) - alchemical_context = openmm.Context(alchemical_system, alchemical_integrator, GLOBAL_ALCHEMY_PLATFORM) + reference_context = openmm.Context( + reference_system, reference_integrator, GLOBAL_ALCHEMY_PLATFORM + ) + alchemical_context = openmm.Context( + alchemical_system, alchemical_integrator, GLOBAL_ALCHEMY_PLATFORM + ) else: reference_context = openmm.Context(reference_system, reference_integrator) alchemical_context = openmm.Context(alchemical_system, alchemical_integrator) @@ -1004,65 +1407,103 @@ def benchmark(reference_system, alchemical_regions, positions, nsteps=500, alchemical_integrator.step(1) # Run simulations. - print('Running reference system...') - timer.start('Run reference system') + print("Running reference system...") + timer.start("Run reference system") reference_integrator.step(nsteps) - timer.stop('Run reference system') + timer.stop("Run reference system") - print('Running alchemical system...') - timer.start('Run alchemical system') + print("Running alchemical system...") + timer.start("Run alchemical system") alchemical_integrator.step(nsteps) - timer.stop('Run alchemical system') - print('Done.') + timer.stop("Run alchemical system") + print("Done.") timer.report_timing() def benchmark_alchemy_from_pdb(): - """CLI entry point for benchmarking alchemical performance from a PDB file. - """ + """CLI entry point for benchmarking alchemical performance from a PDB file.""" logging.basicConfig(level=logging.DEBUG) import mdtraj import argparse + try: from openmm import app except ImportError: # OpenMM < 7.6 from simtk.openmm import app - parser = argparse.ArgumentParser(description='Benchmark performance of alchemically-modified system.') - parser.add_argument('-p', '--pdb', metavar='PDBFILE', type=str, action='store', required=True, - help='PDB file to benchmark; only protein forcefields supported for now (no small molecules)') - parser.add_argument('-s', '--selection', metavar='SELECTION', type=str, action='store', default='not water', - help='MDTraj DSL describing alchemical region (default: "not water")') - parser.add_argument('-n', '--nsteps', metavar='STEPS', type=int, action='store', default=1000, - help='Number of benchmarking steps (default: 1000)') + parser = argparse.ArgumentParser( + description="Benchmark performance of alchemically-modified system." + ) + parser.add_argument( + "-p", + "--pdb", + metavar="PDBFILE", + type=str, + action="store", + required=True, + help="PDB file to benchmark; only protein forcefields supported for now (no small molecules)", + ) + parser.add_argument( + "-s", + "--selection", + metavar="SELECTION", + type=str, + action="store", + default="not water", + help='MDTraj DSL describing alchemical region (default: "not water")', + ) + parser.add_argument( + "-n", + "--nsteps", + metavar="STEPS", + type=int, + action="store", + default=1000, + help="Number of benchmarking steps (default: 1000)", + ) args = parser.parse_args() # Read the PDB file - print('Loading PDB file...') + print("Loading PDB file...") pdbfile = app.PDBFile(args.pdb) - print('Loading forcefield...') - forcefield = app.ForceField('amber99sbildn.xml', 'tip3p.xml') - print('Adding missing hydrogens...') + print("Loading forcefield...") + forcefield = app.ForceField("amber99sbildn.xml", "tip3p.xml") + print("Adding missing hydrogens...") modeller = app.Modeller(pdbfile.topology, pdbfile.positions) modeller.addHydrogens(forcefield) - print('Creating System...') - reference_system = forcefield.createSystem(modeller.topology, nonbondedMethod=app.PME) + print("Creating System...") + reference_system = forcefield.createSystem( + modeller.topology, nonbondedMethod=app.PME + ) # Minimize - print('Minimizing...') + print("Minimizing...") positions = minimize(reference_system, modeller.positions) # Select alchemical regions mdtraj_topology = mdtraj.Topology.from_openmm(modeller.topology) alchemical_atoms = mdtraj_topology.select(args.selection) alchemical_region = AlchemicalRegion(alchemical_atoms=alchemical_atoms) - print('There are %d atoms in the alchemical region.' % len(alchemical_atoms)) + print("There are %d atoms in the alchemical region." % len(alchemical_atoms)) # Benchmark - print('Benchmarking...') - benchmark(reference_system, alchemical_region, positions, nsteps=args.nsteps, timestep=1.0*unit.femtoseconds) - - -def overlap_check(reference_system, alchemical_system, positions, nsteps=50, nsamples=200, - cached_trajectory_filename=None, name=""): + print("Benchmarking...") + benchmark( + reference_system, + alchemical_region, + positions, + nsteps=args.nsteps, + timestep=1.0 * unit.femtoseconds, + ) + + +def overlap_check( + reference_system, + alchemical_system, + positions, + nsteps=50, + nsamples=200, + cached_trajectory_filename=None, + name="", +): """ Test overlap between reference system and alchemical system by running a short simulation. @@ -1099,7 +1540,9 @@ def overlap_check(reference_system, alchemical_system, positions, nsteps=50, nsa reference_system.addForce(openmm.MonteCarloBarostat(pressure, temperature)) # Create integrators. - reference_integrator = openmm.LangevinIntegrator(temperature, collision_rate, timestep) + reference_integrator = openmm.LangevinIntegrator( + temperature, collision_rate, timestep + ) alchemical_integrator = openmm.VerletIntegrator(timestep) # Create contexts. @@ -1110,7 +1553,7 @@ def overlap_check(reference_system, alchemical_system, positions, nsteps=50, nsa # du_n[n] is the potential energy difference of sample n. if cached_trajectory_filename is not None: try: - with open(cached_trajectory_filename, 'rb') as f: + with open(cached_trajectory_filename, "rb") as f: data = pickle.load(f) except FileNotFoundError: data = dict(du_n=[]) @@ -1119,17 +1562,17 @@ def overlap_check(reference_system, alchemical_system, positions, nsteps=50, nsa if not os.path.exists(directory): os.makedirs(directory) else: - positions = data['positions'] - reference_context.setPeriodicBoxVectors(*data['box_vectors']) + positions = data["positions"] + reference_context.setPeriodicBoxVectors(*data["box_vectors"]) else: data = dict(du_n=[]) # Collect simulation data. - iteration = len(data['du_n']) + iteration = len(data["du_n"]) reference_context.setPositions(positions) print() for sample in range(iteration, nsamples): - print('\rSample {}/{}'.format(sample+1, nsamples), end='') + print(f"\rSample {sample+1}/{nsamples}", end="") sys.stdout.flush() # Run dynamics. @@ -1138,28 +1581,30 @@ def overlap_check(reference_system, alchemical_system, positions, nsteps=50, nsa # Get reference energies. reference_state = reference_context.getState(getEnergy=True, getPositions=True) reference_potential = reference_state.getPotentialEnergy() - if np.isnan(reference_potential/kT): + if np.isnan(reference_potential / kT): raise Exception("Reference potential is NaN") # Get alchemical energies. - alchemical_context.setPeriodicBoxVectors(*reference_state.getPeriodicBoxVectors()) + alchemical_context.setPeriodicBoxVectors( + *reference_state.getPeriodicBoxVectors() + ) alchemical_context.setPositions(reference_state.getPositions(asNumpy=True)) alchemical_state = alchemical_context.getState(getEnergy=True) alchemical_potential = alchemical_state.getPotentialEnergy() - if np.isnan(alchemical_potential/kT): + if np.isnan(alchemical_potential / kT): raise Exception("Alchemical potential is NaN") # Update and cache data. - data['du_n'].append((alchemical_potential - reference_potential) / kT) + data["du_n"].append((alchemical_potential - reference_potential) / kT) if cached_trajectory_filename is not None: # Save only last iteration positions and vectors. - data['positions'] = reference_state.getPositions() - data['box_vectors'] = reference_state.getPeriodicBoxVectors() - with open(cached_trajectory_filename, 'wb') as f: + data["positions"] = reference_state.getPositions() + data["box_vectors"] = reference_state.getPeriodicBoxVectors() + with open(cached_trajectory_filename, "wb") as f: pickle.dump(data, f) # Discard data to equilibration and subsample. - du_n = np.array(data['du_n']) + du_n = np.array(data["du_n"]) t0, g, Neff = detect_equilibration(du_n) indices = subsample_correlated_data(du_n, g=g) du_n = du_n[indices] @@ -1169,8 +1614,10 @@ def overlap_check(reference_system, alchemical_system, positions, nsteps=50, nsa # Raise an exception if the error is larger than 3kT. MAX_DEVIATION = 3.0 # kT - report = ('\nDeltaF = {:12.3f} +- {:12.3f} kT ({:3.2f} samples, g = {:3.1f}); ' - 'du mean {:.3f} kT stddev {:.3f} kT').format(DeltaF, dDeltaF, Neff, g, du_n.mean(), du_n.std()) + report = ( + "\nDeltaF = {:12.3f} +- {:12.3f} kT ({:3.2f} samples, g = {:3.1f}); " + "du mean {:.3f} kT stddev {:.3f} kT" + ).format(DeltaF, dDeltaF, Neff, g, du_n.mean(), du_n.std()) print(report) if dDeltaF > MAX_DEVIATION: raise Exception(report) @@ -1190,37 +1637,40 @@ def rstyle(ax): import matplotlib import matplotlib.pyplot as plt - #Set the style of the major and minor grid lines, filled blocks - ax.grid(True, 'major', color='w', linestyle='-', linewidth=1.4) - ax.grid(True, 'minor', color='0.99', linestyle='-', linewidth=0.7) - ax.patch.set_facecolor('0.90') + # Set the style of the major and minor grid lines, filled blocks + ax.grid(True, "major", color="w", linestyle="-", linewidth=1.4) + ax.grid(True, "minor", color="0.99", linestyle="-", linewidth=0.7) + ax.patch.set_facecolor("0.90") ax.set_axisbelow(True) - #Set minor tick spacing to 1/2 of the major ticks - ax.xaxis.set_minor_locator((pylab.MultipleLocator((plt.xticks()[0][1] - plt.xticks()[0][0]) / 2.0))) - ax.yaxis.set_minor_locator((pylab.MultipleLocator((plt.yticks()[0][1] - plt.yticks()[0][0]) / 2.0))) + # Set minor tick spacing to 1/2 of the major ticks + ax.xaxis.set_minor_locator( + pylab.MultipleLocator((plt.xticks()[0][1] - plt.xticks()[0][0]) / 2.0) + ) + ax.yaxis.set_minor_locator( + pylab.MultipleLocator((plt.yticks()[0][1] - plt.yticks()[0][0]) / 2.0) + ) - #Remove axis border + # Remove axis border for child in ax.get_children(): if isinstance(child, matplotlib.spines.Spine): child.set_alpha(0) - #Restyle the tick lines + # Restyle the tick lines for line in ax.get_xticklines() + ax.get_yticklines(): line.set_markersize(5) line.set_color("gray") line.set_markeredgewidth(1.4) - #Remove the minor tick lines - for line in (ax.xaxis.get_ticklines(minor=True) + - ax.yaxis.get_ticklines(minor=True)): + # Remove the minor tick lines + for line in ax.xaxis.get_ticklines(minor=True) + ax.yaxis.get_ticklines(minor=True): line.set_markersize(0) - #Only show bottom left ticks, pointing out of axis - plt.rcParams['xtick.direction'] = 'out' - plt.rcParams['ytick.direction'] = 'out' - ax.xaxis.set_ticks_position('bottom') - ax.yaxis.set_ticks_position('left') + # Only show bottom left ticks, pointing out of axis + plt.rcParams["xtick.direction"] = "out" + plt.rcParams["ytick.direction"] = "out" + ax.xaxis.set_ticks_position("bottom") + ax.yaxis.set_ticks_position("left") def lambda_trace(reference_system, alchemical_regions, positions, nsteps=100): @@ -1231,7 +1681,9 @@ def lambda_trace(reference_system, alchemical_regions, positions, nsteps=100): # Create a factory to produce alchemical intermediates. factory = AbsoluteAlchemicalFactory() - alchemical_system = factory.create_alchemical_system(reference_system, alchemical_regions) + alchemical_system = factory.create_alchemical_system( + reference_system, alchemical_regions + ) alchemical_state = AlchemicalState.from_system(alchemical_system) # Take equally-sized steps. @@ -1241,28 +1693,29 @@ def lambda_trace(reference_system, alchemical_regions, positions, nsteps=100): u_original = compute_energy(reference_system, positions) # Scan through lambda values. - lambda_i = np.zeros([nsteps+1], np.float64) # lambda values for u_i + lambda_i = np.zeros([nsteps + 1], np.float64) # lambda values for u_i # u_i[i] is the potential energy for lambda_i[i] - u_i = unit.Quantity(np.zeros([nsteps+1], np.float64), unit.kilocalories_per_mole) - for i in range(nsteps+1): - lambda_i[i] = 1.0-i*delta + u_i = unit.Quantity(np.zeros([nsteps + 1], np.float64), unit.kilocalories_per_mole) + for i in range(nsteps + 1): + lambda_i[i] = 1.0 - i * delta alchemical_state.set_alchemical_parameters(lambda_i[i]) alchemical_state.apply_to_system(alchemical_system) u_i[i] = compute_energy(alchemical_system, positions) - logger.info("{:12.9f} {:24.8f} kcal/mol".format(lambda_i[i], u_i[i] / GLOBAL_ENERGY_UNIT)) + logger.info(f"{lambda_i[i]:12.9f} {u_i[i] / GLOBAL_ENERGY_UNIT:24.8f} kcal/mol") # Write figure as PDF. from matplotlib.backends.backend_pdf import PdfPages import matplotlib.pyplot as plt - with PdfPages('lambda-trace.pdf') as pdf: + + with PdfPages("lambda-trace.pdf") as pdf: fig = plt.figure(figsize=(10, 5)) ax = fig.add_subplot(111) - plt.plot(1, u_original / unit.kilocalories_per_mole, 'ro', label='unmodified') - plt.plot(lambda_i, u_i / unit.kilocalories_per_mole, 'k.', label='alchemical') - plt.title('T4 lysozyme L99A + p-xylene : AMBER96 + OBC GBSA') - plt.ylabel('potential (kcal/mol)') - plt.xlabel('lambda') + plt.plot(1, u_original / unit.kilocalories_per_mole, "ro", label="unmodified") + plt.plot(lambda_i, u_i / unit.kilocalories_per_mole, "k.", label="alchemical") + plt.title("T4 lysozyme L99A + p-xylene : AMBER96 + OBC GBSA") + plt.ylabel("potential (kcal/mol)") + plt.xlabel("lambda") ax.legend() rstyle(ax) pdf.savefig() # saves the current figure into a pdf page @@ -1270,19 +1723,25 @@ def lambda_trace(reference_system, alchemical_regions, positions, nsteps=100): def generate_trace(test_system): - lambda_trace(test_system['test'].system, test_system['test'].positions, test_system['receptor_atoms'], test_system['ligand_atoms']) + lambda_trace( + test_system["test"].system, + test_system["test"].positions, + test_system["receptor_atoms"], + test_system["ligand_atoms"], + ) # ============================================================================= # TEST ALCHEMICAL FACTORY SUITE # ============================================================================= + def test_resolve_alchemical_region(): """Test the method AbsoluteAlchemicalFactory._resolve_alchemical_region.""" test_cases = [ (testsystems.AlanineDipeptideVacuum(), range(22), 9, 36, 48), (testsystems.AlanineDipeptideVacuum(), range(11, 22), 4, 21, 31), - (testsystems.LennardJonesCluster(), range(27), 0, 0, 0) + (testsystems.LennardJonesCluster(), range(27), 0, 0, 0), ] for i, (test_case, atoms, n_bonds, n_angles, n_torsions) in enumerate(test_cases): @@ -1290,37 +1749,59 @@ def test_resolve_alchemical_region(): # Default arguments are converted to empty list. alchemical_region = AlchemicalRegion(alchemical_atoms=atoms) - resolved_region = AbsoluteAlchemicalFactory._resolve_alchemical_region(system, alchemical_region) - for region in ['bonds', 'angles', 'torsions']: - assert getattr(resolved_region, 'alchemical_' + region) == set() + resolved_region = AbsoluteAlchemicalFactory._resolve_alchemical_region( + system, alchemical_region + ) + for region in ["bonds", "angles", "torsions"]: + assert getattr(resolved_region, "alchemical_" + region) == set() # Numpy arrays are converted to sets. - alchemical_region = AlchemicalRegion(alchemical_atoms=np.array(atoms), - alchemical_bonds=np.array(range(n_bonds)), - alchemical_angles=np.array(range(n_angles)), - alchemical_torsions=np.array(range(n_torsions))) - resolved_region = AbsoluteAlchemicalFactory._resolve_alchemical_region(system, alchemical_region) - for region in ['atoms', 'bonds', 'angles', 'torsions']: - assert isinstance(getattr(resolved_region, 'alchemical_' + region), frozenset) + alchemical_region = AlchemicalRegion( + alchemical_atoms=np.array(atoms), + alchemical_bonds=np.array(range(n_bonds)), + alchemical_angles=np.array(range(n_angles)), + alchemical_torsions=np.array(range(n_torsions)), + ) + resolved_region = AbsoluteAlchemicalFactory._resolve_alchemical_region( + system, alchemical_region + ) + for region in ["atoms", "bonds", "angles", "torsions"]: + assert isinstance( + getattr(resolved_region, "alchemical_" + region), frozenset + ) # Bonds, angles and torsions are inferred correctly. - alchemical_region = AlchemicalRegion(alchemical_atoms=atoms, alchemical_bonds=True, - alchemical_angles=True, alchemical_torsions=True) - resolved_region = AbsoluteAlchemicalFactory._resolve_alchemical_region(system, alchemical_region) - for j, region in enumerate(['bonds', 'angles', 'torsions']): - assert len(getattr(resolved_region, 'alchemical_' + region)) == test_cases[i][j+2] + alchemical_region = AlchemicalRegion( + alchemical_atoms=atoms, + alchemical_bonds=True, + alchemical_angles=True, + alchemical_torsions=True, + ) + resolved_region = AbsoluteAlchemicalFactory._resolve_alchemical_region( + system, alchemical_region + ) + for j, region in enumerate(["bonds", "angles", "torsions"]): + assert ( + len(getattr(resolved_region, "alchemical_" + region)) + == test_cases[i][j + 2] + ) # An exception is if indices are not part of the system. alchemical_region = AlchemicalRegion(alchemical_atoms=[10000000]) with nose.tools.assert_raises(ValueError): - AbsoluteAlchemicalFactory._resolve_alchemical_region(system, alchemical_region) + AbsoluteAlchemicalFactory._resolve_alchemical_region( + system, alchemical_region + ) # An exception is raised if nothing is defined. alchemical_region = AlchemicalRegion() with nose.tools.assert_raises(ValueError): - AbsoluteAlchemicalFactory._resolve_alchemical_region(system, alchemical_region) + AbsoluteAlchemicalFactory._resolve_alchemical_region( + system, alchemical_region + ) + -class TestAbsoluteAlchemicalFactory(object): +class TestAbsoluteAlchemicalFactory: """Test AbsoluteAlchemicalFactory class.""" @classmethod @@ -1338,52 +1819,84 @@ def define_systems(cls): # Basic test systems: Lennard-Jones and water particles only. # Test also dispersion correction and switch off ("on" values # for these options are tested in HostGuestExplicit system). - cls.test_systems['LennardJonesCluster'] = testsystems.LennardJonesCluster() - cls.test_systems['LennardJonesFluid with dispersion correction'] = \ + cls.test_systems["LennardJonesCluster"] = testsystems.LennardJonesCluster() + cls.test_systems["LennardJonesFluid with dispersion correction"] = ( testsystems.LennardJonesFluid(nparticles=100, dispersion_correction=True) - cls.test_systems['TIP3P WaterBox with reaction field, no switch, no dispersion correction'] = \ - testsystems.WaterBox(dispersion_correction=False, switch=False, nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['TIP4P-EW WaterBox and NaCl with PME'] = \ - testsystems.WaterBox(nonbondedMethod=openmm.app.PME, model='tip4pew', ionic_strength=200*unit.millimolar) + ) + cls.test_systems[ + "TIP3P WaterBox with reaction field, no switch, no dispersion correction" + ] = testsystems.WaterBox( + dispersion_correction=False, + switch=False, + nonbondedMethod=openmm.app.CutoffPeriodic, + ) + cls.test_systems["TIP4P-EW WaterBox and NaCl with PME"] = testsystems.WaterBox( + nonbondedMethod=openmm.app.PME, + model="tip4pew", + ionic_strength=200 * unit.millimolar, + ) # Vacuum and implicit. - cls.test_systems['AlanineDipeptideVacuum'] = testsystems.AlanineDipeptideVacuum() - cls.test_systems['AlanineDipeptideImplicit'] = testsystems.AlanineDipeptideImplicit() - cls.test_systems['TolueneImplicitOBC2'] = testsystems.TolueneImplicitOBC2() - cls.test_systems['TolueneImplicitGBn'] = testsystems.TolueneImplicitGBn() + cls.test_systems["AlanineDipeptideVacuum"] = ( + testsystems.AlanineDipeptideVacuum() + ) + cls.test_systems["AlanineDipeptideImplicit"] = ( + testsystems.AlanineDipeptideImplicit() + ) + cls.test_systems["TolueneImplicitOBC2"] = testsystems.TolueneImplicitOBC2() + cls.test_systems["TolueneImplicitGBn"] = testsystems.TolueneImplicitGBn() # Explicit test system: PME and CutoffPeriodic. - #cls.test_systems['AlanineDipeptideExplicit with CutoffPeriodic'] = \ + # cls.test_systems['AlanineDipeptideExplicit with CutoffPeriodic'] = \ # testsystems.AlanineDipeptideExplicit(nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['HostGuestExplicit with PME'] = \ - testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.PME) - cls.test_systems['HostGuestExplicit with CutoffPeriodic'] = \ + cls.test_systems["HostGuestExplicit with PME"] = testsystems.HostGuestExplicit( + nonbondedMethod=openmm.app.PME + ) + cls.test_systems["HostGuestExplicit with CutoffPeriodic"] = ( testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.CutoffPeriodic) + ) @classmethod def define_regions(cls): """Create shared AlchemicalRegions for test systems in cls.test_regions.""" cls.test_regions = dict() - cls.test_regions['LennardJonesCluster'] = AlchemicalRegion(alchemical_atoms=range(2)) - cls.test_regions['LennardJonesFluid'] = AlchemicalRegion(alchemical_atoms=range(10)) - cls.test_regions['Toluene'] = AlchemicalRegion(alchemical_atoms=range(6)) # Only partially modified. - cls.test_regions['AlanineDipeptide'] = AlchemicalRegion(alchemical_atoms=range(22)) - cls.test_regions['HostGuestExplicit'] = AlchemicalRegion(alchemical_atoms=range(126, 156)) - cls.test_regions['TIP3P WaterBox'] = AlchemicalRegion(alchemical_atoms=range(0,3)) + cls.test_regions["LennardJonesCluster"] = AlchemicalRegion( + alchemical_atoms=range(2) + ) + cls.test_regions["LennardJonesFluid"] = AlchemicalRegion( + alchemical_atoms=range(10) + ) + cls.test_regions["Toluene"] = AlchemicalRegion( + alchemical_atoms=range(6) + ) # Only partially modified. + cls.test_regions["AlanineDipeptide"] = AlchemicalRegion( + alchemical_atoms=range(22) + ) + cls.test_regions["HostGuestExplicit"] = AlchemicalRegion( + alchemical_atoms=range(126, 156) + ) + cls.test_regions["TIP3P WaterBox"] = AlchemicalRegion( + alchemical_atoms=range(0, 3) + ) # Modify ions. - for atom in cls.test_systems['TIP4P-EW WaterBox and NaCl with PME'].topology.atoms(): - if atom.name in ['Na', 'Cl']: - cls.test_regions['TIP4P-EW WaterBox and NaCl'] = AlchemicalRegion(alchemical_atoms=range(atom.index, atom.index+1)) + for atom in cls.test_systems[ + "TIP4P-EW WaterBox and NaCl with PME" + ].topology.atoms(): + if atom.name in ["Na", "Cl"]: + cls.test_regions["TIP4P-EW WaterBox and NaCl"] = AlchemicalRegion( + alchemical_atoms=range(atom.index, atom.index + 1) + ) break @classmethod def generate_cases(cls): """Generate all test cases in cls.test_cases combinatorially.""" cls.test_cases = dict() - direct_space_factory = AbsoluteAlchemicalFactory(alchemical_pme_treatment='direct-space', - alchemical_rf_treatment='switched') - exact_pme_factory = AbsoluteAlchemicalFactory(alchemical_pme_treatment='exact') + direct_space_factory = AbsoluteAlchemicalFactory( + alchemical_pme_treatment="direct-space", alchemical_rf_treatment="switched" + ) + exact_pme_factory = AbsoluteAlchemicalFactory(alchemical_pme_treatment="exact") # We generate all possible combinations of annihilate_sterics/electrostatics # for each test system. We also annihilate bonds, angles and torsions every @@ -1391,7 +1904,6 @@ def generate_cases(cls): # each combination of annihilate_sterics/electrostatics. n_test_cases = 0 for test_system_name, test_system in cls.test_systems.items(): - # Find standard alchemical region. for region_name, region in cls.test_regions.items(): if region_name in test_system_name: @@ -1399,55 +1911,91 @@ def generate_cases(cls): assert region_name in test_system_name, test_system_name # Find nonbonded method. - force_idx, nonbonded_force = forces.find_forces(test_system.system, openmm.NonbondedForce, only_one=True) + force_idx, nonbonded_force = forces.find_forces( + test_system.system, openmm.NonbondedForce, only_one=True + ) nonbonded_method = nonbonded_force.getNonbondedMethod() # Create all combinations of annihilate_sterics/electrostatics. - for annihilate_sterics, annihilate_electrostatics in itertools.product((True, False), repeat=2): + for annihilate_sterics, annihilate_electrostatics in itertools.product( + (True, False), repeat=2 + ): # Create new region that we can modify. - test_region = region._replace(annihilate_sterics=annihilate_sterics, - annihilate_electrostatics=annihilate_electrostatics) + test_region = region._replace( + annihilate_sterics=annihilate_sterics, + annihilate_electrostatics=annihilate_electrostatics, + ) # Create test name. test_case_name = test_system_name[:] if annihilate_sterics: - test_case_name += ', annihilated sterics' + test_case_name += ", annihilated sterics" if annihilate_electrostatics: - test_case_name += ', annihilated electrostatics' + test_case_name += ", annihilated electrostatics" # Annihilate bonds and angles every three test_cases. if n_test_cases % 3 == 0: - test_region = test_region._replace(alchemical_bonds=True, alchemical_angles=True, - alchemical_torsions=True) - test_case_name += ', annihilated bonds, angles and torsions' + test_region = test_region._replace( + alchemical_bonds=True, + alchemical_angles=True, + alchemical_torsions=True, + ) + test_case_name += ", annihilated bonds, angles and torsions" # Add different softcore parameters every five test_cases. if n_test_cases % 5 == 0: - test_region = test_region._replace(softcore_alpha=1.0, softcore_beta=1.0, softcore_a=1.0, softcore_b=1.0, - softcore_c=1.0, softcore_d=1.0, softcore_e=1.0, softcore_f=1.0) - test_case_name += ', modified softcore parameters' + test_region = test_region._replace( + softcore_alpha=1.0, + softcore_beta=1.0, + softcore_a=1.0, + softcore_b=1.0, + softcore_c=1.0, + softcore_d=1.0, + softcore_e=1.0, + softcore_f=1.0, + ) + test_case_name += ", modified softcore parameters" # Pre-generate alchemical system. - alchemical_system = direct_space_factory.create_alchemical_system(test_system.system, test_region) + alchemical_system = direct_space_factory.create_alchemical_system( + test_system.system, test_region + ) # Add test case. - cls.test_cases[test_case_name] = (test_system, alchemical_system, test_region) + cls.test_cases[test_case_name] = ( + test_system, + alchemical_system, + test_region, + ) n_test_cases += 1 # If we don't use softcore electrostatics and we annihilate charges # we can test also exact PME treatment. We don't increase n_test_cases # purposely to keep track of which tests are added above. - if (test_region.softcore_beta == 0.0 and annihilate_electrostatics and - nonbonded_method in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald]): - alchemical_system = exact_pme_factory.create_alchemical_system(test_system.system, test_region) - test_case_name += ', exact PME' - cls.test_cases[test_case_name] = (test_system, alchemical_system, test_region) + if ( + test_region.softcore_beta == 0.0 + and annihilate_electrostatics + and nonbonded_method + in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald] + ): + alchemical_system = exact_pme_factory.create_alchemical_system( + test_system.system, test_region + ) + test_case_name += ", exact PME" + cls.test_cases[test_case_name] = ( + test_system, + alchemical_system, + test_region, + ) # If the test system uses reaction field replace reaction field # of the reference system to allow comparisons. if nonbonded_method == openmm.NonbondedForce.CutoffPeriodic: - forcefactories.replace_reaction_field(test_system.system, return_copy=False, - switch_width=direct_space_factory.switch_width) + forcefactories.replace_reaction_field( + test_system.system, + return_copy=False, + switch_width=direct_space_factory.switch_width, + ) def filter_cases(self, condition_func, max_number=None): """Return the list of test cases that satisfy condition_func(test_case_name).""" @@ -1465,43 +2013,82 @@ def filter_cases(self, condition_func, max_number=None): def test_split_force_groups(self): """Forces having different lambda variables should have a different force group.""" # Select 1 implicit, 1 explicit, and 1 exact PME explicit test case randomly. - test_cases = self.filter_cases(lambda x: 'Implicit' in x, max_number=1) - test_cases.update(self.filter_cases(lambda x: 'Explicit ' in x and 'exact PME' in x, max_number=1)) - test_cases.update(self.filter_cases(lambda x: 'Explicit ' in x and 'exact PME' not in x, max_number=1)) - for test_name, (test_system, alchemical_system, alchemical_region) in test_cases.items(): + test_cases = self.filter_cases(lambda x: "Implicit" in x, max_number=1) + test_cases.update( + self.filter_cases( + lambda x: "Explicit " in x and "exact PME" in x, max_number=1 + ) + ) + test_cases.update( + self.filter_cases( + lambda x: "Explicit " in x and "exact PME" not in x, max_number=1 + ) + ) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in test_cases.items(): f = partial(check_split_force_groups, alchemical_system) - f.description = "Testing force splitting among groups of {}".format(test_name) + f.description = f"Testing force splitting among groups of {test_name}" yield f def test_fully_interacting_energy(self): """Compare the energies of reference and fully interacting alchemical system.""" - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - f = partial(compare_system_energies, test_system.system, - alchemical_system, alchemical_region, test_system.positions) - f.description = "Testing fully interacting energy of {}".format(test_name) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + f = partial( + compare_system_energies, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Testing fully interacting energy of {test_name}" yield f def test_noninteracting_energy_components(self): """Check all forces annihilated/decoupled when their lambda variables are zero.""" - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - f = partial(check_noninteracting_energy_components, test_system.system, alchemical_system, - alchemical_region, test_system.positions) - f.description = "Testing non-interacting energy of {}".format(test_name) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + f = partial( + check_noninteracting_energy_components, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Testing non-interacting energy of {test_name}" yield f - @attr('slow') + @attr("slow") def test_fully_interacting_energy_components(self): """Test interacting state energy by force component.""" # This is a very expensive but very informative test. We can # run this locally when test_fully_interacting_energies() fails. - test_cases = self.filter_cases(lambda x: 'Explicit' in x) - for test_name, (test_system, alchemical_system, alchemical_region) in test_cases.items(): - f = partial(check_interacting_energy_components, test_system.system, alchemical_system, - alchemical_region, test_system.positions) + test_cases = self.filter_cases(lambda x: "Explicit" in x) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in test_cases.items(): + f = partial( + check_interacting_energy_components, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) f.description = "Testing energy components of %s..." % test_name yield f - @attr('slow') + @attr("slow") def test_platforms(self): """Test interacting and noninteracting energies on all platforms.""" global GLOBAL_ALCHEMY_PLATFORM @@ -1512,37 +2099,65 @@ def test_platforms(self): default_platform_name = utils.get_fastest_platform().getName() else: default_platform_name = old_global_platform.getName() - platforms = [platform for platform in utils.get_available_platforms() - if platform.getName() != default_platform_name] + platforms = [ + platform + for platform in utils.get_available_platforms() + if platform.getName() != default_platform_name + ] # Test interacting and noninteracting energies on all platforms. for platform in platforms: GLOBAL_ALCHEMY_PLATFORM = platform - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - f = partial(compare_system_energies, test_system.system, alchemical_system, - alchemical_region, test_system.positions) - f.description = "Test fully interacting energy of {} on {}".format(test_name, platform.getName()) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + f = partial( + compare_system_energies, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Test fully interacting energy of {test_name} on {platform.getName()}" yield f - f = partial(check_noninteracting_energy_components, test_system.system, alchemical_system, - alchemical_region, test_system.positions) - f.description = "Test non-interacting energy of {} on {}".format(test_name, platform.getName()) + f = partial( + check_noninteracting_energy_components, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Test non-interacting energy of {test_name} on {platform.getName()}" yield f # Restore global platform GLOBAL_ALCHEMY_PLATFORM = old_global_platform - @attr('slow') + @attr("slow") def test_overlap(self): """Tests overlap between reference and alchemical systems.""" - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - #cached_trajectory_filename = os.path.join(os.environ['HOME'], '.cache', 'alchemy', 'tests', + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + # cached_trajectory_filename = os.path.join(os.environ['HOME'], '.cache', 'alchemy', 'tests', # test_name + '.pickle') cached_trajectory_filename = None - f = partial(overlap_check, test_system.system, alchemical_system, test_system.positions, - cached_trajectory_filename=cached_trajectory_filename, name=test_name) - f.description = "Testing reference/alchemical overlap for {}".format(test_name) + f = partial( + overlap_check, + test_system.system, + alchemical_system, + test_system.positions, + cached_trajectory_filename=cached_trajectory_filename, + name=test_name, + ) + f.description = f"Testing reference/alchemical overlap for {test_name}" yield f + class TestMultiRegionAbsoluteAlchemicalFactory(TestAbsoluteAlchemicalFactory): """Test AbsoluteAlchemicalFactory class using multiple regions.""" @@ -1554,15 +2169,23 @@ def define_systems(cls): # Basic test systems: Lennard-Jones and water particles only. # Test also dispersion correction and switch off ("on" values # for these options are tested in HostGuestExplicit system). - cls.test_systems['LennardJonesCluster'] = testsystems.LennardJonesCluster() - cls.test_systems['LennardJonesFluid with dispersion correction'] = \ + cls.test_systems["LennardJonesCluster"] = testsystems.LennardJonesCluster() + cls.test_systems["LennardJonesFluid with dispersion correction"] = ( testsystems.LennardJonesFluid(nparticles=100, dispersion_correction=True) - cls.test_systems['TIP3P WaterBox with reaction field, no switch, no dispersion correction'] = \ - testsystems.WaterBox(dispersion_correction=False, switch=False, nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['HostGuestExplicit with PME'] = \ - testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.PME) - cls.test_systems['HostGuestExplicit with CutoffPeriodic'] = \ + ) + cls.test_systems[ + "TIP3P WaterBox with reaction field, no switch, no dispersion correction" + ] = testsystems.WaterBox( + dispersion_correction=False, + switch=False, + nonbondedMethod=openmm.app.CutoffPeriodic, + ) + cls.test_systems["HostGuestExplicit with PME"] = testsystems.HostGuestExplicit( + nonbondedMethod=openmm.app.PME + ) + cls.test_systems["HostGuestExplicit with CutoffPeriodic"] = ( testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.CutoffPeriodic) + ) @classmethod def define_regions(cls): @@ -1571,27 +2194,50 @@ def define_regions(cls): cls.test_region_one = dict() cls.test_region_two = dict() - cls.test_region_zero['LennardJonesCluster'] = AlchemicalRegion(alchemical_atoms=range(2), name='zero') - cls.test_region_one['LennardJonesCluster'] = AlchemicalRegion(alchemical_atoms=range(2,4), name='one') - cls.test_region_two['LennardJonesCluster'] = AlchemicalRegion(alchemical_atoms=range(4,6), name='two') - cls.test_region_zero['LennardJonesFluid'] = AlchemicalRegion(alchemical_atoms=range(10), name='zero') - cls.test_region_one['LennardJonesFluid'] = AlchemicalRegion(alchemical_atoms=range(10,20), name='one') - cls.test_region_two['LennardJonesFluid'] = AlchemicalRegion(alchemical_atoms=range(20,30), name='two') - cls.test_region_zero['TIP3P WaterBox'] = AlchemicalRegion(alchemical_atoms=range(3), name='zero') - cls.test_region_one['TIP3P WaterBox'] = AlchemicalRegion(alchemical_atoms=range(3,6), name='one') - cls.test_region_two['TIP3P WaterBox'] = AlchemicalRegion(alchemical_atoms=range(6,9), name='two') - #Three regions push HostGuest system beyond 32 force groups - cls.test_region_zero['HostGuestExplicit'] = AlchemicalRegion(alchemical_atoms=range(126, 156), name='zero') - cls.test_region_one['HostGuestExplicit'] = AlchemicalRegion(alchemical_atoms=range(156,160), name='one') - cls.test_region_two['HostGuestExplicit'] = None + cls.test_region_zero["LennardJonesCluster"] = AlchemicalRegion( + alchemical_atoms=range(2), name="zero" + ) + cls.test_region_one["LennardJonesCluster"] = AlchemicalRegion( + alchemical_atoms=range(2, 4), name="one" + ) + cls.test_region_two["LennardJonesCluster"] = AlchemicalRegion( + alchemical_atoms=range(4, 6), name="two" + ) + cls.test_region_zero["LennardJonesFluid"] = AlchemicalRegion( + alchemical_atoms=range(10), name="zero" + ) + cls.test_region_one["LennardJonesFluid"] = AlchemicalRegion( + alchemical_atoms=range(10, 20), name="one" + ) + cls.test_region_two["LennardJonesFluid"] = AlchemicalRegion( + alchemical_atoms=range(20, 30), name="two" + ) + cls.test_region_zero["TIP3P WaterBox"] = AlchemicalRegion( + alchemical_atoms=range(3), name="zero" + ) + cls.test_region_one["TIP3P WaterBox"] = AlchemicalRegion( + alchemical_atoms=range(3, 6), name="one" + ) + cls.test_region_two["TIP3P WaterBox"] = AlchemicalRegion( + alchemical_atoms=range(6, 9), name="two" + ) + # Three regions push HostGuest system beyond 32 force groups + cls.test_region_zero["HostGuestExplicit"] = AlchemicalRegion( + alchemical_atoms=range(126, 156), name="zero" + ) + cls.test_region_one["HostGuestExplicit"] = AlchemicalRegion( + alchemical_atoms=range(156, 160), name="one" + ) + cls.test_region_two["HostGuestExplicit"] = None @classmethod def generate_cases(cls): """Generate all test cases in cls.test_cases combinatorially.""" cls.test_cases = dict() - direct_space_factory = AbsoluteAlchemicalFactory(alchemical_pme_treatment='direct-space', - alchemical_rf_treatment='switched') - exact_pme_factory = AbsoluteAlchemicalFactory(alchemical_pme_treatment='exact') + direct_space_factory = AbsoluteAlchemicalFactory( + alchemical_pme_treatment="direct-space", alchemical_rf_treatment="switched" + ) + exact_pme_factory = AbsoluteAlchemicalFactory(alchemical_pme_treatment="exact") # We generate all possible combinations of annihilate_sterics/electrostatics # for each test system. We also annihilate bonds, angles and torsions every @@ -1599,7 +2245,6 @@ def generate_cases(cls): # each combination of annihilate_sterics/electrostatics. n_test_cases = 0 for test_system_name, test_system in cls.test_systems.items(): - # Find standard alchemical region zero. for region_name_zero, region_zero in cls.test_region_zero.items(): if region_name_zero in test_system_name: @@ -1618,92 +2263,152 @@ def generate_cases(cls): break assert region_name_two in test_system_name, test_system_name - assert region_name_zero == region_name_one and region_name_one == region_name_two - #We only want two regions for HostGuest or we get too many force groups - if 'HostGuestExplicit' in region_name_one: + assert ( + region_name_zero == region_name_one + and region_name_one == region_name_two + ) + # We only want two regions for HostGuest or we get too many force groups + if "HostGuestExplicit" in region_name_one: test_regions = [region_zero, region_one] else: test_regions = [region_zero, region_one, region_two] # Find nonbonded method. - force_idx, nonbonded_force = forces.find_forces(test_system.system, openmm.NonbondedForce, only_one=True) + force_idx, nonbonded_force = forces.find_forces( + test_system.system, openmm.NonbondedForce, only_one=True + ) nonbonded_method = nonbonded_force.getNonbondedMethod() # Create all combinations of annihilate_sterics/electrostatics. - for annihilate_sterics, annihilate_electrostatics in itertools.product((True, False), repeat=2): + for annihilate_sterics, annihilate_electrostatics in itertools.product( + (True, False), repeat=2 + ): # Create new region that we can modify. for i, test_region in enumerate(test_regions): - test_regions[i] = test_region._replace(annihilate_sterics=annihilate_sterics, - annihilate_electrostatics=annihilate_electrostatics) + test_regions[i] = test_region._replace( + annihilate_sterics=annihilate_sterics, + annihilate_electrostatics=annihilate_electrostatics, + ) # Create test name. test_case_name = test_system_name[:] if annihilate_sterics: - test_case_name += ', annihilated sterics' + test_case_name += ", annihilated sterics" if annihilate_electrostatics: - test_case_name += ', annihilated electrostatics' + test_case_name += ", annihilated electrostatics" # Annihilate bonds and angles every three test_cases. if n_test_cases % 3 == 0: for i, test_region in enumerate(test_regions): - test_regions[i] = test_region._replace(alchemical_bonds=True, alchemical_angles=True, - alchemical_torsions=True) - test_case_name += ', annihilated bonds, angles and torsions' + test_regions[i] = test_region._replace( + alchemical_bonds=True, + alchemical_angles=True, + alchemical_torsions=True, + ) + test_case_name += ", annihilated bonds, angles and torsions" # Add different softcore parameters every five test_cases. if n_test_cases % 5 == 0: for i, test_region in enumerate(test_regions): - test_regions[i] = test_region._replace(softcore_alpha=1.0, softcore_beta=1.0, softcore_a=1.0, softcore_b=1.0, - softcore_c=1.0, softcore_d=1.0, softcore_e=1.0, softcore_f=1.0) - test_case_name += ', modified softcore parameters' - - #region_interactions = frozenset(itertools.combinations(range(len(test_regions)), 2)) + test_regions[i] = test_region._replace( + softcore_alpha=1.0, + softcore_beta=1.0, + softcore_a=1.0, + softcore_b=1.0, + softcore_c=1.0, + softcore_d=1.0, + softcore_e=1.0, + softcore_f=1.0, + ) + test_case_name += ", modified softcore parameters" + + # region_interactions = frozenset(itertools.combinations(range(len(test_regions)), 2)) # Pre-generate alchemical system. - alchemical_system = direct_space_factory.create_alchemical_system(test_system.system, alchemical_regions = test_regions) + alchemical_system = direct_space_factory.create_alchemical_system( + test_system.system, alchemical_regions=test_regions + ) # Add test case. - cls.test_cases[test_case_name] = (test_system, alchemical_system, test_regions) + cls.test_cases[test_case_name] = ( + test_system, + alchemical_system, + test_regions, + ) n_test_cases += 1 # If we don't use softcore electrostatics and we annihilate charges # we can test also exact PME treatment. We don't increase n_test_cases # purposely to keep track of which tests are added above. - if (test_regions[1].softcore_beta == 0.0 and annihilate_electrostatics and - nonbonded_method in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald]): - alchemical_system = exact_pme_factory.create_alchemical_system(test_system.system, alchemical_regions = test_regions) - test_case_name += ', exact PME' - cls.test_cases[test_case_name] = (test_system, alchemical_system, test_regions) + if ( + test_regions[1].softcore_beta == 0.0 + and annihilate_electrostatics + and nonbonded_method + in [openmm.NonbondedForce.PME, openmm.NonbondedForce.Ewald] + ): + alchemical_system = exact_pme_factory.create_alchemical_system( + test_system.system, alchemical_regions=test_regions + ) + test_case_name += ", exact PME" + cls.test_cases[test_case_name] = ( + test_system, + alchemical_system, + test_regions, + ) # If the test system uses reaction field replace reaction field # of the reference system to allow comparisons. if nonbonded_method == openmm.NonbondedForce.CutoffPeriodic: - forcefactories.replace_reaction_field(test_system.system, return_copy=False, - switch_width=direct_space_factory.switch_width) + forcefactories.replace_reaction_field( + test_system.system, + return_copy=False, + switch_width=direct_space_factory.switch_width, + ) def test_split_force_groups(self): """Forces having different lambda variables should have a different force group.""" # Select 1 implicit, 1 explicit, and 1 exact PME explicit test case randomly. - test_cases = self.filter_cases(lambda x: 'Implicit' in x, max_number=1) - test_cases.update(self.filter_cases(lambda x: 'Explicit ' in x and 'exact PME' in x, max_number=1)) - test_cases.update(self.filter_cases(lambda x: 'Explicit ' in x and 'exact PME' not in x, max_number=1)) - for test_name, (test_system, alchemical_system, alchemical_region) in test_cases.items(): + test_cases = self.filter_cases(lambda x: "Implicit" in x, max_number=1) + test_cases.update( + self.filter_cases( + lambda x: "Explicit " in x and "exact PME" in x, max_number=1 + ) + ) + test_cases.update( + self.filter_cases( + lambda x: "Explicit " in x and "exact PME" not in x, max_number=1 + ) + ) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in test_cases.items(): region_names = [] for region in alchemical_region: region_names.append(region.name) f = partial(check_split_force_groups, alchemical_system, region_names) - f.description = "Testing force splitting among groups of {}".format(test_name) + f.description = f"Testing force splitting among groups of {test_name}" yield f def test_noninteracting_energy_components(self): """Check all forces annihilated/decoupled when their lambda variables are zero.""" - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - f = partial(check_multi_noninteracting_energy_components, test_system.system, alchemical_system, - alchemical_region, test_system.positions) - f.description = "Testing non-interacting energy of {}".format(test_name) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + f = partial( + check_multi_noninteracting_energy_components, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Testing non-interacting energy of {test_name}" yield f - @attr('slow') + @attr("slow") def test_platforms(self): """Test interacting and noninteracting energies on all platforms.""" global GLOBAL_ALCHEMY_PLATFORM @@ -1714,43 +2419,70 @@ def test_platforms(self): default_platform_name = utils.get_fastest_platform().getName() else: default_platform_name = old_global_platform.getName() - platforms = [platform for platform in utils.get_available_platforms() - if platform.getName() != default_platform_name] + platforms = [ + platform + for platform in utils.get_available_platforms() + if platform.getName() != default_platform_name + ] # Test interacting and noninteracting energies on all platforms. for platform in platforms: GLOBAL_ALCHEMY_PLATFORM = platform - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - f = partial(compare_system_energies, test_system.system, alchemical_system, - alchemical_region, test_system.positions) - f.description = "Test fully interacting energy of {} on {}".format(test_name, platform.getName()) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + f = partial( + compare_system_energies, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Test fully interacting energy of {test_name} on {platform.getName()}" yield f - f = partial(check_multi_noninteracting_energy_components, test_system.system, alchemical_system, - alchemical_region, test_system.positions) - f.description = "Test non-interacting energy of {} on {}".format(test_name, platform.getName()) + f = partial( + check_multi_noninteracting_energy_components, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) + f.description = f"Test non-interacting energy of {test_name} on {platform.getName()}" yield f # Restore global platform GLOBAL_ALCHEMY_PLATFORM = old_global_platform - @attr('slow') + @attr("slow") def test_fully_interacting_energy_components(self): """Test interacting state energy by force component.""" # This is a very expensive but very informative test. We can # run this locally when test_fully_interacting_energies() fails. - test_cases = self.filter_cases(lambda x: 'Explicit' in x) - for test_name, (test_system, alchemical_system, alchemical_region) in test_cases.items(): - f = partial(check_multi_interacting_energy_components, test_system.system, alchemical_system, - alchemical_region, test_system.positions) + test_cases = self.filter_cases(lambda x: "Explicit" in x) + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in test_cases.items(): + f = partial( + check_multi_interacting_energy_components, + test_system.system, + alchemical_system, + alchemical_region, + test_system.positions, + ) f.description = "Testing energy components of %s..." % test_name yield f -class TestDispersionlessAlchemicalFactory(object): +class TestDispersionlessAlchemicalFactory: """ Only test overlap for dispersionless alchemical factory, since energy agreement will be poor. """ + @classmethod def setup_class(cls): """Create test systems and shared objects.""" @@ -1762,20 +2494,25 @@ def setup_class(cls): def define_systems(cls): """Create test systems and shared objects.""" cls.test_systems = dict() - cls.test_systems['LennardJonesFluid with dispersion correction'] = \ + cls.test_systems["LennardJonesFluid with dispersion correction"] = ( testsystems.LennardJonesFluid(nparticles=100, dispersion_correction=True) + ) @classmethod def define_regions(cls): """Create shared AlchemicalRegions for test systems in cls.test_regions.""" cls.test_regions = dict() - cls.test_regions['LennardJonesFluid'] = AlchemicalRegion(alchemical_atoms=range(10)) + cls.test_regions["LennardJonesFluid"] = AlchemicalRegion( + alchemical_atoms=range(10) + ) @classmethod def generate_cases(cls): """Generate all test cases in cls.test_cases combinatorially.""" cls.test_cases = dict() - factory = AbsoluteAlchemicalFactory(disable_alchemical_dispersion_correction=True) + factory = AbsoluteAlchemicalFactory( + disable_alchemical_dispersion_correction=True + ) # We generate all possible combinations of annihilate_sterics/electrostatics # for each test system. We also annihilate bonds, angles and torsions every @@ -1783,7 +2520,6 @@ def generate_cases(cls): # each combination of annihilate_sterics/electrostatics. n_test_cases = 0 for test_system_name, test_system in cls.test_systems.items(): - # Find standard alchemical region. for region_name, region in cls.test_regions.items(): if region_name in test_system_name: @@ -1792,33 +2528,51 @@ def generate_cases(cls): # Create all combinations of annihilate_sterics. for annihilate_sterics in itertools.product((True, False), repeat=1): - region = region._replace(annihilate_sterics=annihilate_sterics, - annihilate_electrostatics=True) + region = region._replace( + annihilate_sterics=annihilate_sterics, + annihilate_electrostatics=True, + ) # Create test name. test_case_name = test_system_name[:] if annihilate_sterics: - test_case_name += ', annihilated sterics' + test_case_name += ", annihilated sterics" # Pre-generate alchemical system - alchemical_system = factory.create_alchemical_system(test_system.system, region) - cls.test_cases[test_case_name] = (test_system, alchemical_system, region) + alchemical_system = factory.create_alchemical_system( + test_system.system, region + ) + cls.test_cases[test_case_name] = ( + test_system, + alchemical_system, + region, + ) n_test_cases += 1 def test_overlap(self): """Tests overlap between reference and alchemical systems.""" - for test_name, (test_system, alchemical_system, alchemical_region) in self.test_cases.items(): - #cached_trajectory_filename = os.path.join(os.environ['HOME'], '.cache', 'alchemy', 'tests', + for test_name, ( + test_system, + alchemical_system, + alchemical_region, + ) in self.test_cases.items(): + # cached_trajectory_filename = os.path.join(os.environ['HOME'], '.cache', 'alchemy', 'tests', # test_name + '.pickle') cached_trajectory_filename = None - f = partial(overlap_check, test_system.system, alchemical_system, test_system.positions, - cached_trajectory_filename=cached_trajectory_filename, name=test_name) - f.description = "Testing reference/alchemical overlap for no alchemical dispersion {}".format(test_name) + f = partial( + overlap_check, + test_system.system, + alchemical_system, + test_system.positions, + cached_trajectory_filename=cached_trajectory_filename, + name=test_name, + ) + f.description = f"Testing reference/alchemical overlap for no alchemical dispersion {test_name}" yield f -@attr('slow') +@attr("slow") class TestAbsoluteAlchemicalFactorySlow(TestAbsoluteAlchemicalFactory): """Test AbsoluteAlchemicalFactory class with a more comprehensive set of systems.""" @@ -1826,42 +2580,68 @@ class TestAbsoluteAlchemicalFactorySlow(TestAbsoluteAlchemicalFactory): def define_systems(cls): """Create test systems and shared objects.""" cls.test_systems = dict() - cls.test_systems['LennardJonesFluid without dispersion correction'] = \ + cls.test_systems["LennardJonesFluid without dispersion correction"] = ( testsystems.LennardJonesFluid(nparticles=100, dispersion_correction=False) - cls.test_systems['DischargedWaterBox with reaction field, no switch, no dispersion correction'] = \ - testsystems.DischargedWaterBox(dispersion_correction=False, switch=False, - nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['WaterBox with reaction field, no switch, dispersion correction'] = \ - testsystems.WaterBox(dispersion_correction=False, switch=True, nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['WaterBox with reaction field, switch, no dispersion correction'] = \ - testsystems.WaterBox(dispersion_correction=False, switch=True, nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['WaterBox with PME, switch, dispersion correction'] = \ - testsystems.WaterBox(dispersion_correction=True, switch=True, nonbondedMethod=openmm.app.PME) + ) + cls.test_systems[ + "DischargedWaterBox with reaction field, no switch, no dispersion correction" + ] = testsystems.DischargedWaterBox( + dispersion_correction=False, + switch=False, + nonbondedMethod=openmm.app.CutoffPeriodic, + ) + cls.test_systems[ + "WaterBox with reaction field, no switch, dispersion correction" + ] = testsystems.WaterBox( + dispersion_correction=False, + switch=True, + nonbondedMethod=openmm.app.CutoffPeriodic, + ) + cls.test_systems[ + "WaterBox with reaction field, switch, no dispersion correction" + ] = testsystems.WaterBox( + dispersion_correction=False, + switch=True, + nonbondedMethod=openmm.app.CutoffPeriodic, + ) + cls.test_systems["WaterBox with PME, switch, dispersion correction"] = ( + testsystems.WaterBox( + dispersion_correction=True, switch=True, nonbondedMethod=openmm.app.PME + ) + ) # Big systems. - cls.test_systems['LysozymeImplicit'] = testsystems.LysozymeImplicit() - cls.test_systems['DHFRExplicit with reaction field'] = \ - testsystems.DHFRExplicit(nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['SrcExplicit with PME'] = \ - testsystems.SrcExplicit(nonbondedMethod=openmm.app.PME) - cls.test_systems['SrcExplicit with reaction field'] = \ - testsystems.SrcExplicit(nonbondedMethod=openmm.app.CutoffPeriodic) - cls.test_systems['SrcImplicit'] = testsystems.SrcImplicit() + cls.test_systems["LysozymeImplicit"] = testsystems.LysozymeImplicit() + cls.test_systems["DHFRExplicit with reaction field"] = testsystems.DHFRExplicit( + nonbondedMethod=openmm.app.CutoffPeriodic + ) + cls.test_systems["SrcExplicit with PME"] = testsystems.SrcExplicit( + nonbondedMethod=openmm.app.PME + ) + cls.test_systems["SrcExplicit with reaction field"] = testsystems.SrcExplicit( + nonbondedMethod=openmm.app.CutoffPeriodic + ) + cls.test_systems["SrcImplicit"] = testsystems.SrcImplicit() @classmethod def define_regions(cls): - super(TestAbsoluteAlchemicalFactorySlow, cls).define_regions() - cls.test_regions['WaterBox'] = AlchemicalRegion(alchemical_atoms=range(3)) - cls.test_regions['LysozymeImplicit'] = AlchemicalRegion(alchemical_atoms=range(2603, 2621)) - cls.test_regions['DHFRExplicit'] = AlchemicalRegion(alchemical_atoms=range(0, 2849)) - cls.test_regions['Src'] = AlchemicalRegion(alchemical_atoms=range(0, 21)) + super().define_regions() + cls.test_regions["WaterBox"] = AlchemicalRegion(alchemical_atoms=range(3)) + cls.test_regions["LysozymeImplicit"] = AlchemicalRegion( + alchemical_atoms=range(2603, 2621) + ) + cls.test_regions["DHFRExplicit"] = AlchemicalRegion( + alchemical_atoms=range(0, 2849) + ) + cls.test_regions["Src"] = AlchemicalRegion(alchemical_atoms=range(0, 21)) # ============================================================================= # TEST ALCHEMICAL STATE # ============================================================================= -class TestAlchemicalState(object): + +class TestAlchemicalState: """Test AlchemicalState compatibility with CompoundThermodynamicState.""" @classmethod @@ -1870,38 +2650,62 @@ def setup_class(cls): alanine_vacuum = testsystems.AlanineDipeptideVacuum() alanine_explicit = testsystems.AlanineDipeptideExplicit() factory = AbsoluteAlchemicalFactory() - factory_exact_pme = AbsoluteAlchemicalFactory(alchemical_pme_treatment='exact') + factory_exact_pme = AbsoluteAlchemicalFactory(alchemical_pme_treatment="exact") cls.alanine_alchemical_atoms = list(range(22)) cls.alanine_test_system = alanine_explicit # System with only lambda_sterics and lambda_electrostatics. - alchemical_region = AlchemicalRegion(alchemical_atoms=cls.alanine_alchemical_atoms) - alchemical_alanine_system = factory.create_alchemical_system(alanine_vacuum.system, alchemical_region) - cls.alanine_state = states.ThermodynamicState(alchemical_alanine_system, - temperature=300*unit.kelvin) + alchemical_region = AlchemicalRegion( + alchemical_atoms=cls.alanine_alchemical_atoms + ) + alchemical_alanine_system = factory.create_alchemical_system( + alanine_vacuum.system, alchemical_region + ) + cls.alanine_state = states.ThermodynamicState( + alchemical_alanine_system, temperature=300 * unit.kelvin + ) # System with lambda_sterics and lambda_electrostatics and exact PME treatment. - alchemical_alanine_system_exact_pme = factory_exact_pme.create_alchemical_system(alanine_explicit.system, - alchemical_region) - cls.alanine_state_exact_pme = states.ThermodynamicState(alchemical_alanine_system_exact_pme, - temperature=300*unit.kelvin, - pressure=1.0*unit.atmosphere) + alchemical_alanine_system_exact_pme = ( + factory_exact_pme.create_alchemical_system( + alanine_explicit.system, alchemical_region + ) + ) + cls.alanine_state_exact_pme = states.ThermodynamicState( + alchemical_alanine_system_exact_pme, + temperature=300 * unit.kelvin, + pressure=1.0 * unit.atmosphere, + ) # System with all lambdas. - alchemical_region = AlchemicalRegion(alchemical_atoms=cls.alanine_alchemical_atoms, - alchemical_torsions=True, alchemical_angles=True, - alchemical_bonds=True) - fully_alchemical_alanine_system = factory.create_alchemical_system(alanine_vacuum.system, alchemical_region) - cls.full_alanine_state = states.ThermodynamicState(fully_alchemical_alanine_system, - temperature=300*unit.kelvin) + alchemical_region = AlchemicalRegion( + alchemical_atoms=cls.alanine_alchemical_atoms, + alchemical_torsions=True, + alchemical_angles=True, + alchemical_bonds=True, + ) + fully_alchemical_alanine_system = factory.create_alchemical_system( + alanine_vacuum.system, alchemical_region + ) + cls.full_alanine_state = states.ThermodynamicState( + fully_alchemical_alanine_system, temperature=300 * unit.kelvin + ) # Test case: (ThermodynamicState, defined_lambda_parameters) cls.test_cases = [ - (cls.alanine_state, {'lambda_sterics', 'lambda_electrostatics'}), - (cls.alanine_state_exact_pme, {'lambda_sterics', 'lambda_electrostatics'}), - (cls.full_alanine_state, {'lambda_sterics', 'lambda_electrostatics', 'lambda_bonds', - 'lambda_angles', 'lambda_torsions'}) + (cls.alanine_state, {"lambda_sterics", "lambda_electrostatics"}), + (cls.alanine_state_exact_pme, {"lambda_sterics", "lambda_electrostatics"}), + ( + cls.full_alanine_state, + { + "lambda_sterics", + "lambda_electrostatics", + "lambda_bonds", + "lambda_angles", + "lambda_torsions", + }, + ), ] @staticmethod @@ -1912,14 +2716,18 @@ def test_constructor(): AlchemicalState(lambda_electro=1.0) # Properties are initialized correctly. - test_cases = [{}, - {'lambda_sterics': 0.5, 'lambda_angles': 0.5}, - {'lambda_electrostatics': 1.0}] + test_cases = [ + {}, + {"lambda_sterics": 0.5, "lambda_angles": 0.5}, + {"lambda_electrostatics": 1.0}, + ] for test_kwargs in test_cases: alchemical_state = AlchemicalState(**test_kwargs) for parameter in AlchemicalState._get_controlled_parameters(): if parameter in test_kwargs: - assert getattr(alchemical_state, parameter) == test_kwargs[parameter] + assert ( + getattr(alchemical_state, parameter) == test_kwargs[parameter] + ) else: assert getattr(alchemical_state, parameter) is None @@ -1936,9 +2744,9 @@ def test_from_system_constructor(self): for parameter in AlchemicalState._get_controlled_parameters(): property_value = getattr(alchemical_state, parameter) if parameter in defined_lambdas: - assert property_value == 1.0, '{}: {}'.format(parameter, property_value) + assert property_value == 1.0, f"{parameter}: {property_value}" else: - assert property_value is None, '{}: {}'.format(parameter, property_value) + assert property_value is None, f"{parameter}: {property_value}" @staticmethod def test_equality_operator(): @@ -1982,10 +2790,10 @@ def test_apply_to_system(self): # Raise an error if an extra parameter is defined in the state. for state, defined_lambdas in test_cases: - if 'lambda_bonds' in defined_lambdas: + if "lambda_bonds" in defined_lambdas: continue defined_lambdas = set(defined_lambdas) # Copy - defined_lambdas.add('lambda_bonds') # Add extra parameter. + defined_lambdas.add("lambda_bonds") # Add extra parameter. kwargs = dict.fromkeys(defined_lambdas, 1.0) alchemical_state = AlchemicalState(**kwargs) with nose.tools.assert_raises(AlchemicalStateError): @@ -2013,7 +2821,7 @@ def test_check_system_consistency(self): def test_apply_to_context(self): """Test method AlchemicalState.apply_to_context.""" - integrator = openmm.VerletIntegrator(1.0*unit.femtosecond) + integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond) # Raise error if Context has more parameters than AlchemicalState. alchemical_state = AlchemicalState.from_system(self.alanine_state.system) @@ -2060,7 +2868,9 @@ def test_standardize_system(self): assert alchemical_state != standard_alchemical_state for parameter_name, value in alchemical_state._parameters.items(): standard_value = getattr(standard_alchemical_state, parameter_name) - assert (value is None and standard_value is None) or (standard_value == 1.0) + assert (value is None and standard_value is None) or ( + standard_value == 1.0 + ) def test_find_force_groups_to_update(self): """Test method AlchemicalState._find_force_groups_to_update.""" @@ -2073,15 +2883,25 @@ def test_find_force_groups_to_update(self): # Each lambda should be separated in its own force group. expected_force_groups = {} - for force, lambda_name, _ in AlchemicalState._get_system_controlled_parameters( - system, parameters_name_suffix=None): + for ( + force, + lambda_name, + _, + ) in AlchemicalState._get_system_controlled_parameters( + system, parameters_name_suffix=None + ): expected_force_groups[lambda_name] = force.getForceGroup() - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) context = create_context(system, integrator) # No force group should be updated if we don't move. - assert alchemical_state._find_force_groups_to_update(context, alchemical_state2, memo={}) == set() + assert ( + alchemical_state._find_force_groups_to_update( + context, alchemical_state2, memo={} + ) + == set() + ) # Change the lambdas one by one and check that the method # recognize that the force group energy must be updated. @@ -2093,7 +2913,9 @@ def test_find_force_groups_to_update(self): # Change the current state. setattr(alchemical_state2, lambda_name, 0.0) force_group = expected_force_groups[lambda_name] - assert alchemical_state._find_force_groups_to_update(context, alchemical_state2, memo={}) == {force_group} + assert alchemical_state._find_force_groups_to_update( + context, alchemical_state2, memo={} + ) == {force_group} setattr(alchemical_state2, lambda_name, 1.0) # Reset current state. del context @@ -2103,23 +2925,25 @@ def test_alchemical_functions(self): alchemical_state = AlchemicalState.from_system(system) # Add two alchemical variables to the state. - alchemical_state.set_function_variable('lambda', 1.0) - alchemical_state.set_function_variable('lambda2', 0.5) - assert alchemical_state.get_function_variable('lambda') == 1.0 - assert alchemical_state.get_function_variable('lambda2') == 0.5 + alchemical_state.set_function_variable("lambda", 1.0) + alchemical_state.set_function_variable("lambda2", 0.5) + assert alchemical_state.get_function_variable("lambda") == 1.0 + assert alchemical_state.get_function_variable("lambda2") == 0.5 # Cannot call an alchemical variable as a supported parameter. with nose.tools.assert_raises(AlchemicalStateError): - alchemical_state.set_function_variable('lambda_sterics', 0.5) + alchemical_state.set_function_variable("lambda_sterics", 0.5) # Assign string alchemical functions to parameters. - alchemical_state.lambda_sterics = AlchemicalFunction('lambda') - alchemical_state.lambda_electrostatics = AlchemicalFunction('(lambda + lambda2) / 2.0') + alchemical_state.lambda_sterics = AlchemicalFunction("lambda") + alchemical_state.lambda_electrostatics = AlchemicalFunction( + "(lambda + lambda2) / 2.0" + ) assert alchemical_state.lambda_sterics == 1.0 assert alchemical_state.lambda_electrostatics == 0.75 # Setting alchemical variables updates alchemical parameter as well. - alchemical_state.set_function_variable('lambda2', 0) + alchemical_state.set_function_variable("lambda2", 0) assert alchemical_state.lambda_electrostatics == 0.5 # --------------------------------------------------- @@ -2140,7 +2964,9 @@ def test_constructor_compound_state(self): for state, defined_lambdas in test_cases: kwargs = dict.fromkeys(defined_lambdas, 0.5) alchemical_state = AlchemicalState(**kwargs) - compound_state = states.CompoundThermodynamicState(state, [alchemical_state]) + compound_state = states.CompoundThermodynamicState( + state, [alchemical_state] + ) system_state = AlchemicalState.from_system(compound_state.system) assert system_state == alchemical_state @@ -2150,7 +2976,9 @@ def test_lambda_properties_compound_state(self): for state, defined_lambdas in test_cases: alchemical_state = AlchemicalState.from_system(state.system) - compound_state = states.CompoundThermodynamicState(state, [alchemical_state]) + compound_state = states.CompoundThermodynamicState( + state, [alchemical_state] + ) # Defined properties can be assigned and read. for parameter_name in defined_lambdas: @@ -2171,9 +2999,9 @@ def test_lambda_properties_compound_state(self): assert getattr(system_alchemical_state, parameter_name) == 1.0 # Same for alchemical variables setters. - compound_state.set_function_variable('lambda', 0.25) + compound_state.set_function_variable("lambda", 0.25) for parameter_name in defined_lambdas: - setattr(compound_state, parameter_name, AlchemicalFunction('lambda')) + setattr(compound_state, parameter_name, AlchemicalFunction("lambda")) system_alchemical_state = AlchemicalState.from_system(compound_state.system) for parameter_name in defined_lambdas: assert getattr(compound_state, parameter_name) == 0.25 @@ -2183,7 +3011,9 @@ def test_set_system_compound_state(self): """Setting inconsistent system in compound state raise errors.""" alanine_state = copy.deepcopy(self.alanine_state) alchemical_state = AlchemicalState.from_system(alanine_state.system) - compound_state = states.CompoundThermodynamicState(alanine_state, [alchemical_state]) + compound_state = states.CompoundThermodynamicState( + alanine_state, [alchemical_state] + ) # We create an inconsistent state that has different parameters. incompatible_state = copy.deepcopy(alchemical_state) @@ -2209,33 +3039,43 @@ def test_method_compatibility_compound_state(self): # An incompatible state has a different set of defined lambdas. full_alanine_state = copy.deepcopy(self.full_alanine_state) - alchemical_state_incompatible = AlchemicalState.from_system(full_alanine_state.system) - compound_state_incompatible = states.CompoundThermodynamicState(full_alanine_state, - [alchemical_state_incompatible]) + alchemical_state_incompatible = AlchemicalState.from_system( + full_alanine_state.system + ) + compound_state_incompatible = states.CompoundThermodynamicState( + full_alanine_state, [alchemical_state_incompatible] + ) for state in test_cases: state = copy.deepcopy(state) alchemical_state = AlchemicalState.from_system(state.system) - compound_state = states.CompoundThermodynamicState(state, [alchemical_state]) + compound_state = states.CompoundThermodynamicState( + state, [alchemical_state] + ) # A compatible state has the same defined lambda parameters, # but their values can be different. alchemical_state_compatible = copy.deepcopy(alchemical_state) assert alchemical_state.lambda_electrostatics != 0.5 # Test pre-condition. alchemical_state_compatible.lambda_electrostatics = 0.5 - compound_state_compatible = states.CompoundThermodynamicState(copy.deepcopy(state), - [alchemical_state_compatible]) + compound_state_compatible = states.CompoundThermodynamicState( + copy.deepcopy(state), [alchemical_state_compatible] + ) # Test states compatibility. assert compound_state.is_state_compatible(compound_state_compatible) assert not compound_state.is_state_compatible(compound_state_incompatible) # Test context compatibility. - integrator = openmm.VerletIntegrator(1.0*unit.femtosecond) - context = compound_state_compatible.create_context(copy.deepcopy(integrator)) + integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond) + context = compound_state_compatible.create_context( + copy.deepcopy(integrator) + ) assert compound_state.is_context_compatible(context) - context = compound_state_incompatible.create_context(copy.deepcopy(integrator)) + context = compound_state_incompatible.create_context( + copy.deepcopy(integrator) + ) assert not compound_state.is_context_compatible(context) @staticmethod @@ -2246,7 +3086,9 @@ def _check_compatibility(state1, state2, context_state1, is_compatible): assert state2.is_state_compatible(state1) is is_compatible # Test context incompatibility is commutative. - context_state2 = state2.create_context(openmm.VerletIntegrator(1.0*unit.femtosecond)) + context_state2 = state2.create_context( + openmm.VerletIntegrator(1.0 * unit.femtosecond) + ) assert state2.is_context_compatible(context_state1) is is_compatible assert state1.is_context_compatible(context_state2) is is_compatible del context_state2 @@ -2260,21 +3102,24 @@ def test_method_reduced_potential_compound_state(self): # Build a mixed collection of compatible and incompatible thermodynamic states. thermodynamic_states = [ copy.deepcopy(self.alanine_state), - copy.deepcopy(self.alanine_state_exact_pme) + copy.deepcopy(self.alanine_state_exact_pme), ] alchemical_states = [ AlchemicalState(lambda_electrostatics=1.0, lambda_sterics=1.0), AlchemicalState(lambda_electrostatics=0.5, lambda_sterics=1.0), AlchemicalState(lambda_electrostatics=0.5, lambda_sterics=0.0), - AlchemicalState(lambda_electrostatics=1.0, lambda_sterics=1.0) + AlchemicalState(lambda_electrostatics=1.0, lambda_sterics=1.0), ] compound_states = [] for thermo_state in thermodynamic_states: for alchemical_state in alchemical_states: - compound_states.append(states.CompoundThermodynamicState( - copy.deepcopy(thermo_state), [copy.deepcopy(alchemical_state)])) + compound_states.append( + states.CompoundThermodynamicState( + copy.deepcopy(thermo_state), [copy.deepcopy(alchemical_state)] + ) + ) # Group thermodynamic states by compatibility. compatible_groups, _ = states.group_by_compatibility(compound_states) @@ -2285,9 +3130,11 @@ def test_method_reduced_potential_compound_state(self): obtained_energies = [] for compatible_group in compatible_groups: # Create context. - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) context = compatible_group[0].create_context(integrator) - context.setPositions(self.alanine_test_system.positions[:compatible_group[0].n_particles]) + context.setPositions( + self.alanine_test_system.positions[: compatible_group[0].n_particles] + ) # Compute with single-state method. for state in compatible_group: @@ -2295,7 +3142,9 @@ def test_method_reduced_potential_compound_state(self): expected_energies.append(state.reduced_potential(context)) # Compute with multi-state method. - compatible_energies = states.ThermodynamicState.reduced_potential_at_states(context, compatible_group) + compatible_energies = states.ThermodynamicState.reduced_potential_at_states( + context, compatible_group + ) # The first and the last state must be equal. assert np.isclose(compatible_energies[0], compatible_energies[-1]) @@ -2305,9 +3154,11 @@ def test_method_reduced_potential_compound_state(self): def test_serialization(self): """Test AlchemicalState serialization alone and in a compound state.""" - alchemical_state = AlchemicalState(lambda_electrostatics=0.5, lambda_angles=None) - alchemical_state.set_function_variable('lambda', 0.0) - alchemical_state.lambda_sterics = AlchemicalFunction('lambda') + alchemical_state = AlchemicalState( + lambda_electrostatics=0.5, lambda_angles=None + ) + alchemical_state.set_function_variable("lambda", 0.0) + alchemical_state.lambda_sterics = AlchemicalFunction("lambda") # Test serialization/deserialization of AlchemicalState. serialization = utils.serialize(alchemical_state) @@ -2317,17 +3168,28 @@ def test_serialization(self): assert original_pickle == deserialized_pickle # Test serialization/deserialization of AlchemicalState in CompoundState. - test_cases = [copy.deepcopy(self.alanine_state), copy.deepcopy(self.alanine_state_exact_pme)] + test_cases = [ + copy.deepcopy(self.alanine_state), + copy.deepcopy(self.alanine_state_exact_pme), + ] for thermodynamic_state in test_cases: - compound_state = states.CompoundThermodynamicState(thermodynamic_state, [alchemical_state]) + compound_state = states.CompoundThermodynamicState( + thermodynamic_state, [alchemical_state] + ) # The serialized system is standard. serialization = utils.serialize(compound_state) - serialized_standard_system = serialization['thermodynamic_state']['standard_system'] + serialized_standard_system = serialization["thermodynamic_state"][ + "standard_system" + ] # Decompress the serialized_system - serialized_standard_system = zlib.decompress(serialized_standard_system).decode( - states.ThermodynamicState._ENCODING) - assert serialized_standard_system.__hash__() == compound_state._standard_system_hash + serialized_standard_system = zlib.decompress( + serialized_standard_system + ).decode(states.ThermodynamicState._ENCODING) + assert ( + serialized_standard_system.__hash__() + == compound_state._standard_system_hash + ) # The object is deserialized correctly. deserialized_state = utils.deserialize(serialization) diff --git a/openmmtools/tests/test_cache.py b/openmmtools/tests/test_cache.py index 6cbe82862..ead1325be 100644 --- a/openmmtools/tests/test_cache.py +++ b/openmmtools/tests/test_cache.py @@ -16,6 +16,7 @@ import itertools import nose + try: from openmm import unit except ImportError: # OpenMM < 7.6 @@ -30,26 +31,27 @@ # TEST LRU CACHE # ============================================================================= + def test_lru_cache_cache_entry_unpacking(): """Values in LRUCache are unpacked from CacheEntry.""" cache = LRUCache(capacity=5) - cache['first'] = 1 - assert cache['first'] == 1 + cache["first"] = 1 + assert cache["first"] == 1 # When we don't require a time-to-leave, there's # no expiration attribute in the cache entry. - assert not hasattr(cache._data['first'], 'expiration') + assert not hasattr(cache._data["first"], "expiration") def test_lru_cache_maximum_capacity(): """Maximum size of LRUCache is handled correctly.""" cache = LRUCache(capacity=2) - cache['first'] = 1 - cache['second'] = 2 + cache["first"] = 1 + cache["second"] = 2 assert len(cache) == 2 - cache['third'] = 3 + cache["third"] = 3 assert len(cache) == 2 - assert 'first' not in cache + assert "first" not in cache # Test infinite capacity cache = LRUCache() @@ -61,56 +63,57 @@ def test_lru_cache_maximum_capacity(): def test_lru_cache_eliminate_least_recently_used(): """LRUCache deletes LRU element when size exceeds capacity.""" cache = LRUCache(capacity=3) - cache['first'] = 1 - cache['second'] = 2 + cache["first"] = 1 + cache["second"] = 2 # We access 'first' through setting, so that it becomes the LRU. - cache['first'] = 1 - cache['third'] = 3 - cache['fourth'] = 4 # Here size exceed capacity. + cache["first"] = 1 + cache["third"] = 3 + cache["fourth"] = 4 # Here size exceed capacity. assert len(cache) == 3 - assert 'second' not in cache + assert "second" not in cache # We access 'first' through getting now. - cache['first'] - cache['fifth'] = 5 # Size exceed capacity. + cache["first"] + cache["fifth"] = 5 # Size exceed capacity. assert len(cache) == 3 - assert 'third' not in cache + assert "third" not in cache def test_lru_cache_access_to_live(): """LRUCache deletes element after specified number of accesses.""" + def almost_expire_first(): - cache['first'] = 1 # Update expiration. + cache["first"] = 1 # Update expiration. for _ in range(ttl - 1): - cache['second'] - assert 'first' in cache + cache["second"] + assert "first" in cache ttl = 3 cache = LRUCache(capacity=2, time_to_live=ttl) - cache['first'] = 1 - cache['second'] = 2 # First access. - assert cache._data['first'].expiration == ttl + 1 - cache['first'] # Expiration gets updated. - assert cache._data['first'].expiration == ttl + 3 + cache["first"] = 1 + cache["second"] = 2 # First access. + assert cache._data["first"].expiration == ttl + 1 + cache["first"] # Expiration gets updated. + assert cache._data["first"].expiration == ttl + 3 # At the ttl-th read access, 'first' gets deleted. almost_expire_first() - cache['second'] - assert 'second' in cache - assert 'first' not in cache + cache["second"] + assert "second" in cache + assert "first" not in cache # The same happen at the ttl-th write access. almost_expire_first() - cache['second'] = 2 - assert 'second' in cache - assert 'first' not in cache + cache["second"] = 2 + assert "second" in cache + assert "first" not in cache # If we touch at the last minute 'first', it remains in memory. almost_expire_first() - cache['first'] - assert 'second' in cache - assert 'first' in cache + cache["first"] + assert "second" in cache + assert "first" in cache def test_lru_cache_capacity_property(): @@ -122,11 +125,12 @@ def test_lru_cache_capacity_property(): cache.capacity = 1 assert len(cache) == 1 assert cache.capacity == 1 - assert str(capacity-1) in cache + assert str(capacity - 1) in cache def test_lru_cache_time_to_live_property(): """Decreasing the time to live updates the expiration of elements.""" + def add_4_elements(_cache): for i in range(4): _cache[str(i)] = i @@ -139,15 +143,15 @@ def add_4_elements(_cache): cache.time_to_live = 1 assert len(cache) == 4 assert cache.time_to_live == 1 - cache['4'] = 4 + cache["4"] = 4 assert len(cache) == 1 - assert '4' in cache + assert "4" in cache # Increase time_to_live. cache.time_to_live = 2 add_4_elements(cache) assert len(cache) == 2 - assert '2' in cache and '3' in cache + assert "2" in cache and "3" in cache # Setting it back to None makes it limitless. cache.time_to_live = None @@ -159,22 +163,25 @@ def add_4_elements(_cache): # TEST CONTEXT CACHE # ============================================================================= -class TestContextCache(object): + +class TestContextCache: """Test ContextCache class.""" @classmethod def setup_class(cls): """Create the thermodynamic states used in the test suite.""" - water_test = testsystems.WaterBox(box_edge=2.0*unit.nanometer) - cls.water_300k = states.ThermodynamicState(water_test.system, 300*unit.kelvin) - cls.water_310k = states.ThermodynamicState(water_test.system, 310*unit.kelvin) - cls.water_310k_1atm = states.ThermodynamicState(water_test.system, 310*unit.kelvin, - 1*unit.atmosphere) + water_test = testsystems.WaterBox(box_edge=2.0 * unit.nanometer) + cls.water_300k = states.ThermodynamicState(water_test.system, 300 * unit.kelvin) + cls.water_310k = states.ThermodynamicState(water_test.system, 310 * unit.kelvin) + cls.water_310k_1atm = states.ThermodynamicState( + water_test.system, 310 * unit.kelvin, 1 * unit.atmosphere + ) - cls.verlet_2fs = openmm.VerletIntegrator(2.0*unit.femtosecond) - cls.verlet_3fs = openmm.VerletIntegrator(3.0*unit.femtosecond) - cls.langevin_2fs_310k = openmm.LangevinIntegrator(310*unit.kelvin, 5.0/unit.picosecond, - 2.0*unit.femtosecond) + cls.verlet_2fs = openmm.VerletIntegrator(2.0 * unit.femtosecond) + cls.verlet_3fs = openmm.VerletIntegrator(3.0 * unit.femtosecond) + cls.langevin_2fs_310k = openmm.LangevinIntegrator( + 310 * unit.kelvin, 5.0 / unit.picosecond, 2.0 * unit.femtosecond + ) cls.compatible_states = [cls.water_300k, cls.water_310k] cls.compatible_integrators = [cls.verlet_2fs, cls.verlet_3fs] @@ -186,8 +193,9 @@ def setup_class(cls): def cache_incompatible_contexts(cls, cache): """Return the number of contexts created.""" context_ids = set() - for state, integrator in itertools.product(cls.incompatible_states, - cls.incompatible_integrators): + for state, integrator in itertools.product( + cls.incompatible_states, cls.incompatible_integrators + ): # Avoid binding same integrator to multiple contexts integrator = copy.deepcopy(integrator) context, context_integrator = cache.get_context(state, integrator) @@ -199,14 +207,21 @@ def test_copy_integrator_state(self): # Each test case has two integrators of the same class with a different state. test_cases = [ # The Langevin integrators require using setter/getters. - [copy.deepcopy(self.langevin_2fs_310k), - openmm.LangevinIntegrator(300*unit.kelvin, 8.0/unit.picosecond, - 3.0*unit.femtosecond)], + [ + copy.deepcopy(self.langevin_2fs_310k), + openmm.LangevinIntegrator( + 300 * unit.kelvin, 8.0 / unit.picosecond, 3.0 * unit.femtosecond + ), + ], # The Langevin splittin integrator requires setting global variables. - [integrators.LangevinIntegrator(temperature=270*unit.kelvin, - collision_rate=90/unit.picoseconds), - integrators.LangevinIntegrator(temperature=270*unit.kelvin, - collision_rate=180/unit.picoseconds)], + [ + integrators.LangevinIntegrator( + temperature=270 * unit.kelvin, collision_rate=90 / unit.picoseconds + ), + integrators.LangevinIntegrator( + temperature=270 * unit.kelvin, collision_rate=180 / unit.picoseconds + ), + ], ] for integrator1, integrator2 in copy.deepcopy(test_cases): @@ -220,17 +235,19 @@ def read_attribute(integrator, attribute_name): return getattr(integrator, attribute_name) except AttributeError: try: - return getattr(integrator, 'get' + attribute_name)() + return getattr(integrator, "get" + attribute_name)() except AttributeError: return integrator.getGlobalVariableByName(attribute_name) - test_cases[0].append('Temperature') # Getter/setter. - test_cases[1].append('a') # Global variable. + test_cases[0].append("Temperature") # Getter/setter. + test_cases[1].append("a") # Global variable. for integrator1, integrator2, incompatible_attribute in test_cases: ContextCache.INCOMPATIBLE_INTEGRATOR_ATTRIBUTES.add(incompatible_attribute) - if incompatible_attribute == 'Temperature': - old_standard_value = ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES.pop(incompatible_attribute) + if incompatible_attribute == "Temperature": + old_standard_value = ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES.pop( + incompatible_attribute + ) integ = [integrator1, integrator2] old_values = [read_attribute(i, incompatible_attribute) for i in integ] @@ -239,23 +256,29 @@ def read_attribute(integrator, attribute_name): new_values = [read_attribute(i, incompatible_attribute) for i in integ] assert old_values == new_values - ContextCache.INCOMPATIBLE_INTEGRATOR_ATTRIBUTES.remove(incompatible_attribute) - if incompatible_attribute == 'Temperature': - ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES[incompatible_attribute] = old_standard_value + ContextCache.INCOMPATIBLE_INTEGRATOR_ATTRIBUTES.remove( + incompatible_attribute + ) + if incompatible_attribute == "Temperature": + ContextCache.COMPATIBLE_INTEGRATOR_ATTRIBUTES[ + incompatible_attribute + ] = old_standard_value def test_generate_compatible_context_key(self): """ContextCache._generate_context_id creates same id for compatible contexts.""" all_ids = set() - for state, integrator in itertools.product(self.compatible_states, - self.compatible_integrators): + for state, integrator in itertools.product( + self.compatible_states, self.compatible_integrators + ): all_ids.add(ContextCache._generate_context_id(state, integrator)) assert len(all_ids) == 1 def test_generate_incompatible_context_key(self): """ContextCache._generate_context_id creates different ids for incompatible contexts.""" all_ids = set() - for state, integrator in itertools.product(self.incompatible_states, - self.incompatible_integrators): + for state, integrator in itertools.product( + self.incompatible_states, self.incompatible_integrators + ): all_ids.add(ContextCache._generate_context_id(state, integrator)) assert len(all_ids) == 4 @@ -267,8 +290,9 @@ def test_integrator_global_variable_standardization(self): """ cache = ContextCache() thermodynamic_state = copy.deepcopy(self.water_300k) - integrator = integrators.LangevinIntegrator(temperature=300*unit.kelvin, measure_heat=True, - measure_shadow_work=True) + integrator = integrators.LangevinIntegrator( + temperature=300 * unit.kelvin, measure_heat=True, measure_shadow_work=True + ) cache.get_context(thermodynamic_state, integrator) # If we modify a compatible global variable, we retrieve the @@ -279,14 +303,18 @@ def test_integrator_global_variable_standardization(self): context, context_integrator = cache.get_context(thermodynamic_state, integrator) assert len(cache) == 1 - assert context_integrator.getGlobalVariableByName(variable_name) == variable_new_value + assert ( + context_integrator.getGlobalVariableByName(variable_name) + == variable_new_value + ) def test_get_compatible_context(self): """ContextCache.get_context method do not recreate a compatible context.""" cache = ContextCache() context_ids = set() - for state, integrator in itertools.product(self.compatible_states, - self.compatible_integrators): + for state, integrator in itertools.product( + self.compatible_states, self.compatible_integrators + ): # Avoid binding same integrator to multiple contexts integrator = copy.deepcopy(integrator) context, context_integrator = cache.get_context(state, integrator) @@ -314,7 +342,9 @@ def test_get_context_any_integrator(self): assert len(cache) == 1 # Now we create another Context in state1 with a different integrator. - assert type(self.verlet_2fs) is not type(default_integrator) # test precondition + assert type(self.verlet_2fs) is not type( + default_integrator + ) # test precondition cache.get_context(state1, copy.deepcopy(self.verlet_2fs)) assert len(cache) == 2 @@ -328,7 +358,9 @@ def test_get_context_any_integrator(self): # When it has a choice, ContextCache picks the same context # in consecutive calls with same thermodynamic state. # First add another integrator so that ContextCache has 2 possible options. - assert type(self.langevin_2fs_310k) is not type(default_integrator) # test precondition + assert type(self.langevin_2fs_310k) is not type( + default_integrator + ) # test precondition cache.get_context(state1, copy.deepcopy(self.langevin_2fs_310k)) assert len(cache) == 3 context, integrator = cache.get_context(state1) @@ -353,7 +385,9 @@ def test_cache_capacity_ttl(self): def test_platform_property(self): """Platform change at runtime is only possible when cache is empty.""" - platforms = [openmm.Platform.getPlatformByName(name) for name in ['Reference', 'CPU']] + platforms = [ + openmm.Platform.getPlatformByName(name) for name in ["Reference", "CPU"] + ] cache = ContextCache(platform=platforms[0]) cache.platform = platforms[1] integrator = copy.deepcopy(self.compatible_integrators[0]) @@ -384,8 +418,7 @@ def test_platform_properties(self): # setter cache = ContextCache( - platform=cpu_platform, - platform_properties=platform_properties + platform=cpu_platform, platform_properties=platform_properties ) with nose.tools.assert_raises(ValueError) as cm: cache.platform = ref_platform @@ -399,20 +432,19 @@ def test_platform_properties(self): assert "Invalid platform property for this platform." in str(cm.exception) # assert that resetting the platform resets the properties cache = ContextCache( - platform=cpu_platform, - platform_properties=platform_properties + platform=cpu_platform, platform_properties=platform_properties ) cache.platform = None assert cache._platform_properties is None # Functionality test cache = ContextCache( - platform=cpu_platform, - platform_properties=platform_properties + platform=cpu_platform, platform_properties=platform_properties ) thermodynamic_state = copy.deepcopy(self.water_300k) - integrator = integrators.LangevinIntegrator(temperature=300 * unit.kelvin, measure_heat=True, - measure_shadow_work=True) + integrator = integrators.LangevinIntegrator( + temperature=300 * unit.kelvin, measure_heat=True, measure_shadow_work=True + ) context, _ = cache.get_context(thermodynamic_state, integrator) assert context.getPlatform().getPropertyValue(context, "CpuThreads") == "2" diff --git a/openmmtools/tests/test_forcefactories.py b/openmmtools/tests/test_forcefactories.py index 021ad7a2f..c819eaf0a 100644 --- a/openmmtools/tests/test_forcefactories.py +++ b/openmmtools/tests/test_forcefactories.py @@ -32,6 +32,7 @@ # TESTING UTILITIES # ============================================================================= + def create_context(system, integrator, platform=None): """Create a Context. @@ -51,6 +52,7 @@ def create_context(system, integrator, platform=None): # UTILITY FUNCTIONS # ============================================================================= + def compute_forces(system, positions, platform=None, force_group=-1): """Compute forces of the system in the given positions. @@ -72,7 +74,9 @@ def compute_forces(system, positions, platform=None, force_group=-1): return forces -def compare_system_forces(reference_system, alchemical_system, positions, name="", platform=None): +def compare_system_forces( + reference_system, alchemical_system, positions, name="", platform=None +): """Check that the forces of reference and modified systems are close. Parameters @@ -90,19 +94,36 @@ def compare_system_forces(reference_system, alchemical_system, positions, name=" """ # Compute forces - reference_force = compute_forces(reference_system, positions, platform=platform) / GLOBAL_FORCE_UNIT - alchemical_force = compute_forces(alchemical_system, positions, platform=platform) / GLOBAL_FORCE_UNIT + reference_force = ( + compute_forces(reference_system, positions, platform=platform) + / GLOBAL_FORCE_UNIT + ) + alchemical_force = ( + compute_forces(alchemical_system, positions, platform=platform) + / GLOBAL_FORCE_UNIT + ) # Check that error is small. def magnitude(vec): return np.sqrt(np.mean(np.sum(vec**2, axis=1))) - relative_error = magnitude(alchemical_force - reference_force) / magnitude(reference_force) + relative_error = magnitude(alchemical_force - reference_force) / magnitude( + reference_force + ) if np.any(np.abs(relative_error) > MAX_FORCE_RELATIVE_ERROR): - err_msg = ("Maximum allowable relative force error exceeded (was {:.8f}; allowed {:.8f}).\n" - "alchemical_force = {:.8f}, reference_force = {:.8f}, difference = {:.8f}") - raise Exception(err_msg.format(relative_error, MAX_FORCE_RELATIVE_ERROR, magnitude(alchemical_force), - magnitude(reference_force), magnitude(alchemical_force-reference_force))) + err_msg = ( + "Maximum allowable relative force error exceeded (was {:.8f}; allowed {:.8f}).\n" + "alchemical_force = {:.8f}, reference_force = {:.8f}, difference = {:.8f}" + ) + raise Exception( + err_msg.format( + relative_error, + MAX_FORCE_RELATIVE_ERROR, + magnitude(alchemical_force), + magnitude(reference_force), + magnitude(alchemical_force - reference_force), + ) + ) def generate_new_positions(system, positions, platform=None, nsteps=50): @@ -137,23 +158,29 @@ def generate_new_positions(system, positions, platform=None, nsteps=50): # TEST FORCE FACTORIES FUNCTIONS # ============================================================================= + def test_restrain_atoms(): """Check that the restrained molecule's centroid is in the origin.""" host_guest = testsystems.HostGuestExplicit() topology = mdtraj.Topology.from_openmm(host_guest.topology) sampler_state = states.SamplerState(positions=host_guest.positions) - thermodynamic_state = states.ThermodynamicState(host_guest.system, temperature=300*unit.kelvin, - pressure=1.0*unit.atmosphere) + thermodynamic_state = states.ThermodynamicState( + host_guest.system, temperature=300 * unit.kelvin, pressure=1.0 * unit.atmosphere + ) # Restrain all the host carbon atoms. - restrained_atoms = [atom.index for atom in topology.atoms - if atom.element.symbol is 'C' and atom.index <= 125] + restrained_atoms = [ + atom.index + for atom in topology.atoms + if atom.element.symbol == "C" and atom.index <= 125 + ] restrain_atoms(thermodynamic_state, sampler_state, restrained_atoms) # Compute host center_of_geometry. centroid = np.mean(sampler_state.positions[:126], axis=0) assert np.allclose(centroid, np.zeros(3)) + def test_replace_reaction_field(): """Check that replacing reaction-field electrostatics with Custom*Force yields minimal force differences with original system. @@ -164,35 +191,53 @@ def test_replace_reaction_field(): """ test_cases = [ testsystems.AlanineDipeptideExplicit(nonbondedMethod=openmm.app.CutoffPeriodic), - testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.CutoffPeriodic) + testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.CutoffPeriodic), ] - platform = openmm.Platform.getPlatformByName('Reference') + platform = openmm.Platform.getPlatformByName("Reference") for test_system in test_cases: test_name = test_system.__class__.__name__ # Replace reaction field. - modified_rf_system = replace_reaction_field(test_system.system, switch_width=None) + modified_rf_system = replace_reaction_field( + test_system.system, switch_width=None + ) # Make sure positions are not at minimum. positions = generate_new_positions(test_system.system, test_system.positions) # Test forces. - f = partial(compare_system_forces, test_system.system, modified_rf_system, positions, - name=test_name, platform=platform) - f.description = "Testing replace_reaction_field on system {}".format(test_name) + f = partial( + compare_system_forces, + test_system.system, + modified_rf_system, + positions, + name=test_name, + platform=platform, + ) + f.description = f"Testing replace_reaction_field on system {test_name}" yield f for test_system in test_cases: test_name = test_system.__class__.__name__ # Replace reaction field. - modified_rf_system = replace_reaction_field(test_system.system, switch_width=None, shifted=True) + modified_rf_system = replace_reaction_field( + test_system.system, switch_width=None, shifted=True + ) # Make sure positions are not at minimum. positions = generate_new_positions(test_system.system, test_system.positions) # Test forces. - f = partial(compare_system_forces, test_system.system, modified_rf_system, positions, - name=test_name, platform=platform) - f.description = "Testing replace_reaction_field on system {} with shifted=True".format(test_name) + f = partial( + compare_system_forces, + test_system.system, + modified_rf_system, + positions, + name=test_name, + platform=platform, + ) + f.description = ( + f"Testing replace_reaction_field on system {test_name} with shifted=True" + ) yield f diff --git a/openmmtools/tests/test_forces.py b/openmmtools/tests/test_forces.py index fde9269d7..024cee896 100644 --- a/openmmtools/tests/test_forces.py +++ b/openmmtools/tests/test_forces.py @@ -31,12 +31,13 @@ # TESTING UTILITIES # ============================================================================= + def assert_pickles_equal(object1, object2): assert pickle.dumps(object1) == pickle.dumps(object2) def assert_quantity_almost_equal(object1, object2): - assert utils.is_quantity_close(object1, object2), '{} != {}'.format(object1, object2) + assert utils.is_quantity_close(object1, object2), f"{object1} != {object2}" def assert_equal(*args, **kwargs): @@ -49,15 +50,19 @@ def assert_equal(*args, **kwargs): # UTILITY FUNCTIONS TESTS # ============================================================================= + def test_find_forces(): """Generator of tests for the find_forces() utility function.""" system = testsystems.TolueneVacuum().system # Add two CustomBondForces, one is restorable. - restraint_force = HarmonicRestraintBondForce(spring_constant=1.0*unit.kilojoule_per_mole/unit.angstroms**2, - restrained_atom_index1=2, restrained_atom_index2=5) + restraint_force = HarmonicRestraintBondForce( + spring_constant=1.0 * unit.kilojoule_per_mole / unit.angstroms**2, + restrained_atom_index1=2, + restrained_atom_index2=5, + ) system.addForce(restraint_force) - system.addForce(openmm.CustomBondForce('0.0')) + system.addForce(openmm.CustomBondForce("0.0")) def assert_forces_equal(found_forces, expected_force_classes): # Forces should be ordered by their index. @@ -71,125 +76,203 @@ def assert_forces_equal(found_forces, expected_force_classes): # Test find force and include subclasses. found_forces = find_forces(system, openmm.CustomBondForce, include_subclasses=True) - yield assert_forces_equal, found_forces, [(5, HarmonicRestraintBondForce), - (6, openmm.CustomBondForce)] - found_forces = find_forces(system, RadiallySymmetricRestraintForce, include_subclasses=True) + yield ( + assert_forces_equal, + found_forces, + [(5, HarmonicRestraintBondForce), (6, openmm.CustomBondForce)], + ) + found_forces = find_forces( + system, RadiallySymmetricRestraintForce, include_subclasses=True + ) yield assert_forces_equal, found_forces, [(5, HarmonicRestraintBondForce)] # Test exact name matching. - found_forces = find_forces(system, 'HarmonicBondForce') + found_forces = find_forces(system, "HarmonicBondForce") yield assert_forces_equal, found_forces, [(0, openmm.HarmonicBondForce)] # Find all forces containing the word "Harmonic". - found_forces = find_forces(system, '.*Harmonic.*') - yield assert_forces_equal, found_forces, [(0, openmm.HarmonicBondForce), - (1, openmm.HarmonicAngleForce), - (5, HarmonicRestraintBondForce)] + found_forces = find_forces(system, ".*Harmonic.*") + yield ( + assert_forces_equal, + found_forces, + [ + (0, openmm.HarmonicBondForce), + (1, openmm.HarmonicAngleForce), + (5, HarmonicRestraintBondForce), + ], + ) # Find all forces from the name including the subclasses. # Test find force and include subclasses. - found_forces = find_forces(system, 'CustomBond.*', include_subclasses=True) - yield assert_forces_equal, found_forces, [(5, HarmonicRestraintBondForce), - (6, openmm.CustomBondForce)] + found_forces = find_forces(system, "CustomBond.*", include_subclasses=True) + yield ( + assert_forces_equal, + found_forces, + [(5, HarmonicRestraintBondForce), (6, openmm.CustomBondForce)], + ) # With check_multiple=True only one force is returned. force_idx, force = find_forces(system, openmm.NonbondedForce, only_one=True) yield assert_forces_equal, {force_idx: force}, [(3, openmm.NonbondedForce)] # An exception is raised with "only_one" if multiple forces are found. - yield nose.tools.assert_raises, MultipleForcesError, find_forces, system, 'CustomBondForce', True, True + yield ( + nose.tools.assert_raises, + MultipleForcesError, + find_forces, + system, + "CustomBondForce", + True, + True, + ) # An exception is raised with "only_one" if the force wasn't found. - yield nose.tools.assert_raises, NoForceFoundError, find_forces, system, 'NonExistentForce', True + yield ( + nose.tools.assert_raises, + NoForceFoundError, + find_forces, + system, + "NonExistentForce", + True, + ) # ============================================================================= # RESTRAINTS TESTS # ============================================================================= -class TestRadiallySymmetricRestraints(object): + +class TestRadiallySymmetricRestraints: """Test radially symmetric receptor-ligand restraint classes.""" @classmethod def setup_class(cls): cls.well_radius = 12.0 * unit.angstroms - cls.spring_constant = 15000.0 * unit.joule/unit.mole/unit.nanometers**2 + cls.spring_constant = 15000.0 * unit.joule / unit.mole / unit.nanometers**2 cls.restrained_atom_indices1 = [2, 3, 4] cls.restrained_atom_indices2 = [10, 11] - cls.restrained_atom_index1=12 - cls.restrained_atom_index2=2 - cls.custom_parameter_name = 'restraints_parameter' + cls.restrained_atom_index1 = 12 + cls.restrained_atom_index2 = 2 + cls.custom_parameter_name = "restraints_parameter" cls.restraints = [ - HarmonicRestraintForce(spring_constant=cls.spring_constant, - restrained_atom_indices1=cls.restrained_atom_indices1, - restrained_atom_indices2=cls.restrained_atom_indices2), - HarmonicRestraintBondForce(spring_constant=cls.spring_constant, - restrained_atom_index1=cls.restrained_atom_index1, - restrained_atom_index2=cls.restrained_atom_index2), - FlatBottomRestraintForce(spring_constant=cls.spring_constant, well_radius=cls.well_radius, - restrained_atom_indices1=cls.restrained_atom_indices1, - restrained_atom_indices2=cls.restrained_atom_indices2), - FlatBottomRestraintBondForce(spring_constant=cls.spring_constant, well_radius=cls.well_radius, - restrained_atom_index1=cls.restrained_atom_index1, - restrained_atom_index2=cls.restrained_atom_index2), - HarmonicRestraintForce(spring_constant=cls.spring_constant, - restrained_atom_indices1=cls.restrained_atom_indices1, - restrained_atom_indices2=cls.restrained_atom_indices2, - controlling_parameter_name=cls.custom_parameter_name), - FlatBottomRestraintBondForce(spring_constant=cls.spring_constant, well_radius=cls.well_radius, - restrained_atom_index1=cls.restrained_atom_index1, - restrained_atom_index2=cls.restrained_atom_index2, - controlling_parameter_name=cls.custom_parameter_name), + HarmonicRestraintForce( + spring_constant=cls.spring_constant, + restrained_atom_indices1=cls.restrained_atom_indices1, + restrained_atom_indices2=cls.restrained_atom_indices2, + ), + HarmonicRestraintBondForce( + spring_constant=cls.spring_constant, + restrained_atom_index1=cls.restrained_atom_index1, + restrained_atom_index2=cls.restrained_atom_index2, + ), + FlatBottomRestraintForce( + spring_constant=cls.spring_constant, + well_radius=cls.well_radius, + restrained_atom_indices1=cls.restrained_atom_indices1, + restrained_atom_indices2=cls.restrained_atom_indices2, + ), + FlatBottomRestraintBondForce( + spring_constant=cls.spring_constant, + well_radius=cls.well_radius, + restrained_atom_index1=cls.restrained_atom_index1, + restrained_atom_index2=cls.restrained_atom_index2, + ), + HarmonicRestraintForce( + spring_constant=cls.spring_constant, + restrained_atom_indices1=cls.restrained_atom_indices1, + restrained_atom_indices2=cls.restrained_atom_indices2, + controlling_parameter_name=cls.custom_parameter_name, + ), + FlatBottomRestraintBondForce( + spring_constant=cls.spring_constant, + well_radius=cls.well_radius, + restrained_atom_index1=cls.restrained_atom_index1, + restrained_atom_index2=cls.restrained_atom_index2, + controlling_parameter_name=cls.custom_parameter_name, + ), ] def test_restorable_forces(self): """Test that the restraint interface can be restored after serialization.""" for restorable_force in self.restraints: force_serialization = openmm.XmlSerializer.serialize(restorable_force) - deserialized_force = utils.RestorableOpenMMObject.deserialize_xml(force_serialization) + deserialized_force = utils.RestorableOpenMMObject.deserialize_xml( + force_serialization + ) yield assert_pickles_equal, restorable_force, deserialized_force def test_restraint_properties(self): """Test that properties work as expected.""" for restraint in self.restraints: - yield assert_quantity_almost_equal, restraint.spring_constant, self.spring_constant + yield ( + assert_quantity_almost_equal, + restraint.spring_constant, + self.spring_constant, + ) if isinstance(restraint, FlatBottomRestraintForceMixIn): - yield assert_quantity_almost_equal, restraint.well_radius, self.well_radius + yield ( + assert_quantity_almost_equal, + restraint.well_radius, + self.well_radius, + ) if isinstance(restraint, RadiallySymmetricCentroidRestraintForce): - yield assert_equal, restraint.restrained_atom_indices1, self.restrained_atom_indices1 - yield assert_equal, restraint.restrained_atom_indices2, self.restrained_atom_indices2 + yield ( + assert_equal, + restraint.restrained_atom_indices1, + self.restrained_atom_indices1, + ) + yield ( + assert_equal, + restraint.restrained_atom_indices2, + self.restrained_atom_indices2, + ) else: assert isinstance(restraint, RadiallySymmetricBondRestraintForce) - yield assert_equal, restraint.restrained_atom_indices1, [self.restrained_atom_index1] - yield assert_equal, restraint.restrained_atom_indices2, [self.restrained_atom_index2] + yield ( + assert_equal, + restraint.restrained_atom_indices1, + [self.restrained_atom_index1], + ) + yield ( + assert_equal, + restraint.restrained_atom_indices2, + [self.restrained_atom_index2], + ) def test_controlling_parameter_name(self): """Test that the controlling parameter name enters the energy function correctly.""" default_name_restraint = self.restraints[0] custom_name_restraints = self.restraints[-2:] - assert default_name_restraint.controlling_parameter_name == 'lambda_restraints' + assert default_name_restraint.controlling_parameter_name == "lambda_restraints" energy_function = default_name_restraint.getEnergyFunction() - assert 'lambda_restraints' in energy_function + assert "lambda_restraints" in energy_function assert self.custom_parameter_name not in energy_function for custom_name_restraint in custom_name_restraints: - assert custom_name_restraint.controlling_parameter_name == self.custom_parameter_name + assert ( + custom_name_restraint.controlling_parameter_name + == self.custom_parameter_name + ) energy_function = custom_name_restraint.getEnergyFunction() - assert 'lambda_restraints' not in energy_function + assert "lambda_restraints" not in energy_function assert self.custom_parameter_name in energy_function def test_compute_restraint_volume(self): """Test the calculation of the restraint volume.""" testsystem = testsystems.TolueneVacuum() - thermodynamic_state = states.ThermodynamicState(testsystem.system, 300*unit.kelvin) + thermodynamic_state = states.ThermodynamicState( + testsystem.system, 300 * unit.kelvin + ) energy_cutoffs = np.linspace(0.0, 10.0, num=3) radius_cutoffs = np.linspace(0.0, 5.0, num=3) * unit.nanometers - def assert_integrated_analytical_equal(restraint, square_well, radius_cutoff, energy_cutoff): + def assert_integrated_analytical_equal( + restraint, square_well, radius_cutoff, energy_cutoff + ): args = [thermodynamic_state, square_well, radius_cutoff, energy_cutoff] # For flat-bottom, the calculation is only partially analytical. @@ -198,16 +281,23 @@ def assert_integrated_analytical_equal(restraint, square_well, radius_cutoff, en # Make sure there's no analytical component (from _determine_integral_limits) # in the numerical integration calculation. copied_restraint = copy.deepcopy(restraint) - for parent_cls in [RadiallySymmetricCentroidRestraintForce, RadiallySymmetricBondRestraintForce]: + for parent_cls in [ + RadiallySymmetricCentroidRestraintForce, + RadiallySymmetricBondRestraintForce, + ]: if isinstance(copied_restraint, parent_cls): copied_restraint.__class__ = parent_cls integrated_volume = copied_restraint._integrate_restraint_volume(*args) - err_msg = '{}: square_well={}, radius_cutoff={}, energy_cutoff={}\n'.format( - restraint.__class__.__name__, square_well, radius_cutoff, energy_cutoff) - err_msg += 'integrated_volume={}, analytical_volume={}'.format(integrated_volume, - analytical_volume) - assert utils.is_quantity_close(integrated_volume, analytical_volume, rtol=1e-2), err_msg + err_msg = "{}: square_well={}, radius_cutoff={}, energy_cutoff={}\n".format( + restraint.__class__.__name__, square_well, radius_cutoff, energy_cutoff + ) + err_msg += "integrated_volume={}, analytical_volume={}".format( + integrated_volume, analytical_volume + ) + assert utils.is_quantity_close( + integrated_volume, analytical_volume, rtol=1e-2 + ), err_msg for restraint in self.restraints: # Test integrated and analytical agree with no cutoffs. @@ -216,15 +306,41 @@ def assert_integrated_analytical_equal(restraint, square_well, radius_cutoff, en for square_well in [True, False]: # Try energies and distances singly and together. for energy_cutoff in energy_cutoffs: - yield assert_integrated_analytical_equal, restraint, square_well, None, energy_cutoff + yield ( + assert_integrated_analytical_equal, + restraint, + square_well, + None, + energy_cutoff, + ) for radius_cutoff in radius_cutoffs: - yield assert_integrated_analytical_equal, restraint, square_well, radius_cutoff, None + yield ( + assert_integrated_analytical_equal, + restraint, + square_well, + radius_cutoff, + None, + ) for energy_cutoff, radius_cutoff in zip(energy_cutoffs, radius_cutoffs): - yield assert_integrated_analytical_equal, restraint, square_well, radius_cutoff, energy_cutoff - for energy_cutoff, radius_cutoff in zip(energy_cutoffs, reversed(radius_cutoffs)): - yield assert_integrated_analytical_equal, restraint, square_well, radius_cutoff, energy_cutoff + yield ( + assert_integrated_analytical_equal, + restraint, + square_well, + radius_cutoff, + energy_cutoff, + ) + for energy_cutoff, radius_cutoff in zip( + energy_cutoffs, reversed(radius_cutoffs) + ): + yield ( + assert_integrated_analytical_equal, + restraint, + square_well, + radius_cutoff, + energy_cutoff, + ) def test_compute_standard_state_correction(self): """Test standard state correction works correctly in all ensembles.""" @@ -236,58 +352,108 @@ def test_compute_standard_state_correction(self): # Limit the maximum volume to 1nm^3. distance_unit = unit.nanometers state_volume = 1.0 * distance_unit**3 - box_vectors = np.identity(3) * np.cbrt(state_volume / distance_unit**3) * distance_unit + box_vectors = ( + np.identity(3) * np.cbrt(state_volume / distance_unit**3) * distance_unit + ) alanine.system.setDefaultPeriodicBoxVectors(*box_vectors) toluene.system.setDefaultPeriodicBoxVectors(*box_vectors) # Create systems in various ensembles (NVT, NPT and non-periodic). nvt_state = states.ThermodynamicState(alanine.system, temperature) - npt_state = states.ThermodynamicState(alanine.system, temperature, 1.0*unit.atmosphere) + npt_state = states.ThermodynamicState( + alanine.system, temperature, 1.0 * unit.atmosphere + ) nonperiodic_state = states.ThermodynamicState(toluene.system, temperature) - def assert_equal_ssc(expected_restraint_volume, restraint, thermodynamic_state, square_well=False, - radius_cutoff=None, energy_cutoff=None, max_volume=None): - expected_ssc = -math.log(STANDARD_STATE_VOLUME/expected_restraint_volume) - ssc = restraint.compute_standard_state_correction(thermodynamic_state, square_well, - radius_cutoff, energy_cutoff, max_volume) - err_msg = '{} computed SSC != expected SSC'.format(restraint.__class__.__name__) + def assert_equal_ssc( + expected_restraint_volume, + restraint, + thermodynamic_state, + square_well=False, + radius_cutoff=None, + energy_cutoff=None, + max_volume=None, + ): + expected_ssc = -math.log(STANDARD_STATE_VOLUME / expected_restraint_volume) + ssc = restraint.compute_standard_state_correction( + thermodynamic_state, + square_well, + radius_cutoff, + energy_cutoff, + max_volume, + ) + err_msg = f"{restraint.__class__.__name__} computed SSC != expected SSC" nose.tools.assert_equal(ssc, expected_ssc, msg=err_msg) for restraint in self.restraints: # In NPT ensemble, an exception is thrown if max_volume is not provided. - with nose.tools.assert_raises_regexp(TypeError, "max_volume must be provided"): + with nose.tools.assert_raises_regexp( + TypeError, "max_volume must be provided" + ): restraint.compute_standard_state_correction(npt_state) # With non-periodic systems and reweighting to square-well # potential, a cutoff must be given. - with nose.tools.assert_raises_regexp(TypeError, "One between radius_cutoff"): - restraint.compute_standard_state_correction(nonperiodic_state, square_well=True) + with nose.tools.assert_raises_regexp( + TypeError, "One between radius_cutoff" + ): + restraint.compute_standard_state_correction( + nonperiodic_state, square_well=True + ) # While there are no problems if we don't reweight to a square-well potential. - restraint.compute_standard_state_correction(nonperiodic_state, square_well=False) + restraint.compute_standard_state_correction( + nonperiodic_state, square_well=False + ) # SSC is limited by max_volume (in NVT and NPT). - assert_equal_ssc(state_volume, restraint, nvt_state, radius_cutoff=big_radius) - assert_equal_ssc(state_volume, restraint, npt_state, radius_cutoff=big_radius, - max_volume='system') + assert_equal_ssc( + state_volume, restraint, nvt_state, radius_cutoff=big_radius + ) + assert_equal_ssc( + state_volume, + restraint, + npt_state, + radius_cutoff=big_radius, + max_volume="system", + ) # SSC is not limited by max_volume with non periodic systems. - expected_ssc = -math.log(STANDARD_STATE_VOLUME/state_volume) - ssc = restraint.compute_standard_state_correction(nonperiodic_state, radius_cutoff=big_radius) + expected_ssc = -math.log(STANDARD_STATE_VOLUME / state_volume) + ssc = restraint.compute_standard_state_correction( + nonperiodic_state, radius_cutoff=big_radius + ) assert expected_ssc < ssc, (restraint, expected_ssc, ssc) # Check reweighting to square-well potential. expected_volume = _compute_sphere_volume(big_radius) - assert_equal_ssc(expected_volume, restraint, nonperiodic_state, - square_well=True, radius_cutoff=big_radius) + assert_equal_ssc( + expected_volume, + restraint, + nonperiodic_state, + square_well=True, + radius_cutoff=big_radius, + ) energy_cutoff = 10 * nonperiodic_state.kT - radius_cutoff = _compute_harmonic_radius(self.spring_constant, energy_cutoff) + radius_cutoff = _compute_harmonic_radius( + self.spring_constant, energy_cutoff + ) if isinstance(restraint, FlatBottomRestraintForceMixIn): radius_cutoff += self.well_radius expected_volume = _compute_sphere_volume(radius_cutoff) - assert_equal_ssc(expected_volume, restraint, nonperiodic_state, - square_well=True, radius_cutoff=radius_cutoff) + assert_equal_ssc( + expected_volume, + restraint, + nonperiodic_state, + square_well=True, + radius_cutoff=radius_cutoff, + ) max_volume = 3.0 * unit.nanometers**3 - assert_equal_ssc(max_volume, restraint, nonperiodic_state, - square_well=True, max_volume=max_volume) + assert_equal_ssc( + max_volume, + restraint, + nonperiodic_state, + square_well=True, + max_volume=max_volume, + ) diff --git a/openmmtools/tests/test_integrators.py b/openmmtools/tests/test_integrators.py index 0ecc3fb42..4e5cb5738 100755 --- a/openmmtools/tests/test_integrators.py +++ b/openmmtools/tests/test_integrators.py @@ -1,17 +1,17 @@ #!/usr/local/bin/env python -#============================================================================================= +# ============================================================================================= # MODULE DOCSTRING -#============================================================================================= +# ============================================================================================= """ Test custom integrators. """ -#============================================================================================= +# ============================================================================================= # GLOBAL IMPORTS -#============================================================================================= +# ============================================================================================= import copy import inspect @@ -20,6 +20,7 @@ from unittest import TestCase import numpy as np + try: import openmm from openmm import unit @@ -28,20 +29,25 @@ from simtk import openmm from openmmtools import integrators, testsystems -from openmmtools.integrators import (ThermostatedIntegrator, AlchemicalNonequilibriumLangevinIntegrator, - GHMCIntegrator, NoseHooverChainVelocityVerletIntegrator) +from openmmtools.integrators import ( + ThermostatedIntegrator, + AlchemicalNonequilibriumLangevinIntegrator, + GHMCIntegrator, + NoseHooverChainVelocityVerletIntegrator, +) -#============================================================================================= +# ============================================================================================= # CONSTANTS -#============================================================================================= +# ============================================================================================= kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA -#============================================================================================= +# ============================================================================================= # UTILITY SUBROUTINES -#============================================================================================= +# ============================================================================================= + def get_all_custom_integrators(only_thermostated=False): """Return all CustomIntegrators in integrators. @@ -58,17 +64,23 @@ def get_all_custom_integrators(only_thermostated=False): A list of tuples ('IntegratorName', IntegratorClass) """ - predicate = lambda x: (inspect.isclass(x) and - issubclass(x, openmm.CustomIntegrator) and - x != integrators.ThermostatedIntegrator) + predicate = lambda x: ( + inspect.isclass(x) + and issubclass(x, openmm.CustomIntegrator) + and x != integrators.ThermostatedIntegrator + ) if only_thermostated: old_predicate = predicate # Avoid infinite recursion. - predicate = lambda x: old_predicate(x) and issubclass(x, integrators.ThermostatedIntegrator) + predicate = lambda x: old_predicate(x) and issubclass( + x, integrators.ThermostatedIntegrator + ) custom_integrators = inspect.getmembers(integrators, predicate=predicate) return custom_integrators -def check_stability(integrator, test, platform=None, nsteps=100, temperature=300.0*unit.kelvin): +def check_stability( + integrator, test, platform=None, nsteps=100, temperature=300.0 * unit.kelvin +): """ Check that the simulation does not explode over a number integration steps. @@ -88,10 +100,10 @@ def check_stability(integrator, test, platform=None, nsteps=100, temperature=300 else: context = openmm.Context(test.system, integrator) context.setPositions(test.positions) - context.setVelocitiesToTemperature(temperature) # TODO: Make deterministic. + context.setVelocitiesToTemperature(temperature) # TODO: Make deterministic. # Set integrator temperature - if hasattr(integrator, 'setTemperature'): + if hasattr(integrator, "setTemperature"): integrator.setTemperature(temperature) # Take a number of steps. @@ -101,21 +113,29 @@ def check_stability(integrator, test, platform=None, nsteps=100, temperature=300 state = context.getState(getEnergy=True) potential = state.getPotentialEnergy() / kT if np.isnan(potential): - raise Exception("Potential energy for integrator %s became NaN." % integrator.__doc__) + raise Exception( + "Potential energy for integrator %s became NaN." % integrator.__doc__ + ) del context def check_integrator_temperature(integrator, temperature, has_changed): """Check integrator temperature has has_kT_changed variables.""" - kT = (temperature * integrators.kB) + kT = temperature * integrators.kB temperature = temperature / unit.kelvin assert np.isclose(integrator.getTemperature() / unit.kelvin, temperature) - assert np.isclose(integrator.getGlobalVariableByName('kT'), kT.value_in_unit_system(unit.md_unit_system)) - assert np.isclose(integrator.kT.value_in_unit_system(unit.md_unit_system), kT.value_in_unit_system(unit.md_unit_system)) + assert np.isclose( + integrator.getGlobalVariableByName("kT"), + kT.value_in_unit_system(unit.md_unit_system), + ) + assert np.isclose( + integrator.kT.value_in_unit_system(unit.md_unit_system), + kT.value_in_unit_system(unit.md_unit_system), + ) has_kT_changed = False - if 'has_kT_changed' in integrator.global_variable_names: - has_kT_changed = integrator.getGlobalVariableByName('has_kT_changed') + if "has_kT_changed" in integrator.global_variable_names: + has_kT_changed = integrator.getGlobalVariableByName("has_kT_changed") if has_kT_changed is not False: assert has_kT_changed == has_changed @@ -138,7 +158,7 @@ def check_integrator_temperature_getter_setter(integrator): check_integrator_temperature(integrator, temperature, 0) # Setting temperature update kT and has_kT_changed. - temperature += 100*unit.kelvin + temperature += 100 * unit.kelvin integrator.setTemperature(temperature) check_integrator_temperature(integrator, temperature, 1) @@ -147,9 +167,10 @@ def check_integrator_temperature_getter_setter(integrator): check_integrator_temperature(integrator, temperature, 0) -#============================================================================================= +# ============================================================================================= # TESTS -#============================================================================================= +# ============================================================================================= + def test_stabilities(): """ @@ -157,8 +178,10 @@ def test_stabilities(): """ ts = testsystems # shortcut - test_cases = {'harmonic oscillator': ts.HarmonicOscillator(), - 'alanine dipeptide in implicit solvent': ts.AlanineDipeptideImplicit()} + test_cases = { + "harmonic oscillator": ts.HarmonicOscillator(), + "alanine dipeptide in implicit solvent": ts.AlanineDipeptideImplicit(), + } custom_integrators = get_all_custom_integrators() for test_name, test in test_cases.items(): @@ -167,8 +190,10 @@ def test_stabilities(): # NoseHooverChainVelocityVerletIntegrator will print a severe warning here, # because it is being initialized without a system. That's OK. integrator.__doc__ = integrator_name - check_stability.description = ("Testing {} for stability over a short number of " - "integration steps of a {}.").format(integrator_name, test_name) + check_stability.description = ( + "Testing {} for stability over a short number of " + "integration steps of a {}." + ).format(integrator_name, test_name) yield check_stability, integrator, test @@ -194,15 +219,17 @@ def test_nose_hoover_integrator(): conserves the system and bath energy to a reasonable tolerance. Also test that the target temperature is rougly matched (+- 10 K). """ - temperature = 298*unit.kelvin + temperature = 298 * unit.kelvin testsystem = testsystems.WaterBox() - num_dof = 3*testsystem.system.getNumParticles() - testsystem.system.getNumConstraints() + num_dof = ( + 3 * testsystem.system.getNumParticles() - testsystem.system.getNumConstraints() + ) integrator = NoseHooverChainVelocityVerletIntegrator(testsystem.system, temperature) # Create Context and initialize positions. context = openmm.Context(testsystem.system, integrator) context.setPositions(testsystem.positions) context.setVelocitiesToTemperature(temperature) - integrator.step(200) # Short equilibration + integrator.step(200) # Short equilibration energies = [] temperatures = [] for n in range(100): @@ -210,24 +237,28 @@ def test_nose_hoover_integrator(): state = context.getState(getEnergy=True) # temperature kinE = state.getKineticEnergy() - temp = (2.0 * kinE / (num_dof * unit.MOLAR_GAS_CONSTANT_R)).value_in_unit(unit.kelvin) + temp = (2.0 * kinE / (num_dof * unit.MOLAR_GAS_CONSTANT_R)).value_in_unit( + unit.kelvin + ) temperatures.append(temp) # total energy KE = kinE.value_in_unit(unit.kilojoules_per_mole) PE = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole) - bathKE = integrator.getGlobalVariableByName('bathKE') - bathPE = integrator.getGlobalVariableByName('bathPE') + bathKE = integrator.getGlobalVariableByName("bathKE") + bathPE = integrator.getGlobalVariableByName("bathPE") conserved = KE + PE + bathKE + bathPE energies.append(conserved) # Compute maximum deviation from the mean for conserved energies meanenergies = np.mean(energies) - maxdeviation = np.amax(np.abs(energies - meanenergies)/meanenergies) + maxdeviation = np.amax(np.abs(energies - meanenergies) / meanenergies) assert maxdeviation < 1e-3 # Coarse check for target temperature mean_temperature = np.mean(temperatures) - assert abs(mean_temperature - temperature.value_in_unit(unit.kelvin)) < 10.0, mean_temperature + assert ( + abs(mean_temperature - temperature.value_in_unit(unit.kelvin)) < 10.0 + ), mean_temperature def test_pretty_formatting(): @@ -236,28 +267,30 @@ def test_pretty_formatting(): """ custom_integrators = get_all_custom_integrators() for integrator_name, integrator_class in custom_integrators: - integrator = integrator_class() # NoseHooverChainVelocityVerletIntegrator will print a severe warning here, # because it is being initialized without a system. That's OK. - if hasattr(integrator, 'pretty_format'): + if hasattr(integrator, "pretty_format"): # Check formatting as text text = integrator.pretty_format() # Check formatting as text with highlighted steps text = integrator.pretty_format(step_types_to_highlight=[5]) # Check list format lines = integrator.pretty_format(as_list=True) - msg = "integrator.pretty_format(as_list=True) has %d lines while integrator has %d steps" % (len(lines), integrator.getNumComputations()) + msg = ( + "integrator.pretty_format(as_list=True) has %d lines while integrator has %d steps" + % (len(lines), integrator.getNumComputations()) + ) assert len(lines) == integrator.getNumComputations(), msg + def test_update_context_state_calls(): """ Ensure that all integrators only call addUpdateContextState() once. """ custom_integrators = get_all_custom_integrators() for integrator_name, integrator_class in custom_integrators: - integrator = integrator_class() # NoseHooverChainVelocityVerletIntegrator will print a severe warning here, # because it is being initialized without a system. That's OK. @@ -269,11 +302,15 @@ def test_update_context_state_calls(): if step_type == 5: num_force_update += 1 - msg = "Integrator '%s' has %d calls to addUpdateContextState(), while there should be only one." % (integrator_name, num_force_update) - if hasattr(integrator, 'pretty_format'): - msg += '\n' + integrator.pretty_format(step_types_to_highlight=[5]) + msg = ( + "Integrator '%s' has %d calls to addUpdateContextState(), while there should be only one." + % (integrator_name, num_force_update) + ) + if hasattr(integrator, "pretty_format"): + msg += "\n" + integrator.pretty_format(step_types_to_highlight=[5]) assert num_force_update == 1, msg + def test_vvvr_shadow_work_accumulation(): """When `measure_shadow_work==True`, assert that global `shadow_work` is initialized to zero and reaches a nonzero value after integrating a few dozen steps.""" @@ -286,11 +323,19 @@ def test_vvvr_shadow_work_accumulation(): context = openmm.Context(system, integrator) context.setPositions(testsystem.positions) context.setVelocitiesToTemperature(temperature) - assert(integrator.get_shadow_work(dimensionless=True) == 0), "Shadow work should initially be zero." - assert(integrator.get_shadow_work() / unit.kilojoules_per_mole == 0), "integrator.get_shadow_work() should have units of energy." - assert(integrator.shadow_work / unit.kilojoules_per_mole == 0), "integrator.shadow_work should have units of energy." + assert ( + integrator.get_shadow_work(dimensionless=True) == 0 + ), "Shadow work should initially be zero." + assert ( + integrator.get_shadow_work() / unit.kilojoules_per_mole == 0 + ), "integrator.get_shadow_work() should have units of energy." + assert ( + integrator.shadow_work / unit.kilojoules_per_mole == 0 + ), "integrator.shadow_work should have units of energy." integrator.step(25) - assert(integrator.get_shadow_work(dimensionless=True) != 0), "integrator.get_shadow_work() should be nonzero after dynamics" + assert ( + integrator.get_shadow_work(dimensionless=True) != 0 + ), "integrator.get_shadow_work() should be nonzero after dynamics" integrator = integrators.VVVRIntegrator(temperature) context = openmm.Context(system, integrator) @@ -300,6 +345,7 @@ def test_vvvr_shadow_work_accumulation(): del context, integrator + def test_baoab_heat_accumulation(): """When `measure_heat==True`, assert that global `heat` is initialized to zero and reaches a nonzero value after integrating a few dozen steps.""" @@ -312,11 +358,19 @@ def test_baoab_heat_accumulation(): context = openmm.Context(system, integrator) context.setPositions(testsystem.positions) context.setVelocitiesToTemperature(temperature) - assert(integrator.get_heat(dimensionless=True) == 0), "Heat should initially be zero." - assert(integrator.get_heat() / unit.kilojoules_per_mole == 0), "integrator.get_heat() should have units of energy." - assert(integrator.heat / unit.kilojoules_per_mole == 0), "integrator.heat should have units of energy." + assert ( + integrator.get_heat(dimensionless=True) == 0 + ), "Heat should initially be zero." + assert ( + integrator.get_heat() / unit.kilojoules_per_mole == 0 + ), "integrator.get_heat() should have units of energy." + assert ( + integrator.heat / unit.kilojoules_per_mole == 0 + ), "integrator.heat should have units of energy." integrator.step(25) - assert(integrator.get_heat(dimensionless=True) != 0), "integrator.get_heat() should be nonzero after dynamics" + assert ( + integrator.get_heat(dimensionless=True) != 0 + ), "integrator.get_heat() should be nonzero after dynamics" integrator = integrators.VVVRIntegrator(temperature) context = openmm.Context(system, integrator) @@ -326,6 +380,7 @@ def test_baoab_heat_accumulation(): del context, integrator + def test_external_protocol_work_accumulation(): """When `measure_protocol_work==True`, assert that global `protocol_work` is initialized to zero and reaches a zero value after integrating a few dozen steps without perturbation. @@ -335,23 +390,35 @@ def test_external_protocol_work_accumulation(): testsystem = testsystems.HarmonicOscillator() system, topology = testsystem.system, testsystem.topology temperature = 298.0 * unit.kelvin - integrator = integrators.ExternalPerturbationLangevinIntegrator(splitting="O V R V O", temperature=temperature) + integrator = integrators.ExternalPerturbationLangevinIntegrator( + splitting="O V R V O", temperature=temperature + ) context = openmm.Context(system, integrator) context.setPositions(testsystem.positions) context.setVelocitiesToTemperature(temperature) # Check that initial step accumulates no protocol work - assert(integrator.get_protocol_work(dimensionless=True) == 0), "Protocol work should be 0 initially" - assert(integrator.get_protocol_work() / unit.kilojoules_per_mole == 0), "Protocol work should have units of energy" + assert ( + integrator.get_protocol_work(dimensionless=True) == 0 + ), "Protocol work should be 0 initially" + assert ( + integrator.get_protocol_work() / unit.kilojoules_per_mole == 0 + ), "Protocol work should have units of energy" integrator.step(1) - assert(integrator.get_protocol_work(dimensionless=True) == 0), "There should be no protocol work." + assert ( + integrator.get_protocol_work(dimensionless=True) == 0 + ), "There should be no protocol work." # Check that a single step accumulates protocol work pe_1 = context.getState(getEnergy=True).getPotentialEnergy() - perturbed_K=99.0 * unit.kilocalories_per_mole / unit.angstroms**2 - context.setParameter('testsystems_HarmonicOscillator_K', perturbed_K) + perturbed_K = 99.0 * unit.kilocalories_per_mole / unit.angstroms**2 + context.setParameter("testsystems_HarmonicOscillator_K", perturbed_K) pe_2 = context.getState(getEnergy=True).getPotentialEnergy() integrator.step(1) - assert (integrator.get_protocol_work(dimensionless=True) != 0), "There should be protocol work after perturbing." - assert (integrator.protocol_work == (pe_2 - pe_1)), "The potential energy difference should be equal to protocol work." + assert ( + integrator.get_protocol_work(dimensionless=True) != 0 + ), "There should be protocol work after perturbing." + assert integrator.protocol_work == ( + pe_2 - pe_1 + ), "The potential energy difference should be equal to protocol work." del context, integrator integrator = integrators.VVVRIntegrator(temperature) @@ -363,23 +430,36 @@ def test_external_protocol_work_accumulation(): class TestExternalPerturbationLangevinIntegrator(TestCase): - - def create_system(self, testsystem, parameter_name, parameter_initial, temperature = 298.0 * unit.kelvin, platform_name='Reference'): + def create_system( + self, + testsystem, + parameter_name, + parameter_initial, + temperature=298.0 * unit.kelvin, + platform_name="Reference", + ): """ Create an example system to be used by other tests """ system, topology = testsystem.system, testsystem.topology - integrator = integrators.ExternalPerturbationLangevinIntegrator(splitting="O V R V O", temperature=temperature) + integrator = integrators.ExternalPerturbationLangevinIntegrator( + splitting="O V R V O", temperature=temperature + ) # Create the context platform = openmm.Platform.getPlatformByName(platform_name) - if platform_name in ['CPU', 'CUDA']: + if platform_name in ["CPU", "CUDA"]: try: - platform.setPropertyDefaultValue('DeterministicForces', 'true') + platform.setPropertyDefaultValue("DeterministicForces", "true") except Exception as e: - mm_min_version = '7.2.0' - if platform_name == 'CPU' and openmm.version.short_version < mm_min_version: - print("Deterministic CPU forces not present in versions of OpenMM prior to {}".format(mm_min_version)) + mm_min_version = "7.2.0" + if ( + platform_name == "CPU" + and openmm.version.short_version < mm_min_version + ): + print( + f"Deterministic CPU forces not present in versions of OpenMM prior to {mm_min_version}" + ) else: raise e context = openmm.Context(system, integrator, platform) @@ -389,7 +469,16 @@ def create_system(self, testsystem, parameter_name, parameter_initial, temperatu return context, integrator - def run_ncmc(self, context, integrator, temperature, nsteps, parameter_name, parameter_initial, parameter_final): + def run_ncmc( + self, + context, + integrator, + temperature, + nsteps, + parameter_name, + parameter_initial, + parameter_final, + ): """ A simple example of NCMC to be used with unit tests. The protocol work should be reset each time this command is called. @@ -408,7 +497,9 @@ def run_ncmc(self, context, integrator, temperature, nsteps, parameter_name, par integrator.step(1) for step in range(nsteps): lambda_value = float(step + 1) / float(nsteps) - parameter_value = parameter_initial * (1 - lambda_value) + parameter_final * lambda_value + parameter_value = ( + parameter_initial * (1 - lambda_value) + parameter_final * lambda_value + ) initial_energy = context.getState(getEnergy=True).getPotentialEnergy() context.setParameter(parameter_name, parameter_value) final_energy = context.getState(getEnergy=True).getPotentialEnergy() @@ -427,18 +518,24 @@ def test_initial_protocol_work(self): from openmm import app except ImportError: # OpenMM < 7.6 from simtk.openmm import app - parameter_name = 'lambda_electrostatics' + parameter_name = "lambda_electrostatics" temperature = 298.0 * unit.kelvin parameter_initial = 1.0 - platform_name = 'CPU' - nonbonded_method = 'CutoffPeriodic' + platform_name = "CPU" + nonbonded_method = "CutoffPeriodic" # Create the system - testsystem = testsystems.AlchemicalWaterBox(nonbondedMethod=getattr(app, nonbonded_method)) - testsystem.system.addForce(openmm.MonteCarloBarostat(1 * unit.atmospheres, temperature, 2)) - context, integrator = self.create_system(testsystem, parameter_name, parameter_initial, temperature, platform_name) + testsystem = testsystems.AlchemicalWaterBox( + nonbondedMethod=getattr(app, nonbonded_method) + ) + testsystem.system.addForce( + openmm.MonteCarloBarostat(1 * unit.atmospheres, temperature, 2) + ) + context, integrator = self.create_system( + testsystem, parameter_name, parameter_initial, temperature, platform_name + ) - assert (integrator.get_protocol_work(dimensionless=True) == 0) + assert integrator.get_protocol_work(dimensionless=True) == 0 integrator.step(5) assert np.allclose(integrator.get_protocol_work(dimensionless=True), 0) @@ -452,16 +549,20 @@ def test_reset_protocol_work(self): except ImportError: # OpenMM < 7.6 from simtk.openmm import app - parameter_name = 'lambda_electrostatics' + parameter_name = "lambda_electrostatics" temperature = 298.0 * unit.kelvin parameter_initial = 1.0 parameter_final = 0.0 - platform_name = 'CPU' - nonbonded_method = 'CutoffPeriodic' + platform_name = "CPU" + nonbonded_method = "CutoffPeriodic" # Creating the test system with a high frequency barostat. - testsystem = testsystems.AlchemicalAlanineDipeptide(nonbondedMethod=getattr(app, nonbonded_method)) - context, integrator = self.create_system(testsystem, parameter_name, parameter_initial, temperature, platform_name) + testsystem = testsystems.AlchemicalAlanineDipeptide( + nonbondedMethod=getattr(app, nonbonded_method) + ) + context, integrator = self.create_system( + testsystem, parameter_name, parameter_initial, temperature, platform_name + ) # Number of NCMC steps nsteps = 20 @@ -473,8 +574,16 @@ def test_reset_protocol_work(self): # Reseting the protocol work inside the integrator integrator.reset_protocol_work() integrator.reset() - external_protocol_work, integrator_protocol_work = self.run_ncmc(context, integrator, temperature, nsteps, parameter_name, parameter_initial, parameter_final) - assert abs(external_protocol_work - integrator_protocol_work) < 1.E-5 + external_protocol_work, integrator_protocol_work = self.run_ncmc( + context, + integrator, + temperature, + nsteps, + parameter_name, + parameter_initial, + parameter_final, + ) + assert abs(external_protocol_work - integrator_protocol_work) < 1.0e-5 def test_ncmc_update_parameters_in_context(self): """ @@ -492,14 +601,23 @@ def test_ncmc_update_parameters_in_context(self): size = 20.0 temperature = 298.0 * unit.kelvin kT = kB * temperature - nonbonded_method = 'CutoffPeriodic' - platform_name = 'CPU' - timestep = 1. * unit.femtoseconds - collision_rate = 90. / unit.picoseconds - - wbox = testsystems.WaterBox(box_edge=size*unit.angstrom, cutoff=9.*unit.angstrom, nonbondedMethod=getattr(app, nonbonded_method)) + nonbonded_method = "CutoffPeriodic" + platform_name = "CPU" + timestep = 1.0 * unit.femtoseconds + collision_rate = 90.0 / unit.picoseconds + + wbox = testsystems.WaterBox( + box_edge=size * unit.angstrom, + cutoff=9.0 * unit.angstrom, + nonbondedMethod=getattr(app, nonbonded_method), + ) - integrator = integrators.ExternalPerturbationLangevinIntegrator(splitting="V R O R V", temperature=temperature, timestep=timestep, collision_rate=collision_rate) + integrator = integrators.ExternalPerturbationLangevinIntegrator( + splitting="V R O R V", + temperature=temperature, + timestep=timestep, + collision_rate=collision_rate, + ) # Create context platform = openmm.Platform.getPlatformByName(platform_name) @@ -509,13 +627,24 @@ def test_ncmc_update_parameters_in_context(self): context.setVelocitiesToTemperature(temperature) def switchoff(omm_force, omm_context, frac=0.9): - omm_force.setParticleParameters(0, charge=-0.834 * frac, sigma=0.3150752406575124 * frac, epsilon=0.635968 * frac) - omm_force.setParticleParameters(1, charge=0.417 * frac, sigma=0, epsilon=1 * frac) - omm_force.setParticleParameters(2, charge=0.417 * frac, sigma=0, epsilon=1 * frac) + omm_force.setParticleParameters( + 0, + charge=-0.834 * frac, + sigma=0.3150752406575124 * frac, + epsilon=0.635968 * frac, + ) + omm_force.setParticleParameters( + 1, charge=0.417 * frac, sigma=0, epsilon=1 * frac + ) + omm_force.setParticleParameters( + 2, charge=0.417 * frac, sigma=0, epsilon=1 * frac + ) omm_force.updateParametersInContext(omm_context) def switchon(omm_force, omm_context): - omm_force.setParticleParameters(0, charge=-0.834, sigma=0.3150752406575124, epsilon=0.635968) + omm_force.setParticleParameters( + 0, charge=-0.834, sigma=0.3150752406575124, epsilon=0.635968 + ) omm_force.setParticleParameters(1, charge=0.417, sigma=0, epsilon=1) omm_force.setParticleParameters(2, charge=0.417, sigma=0, epsilon=1) omm_force.updateParametersInContext(omm_context) @@ -544,39 +673,60 @@ def switchon(omm_force, omm_context): external_protocol_work += (final_energy - initial_energy) / kT integrator.step(1) integrator_protocol_work = integrator.get_protocol_work(dimensionless=True) - assert abs(external_protocol_work - integrator_protocol_work) < 1.E-5 + assert abs(external_protocol_work - integrator_protocol_work) < 1.0e-5 # Return to unperturbed state switchon(nonbonded_force, context) - def test_protocol_work_accumulation_harmonic_oscillator(self): - """Testing protocol work accumulation for ExternalPerturbationLangevinIntegrator with HarmonicOscillator - """ + """Testing protocol work accumulation for ExternalPerturbationLangevinIntegrator with HarmonicOscillator""" testsystem = testsystems.HarmonicOscillator() - parameter_name = 'testsystems_HarmonicOscillator_x0' + parameter_name = "testsystems_HarmonicOscillator_x0" parameter_initial = 0.0 * unit.angstroms parameter_final = 10.0 * unit.angstroms - for platform_name in ['Reference', 'CPU']: - self.compare_external_protocol_work_accumulation(testsystem, parameter_name, parameter_initial, parameter_final, platform_name=platform_name) + for platform_name in ["Reference", "CPU"]: + self.compare_external_protocol_work_accumulation( + testsystem, + parameter_name, + parameter_initial, + parameter_final, + platform_name=platform_name, + ) def test_protocol_work_accumulation_waterbox(self): - """Testing protocol work accumulation for ExternalPerturbationLangevinIntegrator with AlchemicalWaterBox - """ + """Testing protocol work accumulation for ExternalPerturbationLangevinIntegrator with AlchemicalWaterBox""" try: from openmm import app except ImportError: # OpenMM < 7.6 from simtk.openmm import app - parameter_name = 'lambda_electrostatics' + parameter_name = "lambda_electrostatics" parameter_initial = 1.0 parameter_final = 0.0 - platform_names = [ openmm.Platform.getPlatform(index).getName() for index in range(openmm.Platform.getNumPlatforms()) ] - for nonbonded_method in ['CutoffPeriodic']: - testsystem = testsystems.AlchemicalWaterBox(nonbondedMethod=getattr(app, nonbonded_method), box_edge=12.0*unit.angstroms, cutoff=5.0*unit.angstroms) + platform_names = [ + openmm.Platform.getPlatform(index).getName() + for index in range(openmm.Platform.getNumPlatforms()) + ] + for nonbonded_method in ["CutoffPeriodic"]: + testsystem = testsystems.AlchemicalWaterBox( + nonbondedMethod=getattr(app, nonbonded_method), + box_edge=12.0 * unit.angstroms, + cutoff=5.0 * unit.angstroms, + ) for platform_name in platform_names: - name = '%s %s %s' % (testsystem.name, nonbonded_method, platform_name) - self.compare_external_protocol_work_accumulation(testsystem, parameter_name, parameter_initial, parameter_final, platform_name=platform_name, name=name) - - def test_protocol_work_accumulation_waterbox_barostat(self, temperature=300*unit.kelvin): + name = "{} {} {}".format( + testsystem.name, nonbonded_method, platform_name + ) + self.compare_external_protocol_work_accumulation( + testsystem, + parameter_name, + parameter_initial, + parameter_final, + platform_name=platform_name, + name=name, + ) + + def test_protocol_work_accumulation_waterbox_barostat( + self, temperature=300 * unit.kelvin + ): """ Testing protocol work accumulation for ExternalPerturbationLangevinIntegrator with AlchemicalWaterBox with barostat. For brevity, only using CutoffPeriodic as the non-bonded method. @@ -585,23 +735,46 @@ def test_protocol_work_accumulation_waterbox_barostat(self, temperature=300*unit from openmm import app except ImportError: # OpenMM < 7.6 from simtk.openmm import app - parameter_name = 'lambda_electrostatics' + parameter_name = "lambda_electrostatics" parameter_initial = 1.0 parameter_final = 0.0 - platform_names = [ openmm.Platform.getPlatform(index).getName() for index in range(openmm.Platform.getNumPlatforms()) ] - nonbonded_method = 'CutoffPeriodic' - testsystem = testsystems.AlchemicalWaterBox(nonbondedMethod=getattr(app, nonbonded_method), box_edge=12.0*unit.angstroms, cutoff=5.0*unit.angstroms) + platform_names = [ + openmm.Platform.getPlatform(index).getName() + for index in range(openmm.Platform.getNumPlatforms()) + ] + nonbonded_method = "CutoffPeriodic" + testsystem = testsystems.AlchemicalWaterBox( + nonbondedMethod=getattr(app, nonbonded_method), + box_edge=12.0 * unit.angstroms, + cutoff=5.0 * unit.angstroms, + ) # Adding the barostat with a high frequency - testsystem.system.addForce(openmm.MonteCarloBarostat(1*unit.atmospheres, temperature, 2)) + testsystem.system.addForce( + openmm.MonteCarloBarostat(1 * unit.atmospheres, temperature, 2) + ) for platform_name in platform_names: - name = '%s %s %s' % (testsystem.name, nonbonded_method, platform_name) - self.compare_external_protocol_work_accumulation(testsystem, parameter_name, parameter_initial, parameter_final, platform_name=platform_name, name=name) - - def compare_external_protocol_work_accumulation(self, testsystem, parameter_name, parameter_initial, parameter_final, platform_name='Reference', name=None): - """Compare external work accumulation between Reference and CPU platforms. - """ + name = "{} {} {}".format(testsystem.name, nonbonded_method, platform_name) + self.compare_external_protocol_work_accumulation( + testsystem, + parameter_name, + parameter_initial, + parameter_final, + platform_name=platform_name, + name=name, + ) + + def compare_external_protocol_work_accumulation( + self, + testsystem, + parameter_name, + parameter_initial, + parameter_final, + platform_name="Reference", + name=None, + ): + """Compare external work accumulation between Reference and CPU platforms.""" if name is None: name = testsystem.name @@ -611,15 +784,22 @@ def compare_external_protocol_work_accumulation(self, testsystem, parameter_name temperature = 298.0 * unit.kelvin kT = kB * temperature - context, integrator = self.create_system(testsystem, parameter_name, parameter_initial, - temperature=temperature, platform_name='Reference') + context, integrator = self.create_system( + testsystem, + parameter_name, + parameter_initial, + temperature=temperature, + platform_name="Reference", + ) external_protocol_work = 0.0 nsteps = 20 integrator.step(1) for step in range(nsteps): - lambda_value = float(step+1) / float(nsteps) - parameter_value = parameter_initial * (1-lambda_value) + parameter_final * lambda_value + lambda_value = float(step + 1) / float(nsteps) + parameter_value = ( + parameter_initial * (1 - lambda_value) + parameter_final * lambda_value + ) initial_energy = context.getState(getEnergy=True).getPotentialEnergy() context.setParameter(parameter_name, parameter_value) final_energy = context.getState(getEnergy=True).getPotentialEnergy() @@ -628,23 +808,34 @@ def compare_external_protocol_work_accumulation(self, testsystem, parameter_name integrator.step(1) integrator_protocol_work = integrator.get_protocol_work(dimensionless=True) - message = '\n' - message += 'protocol work discrepancy noted for %s on platform %s\n' % (name, platform_name) - message += 'step %5d : external %16e kT | integrator %16e kT | difference %16e kT' % (step, external_protocol_work, integrator_protocol_work, external_protocol_work - integrator_protocol_work) - self.assertAlmostEqual(external_protocol_work, integrator_protocol_work, msg=message) + message = "\n" + message += "protocol work discrepancy noted for {} on platform {}\n".format( + name, platform_name + ) + message += ( + "step %5d : external %16e kT | integrator %16e kT | difference %16e kT" + % ( + step, + external_protocol_work, + integrator_protocol_work, + external_protocol_work - integrator_protocol_work, + ) + ) + self.assertAlmostEqual( + external_protocol_work, integrator_protocol_work, msg=message + ) del context, integrator def test_temperature_getter_setter(): """Test that temperature setter and getter modify integrator variables.""" - temperature = 350*unit.kelvin + temperature = 350 * unit.kelvin test = testsystems.HarmonicOscillator() custom_integrators = get_all_custom_integrators() thermostated_integrators = dict(get_all_custom_integrators(only_thermostated=True)) for integrator_name, integrator_class in custom_integrators: - # If this is not a ThermostatedIntegrator, the interface should not be added. if integrator_name not in thermostated_integrators: integrator = integrator_class() @@ -652,12 +843,13 @@ def test_temperature_getter_setter(): # because it is being initialized without a system. That's OK. assert ThermostatedIntegrator.is_thermostated(integrator) is False assert ThermostatedIntegrator.restore_interface(integrator) is False - assert not hasattr(integrator, 'getTemperature') + assert not hasattr(integrator, "getTemperature") continue # Test original integrator. - check_integrator_temperature_getter_setter.description = ('Test temperature setter and ' - 'getter of {}').format(integrator_name) + check_integrator_temperature_getter_setter.description = ( + "Test temperature setter and " "getter of {}" + ).format(integrator_name) integrator = integrator_class(temperature=temperature) # NoseHooverChainVelocityVerletIntegrator will print a severe warning here, # because it is being initialized without a system. That's OK. @@ -670,8 +862,9 @@ def test_temperature_getter_setter(): del context # Test Context integrator wrapper. - check_integrator_temperature_getter_setter.description = ('Test temperature wrapper ' - 'of {}').format(integrator_name) + check_integrator_temperature_getter_setter.description = ( + "Test temperature wrapper " "of {}" + ).format(integrator_name) integrator = integrator_class() # NoseHooverChainVelocityVerletIntegrator will print a severe warning here, # because it is being initialized without a system. That's OK. @@ -698,34 +891,47 @@ def test_restorable_integrator_copy(): assert isinstance(integrator_copied, integrator_class) assert set(integrator_copied.__dict__.keys()) == set(integrator.__dict__.keys()) + def run_alchemical_langevin_integrator(nsteps=0, splitting="O { V R H R V } O"): """Check that the AlchemicalLangevinSplittingIntegrator reproduces the analytical free energy difference for a harmonic oscillator deformation, using BAR. Up to 6*sigma is tolerated for error. The total work (protocol work + shadow work) is used. """ - #max deviation from the calculated free energy + # max deviation from the calculated free energy NSIGMA_MAX = 6 n_iterations = 200 # number of forward and reverse protocols # These are the alchemical functions that will be used to control the system temperature = 298.0 * unit.kelvin - sigma = 1.0 * unit.angstrom # stddev of harmonic oscillator - kT = kB * temperature # thermal energy - beta = 1.0 / kT # inverse thermal energy - K = kT / sigma**2 # spring constant corresponding to sigma + sigma = 1.0 * unit.angstrom # stddev of harmonic oscillator + kT = kB * temperature # thermal energy + beta = 1.0 / kT # inverse thermal energy + K = kT / sigma**2 # spring constant corresponding to sigma mass = 39.948 * unit.amu - period = unit.sqrt(mass/K) # period of harmonic oscillator + period = unit.sqrt(mass / K) # period of harmonic oscillator timestep = period / 20.0 collision_rate = 1.0 / period dF_analytical = 1.0 parameters = dict() - parameters['testsystems_HarmonicOscillator_x0'] = (0 * sigma, 2 * sigma) - parameters['testsystems_HarmonicOscillator_U0'] = (0 * kT, 1 * kT) + parameters["testsystems_HarmonicOscillator_x0"] = (0 * sigma, 2 * sigma) + parameters["testsystems_HarmonicOscillator_U0"] = (0 * kT, 1 * kT) alchemical_functions = { - 'forward' : { name : '(1-lambda)*%f + lambda*%f' % (value[0].value_in_unit_system(unit.md_unit_system), value[1].value_in_unit_system(unit.md_unit_system)) for (name, value) in parameters.items() }, - 'reverse' : { name : '(1-lambda)*%f + lambda*%f' % (value[1].value_in_unit_system(unit.md_unit_system), value[0].value_in_unit_system(unit.md_unit_system)) for (name, value) in parameters.items() }, - } + "forward": { + name: "(1-lambda)*{:f} + lambda*{:f}".format( + value[0].value_in_unit_system(unit.md_unit_system), + value[1].value_in_unit_system(unit.md_unit_system), + ) + for (name, value) in parameters.items() + }, + "reverse": { + name: "(1-lambda)*{:f} + lambda*{:f}".format( + value[1].value_in_unit_system(unit.md_unit_system), + value[0].value_in_unit_system(unit.md_unit_system), + ) + for (name, value) in parameters.items() + }, + } # Create harmonic oscillator testsystem testsystem = testsystems.HarmonicOscillator(K=K, mass=mass) @@ -733,21 +939,29 @@ def run_alchemical_langevin_integrator(nsteps=0, splitting="O { V R H R V } O"): positions = testsystem.positions # Get equilibrium samples from initial and final states - burn_in = 5 * 20 # 5 periods - thinning = 5 * 20 # 5 periods + burn_in = 5 * 20 # 5 periods + thinning = 5 * 20 # 5 periods # Collect forward and reverse work values - directions = ['forward', 'reverse'] - work = { direction : np.zeros([n_iterations], np.float64) for direction in directions } + directions = ["forward", "reverse"] + work = {direction: np.zeros([n_iterations], np.float64) for direction in directions} platform = openmm.Platform.getPlatformByName("Reference") for direction in directions: positions = testsystem.positions # Create equilibrium and nonequilibrium integrators - equilibrium_integrator = GHMCIntegrator(temperature=temperature, collision_rate=collision_rate, timestep=timestep) - nonequilibrium_integrator = AlchemicalNonequilibriumLangevinIntegrator(temperature=temperature, collision_rate=collision_rate, timestep=timestep, - alchemical_functions=alchemical_functions[direction], splitting=splitting, nsteps_neq=nsteps, - measure_shadow_work=True) + equilibrium_integrator = GHMCIntegrator( + temperature=temperature, collision_rate=collision_rate, timestep=timestep + ) + nonequilibrium_integrator = AlchemicalNonequilibriumLangevinIntegrator( + temperature=temperature, + collision_rate=collision_rate, + timestep=timestep, + alchemical_functions=alchemical_functions[direction], + splitting=splitting, + nsteps_neq=nsteps, + measure_shadow_work=True, + ) # Create compound integrator compound_integrator = openmm.CompoundIntegrator() @@ -776,38 +990,66 @@ def run_alchemical_langevin_integrator(nsteps=0, splitting="O { V R H R V } O"): nonequilibrium_integrator.reset() # Check initial conditions after reset - current_lambda = nonequilibrium_integrator.getGlobalVariableByName('lambda') - assert current_lambda == 0.0, 'initial lambda should be 0.0 (was %f)' % current_lambda - current_step = nonequilibrium_integrator.getGlobalVariableByName('step') - assert current_step == 0.0, 'initial step should be 0 (was %f)' % current_step - - compound_integrator.step(max(1, nsteps)) # need to execute at least one step - work[direction][iteration] = nonequilibrium_integrator.get_total_work(dimensionless=True) + current_lambda = nonequilibrium_integrator.getGlobalVariableByName("lambda") + assert current_lambda == 0.0, ( + "initial lambda should be 0.0 (was %f)" % current_lambda + ) + current_step = nonequilibrium_integrator.getGlobalVariableByName("step") + assert current_step == 0.0, ( + "initial step should be 0 (was %f)" % current_step + ) + + compound_integrator.step( + max(1, nsteps) + ) # need to execute at least one step + work[direction][iteration] = nonequilibrium_integrator.get_total_work( + dimensionless=True + ) # Check final conditions before reset - current_lambda = nonequilibrium_integrator.getGlobalVariableByName('lambda') - assert current_lambda == 1.0, 'final lambda should be 1.0 (was %f) for splitting %s' % (current_lambda, splitting) - current_step = nonequilibrium_integrator.getGlobalVariableByName('step') - assert int(current_step) == max(1,nsteps), 'final step should be %d (was %f) for splitting %s' % (max(1,nsteps), current_step, splitting) + current_lambda = nonequilibrium_integrator.getGlobalVariableByName("lambda") + assert ( + current_lambda == 1.0 + ), "final lambda should be 1.0 (was {:f}) for splitting {}".format( + current_lambda, splitting + ) + current_step = nonequilibrium_integrator.getGlobalVariableByName("step") + assert int(current_step) == max(1, nsteps), ( + "final step should be %d (was %f) for splitting %s" + % (max(1, nsteps), current_step, splitting) + ) nonequilibrium_integrator.reset() # Clean up del context del compound_integrator - results = _pymbar_bar(work['forward'], work['reverse']) - nsigma = np.abs(results['Delta_f'] - dF_analytical) / results['dDelta_f'] + results = _pymbar_bar(work["forward"], work["reverse"]) + nsigma = np.abs(results["Delta_f"] - dF_analytical) / results["dDelta_f"] print( "analytical DeltaF: {:12.4f}, DeltaF: {:12.4f}, dDeltaF: {:12.4f}, nsigma: {:12.1f}".format( - dF_analytical, results['Delta_f'], results['dDelta_f'], nsigma, + dF_analytical, + results["Delta_f"], + results["dDelta_f"], + nsigma, ) ) if nsigma > NSIGMA_MAX: - raise Exception("The free energy difference for the nonequilibrium switching for splitting '%s' and %d steps is not zero within statistical error." % (splitting, nsteps)) + raise Exception( + "The free energy difference for the nonequilibrium switching for splitting '%s' and %d steps is not zero within statistical error." + % (splitting, nsteps) + ) -def test_periodic_langevin_integrator(splitting="H V R O R V H", ncycles=40, nsteps_neq=1000, nsteps_eq=1000, write_trajectory=False): + +def test_periodic_langevin_integrator( + splitting="H V R O R V H", + ncycles=40, + nsteps_neq=1000, + nsteps_eq=1000, + write_trajectory=False, +): """ Test PeriodicNonequilibriumIntegrator @@ -824,30 +1066,38 @@ def test_periodic_langevin_integrator(splitting="H V R O R V H", ncycles=40, nst write_trajectory : bool, optional, default=True If True, will generate a PDB file that contains the harmonic oscillator trajectory """ - #max deviation from the calculated free energy + # max deviation from the calculated free energy NSIGMA_MAX = 6 # These are the alchemical functions that will be used to control the system temperature = 298.0 * unit.kelvin - sigma = 1.0 * unit.angstrom # stddev of harmonic oscillator - kT = kB * temperature # thermal energy - beta = 1.0 / kT # inverse thermal energy - K = kT / sigma**2 # spring constant corresponding to sigma + sigma = 1.0 * unit.angstrom # stddev of harmonic oscillator + kT = kB * temperature # thermal energy + beta = 1.0 / kT # inverse thermal energy + K = kT / sigma**2 # spring constant corresponding to sigma mass = 39.948 * unit.amu - period = unit.sqrt(mass/K) # period of harmonic oscillator + period = unit.sqrt(mass / K) # period of harmonic oscillator timestep = period / 20.0 collision_rate = 1.0 / period dF_analytical = 5.0 parameters = dict() displacement = 10 * sigma - parameters['testsystems_HarmonicOscillator_x0'] = (0 * sigma, displacement) - parameters['testsystems_HarmonicOscillator_U0'] = (0 * kT, 5 * kT) - integrator_kwargs = {'temperature':temperature, - 'collision_rate': collision_rate, - 'timestep': timestep, - 'measure_shadow_work': False, - 'measure_heat': False} - alchemical_functions = { name : '(1-lambda)*%f + lambda*%f' % (value[0].value_in_unit_system(unit.md_unit_system), value[1].value_in_unit_system(unit.md_unit_system)) for (name, value) in parameters.items() } + parameters["testsystems_HarmonicOscillator_x0"] = (0 * sigma, displacement) + parameters["testsystems_HarmonicOscillator_U0"] = (0 * kT, 5 * kT) + integrator_kwargs = { + "temperature": temperature, + "collision_rate": collision_rate, + "timestep": timestep, + "measure_shadow_work": False, + "measure_heat": False, + } + alchemical_functions = { + name: "(1-lambda)*{:f} + lambda*{:f}".format( + value[0].value_in_unit_system(unit.md_unit_system), + value[1].value_in_unit_system(unit.md_unit_system), + ) + for (name, value) in parameters.items() + } # Create harmonic oscillator testsystem testsystem = testsystems.HarmonicOscillator(K=K, mass=mass) system = testsystem.system @@ -856,11 +1106,14 @@ def test_periodic_langevin_integrator(splitting="H V R O R V H", ncycles=40, nst # Create integrator from openmmtools.integrators import PeriodicNonequilibriumIntegrator - integrator = PeriodicNonequilibriumIntegrator(alchemical_functions=alchemical_functions, - splitting=splitting, - nsteps_eq=nsteps_eq, - nsteps_neq=nsteps_neq, - **integrator_kwargs) + + integrator = PeriodicNonequilibriumIntegrator( + alchemical_functions=alchemical_functions, + splitting=splitting, + nsteps_eq=nsteps_eq, + nsteps_neq=nsteps_neq, + **integrator_kwargs, + ) platform = openmm.Platform.getPlatformByName("Reference") context = openmm.Context(system, integrator, platform) context.setPositions(positions) @@ -873,22 +1126,25 @@ def test_periodic_langevin_integrator(splitting="H V R O R V H", ncycles=40, nst from openmm.app import PDBFile except ImportError: # OpenMM < 7.6 from simtk.openmm.app import PDBFile - filename = 'neq-trajectory.pdb' - print(f'Writing trajectory to {filename}') - with open(filename, 'wt') as outfile: + filename = "neq-trajectory.pdb" + print(f"Writing trajectory to {filename}") + with open(filename, "w") as outfile: # Write reference import copy + pos1 = copy.deepcopy(positions) pos2 = copy.deepcopy(positions) - pos2[0,0] += displacement + pos2[0, 0] += displacement PDBFile.writeModel(topology, pos1, outfile) PDBFile.writeModel(topology, pos2, outfile) interval = 10 PDBFile.writeModel(topology, positions, outfile, modelIndex=0) - for step in range(0,2*nsteps_per_cycle,interval): + for step in range(0, 2 * nsteps_per_cycle, interval): integrator.step(interval) - positions = context.getState(getPositions=True).getPositions(asNumpy=True) + positions = context.getState(getPositions=True).getPositions( + asNumpy=True + ) PDBFile.writeModel(topology, positions, outfile, modelIndex=step) PDBFile.writeModel(topology, pos1, outfile) @@ -903,26 +1159,38 @@ def test_periodic_langevin_integrator(splitting="H V R O R V H", ncycles=40, nst for i in range(nsteps_eq): integrator.step(1) step += 1 - assert integrator.getGlobalVariableByName("step") == (step % nsteps_per_cycle) + assert integrator.getGlobalVariableByName("step") == ( + step % nsteps_per_cycle + ) assert np.isclose(integrator.getGlobalVariableByName("lambda"), 0.0) # neq (0 -> 1) for i in range(nsteps_neq): integrator.step(1) step += 1 - assert integrator.getGlobalVariableByName("step") == (step % nsteps_per_cycle) - assert np.isclose(integrator.getGlobalVariableByName("lambda"), (i+1)/nsteps_neq), f'{step} {integrator.getGlobalVariableByName("lambda")}' + assert integrator.getGlobalVariableByName("step") == ( + step % nsteps_per_cycle + ) + assert np.isclose( + integrator.getGlobalVariableByName("lambda"), (i + 1) / nsteps_neq + ), f'{step} {integrator.getGlobalVariableByName("lambda")}' # eq (1) for i in range(nsteps_eq): integrator.step(1) step += 1 - assert integrator.getGlobalVariableByName("step") == (step % nsteps_per_cycle) + assert integrator.getGlobalVariableByName("step") == ( + step % nsteps_per_cycle + ) assert np.isclose(integrator.getGlobalVariableByName("lambda"), 1.0) # neq (1 -> 0) for i in range(nsteps_neq): integrator.step(1) step += 1 - assert integrator.getGlobalVariableByName("step") == (step % nsteps_per_cycle) - assert np.isclose(integrator.getGlobalVariableByName("lambda"), 1 - (i+1)/nsteps_neq) + assert integrator.getGlobalVariableByName("step") == ( + step % nsteps_per_cycle + ) + assert np.isclose( + integrator.getGlobalVariableByName("lambda"), 1 - (i + 1) / nsteps_neq + ) assert np.isclose(integrator.getGlobalVariableByName("lambda"), 0.0) @@ -952,25 +1220,32 @@ def test_periodic_langevin_integrator(splitting="H V R O R V H", ncycles=40, nst print(np.array(reverse_works).std()) results = _pymbar_bar(np.array(forward_works), np.array(reverse_works)) - nsigma = np.abs(results['Delta_f'] - dF_analytical) / results['dDelta_f'] + nsigma = np.abs(results["Delta_f"] - dF_analytical) / results["dDelta_f"] assert np.isclose(integrator.getGlobalVariableByName("lambda"), 0.0) print( "analytical DeltaF: {:12.4f}, DeltaF: {:12.4f}, dDeltaF: {:12.4f}, nsigma: {:12.1f}".format( - dF_analytical, results['Delta_f'], results['dDelta_f'], nsigma, + dF_analytical, + results["Delta_f"], + results["dDelta_f"], + nsigma, ) ) if nsigma > NSIGMA_MAX: - raise Exception(f"The free energy difference for the nonequilibrium switching for splitting {splitting} is not zero within statistical error.") + raise Exception( + f"The free energy difference for the nonequilibrium switching for splitting {splitting} is not zero within statistical error." + ) # Clean up del context del integrator + def test_alchemical_langevin_integrator(): for splitting in ["O V R H R V O", "H R V O V R H", "O { V R H R V } O"]: for nsteps in [0, 1, 10]: run_alchemical_langevin_integrator(splitting=splitting, nsteps=nsteps) -if __name__=="__main__": + +if __name__ == "__main__": test_alchemical_langevin_integrator() diff --git a/openmmtools/tests/test_integrators_and_testsystems.py b/openmmtools/tests/test_integrators_and_testsystems.py index a9e9ec5b9..090fe9d66 100644 --- a/openmmtools/tests/test_integrators_and_testsystems.py +++ b/openmmtools/tests/test_integrators_and_testsystems.py @@ -1,17 +1,17 @@ #!/usr/local/bin/env python -#============================================================================================= +# ============================================================================================= # MODULE DOCSTRING -#============================================================================================= +# ============================================================================================= """ Test combinations of custom integrators and testsystems to make sure there are no namespace collisions. """ -#============================================================================================= +# ============================================================================================= # GLOBAL IMPORTS -#============================================================================================= +# ============================================================================================= import re import inspect @@ -24,15 +24,16 @@ from simtk import unit from simtk import openmm -#============================================================================================= +# ============================================================================================= # CONSTANTS -#============================================================================================= +# ============================================================================================= kB = unit.BOLTZMANN_CONSTANT_kB * unit.AVOGADRO_CONSTANT_NA -#============================================================================================= +# ============================================================================================= # UTILITY SUBROUTINES -#============================================================================================= +# ============================================================================================= + def check_combination(integrator, test, platform=None): """ @@ -59,9 +60,10 @@ def check_combination(integrator, test, platform=None): return -#============================================================================================= +# ============================================================================================= # TESTS -#============================================================================================= +# ============================================================================================= + def test_integrators_and_testsystems(): """ @@ -71,19 +73,22 @@ def test_integrators_and_testsystems(): from openmmtools import integrators, testsystems # Get all the CustomIntegrators in the integrators module. - is_integrator = lambda x: (inspect.isclass(x) and - issubclass(x, openmm.CustomIntegrator) and - x != integrators.ThermostatedIntegrator) + is_integrator = lambda x: ( + inspect.isclass(x) + and issubclass(x, openmm.CustomIntegrator) + and x != integrators.ThermostatedIntegrator + ) custom_integrators = inspect.getmembers(integrators, predicate=is_integrator) def all_subclasses(cls): """Return list of all subclasses and subsubclasses for a given class.""" return cls.__subclasses__() + [s for s in cls.__subclasses__()] + testsystem_classes = all_subclasses(testsystems.TestSystem) - testsystem_names = [ cls.__name__ for cls in testsystem_classes ] + testsystem_names = [cls.__name__ for cls in testsystem_classes] # Use Reference platform. - platform = openmm.Platform.getPlatformByName('Reference') + platform = openmm.Platform.getPlatformByName("Reference") for testsystem_name in testsystem_names: # Create testsystem. @@ -106,5 +111,7 @@ def all_subclasses(cls): # Create test. f = partial(check_combination, integrator, testsystem, platform) - f.description = "Checking combination of %s and %s" % (integrator_name, testsystem_name) + f.description = "Checking combination of {} and {}".format( + integrator_name, testsystem_name + ) yield f diff --git a/openmmtools/tests/test_mcmc.py b/openmmtools/tests/test_mcmc.py index 5d6c2da31..24f3a03c7 100644 --- a/openmmtools/tests/test_mcmc.py +++ b/openmmtools/tests/test_mcmc.py @@ -32,17 +32,37 @@ # Test various combinations of systems and MCMC schemes analytical_testsystems = [ - ("HarmonicOscillator", testsystems.HarmonicOscillator(), - GHMCMove(timestep=10.0*unit.femtoseconds, n_steps=100)), - ("HarmonicOscillator", testsystems.HarmonicOscillator(), - WeightedMove([(GHMCMove(timestep=10.0 * unit.femtoseconds, n_steps=100), 0.5), - (HMCMove(timestep=10 * unit.femtosecond, n_steps=10), 0.5)])), - ("HarmonicOscillatorArray", testsystems.HarmonicOscillatorArray(N=4), - LangevinDynamicsMove(timestep=10.0*unit.femtoseconds, n_steps=100)), - ("IdealGas", testsystems.IdealGas(nparticles=216), - SequenceMove([HMCMove(timestep=10*unit.femtosecond, n_steps=10), - MonteCarloBarostatMove()])) - ] + ( + "HarmonicOscillator", + testsystems.HarmonicOscillator(), + GHMCMove(timestep=10.0 * unit.femtoseconds, n_steps=100), + ), + ( + "HarmonicOscillator", + testsystems.HarmonicOscillator(), + WeightedMove( + [ + (GHMCMove(timestep=10.0 * unit.femtoseconds, n_steps=100), 0.5), + (HMCMove(timestep=10 * unit.femtosecond, n_steps=10), 0.5), + ] + ), + ), + ( + "HarmonicOscillatorArray", + testsystems.HarmonicOscillatorArray(N=4), + LangevinDynamicsMove(timestep=10.0 * unit.femtoseconds, n_steps=100), + ), + ( + "IdealGas", + testsystems.IdealGas(nparticles=216), + SequenceMove( + [ + HMCMove(timestep=10 * unit.femtosecond, n_steps=10), + MonteCarloBarostatMove(), + ] + ), + ), +] NSIGMA_CUTOFF = 6.0 # cutoff for significance testing @@ -53,6 +73,7 @@ # TEST FUNCTIONS # ============================================================================= + def test_minimizer_all_testsystems(): # testsystem_classes = testsystems.TestSystem.__subclasses__() testsystem_classes = [testsystems.AlanineDipeptideVacuum] @@ -63,14 +84,14 @@ def test_minimizer_all_testsystems(): testsystem = testsystem_class() sampler_state = SamplerState(testsystem.positions) - thermodynamic_state = ThermodynamicState(testsystem.system, 300*unit.kelvin) + thermodynamic_state = ThermodynamicState(testsystem.system, 300 * unit.kelvin) # Create sampler for minimization. sampler = MCMCSampler(thermodynamic_state, sampler_state, move=None) sampler.minimize(max_iterations=0) # Check if NaN. - err_msg = 'Minimization of system {} yielded NaN'.format(class_name) + err_msg = f"Minimization of system {class_name} yielded NaN" assert not sampler_state.has_nan(), err_msg @@ -95,7 +116,7 @@ def subtest_mcmc_expectation(testsystem, move): temperature = 298.0 * unit.kelvin niterations = 500 # number of production iterations if system.usesPeriodicBoundaryConditions(): - pressure = 1.0*unit.atmosphere + pressure = 1.0 * unit.atmosphere else: pressure = None @@ -106,19 +127,29 @@ def subtest_mcmc_expectation(testsystem, move): # Create sampler and thermodynamic state. sampler_state = SamplerState(positions=positions) - thermodynamic_state = ThermodynamicState(system=system, - temperature=temperature, - pressure=pressure) + thermodynamic_state = ThermodynamicState( + system=system, temperature=temperature, pressure=pressure + ) # Create MCMC sampler sampler = MCMCSampler(thermodynamic_state, sampler_state, move=move) # Accumulate statistics. - x_n = np.zeros([niterations], np.float64) # x_n[i] is the x position of atom 1 after iteration i, in angstroms - potential_n = np.zeros([niterations], np.float64) # potential_n[i] is the potential energy after iteration i, in kT - kinetic_n = np.zeros([niterations], np.float64) # kinetic_n[i] is the kinetic energy after iteration i, in kT - temperature_n = np.zeros([niterations], np.float64) # temperature_n[i] is the instantaneous kinetic temperature from iteration i, in K - volume_n = np.zeros([niterations], np.float64) # volume_n[i] is the volume from iteration i, in K + x_n = np.zeros( + [niterations], np.float64 + ) # x_n[i] is the x position of atom 1 after iteration i, in angstroms + potential_n = np.zeros( + [niterations], np.float64 + ) # potential_n[i] is the potential energy after iteration i, in kT + kinetic_n = np.zeros( + [niterations], np.float64 + ) # kinetic_n[i] is the kinetic energy after iteration i, in kT + temperature_n = np.zeros( + [niterations], np.float64 + ) # temperature_n[i] is the instantaneous kinetic temperature from iteration i, in K + volume_n = np.zeros( + [niterations], np.float64 + ) # volume_n[i] is the volume from iteration i, in K for iteration in range(niterations): # Update sampler state. sampler.run(1) @@ -137,51 +168,88 @@ def subtest_mcmc_expectation(testsystem, move): volume_n[iteration] = volume / (unit.nanometers**3) # Compute expected statistics. - if (hasattr(testsystem, 'get_potential_expectation') and - testsystem.get_potential_standard_deviation(thermodynamic_state) / kT.unit != 0.0): - assert potential_n.std() != 0.0, 'Test {} shows no potential fluctuations'.format( - testsystem.__class__.__name__) - - potential_expectation = testsystem.get_potential_expectation(thermodynamic_state) / kT + if ( + hasattr(testsystem, "get_potential_expectation") + and testsystem.get_potential_standard_deviation(thermodynamic_state) / kT.unit + != 0.0 + ): + assert ( + potential_n.std() != 0.0 + ), "Test {} shows no potential fluctuations".format( + testsystem.__class__.__name__ + ) + + potential_expectation = ( + testsystem.get_potential_expectation(thermodynamic_state) / kT + ) [t0, g, Neff_max] = detect_equilibration(potential_n) potential_mean = potential_n[t0:].mean() dpotential_mean = potential_n[t0:].std() / np.sqrt(Neff_max) potential_error = potential_mean - potential_expectation nsigma = abs(potential_error) / dpotential_mean - err_msg = ('Potential energy expectation\n' - 'observed {:10.5f} +- {:10.5f}kT | expected {:10.5f} | ' - 'error {:10.5f} +- {:10.5f} ({:.1f} sigma) | t0 {:5d} | g {:5.1f} | Neff {:8.1f}\n' - '----------------------------------------------------------------------------').format( - potential_mean, dpotential_mean, potential_expectation, potential_error, dpotential_mean, nsigma, t0, g, Neff_max) + err_msg = ( + "Potential energy expectation\n" + "observed {:10.5f} +- {:10.5f}kT | expected {:10.5f} | " + "error {:10.5f} +- {:10.5f} ({:.1f} sigma) | t0 {:5d} | g {:5.1f} | Neff {:8.1f}\n" + "----------------------------------------------------------------------------" + ).format( + potential_mean, + dpotential_mean, + potential_expectation, + potential_error, + dpotential_mean, + nsigma, + t0, + g, + Neff_max, + ) assert nsigma <= NSIGMA_CUTOFF, err_msg.format() if debug: print(err_msg) elif debug: - print('Skipping potential expectation test.') - - if (hasattr(testsystem, 'get_volume_expectation') and - testsystem.get_volume_standard_deviation(thermodynamic_state) / (unit.nanometers**3) != 0.0): - assert volume_n.std() != 0.0, 'Test {} shows no volume fluctuations'.format( - testsystem.__class__.__name__) - - volume_expectation = testsystem.get_volume_expectation(thermodynamic_state) / (unit.nanometers**3) + print("Skipping potential expectation test.") + + if ( + hasattr(testsystem, "get_volume_expectation") + and testsystem.get_volume_standard_deviation(thermodynamic_state) + / (unit.nanometers**3) + != 0.0 + ): + assert volume_n.std() != 0.0, "Test {} shows no volume fluctuations".format( + testsystem.__class__.__name__ + ) + + volume_expectation = testsystem.get_volume_expectation(thermodynamic_state) / ( + unit.nanometers**3 + ) [t0, g, Neff_max] = detect_equilibration(volume_n) volume_mean = volume_n[t0:].mean() dvolume_mean = volume_n[t0:].std() / np.sqrt(Neff_max) volume_error = volume_mean - volume_expectation nsigma = abs(volume_error) / dvolume_mean - err_msg = ('Volume expectation\n' - 'observed {:10.5f} +- {:10.5f}kT | expected {:10.5f} | ' - 'error {:10.5f} +- {:10.5f} ({:.1f} sigma) | t0 {:5d} | g {:5.1f} | Neff {:8.1f}\n' - '----------------------------------------------------------------------------').format( - volume_mean, dvolume_mean, volume_expectation, volume_error, dvolume_mean, nsigma, t0, g, Neff_max) + err_msg = ( + "Volume expectation\n" + "observed {:10.5f} +- {:10.5f}kT | expected {:10.5f} | " + "error {:10.5f} +- {:10.5f} ({:.1f} sigma) | t0 {:5d} | g {:5.1f} | Neff {:8.1f}\n" + "----------------------------------------------------------------------------" + ).format( + volume_mean, + dvolume_mean, + volume_expectation, + volume_error, + dvolume_mean, + nsigma, + t0, + g, + Neff_max, + ) assert nsigma <= NSIGMA_CUTOFF, err_msg.format() if debug: print(err_msg) elif debug: - print('Skipping volume expectation test.') + print("Skipping volume expectation test.") def test_barostat_move_frequency(): @@ -191,11 +259,14 @@ def test_barostat_move_frequency(): testsystem = test_case[1] if testsystem.system.usesPeriodicBoundaryConditions(): break - assert testsystem.system.usesPeriodicBoundaryConditions(), "Can't find periodic test case!" + assert ( + testsystem.system.usesPeriodicBoundaryConditions() + ), "Can't find periodic test case!" sampler_state = SamplerState(testsystem.positions) - thermodynamic_state = ThermodynamicState(testsystem.system, 298*unit.kelvin, - 1*unit.atmosphere) + thermodynamic_state = ThermodynamicState( + testsystem.system, 298 * unit.kelvin, 1 * unit.atmosphere + ) move = MonteCarloBarostatMove(n_attempts=5) # Test-precondition: the frequency must be different than 1 or it @@ -214,7 +285,9 @@ def test_default_context_cache(): """ # By default an independent local context cache is used move = SequenceMove([LangevinDynamicsMove(n_steps=5), GHMCMove(n_steps=5)]) - context_cache = move._get_context_cache(context_cache_input=None) # get default context cache + context_cache = move._get_context_cache( + context_cache_input=None + ) # get default context cache # Assert the default context_cache is the global one assert context_cache is cache.global_context_cache @@ -260,22 +333,32 @@ def test_context_cache_sequence_apply(): local_cache = cache.ContextCache() move = SequenceMove([LangevinDynamicsMove(n_steps=5), GHMCMove(n_steps=5)]) # Context cache before apply without access - assert local_cache._lru._n_access == 0, f"Expected no access in local context cache." + assert ( + local_cache._lru._n_access == 0 + ), f"Expected no access in local context cache." move.apply(thermodynamic_state, sampler_state, context_cache=local_cache) # Context cache now must have 2 accesses - assert local_cache._lru._n_access == 2, "Expected two accesses in local context cache." + assert ( + local_cache._lru._n_access == 2 + ), "Expected two accesses in local context cache." def test_context_cache_compatibility(): """Tests only one context cache is created and used for compatible moves.""" testsystem = testsystems.AlanineDipeptideImplicit() sampler_state = SamplerState(testsystem.positions) - thermodynamic_state = ThermodynamicState(testsystem.system, 300*unit.kelvin) + thermodynamic_state = ThermodynamicState(testsystem.system, 300 * unit.kelvin) # The ContextCache creates only one context with compatible moves. context_cache = cache.ContextCache(capacity=10, time_to_live=None) - move = SequenceMove([LangevinDynamicsMove(n_steps=1), LangevinDynamicsMove(n_steps=1), - LangevinDynamicsMove(n_steps=1), LangevinDynamicsMove(n_steps=1)]) + move = SequenceMove( + [ + LangevinDynamicsMove(n_steps=1), + LangevinDynamicsMove(n_steps=1), + LangevinDynamicsMove(n_steps=1), + LangevinDynamicsMove(n_steps=1), + ] + ) move.apply(thermodynamic_state, sampler_state, context_cache=context_cache) assert len(context_cache) == 1 @@ -311,7 +394,7 @@ def test_dummy_context_cache(): """Test DummyContextCache works for all platforms.""" testsystem = testsystems.AlanineDipeptideImplicit() sampler_state = SamplerState(testsystem.positions) - thermodynamic_state = ThermodynamicState(testsystem.system, 300*unit.kelvin) + thermodynamic_state = ThermodynamicState(testsystem.system, 300 * unit.kelvin) # DummyContextCache works for all platforms. platforms = utils.get_available_platforms() dummy_cache = cache.DummyContextCache() @@ -334,7 +417,9 @@ def test_mcmc_move_context_cache_shallow_copy(): from openmmtools import multistate platform = get_fastest_platform() - context_cache = cache.ContextCache(capacity=None, time_to_live=None, platform=platform) + context_cache = cache.ContextCache( + capacity=None, time_to_live=None, platform=platform + ) testsystem = testsystems.AlanineDipeptideExplicit() n_replicas = 5 # Number of temperature replicas. T_min = 300.0 * unit.kelvin # Minimum temperature. @@ -347,7 +432,8 @@ def test_mcmc_move_context_cache_shallow_copy(): for i in range(n_replicas) ] thermodynamic_states = [ - ThermodynamicState(system=testsystem.system, temperature=T) for T in temperatures + ThermodynamicState(system=testsystem.system, temperature=T) + for T in temperatures ] move = LangevinSplittingDynamicsMove( timestep=4.0 * unit.femtoseconds, @@ -381,18 +467,18 @@ def test_mcmc_move_context_cache_shallow_copy(): def test_moves_serialization(): """Test serialization of various MCMCMoves.""" # Test cases. - platform = openmm.Platform.getPlatformByName('Reference') + platform = openmm.Platform.getPlatformByName("Reference") context_cache = cache.ContextCache(capacity=1, time_to_live=1) dummy_cache = cache.DummyContextCache(platform=platform) test_cases = [ - IntegratorMove(openmm.VerletIntegrator(1.0*unit.femtosecond), n_steps=10), + IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=10), LangevinDynamicsMove(), LangevinSplittingDynamicsMove(), GHMCMove(), HMCMove(context_cache=context_cache), MonteCarloBarostatMove(context_cache=dummy_cache), SequenceMove(move_list=[LangevinDynamicsMove(), GHMCMove()]), - WeightedMove(move_set=[(HMCMove(), 0.5), (MonteCarloBarostatMove(), 0.5)]) + WeightedMove(move_set=[(HMCMove(), 0.5), (MonteCarloBarostatMove(), 0.5)]), ] for move in test_cases: original_pickle = pickle.dumps(move) @@ -409,11 +495,11 @@ def test_move_restart(): # We define a Move that counts the times it is attempted. class MyMove(BaseIntegratorMove): def __init__(self, **kwargs): - super(MyMove, self).__init__(n_steps=1, n_restart_attempts=n_restart_attempts, **kwargs) + super().__init__(n_steps=1, n_restart_attempts=n_restart_attempts, **kwargs) self.attempted_count = 0 def _get_integrator(self, thermodynamic_state): - return integrators.GHMCIntegrator(temperature=300*unit.kelvin) + return integrators.GHMCIntegrator(temperature=300 * unit.kelvin) def _before_integration(self, context, thermodynamic_state): self.attempted_count += 1 @@ -429,16 +515,18 @@ def _before_integration(self, context, thermodynamic_state): system.addParticle(39.9 * unit.amu) force.addParticle(0.0, 1.0, 0.0) particle_position = np.array([np.nan, 0.2, 0.2]) - positions = unit.Quantity(np.vstack((testsystem.positions, particle_position)), - unit=testsystem.positions.unit) + positions = unit.Quantity( + np.vstack((testsystem.positions, particle_position)), + unit=testsystem.positions.unit, + ) # Create and run move. An IntegratoMoveError is raised. sampler_state = SamplerState(positions) - thermodynamic_state = ThermodynamicState(system, 300*unit.kelvin) + thermodynamic_state = ThermodynamicState(system, 300 * unit.kelvin) # We use a local context cache with Reference platform since on the # CPU platform CustomIntegrators raises an error with NaN particles. - reference_platform = openmm.Platform.getPlatformByName('Reference') + reference_platform = openmm.Platform.getPlatformByName("Reference") context_cache = cache.ContextCache(platform=reference_platform) move = MyMove(context_cache=context_cache) with nose.tools.assert_raises(IntegratorMoveError) as cm: @@ -449,19 +537,19 @@ def _before_integration(self, context, thermodynamic_state): # Test serialization of the error. with utils.temporary_directory() as tmp_dir: - prefix = os.path.join(tmp_dir, 'prefix') + prefix = os.path.join(tmp_dir, "prefix") cm.exception.serialize_error(prefix) - assert os.path.exists(prefix + '-move.json') - assert os.path.exists(prefix + '-system.xml') - assert os.path.exists(prefix + '-integrator.xml') - assert os.path.exists(prefix + '-state.xml') + assert os.path.exists(prefix + "-move.json") + assert os.path.exists(prefix + "-system.xml") + assert os.path.exists(prefix + "-integrator.xml") + assert os.path.exists(prefix + "-state.xml") def test_metropolized_moves(): """Test Displacement and Rotation moves.""" testsystem = testsystems.AlanineDipeptideVacuum() original_sampler_state = SamplerState(testsystem.positions) - thermodynamic_state = ThermodynamicState(testsystem.system, 300*unit.kelvin) + thermodynamic_state = ThermodynamicState(testsystem.system, 300 * unit.kelvin) all_metropolized_moves = MetropolizedMove.__subclasses__() for move_class in all_metropolized_moves: @@ -489,8 +577,11 @@ def test_metropolized_moves(): old_n_accepted, old_n_proposed = move.n_accepted, move.n_proposed # Check that we were able to generate both an accepted and a rejected move. - assert len(move.atom_subset) != 0, ('Could not generate an accepted and rejected ' - 'move for class {}'.format(move_class.__name__)) + assert len(move.atom_subset) != 0, ( + "Could not generate an accepted and rejected " "move for class {}".format( + move_class.__name__ + ) + ) def test_langevin_splitting_move(): @@ -498,13 +589,14 @@ def test_langevin_splitting_move(): splittings = ["V R O R V", "V R R R O R R R V", "O { V R V } O"] testsystem = testsystems.AlanineDipeptideVacuum() sampler_state = SamplerState(testsystem.positions) - thermodynamic_state = ThermodynamicState(testsystem.system, 300*unit.kelvin) + thermodynamic_state = ThermodynamicState(testsystem.system, 300 * unit.kelvin) for splitting in splittings: move = LangevinSplittingDynamicsMove(splitting=splitting) # Create MCMC sampler sampler = MCMCSampler(thermodynamic_state, sampler_state, move=move) sampler.run(1) + def test_langevin_dynamics_move_constraint_tolerance(): """Test constraint tolerance is properly set in LangevinDynamicsMove integrator.""" testsystem = testsystems.AlanineDipeptideVacuum() @@ -513,32 +605,40 @@ def test_langevin_dynamics_move_constraint_tolerance(): default_move = LangevinDynamicsMove() default_constraint_tolerance = 1e-8 move_tolerance = default_move.constraint_tolerance - assert move_tolerance == default_constraint_tolerance, f"LangevinDynamicsMove tolerance, {move_tolerance}, is" \ - f" not the same as the expected default tolerance," \ - f" {default_constraint_tolerance}." + assert move_tolerance == default_constraint_tolerance, ( + f"LangevinDynamicsMove tolerance, {move_tolerance}, is" + f" not the same as the expected default tolerance," + f" {default_constraint_tolerance}." + ) default_integrator = default_move._get_integrator(thermodynamic_state) default_integrator_tolerance = default_integrator.getConstraintTolerance() - assert default_integrator_tolerance == default_constraint_tolerance, f"LangevinDynamicsMove integrator tolerance," \ - f" {default_integrator_tolerance}, is not " \ - f"the same as the expected default " \ - f"tolerance, {default_constraint_tolerance}." + assert default_integrator_tolerance == default_constraint_tolerance, ( + f"LangevinDynamicsMove integrator tolerance," + f" {default_integrator_tolerance}, is not " + f"the same as the expected default " + f"tolerance, {default_constraint_tolerance}." + ) # Now we change the tolerance in initializer and check new_constraint_tolerance = 1e-5 new_move = LangevinDynamicsMove(constraint_tolerance=new_constraint_tolerance) new_integrator = new_move._get_integrator(thermodynamic_state) new_integrator_tolerance = new_integrator.getConstraintTolerance() - assert new_integrator_tolerance == new_constraint_tolerance, f"LangevinDynamicsMove integrator tolerance," \ - f" {new_integrator_tolerance}, is not the same as" \ - f" the specified value of {new_constraint_tolerance}." + assert new_integrator_tolerance == new_constraint_tolerance, ( + f"LangevinDynamicsMove integrator tolerance," + f" {new_integrator_tolerance}, is not the same as" + f" the specified value of {new_constraint_tolerance}." + ) # Test by changing public attribute constraint_tolerance = 1e-7 move = LangevinDynamicsMove() # create default move move.constraint_tolerance = constraint_tolerance # change the public attribute integrator = move._get_integrator(thermodynamic_state) integrator_tolerance = integrator.getConstraintTolerance() - assert integrator_tolerance == constraint_tolerance, f"LangevinDynamicsMove integrator tolerance," \ - f" {integrator_tolerance}, is not the same as" \ - f" the specified value of {constraint_tolerance}." + assert integrator_tolerance == constraint_tolerance, ( + f"LangevinDynamicsMove integrator tolerance," + f" {integrator_tolerance}, is not the same as" + f" the specified value of {constraint_tolerance}." + ) # ============================================================================= diff --git a/openmmtools/tests/test_mixing.py b/openmmtools/tests/test_mixing.py index 0c65301dd..3cdb081a6 100644 --- a/openmmtools/tests/test_mixing.py +++ b/openmmtools/tests/test_mixing.py @@ -1,5 +1,3 @@ - - """ Test Cython and weave mixing code. """ @@ -9,6 +7,7 @@ import numpy as np import scipy.stats as stats + def mix_replicas(n_swaps=100, n_states=16, u_kl=None, nswap_attempts=None): """ Utility function to generate replicas and call the mixing function a certain number of times @@ -33,12 +32,15 @@ def mix_replicas(n_swaps=100, n_states=16, u_kl=None, nswap_attempts=None): replica_states = np.array(range(n_states), np.int64) if nswap_attempts is None: nswap_attempts = n_states**4 - Nij_proposed = np.zeros([n_states,n_states], dtype=np.int64) - Nij_accepted = np.zeros([n_states,n_states], dtype=np.int64) + Nij_proposed = np.zeros([n_states, n_states], dtype=np.int64) + Nij_accepted = np.zeros([n_states, n_states], dtype=np.int64) permutation_list = [] from openmmtools.multistate import ReplicaExchangeSampler + for i in range(n_swaps): - ReplicaExchangeSampler._mix_all_replicas_numba(nswap_attempts, n_states, replica_states, u_kl, Nij_proposed, Nij_accepted) + ReplicaExchangeSampler._mix_all_replicas_numba( + nswap_attempts, n_states, replica_states, u_kl, Nij_proposed, Nij_accepted + ) permutation_list.append(copy.deepcopy(replica_states)) permutation_list_np = np.array(permutation_list, dtype=np.int64) return permutation_list_np @@ -75,7 +77,8 @@ def test_even_mixing(verbose=True): """ Testing Cython mixing code with 1000 swap attempts and uniform 0 energies """ - if verbose: print("Testing Cython mixing code with uniform zero energies") + if verbose: + print("Testing Cython mixing code with uniform zero energies") n_swaps = 1000 n_states = 16 corrected_threshold = 0.001 / n_states diff --git a/openmmtools/tests/test_platforms.py b/openmmtools/tests/test_platforms.py index 33f017864..6d9807a43 100644 --- a/openmmtools/tests/test_platforms.py +++ b/openmmtools/tests/test_platforms.py @@ -3,8 +3,8 @@ import os, os.path import logging + def test_openmm_platforms(): - """Testing comparison of platforms. - """ + """Testing comparison of platforms.""" from openmmtools.scripts import test_openmm_platforms - #test_openmm_platforms.main() + # test_openmm_platforms.main() diff --git a/openmmtools/tests/test_sampling.py b/openmmtools/tests/test_sampling.py index 8c2495670..2dd16c723 100644 --- a/openmmtools/tests/test_sampling.py +++ b/openmmtools/tests/test_sampling.py @@ -28,6 +28,7 @@ import yaml from nose.plugins.attrib import attr from nose.tools import assert_raises + try: import openmm from openmm import unit @@ -59,25 +60,34 @@ # SUBROUTINES # ============================================================================== + def check_thermodynamic_states_equality(original_states, restored_states): """Check that the thermodynamic states are equivalent.""" - assert len(original_states) == len(restored_states), '{}, {}'.format( - len(original_states), len(restored_states)) + assert len(original_states) == len(restored_states), "{}, {}".format( + len(original_states), len(restored_states) + ) for original_state, restored_state in zip(original_states, restored_states): - assert original_state._standard_system_hash == restored_state._standard_system_hash + assert ( + original_state._standard_system_hash == restored_state._standard_system_hash + ) assert original_state.temperature == restored_state.temperature assert original_state.pressure == restored_state.pressure if isinstance(original_state, mmtools.states.CompoundThermodynamicState): assert original_state.lambda_sterics == restored_state.lambda_sterics - assert original_state.lambda_electrostatics == restored_state.lambda_electrostatics + assert ( + original_state.lambda_electrostatics + == restored_state.lambda_electrostatics + ) + # ============================================================================== # Harmonic oscillator free energy test # ============================================================================== -class TestHarmonicOscillatorsMultiStateSampler(object): + +class TestHarmonicOscillatorsMultiStateSampler: """Test multistate sampler can detect equilibration and compute free energies of harmonic oscillator""" # ------------------------------------ @@ -85,8 +95,8 @@ class TestHarmonicOscillatorsMultiStateSampler(object): # ------------------------------------ N_SAMPLERS = 3 - N_STATES = 5 # number of thermodynamic states to sample; two additional unsampled states will be added - N_ITERATIONS = 1000 # number of iterations + N_STATES = 5 # number of thermodynamic states to sample; two additional unsampled states will be added + N_ITERATIONS = 1000 # number of iterations SAMPLER = MultiStateSampler ANALYZER = MultiStateSamplerAnalyzer @@ -94,7 +104,8 @@ class TestHarmonicOscillatorsMultiStateSampler(object): def setup_class(cls): # Configure the global context cache to use the Reference platform from openmmtools import cache - platform = openmm.Platform.getPlatformByName('Reference') + + platform = openmm.Platform.getPlatformByName("Reference") cls.old_global_context_cache = cache.global_context_cache cache.global_context_cache = cache.ContextCache(platform=platform) @@ -103,82 +114,129 @@ def setup_class(cls): # Translate the sampler states to be different one from each other. n_particles = 1 - positions = unit.Quantity(np.zeros([n_particles,3]), unit.angstroms) + positions = unit.Quantity(np.zeros([n_particles, 3]), unit.angstroms) cls.sampler_states = [ mmtools.states.SamplerState(positions=positions) - for sampler_index in range(cls.N_SAMPLERS)] + for sampler_index in range(cls.N_SAMPLERS) + ] # Generate list of thermodynamic states and analytical free energies # This list includes both sampled and two unsampled states thermodynamic_states = list() temperature = 300 * unit.kelvin - f_i = np.zeros([cls.N_STATES+2]) # f_i[state_index] is the dimensionless free energy of state `state_index` + f_i = np.zeros( + [cls.N_STATES + 2] + ) # f_i[state_index] is the dimensionless free energy of state `state_index` for state_index in range(cls.N_STATES + 2): - sigma = (1.0 + 0.2 * state_index) * unit.angstroms # compute reasonable standard deviation with good overlap - kT = kB * temperature # compute thermal energy - K = kT / sigma**2 # compute spring constant + sigma = ( + (1.0 + 0.2 * state_index) * unit.angstroms + ) # compute reasonable standard deviation with good overlap + kT = kB * temperature # compute thermal energy + K = kT / sigma**2 # compute spring constant testsystem = testsystems.HarmonicOscillator(K=K, mass=mass) - thermodynamic_state = mmtools.states.ThermodynamicState(testsystem.system, temperature) + thermodynamic_state = mmtools.states.ThermodynamicState( + testsystem.system, temperature + ) thermodynamic_states.append(thermodynamic_state) # Store analytical reduced free energy - f_i[state_index] = - np.log(2 * np.pi * (sigma / unit.angstroms)**2) * (3.0/2.0) + f_i[state_index] = -np.log(2 * np.pi * (sigma / unit.angstroms) ** 2) * ( + 3.0 / 2.0 + ) # delta_f_ij_analytical[i,j] = f_i_analytical[j] - f_i_analytical[i] cls.f_i_analytical = f_i - cls.delta_f_ij_analytical = f_i - f_i[:,np.newaxis] + cls.delta_f_ij_analytical = f_i - f_i[:, np.newaxis] # Define sampled and unsampled states. cls.nstates = cls.N_STATES - cls.unsampled_states = [thermodynamic_states[0], thermodynamic_states[-1]] # first and last - cls.thermodynamic_states = thermodynamic_states[1:-1] # intermediate states + cls.unsampled_states = [ + thermodynamic_states[0], + thermodynamic_states[-1], + ] # first and last + cls.thermodynamic_states = thermodynamic_states[1:-1] # intermediate states @classmethod def teardown_class(cls): # Restore global context cache from openmmtools import cache + cache.global_context_cache = cls.old_global_context_cache def run(self, include_unsampled_states=False): # Create and configure simulation object - move = mmtools.mcmc.MCDisplacementMove(displacement_sigma=1.0*unit.angstroms) - simulation = self.SAMPLER(mcmc_moves=move, number_of_iterations=self.N_ITERATIONS, - online_analysis_interval=self.N_ITERATIONS) + move = mmtools.mcmc.MCDisplacementMove(displacement_sigma=1.0 * unit.angstroms) + simulation = self.SAMPLER( + mcmc_moves=move, + number_of_iterations=self.N_ITERATIONS, + online_analysis_interval=self.N_ITERATIONS, + ) # Define file for temporary storage. with temporary_directory() as tmp_dir: - storage = os.path.join(tmp_dir, 'test_storage.nc') - reporter = MultiStateReporter(storage, checkpoint_interval=self.N_ITERATIONS) + storage = os.path.join(tmp_dir, "test_storage.nc") + reporter = MultiStateReporter( + storage, checkpoint_interval=self.N_ITERATIONS + ) if include_unsampled_states: - simulation.create(self.thermodynamic_states, self.sampler_states, reporter, - unsampled_thermodynamic_states=self.unsampled_states) + simulation.create( + self.thermodynamic_states, + self.sampler_states, + reporter, + unsampled_thermodynamic_states=self.unsampled_states, + ) else: - simulation.create(self.thermodynamic_states, self.sampler_states, reporter) + simulation.create( + self.thermodynamic_states, self.sampler_states, reporter + ) # Run simulation without debug logging import logging + logger = logging.getLogger() logger.setLevel(logging.CRITICAL) simulation.run() # Create Analyzer specfiying statistical_inefficiency without n_equilibration_iterations and # check that it throws an exception - assert_raises(Exception, self.ANALYZER, reporter, statistical_inefficiency=10) + assert_raises( + Exception, self.ANALYZER, reporter, statistical_inefficiency=10 + ) # Create Analyzer specifying n_equilibration_iterations=10 without statistical_inefficiency and # check that equilibration detection returns n_equilibration_iterations > 10 analyzer = self.ANALYZER(reporter, n_equilibration_iterations=10) - sampled_energy_matrix, unsampled_energy_matrix, neighborhoods, replicas_state_indices = list(analyzer._read_energies(truncate_max_n_iterations=True)) - n_equilibration_iterations, statistical_inefficiency, n_effective_max = analyzer._get_equilibration_data(sampled_energy_matrix, neighborhoods, replicas_state_indices) + ( + sampled_energy_matrix, + unsampled_energy_matrix, + neighborhoods, + replicas_state_indices, + ) = list(analyzer._read_energies(truncate_max_n_iterations=True)) + n_equilibration_iterations, statistical_inefficiency, n_effective_max = ( + analyzer._get_equilibration_data( + sampled_energy_matrix, neighborhoods, replicas_state_indices + ) + ) assert n_equilibration_iterations > 10 del analyzer # Create Analyzer specifying both n_equilibration_iterations and statistical_inefficiency # check that it returns the user specified values without running the equilibration detection - analyzer = self.ANALYZER(reporter, n_equilibration_iterations=10, statistical_inefficiency=3) - sampled_energy_matrix, unsampled_energy_matrix, neighborhoods, replicas_state_indices = list(analyzer._read_energies(truncate_max_n_iterations=True)) - n_equilibration_iterations, statistical_inefficiency, n_effective_max = analyzer._get_equilibration_data(sampled_energy_matrix, neighborhoods, replicas_state_indices) + analyzer = self.ANALYZER( + reporter, n_equilibration_iterations=10, statistical_inefficiency=3 + ) + ( + sampled_energy_matrix, + unsampled_energy_matrix, + neighborhoods, + replicas_state_indices, + ) = list(analyzer._read_energies(truncate_max_n_iterations=True)) + n_equilibration_iterations, statistical_inefficiency, n_effective_max = ( + analyzer._get_equilibration_data( + sampled_energy_matrix, neighborhoods, replicas_state_indices + ) + ) assert n_equilibration_iterations == 10 assert statistical_inefficiency == 3 del analyzer @@ -187,8 +245,17 @@ def run(self, include_unsampled_states=False): analyzer = self.ANALYZER(reporter) # Check that default analyzer yields n_equilibration_iterations > 1 - sampled_energy_matrix, unsampled_energy_matrix, neighborhoods, replicas_state_indices = list(analyzer._read_energies(truncate_max_n_iterations=True)) - n_equilibration_iterations, statistical_inefficiency, n_effective_max = analyzer._get_equilibration_data(sampled_energy_matrix, neighborhoods, replicas_state_indices) + ( + sampled_energy_matrix, + unsampled_energy_matrix, + neighborhoods, + replicas_state_indices, + ) = list(analyzer._read_energies(truncate_max_n_iterations=True)) + n_equilibration_iterations, statistical_inefficiency, n_effective_max = ( + analyzer._get_equilibration_data( + sampled_energy_matrix, neighborhoods, replicas_state_indices + ) + ) assert n_equilibration_iterations > 1 # Check if free energies have the right shape and deviations exceed tolerance @@ -196,20 +263,27 @@ def run(self, include_unsampled_states=False): nstates, _ = delta_f_ij.shape if include_unsampled_states: - nstates_expected = self.N_STATES+2 # We expect N_STATES plus two additional states - delta_f_ij_analytical = self.delta_f_ij_analytical # Use the whole matrix + nstates_expected = ( + self.N_STATES + 2 + ) # We expect N_STATES plus two additional states + delta_f_ij_analytical = ( + self.delta_f_ij_analytical + ) # Use the whole matrix else: - nstates_expected = self.N_STATES # We expect only N_STATES - delta_f_ij_analytical = self.delta_f_ij_analytical[1:-1,1:-1] # Use only the intermediate, sampled states + nstates_expected = self.N_STATES # We expect only N_STATES + delta_f_ij_analytical = self.delta_f_ij_analytical[ + 1:-1, 1:-1 + ] # Use only the intermediate, sampled states - assert nstates == nstates_expected, \ - f'analyzer.get_free_energy() returned {delta_f_ij.shape} but expected {nstates_expected,nstates_expected}' + assert ( + nstates == nstates_expected + ), f"analyzer.get_free_energy() returned {delta_f_ij.shape} but expected {nstates_expected,nstates_expected}" error = np.abs(delta_f_ij - delta_f_ij_analytical) indices = np.where(delta_f_ij_stderr > 0.0) - nsigma = np.zeros([nstates,nstates], np.float32) + nsigma = np.zeros([nstates, nstates], np.float32) nsigma[indices] = error[indices] / delta_f_ij_stderr[indices] - MAX_SIGMA = 6.0 # maximum allowed number of standard errors + MAX_SIGMA = 6.0 # maximum allowed number of standard errors if np.any(nsigma > MAX_SIGMA): np.set_printoptions(precision=3) print("delta_f_ij") @@ -222,7 +296,10 @@ def run(self, include_unsampled_states=False): print(delta_f_ij_stderr) print("nsigma") print(nsigma) - raise Exception("Dimensionless free energy difference exceeds MAX_SIGMA of %.1f" % MAX_SIGMA) + raise Exception( + "Dimensionless free energy difference exceeds MAX_SIGMA of %.1f" + % MAX_SIGMA + ) # Clean up. del simulation @@ -235,7 +312,10 @@ def test_without_unsampled_states(self): """Test multistate sampler on a harmonic oscillator without unsampled endstates""" self.run(include_unsampled_states=False) -class TestHarmonicOscillatorsReplicaExchangeSampler(TestHarmonicOscillatorsMultiStateSampler): + +class TestHarmonicOscillatorsReplicaExchangeSampler( + TestHarmonicOscillatorsMultiStateSampler +): """Test replica-exchange sampler can compute free energies of harmonic oscillator""" # ------------------------------------ @@ -247,6 +327,7 @@ class TestHarmonicOscillatorsReplicaExchangeSampler(TestHarmonicOscillatorsMulti SAMPLER = ReplicaExchangeSampler ANALYZER = ReplicaExchangeAnalyzer + class TestHarmonicOscillatorsSAMSSampler(TestHarmonicOscillatorsMultiStateSampler): """Test SAMS sampler can compute free energies of harmonic oscillator""" @@ -256,68 +337,94 @@ class TestHarmonicOscillatorsSAMSSampler(TestHarmonicOscillatorsMultiStateSample N_SAMPLERS = 1 N_STATES = 5 - N_ITERATIONS = 1000 * N_STATES # number of iterations + N_ITERATIONS = 1000 * N_STATES # number of iterations SAMPLER = SAMSSampler ANALYZER = SAMSAnalyzer + # ============================================================================== # TEST REPORTER # ============================================================================== -class TestReporter(object): + +class TestReporter: """Test suite for Reporter class.""" @staticmethod @contextlib.contextmanager - def temporary_reporter(checkpoint_interval=1, checkpoint_storage=None, analysis_particle_indices=()): + def temporary_reporter( + checkpoint_interval=1, checkpoint_storage=None, analysis_particle_indices=() + ): """Create and initialize a reporter in a temporary directory.""" with temporary_directory() as tmp_dir_path: - storage_file = os.path.join(tmp_dir_path, 'temp_dir/test_storage.nc') + storage_file = os.path.join(tmp_dir_path, "temp_dir/test_storage.nc") assert not os.path.isfile(storage_file) - reporter = MultiStateReporter(storage=storage_file, open_mode='w', - checkpoint_interval=checkpoint_interval, - checkpoint_storage=checkpoint_storage, - analysis_particle_indices=analysis_particle_indices) + reporter = MultiStateReporter( + storage=storage_file, + open_mode="w", + checkpoint_interval=checkpoint_interval, + checkpoint_storage=checkpoint_storage, + analysis_particle_indices=analysis_particle_indices, + ) assert reporter.storage_exists(skip_size=True) yield reporter def test_store_thermodynamic_states(self): """Check correct storage of thermodynamic states.""" # Thermodynamic states. - temperature = 300*unit.kelvin + temperature = 300 * unit.kelvin alanine_system = testsystems.AlanineDipeptideImplicit().system alanine_explicit_system = testsystems.AlanineDipeptideExplicit().system - thermodynamic_state_nvt = mmtools.states.ThermodynamicState(alanine_system, temperature) - thermodynamic_state_nvt_compatible = mmtools.states.ThermodynamicState(alanine_system, - temperature + 20*unit.kelvin) - thermodynamic_state_npt = mmtools.states.ThermodynamicState(alanine_explicit_system, - temperature, 1.0*unit.atmosphere) + thermodynamic_state_nvt = mmtools.states.ThermodynamicState( + alanine_system, temperature + ) + thermodynamic_state_nvt_compatible = mmtools.states.ThermodynamicState( + alanine_system, temperature + 20 * unit.kelvin + ) + thermodynamic_state_npt = mmtools.states.ThermodynamicState( + alanine_explicit_system, temperature, 1.0 * unit.atmosphere + ) # Compound states. factory = mmtools.alchemy.AbsoluteAlchemicalFactory() alchemical_region = mmtools.alchemy.AlchemicalRegion(alchemical_atoms=range(22)) - alanine_alchemical = factory.create_alchemical_system(alanine_system, alchemical_region) - alchemical_state_interacting = mmtools.alchemy.AlchemicalState.from_system(alanine_alchemical) + alanine_alchemical = factory.create_alchemical_system( + alanine_system, alchemical_region + ) + alchemical_state_interacting = mmtools.alchemy.AlchemicalState.from_system( + alanine_alchemical + ) alchemical_state_noninteracting = copy.deepcopy(alchemical_state_interacting) alchemical_state_noninteracting.set_alchemical_parameters(0.0) compound_state_interacting = mmtools.states.CompoundThermodynamicState( - thermodynamic_state=mmtools.states.ThermodynamicState(alanine_alchemical, temperature), - composable_states=[alchemical_state_interacting] + thermodynamic_state=mmtools.states.ThermodynamicState( + alanine_alchemical, temperature + ), + composable_states=[alchemical_state_interacting], ) compound_state_noninteracting = mmtools.states.CompoundThermodynamicState( - thermodynamic_state=mmtools.states.ThermodynamicState(alanine_alchemical, temperature), - composable_states=[alchemical_state_noninteracting] + thermodynamic_state=mmtools.states.ThermodynamicState( + alanine_alchemical, temperature + ), + composable_states=[alchemical_state_noninteracting], ) - thermodynamic_states = [thermodynamic_state_nvt, thermodynamic_state_nvt_compatible, - thermodynamic_state_npt, compound_state_interacting, - compound_state_noninteracting] + thermodynamic_states = [ + thermodynamic_state_nvt, + thermodynamic_state_nvt_compatible, + thermodynamic_state_npt, + compound_state_interacting, + compound_state_noninteracting, + ] # Unsampled thermodynamic states. toluene_system = testsystems.TolueneVacuum().system toluene_state = mmtools.states.ThermodynamicState(toluene_system, temperature) - unsampled_states = [copy.deepcopy(toluene_state), copy.deepcopy(toluene_state), - copy.deepcopy(compound_state_interacting)] + unsampled_states = [ + copy.deepcopy(toluene_state), + copy.deepcopy(toluene_state), + copy.deepcopy(compound_state_interacting), + ] with self.temporary_reporter() as reporter: # Check that after writing and reading, states are identical. @@ -327,60 +434,76 @@ def test_store_thermodynamic_states(self): check_thermodynamic_states_equality(unsampled_states, restored_unsampled) # The latest writer only stores one full serialization per compatible state. - ncgrp_states = reporter._storage_analysis.groups['thermodynamic_states'] - ncgrp_unsampled = reporter._storage_analysis.groups['unsampled_states'] + ncgrp_states = reporter._storage_analysis.groups["thermodynamic_states"] + ncgrp_unsampled = reporter._storage_analysis.groups["unsampled_states"] # Load representation of the states on the disk. There # should be only one full serialization per compatible state. def decompact_state_variable(variable): - if variable.dtype == 'S1': + if variable.dtype == "S1": # Handle variables stored in fixed_dimensions data_chars = variable[:] data_str = data_chars.tostring().decode() else: data_str = str(variable[0]) return data_str + states_serialized = [] for state_id in range(len(thermodynamic_states)): - state_str = decompact_state_variable(ncgrp_states.variables['state' + str(state_id)]) + state_str = decompact_state_variable( + ncgrp_states.variables["state" + str(state_id)] + ) state_dict = yaml.load(state_str, Loader=_DictYamlLoader) states_serialized.append(state_dict) unsampled_serialized = [] for state_id in range(len(unsampled_states)): - unsampled_str = decompact_state_variable(ncgrp_unsampled.variables['state' + str(state_id)]) + unsampled_str = decompact_state_variable( + ncgrp_unsampled.variables["state" + str(state_id)] + ) unsampled_dict = yaml.load(unsampled_str, Loader=_DictYamlLoader) unsampled_serialized.append(unsampled_dict) # Two of the three ThermodynamicStates are compatible. - assert 'standard_system' in states_serialized[0] - assert 'standard_system' not in states_serialized[1] - state_compatible_to_1 = states_serialized[1]['_Reporter__compatible_state'] - assert state_compatible_to_1 == 'thermodynamic_states/0' - assert 'standard_system' in states_serialized[2] + assert "standard_system" in states_serialized[0] + assert "standard_system" not in states_serialized[1] + state_compatible_to_1 = states_serialized[1]["_Reporter__compatible_state"] + assert state_compatible_to_1 == "thermodynamic_states/0" + assert "standard_system" in states_serialized[2] # The two CompoundThermodynamicStates are compatible. - assert 'standard_system' in states_serialized[3]['thermodynamic_state'] - thermodynamic_state_4 = states_serialized[4]['thermodynamic_state'] - assert thermodynamic_state_4['_Reporter__compatible_state'] == 'thermodynamic_states/3' + assert "standard_system" in states_serialized[3]["thermodynamic_state"] + thermodynamic_state_4 = states_serialized[4]["thermodynamic_state"] + assert ( + thermodynamic_state_4["_Reporter__compatible_state"] + == "thermodynamic_states/3" + ) # The first two unsampled states are incompatible with everything else # but compatible to each other, while the third unsampled state is # compatible with the alchemical states. - assert 'standard_system' in unsampled_serialized[0] - state_compatible_to_1 = unsampled_serialized[1]['_Reporter__compatible_state'] - assert state_compatible_to_1 == 'unsampled_states/0' - thermodynamic_state_2 = unsampled_serialized[2]['thermodynamic_state'] - assert thermodynamic_state_2['_Reporter__compatible_state'] == 'thermodynamic_states/3' + assert "standard_system" in unsampled_serialized[0] + state_compatible_to_1 = unsampled_serialized[1][ + "_Reporter__compatible_state" + ] + assert state_compatible_to_1 == "unsampled_states/0" + thermodynamic_state_2 = unsampled_serialized[2]["thermodynamic_state"] + assert ( + thermodynamic_state_2["_Reporter__compatible_state"] + == "thermodynamic_states/3" + ) def test_write_sampler_states(self): """Check correct storage of sampler states.""" analysis_particles = (1, 2) - with self.temporary_reporter(analysis_particle_indices=analysis_particles, checkpoint_interval=2) as reporter: + with self.temporary_reporter( + analysis_particle_indices=analysis_particles, checkpoint_interval=2 + ) as reporter: # Create sampler states. alanine_test = testsystems.AlanineDipeptideVacuum() positions = alanine_test.positions - sampler_states = [mmtools.states.SamplerState(positions=positions) - for _ in range(2)] + sampler_states = [ + mmtools.states.SamplerState(positions=positions) for _ in range(2) + ] # Check that after writing and reading, states are identical. for iteration in range(3): @@ -390,10 +513,17 @@ def test_write_sampler_states(self): for state, restored_state in zip(sampler_states, restored_sampler_states): assert np.allclose(state.positions, restored_state.positions) # By default stored velocities are zeros if not present in origin sampler_state - assert np.allclose(np.zeros(state.positions.shape), restored_state.velocities) - assert np.allclose(state.box_vectors / unit.nanometer, restored_state.box_vectors / unit.nanometer) + assert np.allclose( + np.zeros(state.positions.shape), restored_state.velocities + ) + assert np.allclose( + state.box_vectors / unit.nanometer, + restored_state.box_vectors / unit.nanometer, + ) # Check that the analysis particles are written off checkpoint whereas full trajectory is not - restored_analysis_states = reporter.read_sampler_states(iteration=1, analysis_particles_only=True) + restored_analysis_states = reporter.read_sampler_states( + iteration=1, analysis_particles_only=True + ) restored_checkpoint_states = reporter.read_sampler_states(iteration=1) assert type(restored_analysis_states) is list for state in restored_analysis_states: @@ -401,18 +531,30 @@ def test_write_sampler_states(self): assert state.velocities.shape == (len(analysis_particles), 3) assert restored_checkpoint_states is None # Check that the analysis particles are written separate from the checkpoint particles - restored_analysis_states = reporter.read_sampler_states(iteration=2, analysis_particles_only=True) + restored_analysis_states = reporter.read_sampler_states( + iteration=2, analysis_particles_only=True + ) restored_checkpoint_states = reporter.read_sampler_states(iteration=2) assert len(restored_analysis_states) == len(restored_checkpoint_states) - for analysis_state, checkpoint_state in zip(restored_analysis_states, restored_checkpoint_states): + for analysis_state, checkpoint_state in zip( + restored_analysis_states, restored_checkpoint_states + ): # This assert is multiple purpose: Positions are identical; Velocities are indetical and zeros # (since unspecified); Analysis shape is correct # Will raise a ValueError for np.allclose(x,y) if x.shape != y.shape # Will raise AssertionError if the values are not allclose - assert np.allclose(analysis_state.positions, checkpoint_state.positions[analysis_particles, :]) - assert np.allclose(analysis_state.velocities, checkpoint_state.velocities[analysis_particles, :]) - assert np.allclose(analysis_state.box_vectors / unit.nanometer, - checkpoint_state.box_vectors / unit.nanometer) + assert np.allclose( + analysis_state.positions, + checkpoint_state.positions[analysis_particles, :], + ) + assert np.allclose( + analysis_state.velocities, + checkpoint_state.velocities[analysis_particles, :], + ) + assert np.allclose( + analysis_state.box_vectors / unit.nanometer, + checkpoint_state.box_vectors / unit.nanometer, + ) def test_analysis_particle_mismatch(self): """Test that previously stored analysis particles is higher priority.""" @@ -422,28 +564,49 @@ def test_analysis_particle_mismatch(self): # Does not use the temp reporter since we close and reopen reporter a few times with temporary_directory() as tmp_dir_path: # Test that starting with a blank analysis cannot be overwritten - blank_file = os.path.join(tmp_dir_path, 'temp_dir/blank_analysis.nc') - reporter = MultiStateReporter(storage=blank_file, open_mode='w', - analysis_particle_indices=blank_analysis_particles) + blank_file = os.path.join(tmp_dir_path, "temp_dir/blank_analysis.nc") + reporter = MultiStateReporter( + storage=blank_file, + open_mode="w", + analysis_particle_indices=blank_analysis_particles, + ) reporter.close() del reporter - new_blank_reporter = MultiStateReporter(storage=blank_file, open_mode='r', - analysis_particle_indices=set1_analysis_particles) - assert new_blank_reporter.analysis_particle_indices == blank_analysis_particles + new_blank_reporter = MultiStateReporter( + storage=blank_file, + open_mode="r", + analysis_particle_indices=set1_analysis_particles, + ) + assert ( + new_blank_reporter.analysis_particle_indices == blank_analysis_particles + ) del new_blank_reporter # Test that starting from an initial set of particles and passing in a blank does not overwrite - set1_file = os.path.join(tmp_dir_path, 'temp_dir/set1_analysis.nc') - set1_reporter = MultiStateReporter(storage=set1_file, open_mode='w', - analysis_particle_indices=set1_analysis_particles) + set1_file = os.path.join(tmp_dir_path, "temp_dir/set1_analysis.nc") + set1_reporter = MultiStateReporter( + storage=set1_file, + open_mode="w", + analysis_particle_indices=set1_analysis_particles, + ) set1_reporter.close() # Don't delete, we'll need it for another test - new_set1_reporter = MultiStateReporter(storage=set1_file, open_mode='r', - analysis_particle_indices=blank_analysis_particles) - assert new_set1_reporter.analysis_particle_indices == set1_analysis_particles + new_set1_reporter = MultiStateReporter( + storage=set1_file, + open_mode="r", + analysis_particle_indices=blank_analysis_particles, + ) + assert ( + new_set1_reporter.analysis_particle_indices == set1_analysis_particles + ) del new_set1_reporter # Test that passing in a different set than the initial returns the initial set - new2_set1_reporter = MultiStateReporter(storage=set1_file, open_mode='r', - analysis_particle_indices=set2_analysis_particles) - assert new2_set1_reporter.analysis_particle_indices == set1_analysis_particles + new2_set1_reporter = MultiStateReporter( + storage=set1_file, + open_mode="r", + analysis_particle_indices=set2_analysis_particles, + ) + assert ( + new2_set1_reporter.analysis_particle_indices == set1_analysis_particles + ) def test_store_replica_thermodynamic_states(self): """Check storage of replica thermodynamic states indices.""" @@ -451,16 +614,20 @@ def test_store_replica_thermodynamic_states(self): for i, replica_states in enumerate([[2, 1, 0, 3], np.array([3, 1, 0, 2])]): reporter.write_replica_thermodynamic_states(replica_states, iteration=i) reporter.write_last_iteration(i) - restored_replica_states = reporter.read_replica_thermodynamic_states(iteration=i) + restored_replica_states = reporter.read_replica_thermodynamic_states( + iteration=i + ) assert np.all(replica_states == restored_replica_states) def test_store_mcmc_moves(self): """Check storage of MCMC moves.""" - sequence_move = mmtools.mcmc.SequenceMove(move_list=[mmtools.mcmc.LangevinDynamicsMove(), - mmtools.mcmc.GHMCMove()], - context_cache=mmtools.cache.ContextCache(capacity=1)) - integrator_move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0*unit.femtosecond), - n_steps=100) + sequence_move = mmtools.mcmc.SequenceMove( + move_list=[mmtools.mcmc.LangevinDynamicsMove(), mmtools.mcmc.GHMCMove()], + context_cache=mmtools.cache.ContextCache(capacity=1), + ) + integrator_move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=100 + ) mcmc_moves = [sequence_move, integrator_move] with self.temporary_reporter() as reporter: reporter.write_mcmc_moves(mcmc_moves) @@ -473,24 +640,25 @@ def test_store_mcmc_moves(self): def test_store_energies(self): """Check storage of energies.""" - energy_thermodynamic_states = np.array( - [[0, 2, 3], - [1, 2, 0], - [1, 2, 3]]) - energy_neighborhoods = np.array( - [[0, 1, 1], - [1, 1, 0], - [1, 1, 3]] - ) - energy_unsampled_states = np.array( - [[1, 2], - [2, 3.0], - [3, 9.0]]) + energy_thermodynamic_states = np.array([[0, 2, 3], [1, 2, 0], [1, 2, 3]]) + energy_neighborhoods = np.array([[0, 1, 1], [1, 1, 0], [1, 1, 3]]) + energy_unsampled_states = np.array([[1, 2], [2, 3.0], [3, 9.0]]) with self.temporary_reporter() as reporter: - reporter.write_energies(energy_thermodynamic_states, energy_neighborhoods, energy_unsampled_states, iteration=0) - restored_energy_thermodynamic_states, restored_energy_neighborhoods, restored_energy_unsampled_states = reporter.read_energies(iteration=0) - assert np.all(energy_thermodynamic_states == restored_energy_thermodynamic_states) + reporter.write_energies( + energy_thermodynamic_states, + energy_neighborhoods, + energy_unsampled_states, + iteration=0, + ) + ( + restored_energy_thermodynamic_states, + restored_energy_neighborhoods, + restored_energy_unsampled_states, + ) = reporter.read_energies(iteration=0) + assert np.all( + energy_thermodynamic_states == restored_energy_thermodynamic_states + ) assert np.all(energy_neighborhoods == restored_energy_neighborhoods) assert np.all(energy_unsampled_states == restored_energy_unsampled_states) @@ -498,11 +666,11 @@ def test_ensure_dimension_exists(self): """Test ensuring that a dimension exists works as expected.""" with self.temporary_reporter() as reporter: # These should work fine - reporter._ensure_dimension_exists('dim1', 0) - reporter._ensure_dimension_exists('dim2', 1) + reporter._ensure_dimension_exists("dim1", 0) + reporter._ensure_dimension_exists("dim2", 1) # These should raise an exception - assert_raises(ValueError, reporter._ensure_dimension_exists, 'dim1', 1) - assert_raises(ValueError, reporter._ensure_dimension_exists, 'dim2', 2) + assert_raises(ValueError, reporter._ensure_dimension_exists, "dim1", 1) + assert_raises(ValueError, reporter._ensure_dimension_exists, "dim2", 2) def test_store_dict(self): """Check correct storage and restore of dictionaries.""" @@ -522,48 +690,55 @@ def compare_dicts(reference, restored): assert pickle.dumps(sorted_reference) == pickle.dumps(sorted_restored) data = { - 'mybool': False, - 'mystring': 'test', - 'myinteger': 3, 'myfloat': 4.0, - 'mylist': [0, 1, 2, 3], 'mynumpyarray': np.array([2.0, 3, 4]), - 'mynestednumpyarray': np.array([[1, 2, 3], [4.0, 5, 6]]), - 'myquantity': 5.0 / unit.femtosecond, - 'myquantityarray': unit.Quantity(np.array([[1, 2, 3], [4.0, 5, 6]]), unit=unit.angstrom), - 'mynesteddict': {'field1': 'string', 'field2': {'field21': 3.0, 'field22': True}} + "mybool": False, + "mystring": "test", + "myinteger": 3, + "myfloat": 4.0, + "mylist": [0, 1, 2, 3], + "mynumpyarray": np.array([2.0, 3, 4]), + "mynestednumpyarray": np.array([[1, 2, 3], [4.0, 5, 6]]), + "myquantity": 5.0 / unit.femtosecond, + "myquantityarray": unit.Quantity( + np.array([[1, 2, 3], [4.0, 5, 6]]), unit=unit.angstrom + ), + "mynesteddict": { + "field1": "string", + "field2": {"field21": 3.0, "field22": True}, + }, } with self.temporary_reporter() as reporter: # Test both nested and single-string representations. - for name, nested in [('testdict', False), ('nested', True)]: + for name, nested in [("testdict", False), ("nested", True)]: reporter._write_dict(name, data, nested=nested) restored_data = reporter.read_dict(name) compare_dicts(data, restored_data) # Test reading a keyword inside a dict. - restored_data = reporter.read_dict(name + '/mynesteddict/field2') - compare_dicts(data['mynesteddict']['field2'], restored_data) + restored_data = reporter.read_dict(name + "/mynesteddict/field2") + compare_dicts(data["mynesteddict"]["field2"], restored_data) # write_dict supports updates, even with the nested representation # if the structure of the dictionary doesn't change. - data['mybool'] = True - data['mystring'] = 'substituted' + data["mybool"] = True + data["mystring"] = "substituted" reporter._write_dict(name, data, nested=nested) restored_data = reporter.read_dict(name) - assert restored_data['mybool'] is True - assert restored_data['mystring'] == 'substituted' + assert restored_data["mybool"] is True + assert restored_data["mystring"] == "substituted" # In nested representation, dictionaries are stored as groups and # values as variables. Otherwise, there's only a single variable. if nested: dict_group = reporter._storage_analysis.groups[name] - assert 'mynesteddict' in dict_group.groups - assert 'mylist' in dict_group.variables + assert "mynesteddict" in dict_group.groups + assert "mylist" in dict_group.variables else: assert name in reporter._storage_analysis.variables # Write dict fixed_dimension creates static dimensions and read/writes correctly - reporter._write_dict('fixed', data, fixed_dimension=True) - restored_fixed_data = reporter.read_dict('fixed') + reporter._write_dict("fixed", data, fixed_dimension=True) + restored_fixed_data = reporter.read_dict("fixed") compare_dicts(data, restored_fixed_data) def test_store_mixing_statistics(self): @@ -571,8 +746,12 @@ def test_store_mixing_statistics(self): n_accepted_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) n_proposed_matrix = np.array([[3, 3, 3], [6, 6, 6], [9, 9, 9]]) with self.temporary_reporter() as reporter: - reporter.write_mixing_statistics(n_accepted_matrix, n_proposed_matrix, iteration=0) - restored_n_accepted, restored_n_proposed = reporter.read_mixing_statistics(iteration=0) + reporter.write_mixing_statistics( + n_accepted_matrix, n_proposed_matrix, iteration=0 + ) + restored_n_accepted, restored_n_proposed = reporter.read_mixing_statistics( + iteration=0 + ) assert np.all(n_accepted_matrix == restored_n_accepted) assert np.all(n_proposed_matrix == restored_n_proposed) @@ -581,8 +760,10 @@ def test_store_mixing_statistics(self): # TEST MULTISTATE SAMPLERS # ============================================================================== -class TestBaseMultistateSampler(object): + +class TestBaseMultistateSampler: """Minimal Base class to test sampler objects""" + # ------------------------------------ # VARIABLES TO SET FOR EACH TEST CLASS # ------------------------------------ @@ -603,7 +784,7 @@ def temporary_storage_path(): """ mpicomm = mpiplus.get_mpicomm() with temporary_directory() as tmp_dir_path: - storage_file_path = os.path.join(tmp_dir_path, 'test_storage.nc') + storage_file_path = os.path.join(tmp_dir_path, "test_storage.nc") if mpicomm is not None: storage_file_path = mpicomm.bcast(storage_file_path, root=0) yield storage_file_path @@ -631,29 +812,34 @@ def property_creator(name, on_disk_name, value, on_disk_value): 'on_disk_value': on_disk_value } """ - return {name: { - - 'value': value, - 'on_disk_value': on_disk_value, - 'on_disk_name': on_disk_name - }} + return { + name: { + "value": value, + "on_disk_value": on_disk_value, + "on_disk_name": on_disk_name, + } + } class TestMultiStateSampler(TestBaseMultistateSampler): """Base test suite for the multi-state classes""" + # -------------------------------------- # Optional helper function to overwrite. # -------------------------------------- @classmethod - def call_sampler_create(cls, sampler, reporter, - thermodynamic_states, - sampler_states, - unsampled_states): + def call_sampler_create( + cls, sampler, reporter, thermodynamic_states, sampler_states, unsampled_states + ): """Helper function to call the create method for the sampler""" # Allows initial thermodynamic states to be handled by the built in methods - sampler.create(thermodynamic_states, sampler_states, reporter, - unsampled_thermodynamic_states=unsampled_states) + sampler.create( + thermodynamic_states, + sampler_states, + reporter, + unsampled_thermodynamic_states=unsampled_states, + ) # -------------------------------- # Tests overwritten by sub-classes @@ -678,11 +864,15 @@ def _compute_energies_independently(cls, sampler): # Compute the energies independently. energy_thermodynamic_states = np.zeros((n_replicas, n_states)) energy_unsampled_states = np.zeros((n_replicas, len(unsampled_states))) - for energies, states in [(energy_thermodynamic_states, thermodynamic_states), - (energy_unsampled_states, unsampled_states)]: + for energies, states in [ + (energy_thermodynamic_states, thermodynamic_states), + (energy_unsampled_states, unsampled_states), + ]: for i, sampler_state in enumerate(sampler_states): for j, state in enumerate(states): - context, integrator = mmtools.cache.global_context_cache.get_context(state) + context, integrator = ( + mmtools.cache.global_context_cache.get_context(state) + ) sampler_state.apply_to_context(context) energies[i][j] = state.reduced_potential(context) return energy_thermodynamic_states, energy_unsampled_states @@ -700,13 +890,18 @@ def setup_class(cls): # Translate the sampler states to be different one from each other. alanine_sampler_states = [ - mmtools.states.SamplerState(positions=alanine_test.positions + 10 * i * unit.nanometers) - for i in range(cls.N_SAMPLERS)] + mmtools.states.SamplerState( + positions=alanine_test.positions + 10 * i * unit.nanometers + ) + for i in range(cls.N_SAMPLERS) + ] # Set increasing temperature. temperatures = [(300 + 10 * i) * unit.kelvin for i in range(cls.N_STATES)] - alanine_thermodynamic_states = [mmtools.states.ThermodynamicState(alanine_test.system, temperatures[i]) - for i in range(cls.N_STATES)] + alanine_thermodynamic_states = [ + mmtools.states.ThermodynamicState(alanine_test.system, temperatures[i]) + for i in range(cls.N_STATES) + ] # No unsampled states for this test. cls.alanine_test = (alanine_thermodynamic_states, alanine_sampler_states, []) @@ -715,22 +910,33 @@ def setup_class(cls): # ----------------------------------------------------------------------------------------- hostguest_test = testsystems.HostGuestVacuum() factory = mmtools.alchemy.AbsoluteAlchemicalFactory() - alchemical_region = mmtools.alchemy.AlchemicalRegion(alchemical_atoms=range(126, 156)) - hostguest_alchemical = factory.create_alchemical_system(hostguest_test.system, alchemical_region) + alchemical_region = mmtools.alchemy.AlchemicalRegion( + alchemical_atoms=range(126, 156) + ) + hostguest_alchemical = factory.create_alchemical_system( + hostguest_test.system, alchemical_region + ) # Translate the sampler states to be different one from each other. hostguest_sampler_states = [ - mmtools.states.SamplerState(positions=hostguest_test.positions + 10 * i * unit.nanometers) - for i in range(cls.N_SAMPLERS)] + mmtools.states.SamplerState( + positions=hostguest_test.positions + 10 * i * unit.nanometers + ) + for i in range(cls.N_SAMPLERS) + ] # Create the three basic thermodynamic states. temperatures = [(300 + 10 * i) * unit.kelvin for i in range(cls.N_STATES)] - hostguest_thermodynamic_states = [mmtools.states.ThermodynamicState(hostguest_alchemical, temperatures[i]) - for i in range(cls.N_STATES)] + hostguest_thermodynamic_states = [ + mmtools.states.ThermodynamicState(hostguest_alchemical, temperatures[i]) + for i in range(cls.N_STATES) + ] # Create alchemical states at different parameter values. - alchemical_states = [mmtools.alchemy.AlchemicalState.from_system(hostguest_alchemical) - for _ in range(cls.N_STATES)] + alchemical_states = [ + mmtools.alchemy.AlchemicalState.from_system(hostguest_alchemical) + for _ in range(cls.N_STATES) + ] for i, alchemical_state in enumerate(alchemical_states): alchemical_state.set_alchemical_parameters(float(i) / (cls.N_STATES - 1)) @@ -738,19 +944,28 @@ def setup_class(cls): hostguest_compound_states = list() for i in range(cls.N_STATES): hostguest_compound_states.append( - mmtools.states.CompoundThermodynamicState(thermodynamic_state=hostguest_thermodynamic_states[i], - composable_states=[alchemical_states[i]]) + mmtools.states.CompoundThermodynamicState( + thermodynamic_state=hostguest_thermodynamic_states[i], + composable_states=[alchemical_states[i]], + ) ) # Unsampled states. - nonalchemical_state = mmtools.states.ThermodynamicState(hostguest_test.system, temperatures[0]) + nonalchemical_state = mmtools.states.ThermodynamicState( + hostguest_test.system, temperatures[0] + ) hostguest_unsampled_states = [copy.deepcopy(nonalchemical_state)] - cls.hostguest_test = (hostguest_compound_states, hostguest_sampler_states, hostguest_unsampled_states) + cls.hostguest_test = ( + hostguest_compound_states, + hostguest_sampler_states, + hostguest_unsampled_states, + ) # Debugging Messages to sent to Nose with --nocapture enabled output_descr = "Testing Sampler: {} -- States: {} -- Samplers: {}".format( - cls.SAMPLER.__name__, cls.N_STATES, cls.N_SAMPLERS) + cls.SAMPLER.__name__, cls.N_STATES, cls.N_SAMPLERS + ) len_output = len(output_descr) print("#" * len_output) print(output_descr) @@ -765,7 +980,6 @@ def get_node_replica_ids(tot_n_replicas): else: return set(range(mpicomm.rank, tot_n_replicas, mpicomm.size)) - def test_create(self): """Test creation of a new MultiState simulation. @@ -774,19 +988,25 @@ def test_create(self): open Reporter for writing. """ - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) n_states = len(thermodynamic_states) n_samplers = len(sampler_states) with self.temporary_storage_path() as storage_path: reporter = self.REPORTER(storage_path, checkpoint_interval=1) sampler = self.SAMPLER() - if hasattr(sampler, 'replica_mixing_scheme'): - sampler.replica_mixing_scheme = 'swap-neighbors' + if hasattr(sampler, "replica_mixing_scheme"): + sampler.replica_mixing_scheme = "swap-neighbors" sampler.locality = 2 - self.call_sampler_create(sampler, reporter, - thermodynamic_states, - sampler_states, unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Check that reporter has reporter only if rank 0. mpicomm = mpiplus.get_mpicomm() @@ -799,52 +1019,85 @@ def test_create(self): reporter.close() # Open reporter to read stored data. - reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=1) + reporter = self.REPORTER(storage_path, open_mode="r", checkpoint_interval=1) # The n_states sampler states have been distributed restored_sampler_states = reporter.read_sampler_states(iteration=0) restored_thermo_states, _ = reporter.read_thermodynamic_states() - assert sampler.n_states == n_states, ("Mismatch: sampler.n_states = {} " - "but n_states = {}".format(sampler.n_states, n_states)) - assert sampler.n_replicas == n_samplers, ("Mismatch: sampler.n_replicas = {} " - "but n_samplers = {}".format(sampler.n_replicas, n_samplers)) + assert sampler.n_states == n_states, ( + "Mismatch: sampler.n_states = {} " "but n_states = {}".format( + sampler.n_states, n_states + ) + ) + assert sampler.n_replicas == n_samplers, ( + "Mismatch: sampler.n_replicas = {} " "but n_samplers = {}".format( + sampler.n_replicas, n_samplers + ) + ) assert len(restored_sampler_states) == n_samplers assert len(restored_thermo_states) == n_states - assert np.allclose(restored_sampler_states[0].positions, sampler._sampler_states[0].positions) + assert np.allclose( + restored_sampler_states[0].positions, + sampler._sampler_states[0].positions, + ) # MCMCMove was stored correctly. restored_mcmc_moves = reporter.read_mcmc_moves() assert len(sampler._mcmc_moves) == n_states assert len(restored_mcmc_moves) == n_states - for sampler_move, restored_move in zip(sampler._mcmc_moves, restored_mcmc_moves): + for sampler_move, restored_move in zip( + sampler._mcmc_moves, restored_mcmc_moves + ): assert isinstance(sampler_move, mmtools.mcmc.LangevinDynamicsMove) assert isinstance(restored_move, mmtools.mcmc.LangevinDynamicsMove) # Options have been stored. - stored_options = reporter.read_dict('options') + stored_options = reporter.read_dict("options") options_to_store = dict() for cls in inspect.getmro(type(sampler)): - parameter_names, _, _, defaults, _, _, _ = inspect.getfullargspec(cls.__init__) + parameter_names, _, _, defaults, _, _, _ = inspect.getfullargspec( + cls.__init__ + ) if defaults: - for parameter_name in parameter_names[-len(defaults):]: - options_to_store[parameter_name] = getattr(sampler, '_' + parameter_name) - options_to_store.pop('mcmc_moves') # mcmc_moves are stored separately + for parameter_name in parameter_names[-len(defaults) :]: + options_to_store[parameter_name] = getattr( + sampler, "_" + parameter_name + ) + options_to_store.pop("mcmc_moves") # mcmc_moves are stored separately for key, value in options_to_store.items(): if np.isscalar(value): - assert stored_options[key] == value, "stored_options['%s'] = %s, but value = %s" % (key, stored_options[key], value) - assert getattr(sampler, '_' + key) == value, "getattr(sampler, '%s') = %s, but value = %s" % ('_' + key, getattr(sampler, '_' + key), value) + assert ( + stored_options[key] == value + ), "stored_options['{}'] = {}, but value = {}".format( + key, stored_options[key], value + ) + assert ( + getattr(sampler, "_" + key) == value + ), "getattr(sampler, '{}') = {}, but value = {}".format( + "_" + key, getattr(sampler, "_" + key), value + ) else: - assert np.all(stored_options[key] == value), "stored_options['%s'] = %s, but value = %s" % (key, stored_options[key], value) - assert np.all(getattr(sampler, '_' + key) == value), "getattr(sampler, '%s') = %s, but value = %s" % ('_' + key, getattr(sampler, '_' + key), value) + assert np.all( + stored_options[key] == value + ), "stored_options['{}'] = {}, but value = {}".format( + key, stored_options[key], value + ) + assert np.all( + getattr(sampler, "_" + key) == value + ), "getattr(sampler, '{}') = {}, but value = {}".format( + "_" + key, getattr(sampler, "_" + key), value + ) # A default title has been added to the stored metadata. - metadata = reporter.read_dict('metadata') + metadata = reporter.read_dict("metadata") assert len(metadata) == 1 - assert sampler.metadata['title'] == metadata['title'] + assert sampler.metadata["title"] == metadata["title"] def test_citations(self): """Test that citations are displayed and suppressed as needed.""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: sampler = self.SAMPLER() @@ -857,13 +1110,17 @@ def test_citations(self): with self.captured_output() as (out, _): sampler._display_citations(overwrite_global=True) cite_string = out.getvalue().strip() - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Reset internal flag sampler._have_displayed_citations_before = False # Test that the overwrite flag worked - assert cite_string != '' + assert cite_string != "" # Test that the output is not generate when the global is set with self.captured_output() as (out, _): sampler._global_citation_silence = True @@ -895,45 +1152,65 @@ def test_from_storage(self): """ # We don't want to restore reporter and timing data attributes __NON_RESTORABLE_ATTRIBUTES__ = ("_reporter", "_timing_data") - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.hostguest_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.hostguest_test + ) n_replicas = len(sampler_states) with self.temporary_storage_path() as storage_path: number_of_iterations = 3 move = mmtools.mcmc.LangevinDynamicsMove(n_steps=1) - sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=number_of_iterations) - if hasattr(sampler, 'replica_mixing_scheme'): + sampler = self.SAMPLER( + mcmc_moves=move, number_of_iterations=number_of_iterations + ) + if hasattr(sampler, "replica_mixing_scheme"): # TODO: Test both 'swap-all' with locality=None and 'swap-neighbors' with locality=1 - sampler.replica_mixing_scheme = 'swap-neighbors' # required for non-global locality + sampler.replica_mixing_scheme = ( + "swap-neighbors" # required for non-global locality + ) sampler.locality = 1 reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Test at the beginning and after few iterations. for iteration in range(2): # Store the state of the initial repex object (its __dict__). We leave the # reporter out because when the NetCDF file is copied, it runs into issues. - original_dict = copy.deepcopy({k: v for k, v in sampler.__dict__.items() - if k not in __NON_RESTORABLE_ATTRIBUTES__}) + original_dict = copy.deepcopy( + { + k: v + for k, v in sampler.__dict__.items() + if k not in __NON_RESTORABLE_ATTRIBUTES__ + } + ) # Delete repex to close reporter before creating a new one # to avoid weird issues with multiple NetCDF files open. del sampler reporter.close() sampler = self.SAMPLER.from_storage(reporter) - restored_dict = copy.deepcopy({k: v for k, v in sampler.__dict__.items() - if k not in __NON_RESTORABLE_ATTRIBUTES__}) + restored_dict = copy.deepcopy( + { + k: v + for k, v in sampler.__dict__.items() + if k not in __NON_RESTORABLE_ATTRIBUTES__ + } + ) # Check thermodynamic states. - original_ts = original_dict.pop('_thermodynamic_states') - restored_ts = restored_dict.pop('_thermodynamic_states') + original_ts = original_dict.pop("_thermodynamic_states") + restored_ts = restored_dict.pop("_thermodynamic_states") check_thermodynamic_states_equality(original_ts, restored_ts) # Check unsampled thermodynamic states. - original_us = original_dict.pop('_unsampled_states') - restored_us = restored_dict.pop('_unsampled_states') + original_us = original_dict.pop("_unsampled_states") + restored_us = restored_dict.pop("_unsampled_states") check_thermodynamic_states_equality(original_us, restored_us) # The reporter of the restored simulation must be open only in node 0. @@ -948,47 +1225,59 @@ def test_from_storage(self): node_replica_ids = self.get_node_replica_ids(n_replicas) # Check sampler states. Non 0 nodes only hold their positions. - original_ss = original_dict.pop('_sampler_states') - restored_ss = restored_dict.pop('_sampler_states') - for replica_id, (original, restored) in enumerate(zip(original_ss, restored_ss)): + original_ss = original_dict.pop("_sampler_states") + restored_ss = restored_dict.pop("_sampler_states") + for replica_id, (original, restored) in enumerate( + zip(original_ss, restored_ss) + ): if replica_id in node_replica_ids: assert np.allclose(original.positions, restored.positions) assert np.all(original.box_vectors == restored.box_vectors) # Check energies. Non 0 nodes only hold their energies. - original_neighborhoods = original_dict.pop('_neighborhoods') - restored_neighborhoods = restored_dict.pop('_neighborhoods') - original_ets = original_dict.pop('_energy_thermodynamic_states') - restored_ets = restored_dict.pop('_energy_thermodynamic_states') - original_eus = original_dict.pop('_energy_unsampled_states') - restored_eus = restored_dict.pop('_energy_unsampled_states') + original_neighborhoods = original_dict.pop("_neighborhoods") + restored_neighborhoods = restored_dict.pop("_neighborhoods") + original_ets = original_dict.pop("_energy_thermodynamic_states") + restored_ets = restored_dict.pop("_energy_thermodynamic_states") + original_eus = original_dict.pop("_energy_unsampled_states") + restored_eus = restored_dict.pop("_energy_unsampled_states") for replica_id in node_replica_ids: - assert np.allclose(original_neighborhoods[replica_id], restored_neighborhoods[replica_id]) - assert np.allclose(original_ets[replica_id], restored_ets[replica_id]) - assert np.allclose(original_eus[replica_id], restored_eus[replica_id]) + assert np.allclose( + original_neighborhoods[replica_id], + restored_neighborhoods[replica_id], + ) + assert np.allclose( + original_ets[replica_id], restored_ets[replica_id] + ) + assert np.allclose( + original_eus[replica_id], restored_eus[replica_id] + ) # Only node 0 has updated accepted and proposed exchanges. - original_accepted = original_dict.pop('_n_accepted_matrix') - restored_accepted = restored_dict.pop('_n_accepted_matrix') - original_proposed = original_dict.pop('_n_proposed_matrix') - restored_proposed = restored_dict.pop('_n_proposed_matrix') + original_accepted = original_dict.pop("_n_accepted_matrix") + restored_accepted = restored_dict.pop("_n_accepted_matrix") + original_proposed = original_dict.pop("_n_proposed_matrix") + restored_proposed = restored_dict.pop("_n_proposed_matrix") if len(node_replica_ids) == n_replicas: assert np.all(original_accepted == restored_accepted) assert np.all(original_proposed == restored_proposed) # Test mcmc moves with pickle. - original_mcmc_moves = original_dict.pop('_mcmc_moves') - restored_mcmc_moves = restored_dict.pop('_mcmc_moves') + original_mcmc_moves = original_dict.pop("_mcmc_moves") + restored_mcmc_moves = restored_dict.pop("_mcmc_moves") if len(node_replica_ids) == n_replicas: - assert pickle.dumps(original_mcmc_moves) == pickle.dumps(restored_mcmc_moves) + assert pickle.dumps(original_mcmc_moves) == pickle.dumps( + restored_mcmc_moves + ) # Check all other arrays. Instantiate list so that we can pop from original_dict. for attr, original_value in list(original_dict.items()): if isinstance(original_value, np.ndarray): original_value = original_dict.pop(attr) restored_value = restored_dict.pop(attr) - assert np.all(original_value == restored_value), '{}: {}\t{}'.format( - attr, original_value, restored_value) + assert np.all( + original_value == restored_value + ), "{}: {}\t{}".format(attr, original_value, restored_value) # Everything else should be a dict of builtins. assert original_dict == restored_dict @@ -999,21 +1288,27 @@ def test_from_storage(self): def actual_stored_properties_check(self, additional_properties=None): """Stored properties check which expects a keyword""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: sampler = self.SAMPLER(number_of_iterations=5, online_analysis_interval=1) reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Update options and check the storage is synchronized. - sampler.number_of_iterations = float('inf') + sampler.number_of_iterations = float("inf") # Process Additional properties if additional_properties is not None: for add_property, property_data in additional_properties.items(): - setattr(sampler, add_property, property_data['value']) + setattr(sampler, add_property, property_data["value"]) # Displace positions of the first sampler state. sampler_states = sampler.sampler_states @@ -1025,24 +1320,30 @@ def actual_stored_properties_check(self, additional_properties=None): mpicomm = mpiplus.get_mpicomm() if mpicomm is None or mpicomm.rank == 0: reporter.close() - reporter = self.REPORTER(storage_path, open_mode='r') - restored_options = reporter.read_dict('options') - assert restored_options['number_of_iterations'] == float('inf') + reporter = self.REPORTER(storage_path, open_mode="r") + restored_options = reporter.read_dict("options") + assert restored_options["number_of_iterations"] == float("inf") if additional_properties is not None: for _, property_data in additional_properties.items(): - on_disk_name = property_data['on_disk_name'] - on_disk_value = property_data['on_disk_value'] + on_disk_name = property_data["on_disk_name"] + on_disk_value = property_data["on_disk_value"] restored_value = restored_options[on_disk_name] if on_disk_value is None: - assert restored_value is on_disk_value, "Restored {} is not {}".format(restored_value, - on_disk_value) + assert ( + restored_value is on_disk_value + ), "Restored {} is not {}".format( + restored_value, on_disk_value + ) else: - assert restored_value == on_disk_value, "Restored {} != {}".format(restored_value, - on_disk_value) + assert ( + restored_value == on_disk_value + ), "Restored {} != {}".format(restored_value, on_disk_value) restored_sampler_states = reporter.read_sampler_states(iteration=0) - assert np.allclose(restored_sampler_states[0].positions, - original_positions + displacement_vector) + assert np.allclose( + restored_sampler_states[0].positions, + original_positions + displacement_vector, + ) def test_propagate_replicas(self): """Test method _propagate_replicas from MultiStateSampler. @@ -1052,7 +1353,9 @@ def test_propagate_replicas(self): the new positions and box vectors. """ - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) n_replicas = len(sampler_states) if n_replicas == 1: # This test is intended for use with more than one replica @@ -1062,18 +1365,30 @@ def test_propagate_replicas(self): # For this test to work, positions should be the same but # translated, so that minimized positions should satisfy # the same condition. - original_diffs = [np.average(sampler_states[i].positions - sampler_states[i+1].positions) - for i in range(n_replicas - 1)] - assert not np.allclose(original_diffs, [0 for _ in range(n_replicas - 1)]), "sampler %s failed" % self.SAMPLER + original_diffs = [ + np.average( + sampler_states[i].positions - sampler_states[i + 1].positions + ) + for i in range(n_replicas - 1) + ] + assert not np.allclose( + original_diffs, [0 for _ in range(n_replicas - 1)] + ), "sampler %s failed" % self.SAMPLER # Create a replica exchange that propagates only 1 femtosecond # per iteration so that positions won't change much. - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0*unit.femtosecond), n_steps=1) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) sampler = self.SAMPLER(mcmc_moves=move) reporter = self.REPORTER(storage_path) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Propagate. sampler._propagate_replicas() @@ -1082,8 +1397,13 @@ def test_propagate_replicas(self): # be still translated the same way (i.e. we are not assigning # the minimized positions to the incorrect sampler states). new_sampler_states = sampler._sampler_states - new_diffs = [np.average(new_sampler_states[i].positions - new_sampler_states[i+1].positions) - for i in range(n_replicas - 1)] + new_diffs = [ + np.average( + new_sampler_states[i].positions + - new_sampler_states[i + 1].positions + ) + for i in range(n_replicas - 1) + ] assert np.allclose(original_diffs, new_diffs, rtol=1e-4) def test_compute_energies(self): @@ -1094,33 +1414,48 @@ def test_compute_energies(self): when it communicates them to the other nodes. """ - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.hostguest_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.hostguest_test + ) n_states = len(thermodynamic_states) n_replicas = len(sampler_states) with self.temporary_storage_path() as storage_path: sampler = self.SAMPLER() reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Let MultiStateSampler distribute the computation of energies among nodes. sampler._compute_energies() # Compute energies at all states - energy_thermodynamic_states, energy_unsampled_states = self._compute_energies_independently(sampler) + energy_thermodynamic_states, energy_unsampled_states = ( + self._compute_energies_independently(sampler) + ) # Only node 0 has all the energies. mpicomm = mpiplus.get_mpicomm() if mpicomm is None or mpicomm.rank == 0: for replica_index in range(n_replicas): - neighborhood = sampler._neighborhoods[replica_index,:] + neighborhood = sampler._neighborhoods[replica_index, :] msg = f"{sampler} failed test_compute_energies:\n" msg += f"{sampler._energy_thermodynamic_states}\n" msg += f"{energy_thermodynamic_states}" - assert np.allclose(sampler._energy_thermodynamic_states[replica_index,neighborhood], energy_thermodynamic_states[replica_index,neighborhood]), msg - assert np.allclose(sampler._energy_unsampled_states, energy_unsampled_states) + assert np.allclose( + sampler._energy_thermodynamic_states[ + replica_index, neighborhood + ], + energy_thermodynamic_states[replica_index, neighborhood], + ), msg + assert np.allclose( + sampler._energy_unsampled_states, energy_unsampled_states + ) def test_minimize(self): """Test MultiStateSampler minimize method. @@ -1131,7 +1466,9 @@ def test_minimize(self): decreased. """ - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) n_states = len(thermodynamic_states) n_replicas = len(sampler_states) if n_replicas == 1: @@ -1141,21 +1478,32 @@ def test_minimize(self): with self.temporary_storage_path() as storage_path: sampler = self.SAMPLER() reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # For this test to work, positions should be the same but # translated, so that minimized positions should satisfy # the same condition. - original_diffs = [np.average(sampler_states[i].positions - sampler_states[i + 1].positions) - for i in range(n_replicas - 1)] + original_diffs = [ + np.average( + sampler_states[i].positions - sampler_states[i + 1].positions + ) + for i in range(n_replicas - 1) + ] assert not np.allclose(original_diffs, [0 for _ in range(n_replicas - 1)]) # Compute initial energies. sampler._compute_energies() state_indices = sampler._replica_thermodynamic_states - original_energies = [sampler._energy_thermodynamic_states[i, j] for i, j in enumerate(state_indices)] + original_energies = [ + sampler._energy_thermodynamic_states[i, j] + for i, j in enumerate(state_indices) + ] # Minimize. sampler.minimize() @@ -1164,8 +1512,13 @@ def test_minimize(self): # be still translated the same way (i.e. we are not assigning # the minimized positions to the incorrect sampler states). new_sampler_states = sampler._sampler_states - new_diffs = [np.average(new_sampler_states[i].positions - new_sampler_states[i + 1].positions) - for i in range(n_replicas - 1)] + new_diffs = [ + np.average( + new_sampler_states[i].positions + - new_sampler_states[i + 1].positions + ) + for i in range(n_replicas - 1) + ] assert np.allclose(original_diffs, new_diffs, atol=0.1) # Each replica keeps only the info for the replicas it is @@ -1177,15 +1530,21 @@ def test_minimize(self): for replica_index in node_replica_ids: state_index = sampler._replica_thermodynamic_states[replica_index] old_energy = original_energies[replica_index] - new_energy = sampler._energy_thermodynamic_states[replica_index, state_index] - assert new_energy <= old_energy, "Energies did not decrease: Replica {} was originally {}, now {}".format(replica_index, old_energy, new_energy) + new_energy = sampler._energy_thermodynamic_states[ + replica_index, state_index + ] + assert ( + new_energy <= old_energy + ), f"Energies did not decrease: Replica {replica_index} was originally {old_energy}, now {new_energy}" # The storage has been updated. reporter.close() if len(node_replica_ids) == n_states: - reporter = self.REPORTER(storage_path, open_mode='r') + reporter = self.REPORTER(storage_path, open_mode="r") stored_sampler_states = reporter.read_sampler_states(iteration=0) - for new_state, stored_state in zip(new_sampler_states, stored_sampler_states): + for new_state, stored_state in zip( + new_sampler_states, stored_sampler_states + ): assert np.allclose(new_state.positions, stored_state.positions) def test_equilibrate(self): @@ -1196,16 +1555,22 @@ def test_equilibrate(self): updated positions. """ - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) n_replicas = len(sampler_states) with self.temporary_storage_path() as storage_path: # We create a ReplicaExchange with a GHMC move but use Langevin for equilibration. sampler = self.SAMPLER(mcmc_moves=mmtools.mcmc.GHMCMove()) reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Equilibrate equilibration_move = mmtools.mcmc.LangevinDynamicsMove(n_steps=1) @@ -1219,10 +1584,17 @@ def test_equilibrate(self): # The storage has been updated. reporter.close() if len(node_replica_ids) == n_replicas: - reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=1) + reporter = self.REPORTER( + storage_path, open_mode="r", checkpoint_interval=1 + ) stored_sampler_states = reporter.read_sampler_states(iteration=0) for stored_state in stored_sampler_states: - assert any([np.allclose(new_state.positions, stored_state.positions) for new_state in sampler._sampler_states]) + assert any( + [ + np.allclose(new_state.positions, stored_state.positions) + for new_state in sampler._sampler_states + ] + ) # We are still at iteration 0. assert sampler._iteration == 0 @@ -1232,20 +1604,28 @@ def test_run_extend(self): test_cases = [self.alanine_test, self.hostguest_test] for test_case in test_cases: - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(test_case) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + test_case + ) with self.temporary_storage_path() as storage_path: - moves = mmtools.mcmc.SequenceMove([ - mmtools.mcmc.LangevinDynamicsMove(n_steps=1), - mmtools.mcmc.MCRotationMove(), - mmtools.mcmc.GHMCMove(n_steps=1) - ]) + moves = mmtools.mcmc.SequenceMove( + [ + mmtools.mcmc.LangevinDynamicsMove(n_steps=1), + mmtools.mcmc.MCRotationMove(), + mmtools.mcmc.GHMCMove(n_steps=1), + ] + ) sampler = self.SAMPLER(mcmc_moves=moves, number_of_iterations=2) reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # MultiStateSampler.run doesn't go past number_of_iterations. assert not sampler.is_completed @@ -1259,44 +1639,59 @@ def test_run_extend(self): # Extract the sampled thermodynamic states # Only use propagated states since the last iteration is not subject to MCMC moves - sampled_states = list(reporter.read_replica_thermodynamic_states()[1:].flat) + sampled_states = list( + reporter.read_replica_thermodynamic_states()[1:].flat + ) # All replicas must have moves with updated statistics. for state_index, sequence_move in enumerate(sampler._mcmc_moves): # LangevinDynamicsMove (index 0) doesn't have statistics. for move_id in [1, 2]: - assert sequence_move.move_list[move_id].n_proposed == sampled_states.count(state_index) + assert sequence_move.move_list[ + move_id + ].n_proposed == sampled_states.count(state_index) # The MCMCMoves statistics in the storage are updated. mpicomm = mpiplus.get_mpicomm() if mpicomm is None or mpicomm.rank == 0: reporter.close() - reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=1) + reporter = self.REPORTER( + storage_path, open_mode="r", checkpoint_interval=1 + ) restored_mcmc_moves = reporter.read_mcmc_moves() for state_index, sequence_move in enumerate(restored_mcmc_moves): # LangevinDynamicsMove (index 0) doesn't have statistic for move_id in [1, 2]: - assert sequence_move.move_list[move_id].n_proposed == sampled_states.count(state_index) + assert sequence_move.move_list[ + move_id + ].n_proposed == sampled_states.count(state_index) def test_checkpointing(self): """Test that checkpointing writes infrequently""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: # For this test, we simply check that the checkpointing writes on the interval # We don't care about the numbers, per se, but we do care about when things are written n_iterations = 3 - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), - n_steps=1) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) reporter = self.REPORTER(storage_path, checkpoint_interval=2) sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Propagate. sampler.run() reporter.close() - reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=2) + reporter = self.REPORTER(storage_path, open_mode="r", checkpoint_interval=2) for i in range(n_iterations): energies, _, _ = reporter.read_energies(i) states = reporter.read_sampler_states(i) @@ -1312,20 +1707,28 @@ def test_resume_positions_velocities_from_storage(self): test_cases = [self.alanine_test, self.hostguest_test] for test_case in test_cases: - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(test_case) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + test_case + ) with self.temporary_storage_path() as storage_path: - moves = mmtools.mcmc.SequenceMove([ - mmtools.mcmc.LangevinDynamicsMove(n_steps=1), - mmtools.mcmc.MCRotationMove(), - mmtools.mcmc.GHMCMove(n_steps=1) - ]) + moves = mmtools.mcmc.SequenceMove( + [ + mmtools.mcmc.LangevinDynamicsMove(n_steps=1), + mmtools.mcmc.MCRotationMove(), + mmtools.mcmc.GHMCMove(n_steps=1), + ] + ) sampler = self.SAMPLER(mcmc_moves=moves, number_of_iterations=3) reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Run 3 iterations sampler.run(n_iterations=3) # store a copy of the original states @@ -1336,35 +1739,48 @@ def test_resume_positions_velocities_from_storage(self): # recreate sampler from storage sampler = self.SAMPLER.from_storage(reporter) restored_states = sampler.sampler_states - for original_state, restored_state in zip(original_states, restored_states): - assert np.allclose(original_state.positions, restored_state.positions) - assert np.allclose(original_state.velocities, restored_state.velocities) - + for original_state, restored_state in zip( + original_states, restored_states + ): + assert np.allclose( + original_state.positions, restored_state.positions + ) + assert np.allclose( + original_state.velocities, restored_state.velocities + ) def test_last_iteration_functions(self): """Test that the last_iteration functions work right""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: # For this test, we simply check that the checkpointing writes on the interval # We don't care about the numbers, per se, but we do care about when things are written n_iterations = 10 - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations) reporter = self.REPORTER(storage_path, checkpoint_interval=2) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Propagate. sampler.run() reporter.close() - reporter = self.REPORTER(storage_path, open_mode='a', checkpoint_interval=2) + reporter = self.REPORTER(storage_path, open_mode="a", checkpoint_interval=2) all_energies, _, _ = reporter.read_energies() # Break the checkpoint last_index = 4 reporter.write_last_iteration(last_index) # 5th iteration reporter.close() del reporter - reporter = self.REPORTER(storage_path, open_mode='r', checkpoint_interval=2) + reporter = self.REPORTER(storage_path, open_mode="r", checkpoint_interval=2) # Check single positive index within range energies, _, _ = reporter.read_energies(1) assert np.all(energies == all_energies[1]) @@ -1374,7 +1790,8 @@ def test_last_iteration_functions(self): # Check slice energies, _, _ = reporter.read_energies() assert np.all( - energies == all_energies[:last_index + 1]) # +1 to make sure we get the last index + energies == all_energies[: last_index + 1] + ) # +1 to make sure we get the last index # Check negative slicing energies, _, _ = reporter.read_energies(slice(-1, None, -1)) assert np.all(energies == all_energies[last_index::-1]) @@ -1385,10 +1802,12 @@ def test_last_iteration_functions(self): def test_separate_checkpoint_file(self): """Test that a separate checkpoint file can be created""" with self.temporary_storage_path() as storage_path: - cp_file = 'checkpoint_file.nc' + cp_file = "checkpoint_file.nc" base, head = os.path.split(storage_path) cp_path = os.path.join(base, cp_file) - reporter = self.REPORTER(storage_path, checkpoint_storage=cp_file, open_mode='w') + reporter = self.REPORTER( + storage_path, checkpoint_storage=cp_file, open_mode="w" + ) reporter.close() assert os.path.isfile(storage_path) assert os.path.isfile(cp_path) @@ -1396,45 +1815,63 @@ def test_separate_checkpoint_file(self): def test_checkpoint_uuid_matching(self): """Test that checkpoint and storage files have the same UUID""" with self.temporary_storage_path() as storage_path: - cp_file = 'checkpoint_file.nc' - reporter = self.REPORTER(storage_path, checkpoint_storage=cp_file, open_mode='w') + cp_file = "checkpoint_file.nc" + reporter = self.REPORTER( + storage_path, checkpoint_storage=cp_file, open_mode="w" + ) assert reporter._storage_checkpoint.UUID == reporter._storage_analysis.UUID def test_uuid_mismatch_errors(self): """Test that trying to use separate checkpoint file fails the UUID check""" with self.temporary_storage_path() as storage_path: file_base, ext = os.path.splitext(storage_path) - storage_mod = file_base + '_mod' + ext - cp_file_main = 'checkpoint_file.nc' - cp_file_mod = 'checkpoint_mod.nc' - reporter_main = self.REPORTER(storage_path, checkpoint_storage=cp_file_main, open_mode='w') + storage_mod = file_base + "_mod" + ext + cp_file_main = "checkpoint_file.nc" + cp_file_mod = "checkpoint_mod.nc" + reporter_main = self.REPORTER( + storage_path, checkpoint_storage=cp_file_main, open_mode="w" + ) reporter_main.close() - reporter_mod = self.REPORTER(storage_mod, checkpoint_storage=cp_file_mod, open_mode='w') + reporter_mod = self.REPORTER( + storage_mod, checkpoint_storage=cp_file_mod, open_mode="w" + ) reporter_mod.close() del reporter_main, reporter_mod with assert_raises(IOError): - self.REPORTER(storage_path, checkpoint_storage=cp_file_mod, open_mode='r') + self.REPORTER( + storage_path, checkpoint_storage=cp_file_mod, open_mode="r" + ) def test_analysis_opens_without_checkpoint(self): """Test that the analysis file can open without the checkpoint file""" with self.temporary_storage_path() as storage_path: - cp_file = 'checkpoint_file.nc' - cp_file_mod = 'checkpoint_mod.nc' - reporter = self.REPORTER(storage_path, checkpoint_storage=cp_file, open_mode='w') + cp_file = "checkpoint_file.nc" + cp_file_mod = "checkpoint_mod.nc" + reporter = self.REPORTER( + storage_path, checkpoint_storage=cp_file, open_mode="w" + ) reporter.close() del reporter - self.REPORTER(storage_path, checkpoint_storage=cp_file_mod, open_mode='r') + self.REPORTER(storage_path, checkpoint_storage=cp_file_mod, open_mode="r") def test_storage_reporter_and_string(self): """Test that creating a MultiState by storage string and reporter is the same""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: n_iterations = 5 - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations) - self.call_sampler_create(sampler, storage_path, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + storage_path, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Propagate. sampler.run() energies_str, _, _ = sampler._reporter.read_energies() @@ -1446,25 +1883,40 @@ def test_storage_reporter_and_string(self): def test_online_analysis_works(self): """Test online analysis runs""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: n_iterations = 10 online_interval = 2 - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) - sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations, - online_analysis_interval=online_interval, - online_analysis_minimum_iterations=3) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) + sampler = self.SAMPLER( + mcmc_moves=move, + number_of_iterations=n_iterations, + online_analysis_interval=online_interval, + online_analysis_minimum_iterations=3, + ) reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Run sampler.run() def validate_this_test(): # The stored values of online analysis should be up to date. - last_written_free_energy = self.SAMPLER._read_last_free_energy(sampler._reporter, sampler.iteration) - last_mbar_f_k, (last_free_energy, last_err_free_energy) = last_written_free_energy + last_written_free_energy = self.SAMPLER._read_last_free_energy( + sampler._reporter, sampler.iteration + ) + last_mbar_f_k, (last_free_energy, last_err_free_energy) = ( + last_written_free_energy + ) assert len(sampler._last_mbar_f_k) == len(thermodynamic_states) assert not np.all(sampler._last_mbar_f_k == 0) @@ -1475,12 +1927,16 @@ def validate_this_test(): # Error should not be 0 yet assert sampler._last_err_free_energy != 0 - assert sampler._last_err_free_energy == last_err_free_energy, \ - ("SAMPLER %s : sampler._last_err_free_energy = %s, " - "last_err_free_energy = %s" % (self.SAMPLER.__name__, - sampler._last_err_free_energy, - last_err_free_energy) - ) + assert sampler._last_err_free_energy == last_err_free_energy, ( + "SAMPLER %s : sampler._last_err_free_energy = %s, " + "last_err_free_energy = %s" + % ( + self.SAMPLER.__name__, + sampler._last_err_free_energy, + last_err_free_energy, + ) + ) + try: validate_this_test() except AssertionError as e: @@ -1488,8 +1944,13 @@ def validate_this_test(): # Only run up until we have sampled every state, or we hit some cycle limit cycle_limit = 100 # Put some upper limit of cycles cycles = 0 - while (not np.unique(sampler._reporter.read_replica_thermodynamic_states()).size == self.N_STATES - and cycles < cycle_limit): + while ( + not np.unique( + sampler._reporter.read_replica_thermodynamic_states() + ).size + == self.N_STATES + and cycles < cycle_limit + ): sampler.extend(20) cycles += 1 try: @@ -1506,19 +1967,30 @@ def validate_this_test(): def test_online_analysis_stops(self): """Test online analysis will stop the simulation""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: n_iterations = 5 online_interval = 1 - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) - sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations, - online_analysis_interval=online_interval, - online_analysis_minimum_iterations=0, - online_analysis_target_error=np.inf) # use infinite error to stop right away + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) + sampler = self.SAMPLER( + mcmc_moves=move, + number_of_iterations=n_iterations, + online_analysis_interval=online_interval, + online_analysis_minimum_iterations=0, + online_analysis_target_error=np.inf, + ) # use infinite error to stop right away reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Run sampler.run() assert sampler._iteration < n_iterations @@ -1534,7 +2006,9 @@ def test_context_cache_default(self): def test_context_cache_energy_propagation(self): """Test specifying different context caches for energy and propagation in a short simulation.""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) n_replicas = len(sampler_states) if n_replicas == 1: # This test is intended for use with more than one replica @@ -1543,55 +2017,85 @@ def test_context_cache_energy_propagation(self): with self.temporary_storage_path() as storage_path: # Create a replica exchange that propagates only 1 femtosecond # per iteration so that positions won't change much. - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) sampler = self.SAMPLER(mcmc_moves=move) reporter = self.REPORTER(storage_path) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Set context cache attributes - sampler.energy_context_cache = cache.ContextCache(capacity=None, time_to_live=None) - sampler.sampler_context_cache = cache.ContextCache(capacity=None, time_to_live=None) + sampler.energy_context_cache = cache.ContextCache( + capacity=None, time_to_live=None + ) + sampler.sampler_context_cache = cache.ContextCache( + capacity=None, time_to_live=None + ) # Compute energies sampler._compute_energies() # Check only energy context cache has been accessed - assert sampler.energy_context_cache._lru._n_access > 0, \ - f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }." - assert sampler.sampler_context_cache._lru._n_access == 0, \ - f"{sampler.sampler_context_cache._lru._n_access} accesses, expected 0." + assert ( + sampler.energy_context_cache._lru._n_access > 0 + ), f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }." + assert ( + sampler.sampler_context_cache._lru._n_access == 0 + ), f"{sampler.sampler_context_cache._lru._n_access} accesses, expected 0." # Propagate replicas sampler._propagate_replicas() # Check propagation context cache has been accessed after propagation - assert sampler.sampler_context_cache._lru._n_access > 0, \ - f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }." + assert ( + sampler.sampler_context_cache._lru._n_access > 0 + ), f"Expected more than 0 accesses, received {sampler.energy_context_cache._lru._n_access }." def test_real_time_analysis_yaml(self): """Test expected number of entries in real time analysis output yaml file.""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: n_iterations = 13 online_interval = 3 - expected_yaml_entries = int(n_iterations/online_interval) - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) - sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations, - online_analysis_interval=online_interval) + expected_yaml_entries = int(n_iterations / online_interval) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) + sampler = self.SAMPLER( + mcmc_moves=move, + number_of_iterations=n_iterations, + online_analysis_interval=online_interval, + ) reporter = self.REPORTER(storage_path, checkpoint_interval=online_interval) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Run sampler.run() # load file and check number of iterations - storage_dir, reporter_filename = os.path.split(sampler._reporter._storage_analysis_file_path) + storage_dir, reporter_filename = os.path.split( + sampler._reporter._storage_analysis_file_path + ) # remove extension from filename yaml_prefix = os.path.splitext(reporter_filename)[0] - output_filepath = os.path.join(storage_dir, f"{yaml_prefix}_real_time_analysis.yaml") + output_filepath = os.path.join( + storage_dir, f"{yaml_prefix}_real_time_analysis.yaml" + ) with open(output_filepath) as yaml_file: yaml_contents = yaml.safe_load(yaml_file) # Make sure we get the correct number of entries - assert len(yaml_contents) == expected_yaml_entries, \ - "Expected yaml entries do not match the actual number entries in the file." + assert ( + len(yaml_contents) == expected_yaml_entries + ), "Expected yaml entries do not match the actual number entries in the file." + def test_real_time_analysis_can_be_none(): """Test if real time analysis can be done""" @@ -1599,20 +2103,40 @@ def test_real_time_analysis_can_be_none(): n_replicas = 3 T_min = 298.0 * unit.kelvin # Minimum temperature. T_max = 600.0 * unit.kelvin # Maximum temperature. - temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0) - for i in range(n_replicas)] - temperatures = [T_min + (T_max - T_min) * (math.exp(float(i) / float(n_replicas-1)) - 1.0) / (math.e - 1.0) - for i in range(n_replicas)] - thermodynamic_states = [states.ThermodynamicState(system=testsystem.system, temperature=T) - for T in temperatures] - move = mcmc.GHMCMove(timestep=2.0*unit.femtoseconds, n_steps=50) - simulation = MultiStateSampler(mcmc_moves=move, number_of_iterations=2, online_analysis_interval=None) - storage_path = tempfile.NamedTemporaryFile(delete=False).name + '.nc' + temperatures = [ + T_min + + (T_max - T_min) + * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) + / (math.e - 1.0) + for i in range(n_replicas) + ] + temperatures = [ + T_min + + (T_max - T_min) + * (math.exp(float(i) / float(n_replicas - 1)) - 1.0) + / (math.e - 1.0) + for i in range(n_replicas) + ] + thermodynamic_states = [ + states.ThermodynamicState(system=testsystem.system, temperature=T) + for T in temperatures + ] + move = mcmc.GHMCMove(timestep=2.0 * unit.femtoseconds, n_steps=50) + simulation = MultiStateSampler( + mcmc_moves=move, number_of_iterations=2, online_analysis_interval=None + ) + storage_path = tempfile.NamedTemporaryFile(delete=False).name + ".nc" reporter = MultiStateReporter(storage_path, checkpoint_interval=1) - simulation.create(thermodynamic_states=thermodynamic_states, - sampler_states=states.SamplerState(testsystem.positions), storage=reporter) + simulation.create( + thermodynamic_states=thermodynamic_states, + sampler_states=states.SamplerState(testsystem.positions), + storage=reporter, + ) + + ############# + class TestExtraSamplersMultiStateSampler(TestMultiStateSampler): """Test MultiStateSampler with more samplers than states""" @@ -1645,10 +2169,14 @@ class TestReplicaExchange(TestMultiStateSampler): def test_stored_properties(self): """Test that storage is kept in sync with options. Unique to ReplicaExchange""" additional_values = {} - additional_values.update(self.property_creator('replica_mixing_scheme', 'replica_mixing_scheme', None, None)) + additional_values.update( + self.property_creator( + "replica_mixing_scheme", "replica_mixing_scheme", None, None + ) + ) self.actual_stored_properties_check(additional_properties=additional_values) - @attr('slow') # Skip on Travis-CI + @attr("slow") # Skip on Travis-CI def test_uniform_mixing(self): """Test that mixing is uniform for a sequence of harmonic oscillators. @@ -1659,14 +2187,14 @@ def test_uniform_mixing(self): """ temperature = 300.0 * unit.kelvin sigma = 1.0 * unit.angstrom # Oscillator width - #n_states = 50 # Number of harmonic oscillators. + # n_states = 50 # Number of harmonic oscillators. n_states = 6 # DEBUG n_states = 20 # DEBUG collision_rate = 10.0 / unit.picoseconds number_of_iterations = 2000 - number_of_iterations = 200 # DEBUG + number_of_iterations = 200 # DEBUG # Build an equidistant sequence of harmonic oscillators. sampler_states = [] @@ -1674,8 +2202,8 @@ def test_uniform_mixing(self): # The minima of the harmonic oscillators are 1 kT from each other. K = mmtools.constants.kB * temperature / sigma**2 # spring constant - mass = 39.948*unit.amu # mass - period = 2*np.pi*np.sqrt(mass/K) + mass = 39.948 * unit.amu # mass + period = 2 * np.pi * np.sqrt(mass / K) n_steps = 20 # Number of steps per iteration. timestep = period / n_steps spacing_sigma = 0.05 @@ -1687,47 +2215,58 @@ def test_uniform_mixing(self): # Determine the position of the harmonic oscillator minimum. minimum_position = oscillator_idx * sigma * spacing_sigma - minimum_position_unitless = minimum_position.value_in_unit_system(unit.md_unit_system) + minimum_position_unitless = minimum_position.value_in_unit_system( + unit.md_unit_system + ) positions[0][0] = minimum_position # Create an oscillator starting from its minimum. force = system.getForce(0) - assert force.getGlobalParameterName(1) == 'testsystems_HarmonicOscillator_x0' + assert ( + force.getGlobalParameterName(1) == "testsystems_HarmonicOscillator_x0" + ) force.setGlobalParameterDefaultValue(1, minimum_position_unitless) - thermodynamic_states.append(mmtools.states.ThermodynamicState( - system=system, temperature=temperature)) + thermodynamic_states.append( + mmtools.states.ThermodynamicState( + system=system, temperature=temperature + ) + ) sampler_states.append(mmtools.states.SamplerState(positions)) # Run a short repex simulation and gather data. with self.temporary_storage_path() as storage_path: # Create and run object. sampler = self.SAMPLER( - mcmc_moves=mmtools.mcmc.LangevinDynamicsMove(timestep=timestep, collision_rate=collision_rate, n_steps=n_steps), + mcmc_moves=mmtools.mcmc.LangevinDynamicsMove( + timestep=timestep, collision_rate=collision_rate, n_steps=n_steps + ), number_of_iterations=number_of_iterations, ) - reporter = self.REPORTER(storage_path, checkpoint_interval=number_of_iterations) + reporter = self.REPORTER( + storage_path, checkpoint_interval=number_of_iterations + ) sampler.create(thermodynamic_states, sampler_states, reporter) - #sampler.replica_mixing_scheme = 'swap-neighbors' - sampler.replica_mixing_scheme = 'swap-all' + # sampler.replica_mixing_scheme = 'swap-neighbors' + sampler.replica_mixing_scheme = "swap-all" sampler.run() # Retrieve from the reporter the mixing information before deleting. # Only the reporter from MPI node 0 should be open. n_accepted_matrix, n_proposed_matrix = mpiplus.run_single_node( - task=reporter.read_mixing_statistics, - rank=0, broadcast_result=True + task=reporter.read_mixing_statistics, rank=0, broadcast_result=True ) replica_thermo_states = mpiplus.run_single_node( task=reporter.read_replica_thermodynamic_states, - rank=0, broadcast_result=True + rank=0, + broadcast_result=True, ) del sampler, reporter # No need to analyze the same data in multiple MPI processes. mpicomm = mpiplus.get_mpicomm() if mpicomm is not None and mpicomm.rank == 0: - print('Acceptance matrix') + print("Acceptance matrix") print(n_accepted_matrix) print() @@ -1735,10 +2274,10 @@ def test_uniform_mixing(self): replica_thermo_state_counts = np.empty(n_states) for replica_idx in range(n_states): state_trajectory = replica_thermo_states[:, replica_idx] - #print(f"replica {replica_idx} : {''.join([ str(state) for state in state_trajectory ])}") + # print(f"replica {replica_idx} : {''.join([ str(state) for state in state_trajectory ])}") n_visited_states = len(set(state_trajectory)) replica_thermo_state_counts[replica_idx] = n_visited_states - print(replica_idx, ':', n_visited_states) + print(replica_idx, ":", n_visited_states) print() # Count the number of visited states by each MPI process. @@ -1747,23 +2286,38 @@ def test_uniform_mixing(self): mpi_sem_thermo_state_counts = np.empty(n_mpi_processes) for mpi_idx in range(n_mpi_processes): # Find replicas assigned to this MPI process. - replica_indices = list(i for i in range(n_states) if i % n_mpi_processes == mpi_idx) + replica_indices = list( + i for i in range(n_states) if i % n_mpi_processes == mpi_idx + ) # Find the average number of states visited by # the replicas assigned to this MPI process. - mpi_avg_thermo_state_counts[mpi_idx] = np.mean(replica_thermo_state_counts[replica_indices]) - mpi_sem_thermo_state_counts[mpi_idx] = np.std(replica_thermo_state_counts[replica_indices], ddof=1) / np.sqrt(len(replica_indices)) + mpi_avg_thermo_state_counts[mpi_idx] = np.mean( + replica_thermo_state_counts[replica_indices] + ) + mpi_sem_thermo_state_counts[mpi_idx] = np.std( + replica_thermo_state_counts[replica_indices], ddof=1 + ) / np.sqrt(len(replica_indices)) # These should be roughly equal. - print('MPI process mean number of thermo states visited:') - for mpi_idx, (mean, sem) in enumerate(zip(mpi_avg_thermo_state_counts, - mpi_sem_thermo_state_counts)): - print('{}: {} +- {}'.format(mpi_idx, mean, 2*sem)) + print("MPI process mean number of thermo states visited:") + for mpi_idx, (mean, sem) in enumerate( + zip(mpi_avg_thermo_state_counts, mpi_sem_thermo_state_counts) + ): + print(f"{mpi_idx}: {mean} +- {2*sem}") # Check if the confidence intervals overlap. def are_overlapping(interval1, interval2): - return min(interval1[1], interval2[1]) - max(interval1[0], interval2[0]) > 0 - - cis = [(mean-2*sem, mean+2*sem) for mean, sem in zip(mpi_avg_thermo_state_counts, mpi_sem_thermo_state_counts)] + return ( + min(interval1[1], interval2[1]) - max(interval1[0], interval2[0]) + > 0 + ) + + cis = [ + (mean - 2 * sem, mean + 2 * sem) + for mean, sem in zip( + mpi_avg_thermo_state_counts, mpi_sem_thermo_state_counts + ) + ] for i in range(1, len(cis)): assert are_overlapping(cis[0], cis[i]) @@ -1788,35 +2342,46 @@ def test_stored_properties(self): """Test that storage is kept in sync with options. Unique to SAMSSampler""" additional_values = {} options = { - 'state_update_scheme': 'global-jump', - 'locality': None, - 'update_stages': 'one-stage', - 'weight_update_method': 'optimal', - 'adapt_target_probabilities': False, - } - for (name, value) in options.items(): + "state_update_scheme": "global-jump", + "locality": None, + "update_stages": "one-stage", + "weight_update_method": "optimal", + "adapt_target_probabilities": False, + } + for name, value in options.items(): additional_values.update(self.property_creator(name, name, value, value)) self.actual_stored_properties_check(additional_properties=additional_values) def test_state_histogram(self): """Ensure SAMS on-the-fly state histograms match actually visited states""" - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: # For this test, we simply check that the checkpointing writes on the interval # We don't care about the numbers, per se, but we do care about when things are written n_iterations = 10 - move = mmtools.mcmc.IntegratorMove(openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1) + move = mmtools.mcmc.IntegratorMove( + openmm.VerletIntegrator(1.0 * unit.femtosecond), n_steps=1 + ) sampler = self.SAMPLER(mcmc_moves=move, number_of_iterations=n_iterations) reporter = self.REPORTER(storage_path, checkpoint_interval=2) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) # Propagate. sampler.run() reporter.close() - reporter = self.REPORTER(storage_path, open_mode='a', checkpoint_interval=2) + reporter = self.REPORTER(storage_path, open_mode="a", checkpoint_interval=2) replica_thermodynamic_states = reporter.read_replica_thermodynamic_states() - N_k, _ = np.histogram(replica_thermodynamic_states, bins=np.arange(-0.5, sampler.n_states + 0.5)) + N_k, _ = np.histogram( + replica_thermodynamic_states, + bins=np.arange(-0.5, sampler.n_states + 0.5), + ) assert np.all(sampler._state_histogram == N_k) # TODO: Test all update methods @@ -1839,13 +2404,13 @@ def test_stored_properties(self): """Test that storage is kept in sync with options. Unique to SAMSSampler""" additional_values = {} options = { - 'state_update_scheme': 'global-jump', - 'locality': None, - 'update_stages': 'two-stage', - 'weight_update_method' : 'rao-blackwellized', - 'adapt_target_probabilities': False, - } - for (name, value) in options.items(): + "state_update_scheme": "global-jump", + "locality": None, + "update_stages": "two-stage", + "weight_update_method": "rao-blackwellized", + "adapt_target_probabilities": False, + } + for name, value in options.items(): additional_values.update(self.property_creator(name, name, value, value)) self.actual_stored_properties_check(additional_properties=additional_values) @@ -1853,7 +2418,6 @@ def test_stored_properties(self): class TestParallelTempering(TestMultiStateSampler): - # ------------------------------------ # VARIABLES TO SET FOR EACH TEST CLASS # ------------------------------------ @@ -1865,49 +2429,69 @@ class TestParallelTempering(TestMultiStateSampler): N_STATES = 3 SAMPLER = ParallelTemperingSampler REPORTER = MultiStateReporter - MIN_TEMP = 300*unit.kelvin - MAX_TEMP = 350*unit.kelvin + MIN_TEMP = 300 * unit.kelvin + MAX_TEMP = 350 * unit.kelvin # -------------------------------------- # Optional helper function to overwrite. # -------------------------------------- @classmethod - def call_sampler_create(cls, sampler, reporter, - thermodynamic_states, - sampler_states, - unsampled_states): + def call_sampler_create( + cls, sampler, reporter, thermodynamic_states, sampler_states, unsampled_states + ): """ Helper function to call the create method for the sampler ParallelTempering has a unique call """ single_state = thermodynamic_states[0] # Allows initial thermodynamic states to be handled by the built in methods - sampler.create(single_state, sampler_states, reporter, - min_temperature=cls.MIN_TEMP, max_temperature=cls.MAX_TEMP, n_temperatures=cls.N_STATES, - unsampled_thermodynamic_states=unsampled_states) + sampler.create( + single_state, + sampler_states, + reporter, + min_temperature=cls.MIN_TEMP, + max_temperature=cls.MAX_TEMP, + n_temperatures=cls.N_STATES, + unsampled_thermodynamic_states=unsampled_states, + ) def test_temperatures(self): """ Test temperatures are created with desired range """ - thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(self.alanine_test) + thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy( + self.alanine_test + ) with self.temporary_storage_path() as storage_path: sampler = self.SAMPLER() reporter = self.REPORTER(storage_path, checkpoint_interval=1) - self.call_sampler_create(sampler, reporter, - thermodynamic_states, sampler_states, - unsampled_states) + self.call_sampler_create( + sampler, + reporter, + thermodynamic_states, + sampler_states, + unsampled_states, + ) try: from openmm import unit except ImportError: # OpenMM < 7.6 from simtk import unit - temperatures = [state.temperature/unit.kelvin for state in sampler._thermodynamic_states] # in kelvin - assert len(temperatures) == self.N_STATES, f"There are {len(temperatures)} thermodynamic states; expected {self.N_STATES}" - assert np.isclose(min(temperatures), (self.MIN_TEMP/unit.kelvin)), f"Min temperature is {min(temperatures)} K; expected {(self.MIN_TEMP/unit.kelvin)} K" - assert np.isclose(max(temperatures), (self.MAX_TEMP/unit.kelvin)), f"Max temperature is {max(temperatures)} K; expected {(self.MAX_TEMP/unit.kelvin)} K" + temperatures = [ + state.temperature / unit.kelvin + for state in sampler._thermodynamic_states + ] # in kelvin + assert ( + len(temperatures) == self.N_STATES + ), f"There are {len(temperatures)} thermodynamic states; expected {self.N_STATES}" + assert np.isclose( + min(temperatures), (self.MIN_TEMP / unit.kelvin) + ), f"Min temperature is {min(temperatures)} K; expected {(self.MIN_TEMP/unit.kelvin)} K" + assert np.isclose( + max(temperatures), (self.MAX_TEMP / unit.kelvin) + ), f"Max temperature is {max(temperatures)} K; expected {(self.MAX_TEMP/unit.kelvin)} K" # ---------------------------------- # Methods overwritten from the Super @@ -1933,11 +2517,15 @@ def _compute_energies_independently(cls, sampler): # parallel tempering specific subclass implementation works as desired energy_thermodynamic_states = np.zeros((n_replicas, n_states)) energy_unsampled_states = np.zeros((n_replicas, len(unsampled_states))) - for energies, states in [(energy_thermodynamic_states, thermodynamic_states), - (energy_unsampled_states, unsampled_states)]: + for energies, states in [ + (energy_thermodynamic_states, thermodynamic_states), + (energy_unsampled_states, unsampled_states), + ]: for i, sampler_state in enumerate(sampler_states): for j, state in enumerate(states): - context, integrator = mmtools.cache.global_context_cache.get_context(state) + context, integrator = ( + mmtools.cache.global_context_cache.get_context(state) + ) sampler_state.apply_to_context(context) energies[i][j] = state.reduced_potential(context) return energy_thermodynamic_states, energy_unsampled_states @@ -1957,29 +2545,42 @@ def test_resume_velocities_from_legacy_storage(self): This emulates the behavior of reading older versions (previous to 0.21.3 release) of serialized simulations. """ import netCDF4 + origin_reporter_path = testsystems.get_data_filename( os.path.join("data", "reporter-examples", "alanine_dipeptide_legacy.nc") ) origin_checkpoint_path = testsystems.get_data_filename( - os.path.join("data", "reporter-examples", "alanine_dipeptide_legacy_checkpoint.nc") + os.path.join( + "data", "reporter-examples", "alanine_dipeptide_legacy_checkpoint.nc" + ) ) # Assert no velocities in legacy dataset variables - netcdf_data = netCDF4.Dataset(origin_checkpoint_path) # open checkpoint for reading - assert 'velocities' not in netcdf_data.variables, "velocities variable should not exist in legacy reporter " \ - "netcdf file." + netcdf_data = netCDF4.Dataset( + origin_checkpoint_path + ) # open checkpoint for reading + assert "velocities" not in netcdf_data.variables, ( + "velocities variable should not exist in legacy reporter " "netcdf file." + ) with self.temporary_storage_path() as storage_path: # copy files to temporary directory - temporary_checkpoint_path = f"{os.path.splitext(storage_path)[0]}_checkpoint.nc" - reporter_path = shutil.copy(origin_reporter_path, storage_path) # copy reporter file - checkpoint_path = shutil.copy(origin_checkpoint_path, temporary_checkpoint_path) # copy checkpoint file + temporary_checkpoint_path = ( + f"{os.path.splitext(storage_path)[0]}_checkpoint.nc" + ) + reporter_path = shutil.copy( + origin_reporter_path, storage_path + ) # copy reporter file + checkpoint_path = shutil.copy( + origin_checkpoint_path, temporary_checkpoint_path + ) # copy checkpoint file # Load repex simulation reporter = self.REPORTER(reporter_path, checkpoint_interval=1) sampler = self.SAMPLER.from_storage(reporter) # Assert velocities are initialized as zeros for state in sampler.sampler_states: - assert np.all(state.velocities.value_in_unit_system(unit.md_unit_system) == 0), \ - "Velocities in sampler state from legacy checkpoint are expected to be all zeros." + assert np.all( + state.velocities.value_in_unit_system(unit.md_unit_system) == 0 + ), "Velocities in sampler state from legacy checkpoint are expected to be all zeros." # Resume simulation sampler.extend(n_iterations=1) @@ -1988,28 +2589,32 @@ def test_resume_velocities_from_legacy_storage(self): del sampler reporter.close() # assert velocities variable exist - netcdf_data = netCDF4.Dataset(checkpoint_path) # open checkpoint for reading - assert 'velocities' in netcdf_data.variables, "velocities variable should exist in new reporter " \ - "netcdf file." + netcdf_data = netCDF4.Dataset( + checkpoint_path + ) # open checkpoint for reading + assert "velocities" in netcdf_data.variables, ( + "velocities variable should exist in new reporter " "netcdf file." + ) netcdf_data.close() # close or it errors in next line # Load repex simulation from new reporter file new_sampler = self.SAMPLER.from_storage(reporter) # assert velocities in sampler states are non-zero for state in new_sampler.sampler_states: - assert np.any(state.velocities.value_in_unit_system(unit.md_unit_system) != 0), \ - "At least some velocity in sampler state from new checkpoint is expected to different from zero." + assert np.any( + state.velocities.value_in_unit_system(unit.md_unit_system) != 0 + ), "At least some velocity in sampler state from new checkpoint is expected to different from zero." + # ============================================================================== # MAIN AND TESTS # ============================================================================== if __name__ == "__main__": - # Test simple system of harmonic oscillators. # Disabled until we fix the test # test_replica_exchange() - print('Creating class') + print("Creating class") repex = TestReplicaExchange() - print('testing...') + print("testing...") repex.test_uniform_mixing() diff --git a/openmmtools/tests/test_states.py b/openmmtools/tests/test_states.py index 8a92d4da3..35cca94b0 100644 --- a/openmmtools/tests/test_states.py +++ b/openmmtools/tests/test_states.py @@ -26,14 +26,15 @@ # ============================================================================= # We use CPU as OpenCL sometimes causes segfaults on Travis. -DEFAULT_PLATFORM = openmm.Platform.getPlatformByName('CPU') -DEFAULT_PLATFORM.setPropertyDefaultValue('DeterministicForces', 'true') +DEFAULT_PLATFORM = openmm.Platform.getPlatformByName("CPU") +DEFAULT_PLATFORM.setPropertyDefaultValue("DeterministicForces", "true") # ============================================================================= # UTILITY FUNCTIONS # ============================================================================= + def create_default_context(thermodynamic_state, integrator): """Shortcut to create a context from the thermodynamic state using the DEFAULT_PLATFORM.""" return thermodynamic_state.create_context(integrator, DEFAULT_PLATFORM) @@ -51,7 +52,8 @@ def get_barostat_temperature(barostat): # TEST THERMODYNAMIC STATE # ============================================================================= -class TestThermodynamicState(object): + +class TestThermodynamicState: """Test suite for states.ThermodynamicState class.""" @classmethod @@ -70,12 +72,14 @@ def setup_class(cls): cls.toluene_positions = toluene_implicit.positions cls.toluene_implicit = toluene_implicit.system cls.toluene_vacuum = testsystems.TolueneVacuum().system - thermostat = openmm.AndersenThermostat(cls.std_temperature, - 1.0/unit.picosecond) + thermostat = openmm.AndersenThermostat( + cls.std_temperature, 1.0 / unit.picosecond + ) cls.toluene_vacuum.addForce(thermostat) cls.alanine_explicit = copy.deepcopy(cls.alanine_no_thermostat) - thermostat = openmm.AndersenThermostat(cls.std_temperature, - 1.0/unit.picosecond) + thermostat = openmm.AndersenThermostat( + cls.std_temperature, 1.0 / unit.picosecond + ) cls.alanine_explicit.addForce(thermostat) # A system correctly barostated @@ -96,62 +100,88 @@ def setup_class(cls): cls.multiple_barostat_alanine.addForce(barostat) # A system with an unsupported MonteCarloAnisotropicBarostat - cls.unsupported_anisotropic_barostat_alanine = copy.deepcopy(cls.alanine_explicit) + cls.unsupported_anisotropic_barostat_alanine = copy.deepcopy( + cls.alanine_explicit + ) pressure_in_bars = cls.std_pressure / unit.bar - anisotropic_pressure = openmm.Vec3(pressure_in_bars, pressure_in_bars, - 1.0+pressure_in_bars) - cls.unsupported_anisotropic_barostat = openmm.MonteCarloAnisotropicBarostat(anisotropic_pressure, cls.std_temperature) - cls.unsupported_anisotropic_barostat_alanine.addForce(cls.unsupported_anisotropic_barostat) + anisotropic_pressure = openmm.Vec3( + pressure_in_bars, pressure_in_bars, 1.0 + pressure_in_bars + ) + cls.unsupported_anisotropic_barostat = openmm.MonteCarloAnisotropicBarostat( + anisotropic_pressure, cls.std_temperature + ) + cls.unsupported_anisotropic_barostat_alanine.addForce( + cls.unsupported_anisotropic_barostat + ) # A system with an unsupported MonteCarloMembraneBarostat - cls.membrane_barostat_alanine_gamma_nonzero = copy.deepcopy(cls.alanine_explicit) + cls.membrane_barostat_alanine_gamma_nonzero = copy.deepcopy( + cls.alanine_explicit + ) # working around a bug in the unit conversion https://github.com/openmm/openmm/issues/2406 cls.membrane_barostat_gamma_nonzero = openmm.MonteCarloMembraneBarostat( - cls.std_pressure, cls.modified_surface_tension.value_in_unit(unit.bar*unit.nanometer), cls.std_temperature, - openmm.MonteCarloMembraneBarostat.XYIsotropic, openmm.MonteCarloMembraneBarostat.ZFree + cls.std_pressure, + cls.modified_surface_tension.value_in_unit(unit.bar * unit.nanometer), + cls.std_temperature, + openmm.MonteCarloMembraneBarostat.XYIsotropic, + openmm.MonteCarloMembraneBarostat.ZFree, + ) + cls.membrane_barostat_alanine_gamma_nonzero.addForce( + cls.membrane_barostat_gamma_nonzero ) - cls.membrane_barostat_alanine_gamma_nonzero.addForce(cls.membrane_barostat_gamma_nonzero) # A system with a supported MonteCarloAnisotropicBarostat cls.supported_anisotropic_barostat_alanine = copy.deepcopy(cls.alanine_explicit) pressure_in_bars = cls.std_pressure / unit.bar - anisotropic_pressure = openmm.Vec3(pressure_in_bars, pressure_in_bars, - pressure_in_bars) - cls.supported_anisotropic_barostat = openmm.MonteCarloAnisotropicBarostat(anisotropic_pressure, cls.std_temperature) - cls.supported_anisotropic_barostat_alanine.addForce(cls.supported_anisotropic_barostat) + anisotropic_pressure = openmm.Vec3( + pressure_in_bars, pressure_in_bars, pressure_in_bars + ) + cls.supported_anisotropic_barostat = openmm.MonteCarloAnisotropicBarostat( + anisotropic_pressure, cls.std_temperature + ) + cls.supported_anisotropic_barostat_alanine.addForce( + cls.supported_anisotropic_barostat + ) # A system with a supported MonteCarloMembraneBarostat cls.membrane_barostat_alanine_gamma_zero = copy.deepcopy(cls.alanine_explicit) cls.membrane_barostat_gamma_zero = openmm.MonteCarloMembraneBarostat( - cls.std_pressure, 0.0, cls.std_temperature, - openmm.MonteCarloMembraneBarostat.XYIsotropic, openmm.MonteCarloMembraneBarostat.ZFree + cls.std_pressure, + 0.0, + cls.std_temperature, + openmm.MonteCarloMembraneBarostat.XYIsotropic, + openmm.MonteCarloMembraneBarostat.ZFree, + ) + cls.membrane_barostat_alanine_gamma_zero.addForce( + cls.membrane_barostat_gamma_zero ) - cls.membrane_barostat_alanine_gamma_zero.addForce(cls.membrane_barostat_gamma_zero) # A system with an inconsistent pressure in the barostat. cls.inconsistent_pressure_alanine = copy.deepcopy(cls.alanine_explicit) - barostat = openmm.MonteCarloBarostat(cls.std_pressure + 0.2*unit.bar, - cls.std_temperature) + barostat = openmm.MonteCarloBarostat( + cls.std_pressure + 0.2 * unit.bar, cls.std_temperature + ) cls.inconsistent_pressure_alanine.addForce(barostat) # A system with an inconsistent temperature in the barostat. cls.inconsistent_temperature_alanine = copy.deepcopy(cls.alanine_no_thermostat) - barostat = openmm.MonteCarloBarostat(cls.std_pressure, - cls.std_temperature + 1.0*unit.kelvin) - thermostat = openmm.AndersenThermostat(cls.std_temperature + 1.0*unit.kelvin, - 1.0/unit.picosecond) + barostat = openmm.MonteCarloBarostat( + cls.std_pressure, cls.std_temperature + 1.0 * unit.kelvin + ) + thermostat = openmm.AndersenThermostat( + cls.std_temperature + 1.0 * unit.kelvin, 1.0 / unit.picosecond + ) cls.inconsistent_temperature_alanine.addForce(barostat) cls.inconsistent_temperature_alanine.addForce(thermostat) @staticmethod def get_integrators(temperature): - friction = 5.0/unit.picosecond - time_step = 2.0*unit.femtosecond + friction = 5.0 / unit.picosecond + time_step = 2.0 * unit.femtosecond # Test cases verlet = openmm.VerletIntegrator(time_step) - langevin = openmm.LangevinIntegrator(temperature, - friction, time_step) + langevin = openmm.LangevinIntegrator(temperature, friction, time_step) velocity_verlet = integrators.VelocityVerletIntegrator() ghmc = integrators.GHMCIntegrator(temperature) # Copying a CustomIntegrator will make it lose any extra function @@ -166,15 +196,30 @@ def get_integrators(temperature): compound_verlet.addIntegrator(openmm.VerletIntegrator(time_step)) compound_verlet.addIntegrator(openmm.VerletIntegrator(time_step)) - return [(False, verlet), (False, velocity_verlet), (False, compound_verlet), - (True, langevin), (True, ghmc), (True, custom_ghmc), (True, compound_ghmc)] + return [ + (False, verlet), + (False, velocity_verlet), + (False, compound_verlet), + (True, langevin), + (True, ghmc), + (True, custom_ghmc), + (True, compound_ghmc), + ] def test_single_instance_standard_system(self): """ThermodynamicState should store only 1 System per compatible state.""" - state_nvt_300 = ThermodynamicState(system=self.alanine_explicit, temperature=300*unit.kelvin) - state_nvt_350 = ThermodynamicState(system=self.alanine_explicit, temperature=350*unit.kelvin) - state_npt_1 = ThermodynamicState(system=self.alanine_explicit, pressure=1.0*unit.atmosphere) - state_npt_2 = ThermodynamicState(system=self.alanine_explicit, pressure=2.0*unit.atmosphere) + state_nvt_300 = ThermodynamicState( + system=self.alanine_explicit, temperature=300 * unit.kelvin + ) + state_nvt_350 = ThermodynamicState( + system=self.alanine_explicit, temperature=350 * unit.kelvin + ) + state_npt_1 = ThermodynamicState( + system=self.alanine_explicit, pressure=1.0 * unit.atmosphere + ) + state_npt_2 = ThermodynamicState( + system=self.alanine_explicit, pressure=2.0 * unit.atmosphere + ) assert state_nvt_300._standard_system == state_nvt_350._standard_system assert state_nvt_300._standard_system != state_npt_1._standard_system assert state_npt_1._standard_system == state_npt_2._standard_system @@ -192,16 +237,25 @@ def test_method_find_barostat(self): barostat = ThermodynamicState._find_barostat(self.barostated_alanine) assert isinstance(barostat, openmm.MonteCarloBarostat) - barostat = ThermodynamicState._find_barostat(self.supported_anisotropic_barostat_alanine) + barostat = ThermodynamicState._find_barostat( + self.supported_anisotropic_barostat_alanine + ) assert isinstance(barostat, openmm.MonteCarloAnisotropicBarostat) - barostat = ThermodynamicState._find_barostat(self.membrane_barostat_alanine_gamma_zero) + barostat = ThermodynamicState._find_barostat( + self.membrane_barostat_alanine_gamma_zero + ) assert isinstance(barostat, openmm.MonteCarloMembraneBarostat) # Raise exception if multiple or unsupported barostats found TE = ThermodynamicsError # shortcut - test_cases = [(self.multiple_barostat_alanine, TE.MULTIPLE_BAROSTATS), - (self.unsupported_anisotropic_barostat_alanine, TE.UNSUPPORTED_ANISOTROPIC_BAROSTAT)] + test_cases = [ + (self.multiple_barostat_alanine, TE.MULTIPLE_BAROSTATS), + ( + self.unsupported_anisotropic_barostat_alanine, + TE.UNSUPPORTED_ANISOTROPIC_BAROSTAT, + ), + ] for system, err_code in test_cases: with nose.tools.assert_raises(ThermodynamicsError) as cm: ThermodynamicState._find_barostat(system) @@ -211,14 +265,16 @@ def test_method_find_thermostat(self): """ThermodynamicState._find_thermostat() method.""" system = copy.deepcopy(self.alanine_no_thermostat) assert ThermodynamicState._find_thermostat(system) is None - thermostat = openmm.AndersenThermostat(self.std_temperature, - 1.0/unit.picosecond) + thermostat = openmm.AndersenThermostat( + self.std_temperature, 1.0 / unit.picosecond + ) system.addForce(thermostat) assert ThermodynamicState._find_thermostat(system) is not None # An error is raised with two thermostats. - thermostat2 = openmm.AndersenThermostat(self.std_temperature, - 1.0/unit.picosecond) + thermostat2 = openmm.AndersenThermostat( + self.std_temperature, 1.0 / unit.picosecond + ) system.addForce(thermostat2) with nose.tools.assert_raises(ThermodynamicsError) as cm: ThermodynamicState._find_thermostat(system) @@ -232,9 +288,9 @@ def test_method_is_barostat_consistent(self): barostat = openmm.MonteCarloBarostat(pressure, temperature) assert state._is_barostat_consistent(barostat) - barostat = openmm.MonteCarloBarostat(pressure + 0.2*unit.bar, temperature) + barostat = openmm.MonteCarloBarostat(pressure + 0.2 * unit.bar, temperature) assert not state._is_barostat_consistent(barostat) - barostat = openmm.MonteCarloBarostat(pressure, temperature + 10*unit.kelvin) + barostat = openmm.MonteCarloBarostat(pressure, temperature + 10 * unit.kelvin) assert not state._is_barostat_consistent(barostat) assert not state._is_barostat_consistent(self.supported_anisotropic_barostat) @@ -251,20 +307,21 @@ def test_method_set_system_temperature(self): assert thermostat.getDefaultTemperature() == self.std_temperature # Change temperature of thermostat and barostat. - new_temperature = self.std_temperature + 1.0*unit.kelvin + new_temperature = self.std_temperature + 1.0 * unit.kelvin ThermodynamicState._set_system_temperature(system, new_temperature) assert thermostat.getDefaultTemperature() == new_temperature def test_property_temperature(self): """ThermodynamicState.temperature property.""" - for system in [self.barostated_alanine, - self.supported_anisotropic_barostat_alanine, - self.membrane_barostat_alanine_gamma_zero]: - state = ThermodynamicState(system, - self.std_temperature) + for system in [ + self.barostated_alanine, + self.supported_anisotropic_barostat_alanine, + self.membrane_barostat_alanine_gamma_zero, + ]: + state = ThermodynamicState(system, self.std_temperature) assert state.temperature == self.std_temperature - temperature = self.std_temperature + 10.0*unit.kelvin + temperature = self.std_temperature + 10.0 * unit.kelvin state.temperature = temperature assert state.temperature == temperature assert get_barostat_temperature(state.barostat) == temperature @@ -276,14 +333,18 @@ def test_property_temperature(self): def test_method_set_system_pressure(self): """ThermodynamicState._set_system_pressure() method.""" - for system in [self.barostated_alanine, - self.supported_anisotropic_barostat_alanine, - self.membrane_barostat_alanine_gamma_zero]: + for system in [ + self.barostated_alanine, + self.supported_anisotropic_barostat_alanine, + self.membrane_barostat_alanine_gamma_zero, + ]: state = ThermodynamicState(self.alanine_explicit, self.std_temperature) system = state.system assert state._find_barostat(system) is None state._set_system_pressure(system, self.std_pressure) - assert state._find_barostat(system).getDefaultPressure() == self.std_pressure + assert ( + state._find_barostat(system).getDefaultPressure() == self.std_pressure + ) state._set_system_pressure(system, None) assert state._find_barostat(system) is None @@ -291,7 +352,7 @@ def test_property_pressure_barostat(self): """ThermodynamicState.pressure and barostat properties.""" # Vacuum and implicit system are read with no pressure nonperiodic_testcases = [self.toluene_vacuum, self.toluene_implicit] - new_barostat = openmm.MonteCarloBarostat(1.0*unit.bar, self.std_temperature) + new_barostat = openmm.MonteCarloBarostat(1.0 * unit.bar, self.std_temperature) for system in nonperiodic_testcases: state = ThermodynamicState(system, self.std_temperature) assert state.pressure is None @@ -299,7 +360,7 @@ def test_property_pressure_barostat(self): # We can't set the pressure on non-periodic systems with nose.tools.assert_raises(ThermodynamicsError) as cm: - state.pressure = 1.0*unit.bar + state.pressure = 1.0 * unit.bar assert cm.exception.code == ThermodynamicsError.BAROSTATED_NONPERIODIC with nose.tools.assert_raises(ThermodynamicsError) as cm: state.barostat = new_barostat @@ -309,9 +370,11 @@ def test_property_pressure_barostat(self): assert state.barostat is None # Correctly reads and set system pressures - periodic_testcases = [self.alanine_explicit, - self.supported_anisotropic_barostat_alanine, - self.membrane_barostat_alanine_gamma_zero] + periodic_testcases = [ + self.alanine_explicit, + self.supported_anisotropic_barostat_alanine, + self.membrane_barostat_alanine_gamma_zero, + ] for system in periodic_testcases: if system is self.alanine_explicit: state = ThermodynamicState(system, self.std_temperature) @@ -325,7 +388,7 @@ def test_property_pressure_barostat(self): assert get_barostat_temperature(state.barostat) == self.std_temperature # Changing the exposed barostat doesn't affect the state. - new_pressure = self.std_pressure + 1.0*unit.bar + new_pressure = self.std_pressure + 1.0 * unit.bar barostat = state.barostat barostat.setDefaultPressure(new_pressure) assert state.barostat.getDefaultPressure() == self.std_pressure @@ -355,7 +418,7 @@ def test_property_pressure_barostat(self): assert state.pressure is None # It is impossible to assign an unsupported barostat with incorrect temperature - new_temperature = self.std_temperature + 10.0*unit.kelvin + new_temperature = self.std_temperature + 10.0 * unit.kelvin ThermodynamicState._set_barostat_temperature(barostat, new_temperature) with nose.tools.assert_raises(ThermodynamicsError) as cm: state.barostat = barostat @@ -364,14 +427,21 @@ def test_property_pressure_barostat(self): # Assign incompatible barostat raise error with nose.tools.assert_raises(ThermodynamicsError) as cm: state.barostat = self.unsupported_anisotropic_barostat - assert cm.exception.code == ThermodynamicsError.UNSUPPORTED_ANISOTROPIC_BAROSTAT + assert ( + cm.exception.code + == ThermodynamicsError.UNSUPPORTED_ANISOTROPIC_BAROSTAT + ) # Assign barostat with different type raise error - if state.barostat is not None and type(state.barostat) != type(self.supported_anisotropic_barostat): + if state.barostat is not None and type(state.barostat) != type( + self.supported_anisotropic_barostat + ): with nose.tools.assert_raises(ThermodynamicsError) as cm: state.barostat = self.supported_anisotropic_barostat assert cm.exception.code == ThermodynamicsError.INCONSISTENT_BAROSTA - if state.barostat is not None and type(state.barostat) != type(self.membrane_barostat_gamma_zero): + if state.barostat is not None and type(state.barostat) != type( + self.membrane_barostat_gamma_zero + ): with nose.tools.assert_raises(ThermodynamicsError) as cm: state.barostat = self.membrane_barostat_gamma_zero assert cm.exception.code == ThermodynamicsError.INCONSISTENT_BAROSTAT @@ -396,16 +466,28 @@ def test_surface_tension(self): assert cm.exception.code == ThermodynamicsError.SURFACE_TENSION_NOT_SUPPORTED # test setting and getting surface tension - state = ThermodynamicState(self.membrane_barostat_alanine_gamma_zero, self.std_temperature) - assert utils.is_quantity_close(state.surface_tension, 0.0 * unit.bar * unit.nanometer, rtol=0.0, atol=1e-10) + state = ThermodynamicState( + self.membrane_barostat_alanine_gamma_zero, self.std_temperature + ) + assert utils.is_quantity_close( + state.surface_tension, 0.0 * unit.bar * unit.nanometer, rtol=0.0, atol=1e-10 + ) state.surface_tension = self.modified_surface_tension - assert utils.is_quantity_close(state.surface_tension, self.modified_surface_tension) + assert utils.is_quantity_close( + state.surface_tension, self.modified_surface_tension + ) state.surface_tension = 0.0 * unit.bar * unit.nanometer - assert utils.is_quantity_close(state.surface_tension, 0.0 * unit.bar * unit.nanometer, rtol=0.0, atol=1e-10) + assert utils.is_quantity_close( + state.surface_tension, 0.0 * unit.bar * unit.nanometer, rtol=0.0, atol=1e-10 + ) # test initial surface tension of nonzero-gamma barostat - state = ThermodynamicState(self.membrane_barostat_alanine_gamma_nonzero, self.std_temperature) - assert utils.is_quantity_close(state.surface_tension, self.modified_surface_tension) + state = ThermodynamicState( + self.membrane_barostat_alanine_gamma_nonzero, self.std_temperature + ) + assert utils.is_quantity_close( + state.surface_tension, self.modified_surface_tension + ) def test_property_volume(self): """Check that volume is computed correctly.""" @@ -429,20 +511,27 @@ def test_property_system(self): state = ThermodynamicState(self.barostated_alanine, self.std_temperature) assert state.pressure == self.std_pressure # pre-condition - inconsistent_barostat_temperature = copy.deepcopy(self.inconsistent_temperature_alanine) + inconsistent_barostat_temperature = copy.deepcopy( + self.inconsistent_temperature_alanine + ) thermostat = state._find_thermostat(inconsistent_barostat_temperature) thermostat.setDefaultTemperature(self.std_temperature) TE = ThermodynamicsError # shortcut - test_cases = [(self.toluene_vacuum, TE.NO_BAROSTAT), - (self.barostated_toluene, TE.BAROSTATED_NONPERIODIC), - (self.multiple_barostat_alanine, TE.MULTIPLE_BAROSTATS), - (self.unsupported_anisotropic_barostat_alanine, TE.UNSUPPORTED_ANISOTROPIC_BAROSTAT), - (self.supported_anisotropic_barostat_alanine, TE.INCONSISTENT_BAROSTAT), - (self.membrane_barostat_alanine_gamma_zero, TE.INCONSISTENT_BAROSTAT), - (self.inconsistent_pressure_alanine, TE.INCONSISTENT_BAROSTAT), - (self.inconsistent_temperature_alanine, TE.INCONSISTENT_THERMOSTAT), - (inconsistent_barostat_temperature, TE.INCONSISTENT_BAROSTAT)] + test_cases = [ + (self.toluene_vacuum, TE.NO_BAROSTAT), + (self.barostated_toluene, TE.BAROSTATED_NONPERIODIC), + (self.multiple_barostat_alanine, TE.MULTIPLE_BAROSTATS), + ( + self.unsupported_anisotropic_barostat_alanine, + TE.UNSUPPORTED_ANISOTROPIC_BAROSTAT, + ), + (self.supported_anisotropic_barostat_alanine, TE.INCONSISTENT_BAROSTAT), + (self.membrane_barostat_alanine_gamma_zero, TE.INCONSISTENT_BAROSTAT), + (self.inconsistent_pressure_alanine, TE.INCONSISTENT_BAROSTAT), + (self.inconsistent_temperature_alanine, TE.INCONSISTENT_THERMOSTAT), + (inconsistent_barostat_temperature, TE.INCONSISTENT_BAROSTAT), + ] for i, (system, error_code) in enumerate(test_cases): with nose.tools.assert_raises(ThermodynamicsError) as cm: state.system = system @@ -451,7 +540,7 @@ def test_property_system(self): # It is possible to set an inconsistent system # if thermodynamic state is changed first. inconsistent_system = self.inconsistent_pressure_alanine - state.pressure = self.std_pressure + 0.2*unit.bar + state.pressure = self.std_pressure + 0.2 * unit.bar state.system = self.inconsistent_pressure_alanine state_system_str = openmm.XmlSerializer.serialize(state.system) inconsistent_system_str = openmm.XmlSerializer.serialize(inconsistent_system) @@ -470,7 +559,9 @@ def test_method_set_system(self): state.set_system(system, fix_state=True) system = state.system thermostat = state._find_thermostat(system) - assert utils.is_quantity_close(thermostat.getDefaultTemperature(), self.std_temperature) + assert utils.is_quantity_close( + thermostat.getDefaultTemperature(), self.std_temperature + ) assert state.barostat is None # In NPT, we can't set the system without adding a barostat. @@ -505,10 +596,14 @@ def test_method_get_system(self): def test_constructor_unsupported_barostat(self): """Exception is raised on construction with unsupported barostats.""" TE = ThermodynamicsError # shortcut - test_cases = [(self.barostated_toluene, TE.BAROSTATED_NONPERIODIC), - (self.multiple_barostat_alanine, TE.MULTIPLE_BAROSTATS), - (self.unsupported_anisotropic_barostat_alanine, TE.UNSUPPORTED_ANISOTROPIC_BAROSTAT) - ] + test_cases = [ + (self.barostated_toluene, TE.BAROSTATED_NONPERIODIC), + (self.multiple_barostat_alanine, TE.MULTIPLE_BAROSTATS), + ( + self.unsupported_anisotropic_barostat_alanine, + TE.UNSUPPORTED_ANISOTROPIC_BAROSTAT, + ), + ] for i, (system, err_code) in enumerate(test_cases): with nose.tools.assert_raises(TE) as cm: ThermodynamicState(system=system, temperature=self.std_temperature) @@ -525,19 +620,23 @@ def test_constructor_barostat(self): assert state.barostat is None # If we specify pressure, barostat is added - state = ThermodynamicState(system=system, temperature=self.std_temperature, - pressure=self.std_pressure) + state = ThermodynamicState( + system=system, temperature=self.std_temperature, pressure=self.std_pressure + ) assert state.barostat is not None # If we feed a barostat with an inconsistent temperature, it's fixed. - state = ThermodynamicState(self.inconsistent_temperature_alanine, - temperature=self.std_temperature) + state = ThermodynamicState( + self.inconsistent_temperature_alanine, temperature=self.std_temperature + ) assert state._is_barostat_consistent(state.barostat) # If we feed a barostat with an inconsistent pressure, it's fixed. - state = ThermodynamicState(self.inconsistent_pressure_alanine, - temperature=self.std_temperature, - pressure=self.std_pressure) + state = ThermodynamicState( + self.inconsistent_pressure_alanine, + temperature=self.std_temperature, + pressure=self.std_pressure, + ) assert state.pressure == self.std_pressure # The original system is unaltered. @@ -555,14 +654,16 @@ def test_constructor_thermostat(self): # With thermostat, temperature is inferred correctly. system = copy.deepcopy(self.alanine_explicit) - new_temperature = self.std_temperature + 1.0*unit.kelvin + new_temperature = self.std_temperature + 1.0 * unit.kelvin thermostat = ThermodynamicState._find_thermostat(system) thermostat.setDefaultTemperature(new_temperature) state = ThermodynamicState(system=system) assert state.temperature == new_temperature # If barostat is inconsistent, an error is raised. - system.addForce(openmm.MonteCarloBarostat(self.std_pressure, self.std_temperature)) + system.addForce( + openmm.MonteCarloBarostat(self.std_pressure, self.std_temperature) + ) with nose.tools.assert_raises(ThermodynamicsError) as cm: ThermodynamicState(system=system) assert cm.exception.code == ThermodynamicsError.INCONSISTENT_BAROSTAT @@ -576,7 +677,7 @@ def test_method_is_integrator_thermostated(self): """ThermodynamicState._is_integrator_thermostated method.""" state = ThermodynamicState(self.toluene_vacuum, self.std_temperature) test_cases = self.get_integrators(self.std_temperature) - inconsistent_temperature = self.std_temperature + 1.0*unit.kelvin + inconsistent_temperature = self.std_temperature + 1.0 * unit.kelvin for thermostated, integrator in test_cases: # If integrator expose a getTemperature method, return True. @@ -584,7 +685,9 @@ def test_method_is_integrator_thermostated(self): # If temperature is different, it raises an exception. if thermostated: - for _integrator in ThermodynamicState._loop_over_integrators(integrator): + for _integrator in ThermodynamicState._loop_over_integrators( + integrator + ): try: _integrator.setTemperature(inconsistent_temperature) except AttributeError: # handle CompoundIntegrator case @@ -596,13 +699,15 @@ def test_method_is_integrator_thermostated(self): def test_method_set_integrator_temperature(self): """ThermodynamicState._set_integrator_temperature() method.""" test_cases = self.get_integrators(self.std_temperature) - new_temperature = self.std_temperature + 1.0*unit.kelvin + new_temperature = self.std_temperature + 1.0 * unit.kelvin state = ThermodynamicState(self.toluene_vacuum, new_temperature) for thermostated, integrator in test_cases: if thermostated: assert state._set_integrator_temperature(integrator) - for _integrator in ThermodynamicState._loop_over_integrators(integrator): + for _integrator in ThermodynamicState._loop_over_integrators( + integrator + ): try: assert _integrator.getTemperature() == new_temperature except AttributeError: # handle CompoundIntegrator case @@ -622,8 +727,9 @@ def check_barostat_thermostat(_system, cmp_op): assert cmp_op(thermostat.getDefaultTemperature(), self.std_temperature) # Create NPT system in non-standard state. - npt_state = ThermodynamicState(self.inconsistent_pressure_alanine, - self.std_temperature + 1.0*unit.kelvin) + npt_state = ThermodynamicState( + self.inconsistent_pressure_alanine, self.std_temperature + 1.0 * unit.kelvin + ) npt_system = npt_state.system check_barostat_thermostat(npt_system, operator.ne) @@ -636,17 +742,24 @@ def test_method_create_context(self): state = ThermodynamicState(self.toluene_vacuum, self.std_temperature) toluene_str = openmm.XmlSerializer.serialize(self.toluene_vacuum) test_integrators = self.get_integrators(self.std_temperature) - inconsistent_temperature = self.std_temperature + 1.0*unit.kelvin + inconsistent_temperature = self.std_temperature + 1.0 * unit.kelvin # Divide test platforms among the integrators since we # can't bind the same integrator to multiple contexts. test_platforms = utils.get_available_platforms() - test_platforms = [test_platforms[i % len(test_platforms)] - for i in range(len(test_integrators))] + test_platforms = [ + test_platforms[i % len(test_platforms)] + for i in range(len(test_integrators)) + ] - for (is_thermostated, integrator), platform in zip(test_integrators, test_platforms): + for (is_thermostated, integrator), platform in zip( + test_integrators, test_platforms + ): context = state.create_context(integrator, platform) - assert platform is None or platform.getName() == context.getPlatform().getName() + assert ( + platform is None + or platform.getName() == context.getPlatform().getName() + ) assert isinstance(integrator, context.getIntegrator().__class__) if is_thermostated: @@ -654,7 +767,9 @@ def test_method_create_context(self): # create_context complains if integrator is inconsistent inconsistent_integrator = copy.deepcopy(integrator) - for _integrator in ThermodynamicState._loop_over_integrators(inconsistent_integrator): + for _integrator in ThermodynamicState._loop_over_integrators( + inconsistent_integrator + ): try: _integrator.setTemperature(inconsistent_temperature) except AttributeError: # handle CompoundIntegrator case @@ -677,20 +792,22 @@ def test_method_create_context(self): state.create_context( openmm.VerletIntegrator(0.001), platform=None, - platform_properties=platform_properties + platform_properties=platform_properties, ) - assert str(cm.exception) == "To set platform_properties, you need to also specify the platform." + assert ( + str(cm.exception) + == "To set platform_properties, you need to also specify the platform." + ) platform = openmm.Platform.getPlatformByName("CPU") context = state.create_context( - openmm.VerletIntegrator(0.001), - platform=platform, - platform_properties=platform_properties + openmm.VerletIntegrator(0.001), + platform=platform, + platform_properties=platform_properties, ) assert context.getPlatform().getPropertyValue(context, "CpuThreads") == "2" del context - def test_method_is_compatible(self): """ThermodynamicState context and state compatibility methods.""" @@ -698,19 +815,27 @@ def check_compatibility(state1, state2, is_compatible): """Check compatibility of contexts thermostated by force or integrator.""" assert state1.is_state_compatible(state2) is is_compatible assert state2.is_state_compatible(state1) is is_compatible - time_step = 1.0*unit.femtosecond - friction = 5.0/unit.picosecond + time_step = 1.0 * unit.femtosecond + friction = 5.0 / unit.picosecond integrator1 = openmm.VerletIntegrator(time_step) - integrator2 = openmm.LangevinIntegrator(state2.temperature, friction, time_step) + integrator2 = openmm.LangevinIntegrator( + state2.temperature, friction, time_step + ) context1 = create_default_context(state1, integrator1) context2 = create_default_context(state2, integrator2) assert state1.is_context_compatible(context2) is is_compatible assert state2.is_context_compatible(context1) is is_compatible toluene_vacuum = ThermodynamicState(self.toluene_vacuum, self.std_temperature) - toluene_implicit = ThermodynamicState(self.toluene_implicit, self.std_temperature) - alanine_explicit = ThermodynamicState(self.alanine_explicit, self.std_temperature) - barostated_alanine = ThermodynamicState(self.barostated_alanine, self.std_temperature) + toluene_implicit = ThermodynamicState( + self.toluene_implicit, self.std_temperature + ) + alanine_explicit = ThermodynamicState( + self.alanine_explicit, self.std_temperature + ) + barostated_alanine = ThermodynamicState( + self.barostated_alanine, self.std_temperature + ) # Different systems/ensembles are incompatible. check_compatibility(toluene_vacuum, toluene_vacuum, True) @@ -720,43 +845,46 @@ def check_compatibility(state1, state2, is_compatible): # System in same ensemble with different parameters are compatible. alanine_explicit2 = copy.deepcopy(alanine_explicit) - alanine_explicit2.temperature = alanine_explicit.temperature + 1.0*unit.kelvin + alanine_explicit2.temperature = alanine_explicit.temperature + 1.0 * unit.kelvin check_compatibility(alanine_explicit, alanine_explicit2, True) barostated_alanine2 = copy.deepcopy(barostated_alanine) - barostated_alanine2.pressure = barostated_alanine.pressure + 0.2*unit.bars + barostated_alanine2.pressure = barostated_alanine.pressure + 0.2 * unit.bars check_compatibility(barostated_alanine, barostated_alanine2, True) check_compatibility( ThermodynamicState(self.membrane_barostat_alanine_gamma_zero), ThermodynamicState(self.membrane_barostat_alanine_gamma_nonzero), - True + True, ) check_compatibility( ThermodynamicState(self.barostated_alanine), ThermodynamicState(self.membrane_barostat_alanine_gamma_nonzero), - False + False, ) def test_method_apply_to_context(self): """ThermodynamicState.apply_to_context() method.""" - friction = 5.0/unit.picosecond - time_step = 2.0*unit.femtosecond + friction = 5.0 / unit.picosecond + time_step = 2.0 * unit.femtosecond state0 = ThermodynamicState(self.barostated_alanine, self.std_temperature) - langevin_integrator = openmm.LangevinIntegrator(self.std_temperature, friction, time_step) + langevin_integrator = openmm.LangevinIntegrator( + self.std_temperature, friction, time_step + ) context = create_default_context(state0, langevin_integrator) - verlet_integrator = openmm.VerletIntegrator(1.0*unit.femtosecond) + verlet_integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond) thermostated_context = create_default_context(state0, verlet_integrator) # Change context pressure. barostat = state0._find_barostat(context.getSystem()) assert barostat.getDefaultPressure() == self.std_pressure assert context.getParameter(barostat.Pressure()) == self.std_pressure / unit.bar - new_pressure = self.std_pressure + 1.0*unit.bars - state1 = ThermodynamicState(self.barostated_alanine, self.std_temperature, - new_pressure) + new_pressure = self.std_pressure + 1.0 * unit.bars + state1 = ThermodynamicState( + self.barostated_alanine, self.std_temperature, new_pressure + ) state1.apply_to_context(context) assert barostat.getDefaultPressure() == new_pressure assert context.getParameter(barostat.Pressure()) == new_pressure / unit.bar @@ -770,26 +898,38 @@ def test_method_apply_to_context(self): assert get_barostat_temperature(barostat) == self.std_temperature # TODO remove try except when OpenMM 7.1 works on travis try: - assert c.getParameter(barostat.Temperature()) == self.std_temperature / unit.kelvin + assert ( + c.getParameter(barostat.Temperature()) + == self.std_temperature / unit.kelvin + ) except AttributeError: pass if thermostat is not None: - assert c.getParameter(thermostat.Temperature()) == self.std_temperature / unit.kelvin + assert ( + c.getParameter(thermostat.Temperature()) + == self.std_temperature / unit.kelvin + ) else: assert context.getIntegrator().getTemperature() == self.std_temperature - new_temperature = self.std_temperature + 10.0*unit.kelvin + new_temperature = self.std_temperature + 10.0 * unit.kelvin state2 = ThermodynamicState(self.barostated_alanine, new_temperature) state2.apply_to_context(c) assert get_barostat_temperature(barostat) == new_temperature # TODO remove try except when OpenMM 7.1 works on travis try: - assert c.getParameter(barostat.Temperature()) == new_temperature / unit.kelvin + assert ( + c.getParameter(barostat.Temperature()) + == new_temperature / unit.kelvin + ) except AttributeError: pass if thermostat is not None: - assert c.getParameter(thermostat.Temperature()) == new_temperature / unit.kelvin + assert ( + c.getParameter(thermostat.Temperature()) + == new_temperature / unit.kelvin + ) else: assert context.getIntegrator().getTemperature() == new_temperature @@ -799,22 +939,31 @@ def test_method_apply_to_context(self): state2.apply_to_context(context) assert cm.exception.code == ThermodynamicsError.INCOMPATIBLE_ENSEMBLE - state3 = ThermodynamicState(self.membrane_barostat_alanine_gamma_zero, self.std_temperature) + state3 = ThermodynamicState( + self.membrane_barostat_alanine_gamma_zero, self.std_temperature + ) with nose.tools.assert_raises(ThermodynamicsError) as cm: state3.apply_to_context(context) assert cm.exception.code == ThermodynamicsError.INCOMPATIBLE_ENSEMBLE # apply surface tension - gamma_context = openmm.Context(self.membrane_barostat_alanine_gamma_zero, openmm.VerletIntegrator(0.001)) + gamma_context = openmm.Context( + self.membrane_barostat_alanine_gamma_zero, openmm.VerletIntegrator(0.001) + ) state3.apply_to_context(gamma_context) - assert gamma_context.getParameter(self.membrane_barostat_gamma_nonzero.SurfaceTension()) == 0.0 + assert ( + gamma_context.getParameter( + self.membrane_barostat_gamma_nonzero.SurfaceTension() + ) + == 0.0 + ) state3.surface_tension = self.modified_surface_tension state3.apply_to_context(gamma_context) - assert (gamma_context.getParameter(self.membrane_barostat_gamma_nonzero.SurfaceTension()) - == self.modified_surface_tension.value_in_unit(unit.nanometer*unit.bar)) + assert gamma_context.getParameter( + self.membrane_barostat_gamma_nonzero.SurfaceTension() + ) == self.modified_surface_tension.value_in_unit(unit.nanometer * unit.bar) state3.surface_tension = 0.0 * unit.nanometer * unit.bar - # Clean up contexts. del context, langevin_integrator del thermostated_context, verlet_integrator @@ -827,15 +976,12 @@ def test_method_apply_to_context(self): assert cm.exception.code == ThermodynamicsError.INCOMPATIBLE_ENSEMBLE del nvt_context, verlet_integrator - - - def test_method_reduced_potential(self): """ThermodynamicState.reduced_potential() method.""" kj_mol = unit.kilojoule_per_mole beta = 1.0 / (unit.MOLAR_GAS_CONSTANT_R * self.std_temperature) state = ThermodynamicState(self.alanine_explicit, self.std_temperature) - integrator = openmm.VerletIntegrator(1.0*unit.femtosecond) + integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond) context = create_default_context(state, integrator) context.setPositions(self.alanine_positions) sampler_state = SamplerState.from_context(context) @@ -849,8 +995,9 @@ def test_method_reduced_potential(self): # Compute constant pressure reduced potential. state.pressure = self.std_pressure reduced_potential = state.reduced_potential(sampler_state) - pressure_volume_work = (self.std_pressure * sampler_state.volume * - unit.AVOGADRO_CONSTANT_NA) + pressure_volume_work = ( + self.std_pressure * sampler_state.volume * unit.AVOGADRO_CONSTANT_NA + ) potential_energy = (reduced_potential / beta - pressure_volume_work) / kj_mol assert np.isclose(sampler_state.potential_energy / kj_mol, potential_energy) assert np.isclose(reduced_potential, state.reduced_potential(context)) @@ -862,18 +1009,26 @@ def test_method_reduced_potential(self): assert cm.exception.code == ThermodynamicsError.INCOMPATIBLE_SAMPLER_STATE # Compute constant surface tension reduced potential. - state = ThermodynamicState(self.membrane_barostat_alanine_gamma_nonzero, self.std_temperature) + state = ThermodynamicState( + self.membrane_barostat_alanine_gamma_nonzero, self.std_temperature + ) integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond) context = create_default_context(state, integrator) context.setPositions(self.alanine_positions) sampler_state = SamplerState.from_context(context) state.pressure = self.std_pressure reduced_potential = state.reduced_potential(sampler_state) - pressure_volume_work = (self.std_pressure * sampler_state.volume * - unit.AVOGADRO_CONSTANT_NA) - surface_work = (self.modified_surface_tension * sampler_state.area_xy * - unit.AVOGADRO_CONSTANT_NA) - potential_energy = (reduced_potential / beta - pressure_volume_work + surface_work) / kj_mol + pressure_volume_work = ( + self.std_pressure * sampler_state.volume * unit.AVOGADRO_CONSTANT_NA + ) + surface_work = ( + self.modified_surface_tension + * sampler_state.area_xy + * unit.AVOGADRO_CONSTANT_NA + ) + potential_energy = ( + reduced_potential / beta - pressure_volume_work + surface_work + ) / kj_mol assert np.isclose(sampler_state.potential_energy / kj_mol, potential_energy) assert np.isclose(reduced_potential, state.reduced_potential(context)) @@ -886,15 +1041,23 @@ def test_method_reduced_potential_at_states(self): """ # Build a mixed collection of compatible and incompatible thermodynamic states. thermodynamic_states = [ - ThermodynamicState(self.alanine_explicit, temperature=300*unit.kelvin, - pressure=1.0*unit.atmosphere), - ThermodynamicState(self.toluene_implicit, temperature=200*unit.kelvin), - ThermodynamicState(self.alanine_explicit, temperature=250*unit.kelvin, - pressure=1.2*unit.atmosphere) + ThermodynamicState( + self.alanine_explicit, + temperature=300 * unit.kelvin, + pressure=1.0 * unit.atmosphere, + ), + ThermodynamicState(self.toluene_implicit, temperature=200 * unit.kelvin), + ThermodynamicState( + self.alanine_explicit, + temperature=250 * unit.kelvin, + pressure=1.2 * unit.atmosphere, + ), ] # Group thermodynamic states by compatibility. - compatible_groups, original_indices = group_by_compatibility(thermodynamic_states) + compatible_groups, original_indices = group_by_compatibility( + thermodynamic_states + ) assert len(compatible_groups) == 2 assert original_indices == [[0, 2], [1]] @@ -903,7 +1066,7 @@ def test_method_reduced_potential_at_states(self): obtained_energies = [] for compatible_group in compatible_groups: # Create context. - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) context = create_default_context(compatible_group[0], integrator) if len(compatible_group) == 2: context.setPositions(self.alanine_positions) @@ -916,7 +1079,11 @@ def test_method_reduced_potential_at_states(self): expected_energies.append(state.reduced_potential(context)) # Compute with multi-state method. - obtained_energies.extend(ThermodynamicState.reduced_potential_at_states(context, compatible_group)) + obtained_energies.extend( + ThermodynamicState.reduced_potential_at_states( + context, compatible_group + ) + ) expected_energies = np.array(expected_energies) assert np.allclose(np.array(expected_energies), np.array(obtained_energies)) @@ -924,8 +1091,10 @@ def test_method_reduced_potential_at_states(self): sampler_state = SamplerState(positions=self.alanine_positions) thermodynamic_states = [thermodynamic_states[i] for i in [0, 2]] from openmmtools.cache import ContextCache + obtained_energies = reduced_potential_at_states( - sampler_state, thermodynamic_states, ContextCache()) + sampler_state, thermodynamic_states, ContextCache() + ) assert np.allclose(expected_energies[:2], obtained_energies) @@ -933,47 +1102,61 @@ def test_method_reduced_potential_at_states(self): # TEST SAMPLER STATE # ============================================================================= -class TestSamplerState(object): + +class TestSamplerState: """Test suite for states.SamplerState class.""" @classmethod def setup_class(cls): """Create various variables shared by tests in suite.""" - temperature = 300*unit.kelvin + temperature = 300 * unit.kelvin alanine_vacuum = testsystems.AlanineDipeptideVacuum() cls.alanine_vacuum_positions = alanine_vacuum.positions - cls.alanine_vacuum_state = ThermodynamicState(alanine_vacuum.system, - temperature) + cls.alanine_vacuum_state = ThermodynamicState( + alanine_vacuum.system, temperature + ) alanine_explicit = testsystems.AlanineDipeptideExplicit() cls.alanine_explicit_positions = alanine_explicit.positions - cls.alanine_explicit_state = ThermodynamicState(alanine_explicit.system, - temperature) + cls.alanine_explicit_state = ThermodynamicState( + alanine_explicit.system, temperature + ) @staticmethod def is_sampler_state_equal_context(sampler_state, context): """Check sampler and openmm states in context are equal.""" equal = True ss = sampler_state # Shortcut. - os = context.getState(getPositions=True, getEnergy=True, - getVelocities=True) - equal = equal and np.allclose(ss.positions.value_in_unit(ss.positions.unit), - os.getPositions().value_in_unit(ss.positions.unit)) - equal = equal and np.allclose(ss.velocities.value_in_unit(ss.velocities.unit), - os.getVelocities().value_in_unit(ss.velocities.unit)) - equal = equal and np.allclose(ss.box_vectors.value_in_unit(ss.box_vectors.unit), - os.getPeriodicBoxVectors().value_in_unit(ss.box_vectors.unit)) - equal = equal and np.isclose(ss.potential_energy.value_in_unit(ss.potential_energy.unit), - os.getPotentialEnergy().value_in_unit(ss.potential_energy.unit)) - equal = equal and np.isclose(ss.kinetic_energy.value_in_unit(ss.kinetic_energy.unit), - os.getKineticEnergy().value_in_unit(ss.kinetic_energy.unit)) - equal = equal and np.isclose(ss.volume.value_in_unit(ss.volume.unit), - os.getPeriodicBoxVolume().value_in_unit(ss.volume.unit)) + os = context.getState(getPositions=True, getEnergy=True, getVelocities=True) + equal = equal and np.allclose( + ss.positions.value_in_unit(ss.positions.unit), + os.getPositions().value_in_unit(ss.positions.unit), + ) + equal = equal and np.allclose( + ss.velocities.value_in_unit(ss.velocities.unit), + os.getVelocities().value_in_unit(ss.velocities.unit), + ) + equal = equal and np.allclose( + ss.box_vectors.value_in_unit(ss.box_vectors.unit), + os.getPeriodicBoxVectors().value_in_unit(ss.box_vectors.unit), + ) + equal = equal and np.isclose( + ss.potential_energy.value_in_unit(ss.potential_energy.unit), + os.getPotentialEnergy().value_in_unit(ss.potential_energy.unit), + ) + equal = equal and np.isclose( + ss.kinetic_energy.value_in_unit(ss.kinetic_energy.unit), + os.getKineticEnergy().value_in_unit(ss.kinetic_energy.unit), + ) + equal = equal and np.isclose( + ss.volume.value_in_unit(ss.volume.unit), + os.getPeriodicBoxVolume().value_in_unit(ss.volume.unit), + ) return equal @staticmethod def create_context(thermodynamic_state): - integrator = openmm.VerletIntegrator(1.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(1.0 * unit.femtoseconds) return thermodynamic_state.create_context(integrator, DEFAULT_PLATFORM) def test_inconsistent_n_particles(self): @@ -1008,7 +1191,9 @@ def test_constructor_from_context(self): alanine_vacuum_context.setPositions(self.alanine_vacuum_positions) sampler_state = SamplerState.from_context(alanine_vacuum_context) - assert self.is_sampler_state_equal_context(sampler_state, alanine_vacuum_context) + assert self.is_sampler_state_equal_context( + sampler_state, alanine_vacuum_context + ) def test_unitless_cache(self): """Test that the unitless cache for positions and velocities is invalidated.""" @@ -1019,7 +1204,7 @@ def test_unitless_cache(self): test_cases = [ SamplerState(positions), - SamplerState.from_context(alanine_vacuum_context) + SamplerState.from_context(alanine_vacuum_context), ] pos_unit = unit.micrometer @@ -1030,29 +1215,53 @@ def test_unitless_cache(self): old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions) sampler_state.positions[5] = [1.0, 1.0, 1.0] * pos_unit assert sampler_state.positions.has_changed - assert np.all(old_unitless_positions[5] != sampler_state._unitless_positions[5]) + assert np.all( + old_unitless_positions[5] != sampler_state._unitless_positions[5] + ) sampler_state.positions = copy.deepcopy(positions) assert sampler_state._unitless_positions_cache is None if isinstance(sampler_state._positions._value, np.ndarray): - old_unitless_positions = copy.deepcopy(sampler_state._unitless_positions) - sampler_state.positions[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * pos_unit + old_unitless_positions = copy.deepcopy( + sampler_state._unitless_positions + ) + sampler_state.positions[5:8] = [ + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + ] * pos_unit assert sampler_state.positions.has_changed - assert np.all(old_unitless_positions[5:8] != sampler_state._unitless_positions[5:8]) + assert np.all( + old_unitless_positions[5:8] + != sampler_state._unitless_positions[5:8] + ) if sampler_state.velocities is not None: - old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities) + old_unitless_velocities = copy.deepcopy( + sampler_state._unitless_velocities + ) sampler_state.velocities[5] = [1.0, 1.0, 1.0] * vel_unit assert sampler_state.velocities.has_changed - assert np.all(old_unitless_velocities[5] != sampler_state._unitless_velocities[5]) + assert np.all( + old_unitless_velocities[5] != sampler_state._unitless_velocities[5] + ) sampler_state.velocities = copy.deepcopy(sampler_state.velocities) assert sampler_state._unitless_velocities_cache is None if isinstance(sampler_state._velocities._value, np.ndarray): - old_unitless_velocities = copy.deepcopy(sampler_state._unitless_velocities) - sampler_state.velocities[5:8] = [[2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]] * vel_unit + old_unitless_velocities = copy.deepcopy( + sampler_state._unitless_velocities + ) + sampler_state.velocities[5:8] = [ + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + ] * vel_unit assert sampler_state.velocities.has_changed - assert np.all(old_unitless_velocities[5:8] != sampler_state._unitless_velocities[5:8]) + assert np.all( + old_unitless_velocities[5:8] + != sampler_state._unitless_velocities[5:8] + ) else: assert sampler_state._unitless_velocities is None @@ -1110,31 +1319,42 @@ def test_operator_getitem(self): sliced_sampler_state = sampler_state[0] assert sliced_sampler_state.n_particles == 1 assert len(sliced_sampler_state.velocities) == 1 - assert np.allclose(sliced_sampler_state.positions[0], - self.alanine_explicit_positions[0]) + assert np.allclose( + sliced_sampler_state.positions[0], self.alanine_explicit_positions[0] + ) # Modifying the sliced sampler state doesn't modify original. sliced_sampler_state.positions[0][0] += 1 * unit.angstrom - assert sliced_sampler_state.positions[0][0] == sampler_state.positions[0][0] + 1 * unit.angstrom + assert ( + sliced_sampler_state.positions[0][0] + == sampler_state.positions[0][0] + 1 * unit.angstrom + ) # SamplerState.__getitem__ should work for both slices and lists. - for sliced_sampler_state in [sampler_state[2:10], - sampler_state[list(range(2, 10))]]: + for sliced_sampler_state in [ + sampler_state[2:10], + sampler_state[list(range(2, 10))], + ]: assert sliced_sampler_state.n_particles == 8 assert len(sliced_sampler_state.velocities) == 8 - assert np.allclose(sliced_sampler_state.positions, - self.alanine_explicit_positions[2:10]) + assert np.allclose( + sliced_sampler_state.positions, self.alanine_explicit_positions[2:10] + ) sliced_sampler_state = sampler_state[2:10:2] assert sliced_sampler_state.n_particles == 4 assert len(sliced_sampler_state.velocities) == 4 - assert np.allclose(sliced_sampler_state.positions, - self.alanine_explicit_positions[2:10:2]) + assert np.allclose( + sliced_sampler_state.positions, self.alanine_explicit_positions[2:10:2] + ) # Modifying the sliced sampler state doesn't modify original. We check # this here too since the algorithm for slice objects is different. sliced_sampler_state.positions[0][0] += 1 * unit.angstrom - assert sliced_sampler_state.positions[0][0] == sampler_state.positions[2][0] + 1 * unit.angstrom + assert ( + sliced_sampler_state.positions[0][0] + == sampler_state.positions[2][0] + 1 * unit.angstrom + ) # The other attributes are copied correctly. assert sliced_sampler_state.volume == sampler_state.volume @@ -1148,7 +1368,7 @@ def test_dict_representation(self): """Setting the state of the object should work when ignoring velocities.""" alanine_vacuum_context = self.create_context(self.alanine_vacuum_state) alanine_vacuum_context.setPositions(self.alanine_vacuum_positions) - alanine_vacuum_context.setVelocitiesToTemperature(300*unit.kelvin) + alanine_vacuum_context.setVelocitiesToTemperature(300 * unit.kelvin) # Test precondition. vacuum_sampler_state = SamplerState.from_context(alanine_vacuum_context) @@ -1157,10 +1377,14 @@ def test_dict_representation(self): # Get a dictionary representation without velocities. serialization = vacuum_sampler_state.__getstate__(ignore_velocities=True) - assert serialization['velocities'] is None + assert serialization["velocities"] is None # Do not overwrite velocities when setting a state. - serialization['velocities'] = np.random.rand(*vacuum_sampler_state.positions.shape) * unit.nanometer/unit.picosecond + serialization["velocities"] = ( + np.random.rand(*vacuum_sampler_state.positions.shape) + * unit.nanometer + / unit.picosecond + ) vacuum_sampler_state.__setstate__(serialization, ignore_velocities=True) assert np.all(vacuum_sampler_state.velocities == old_velocities) @@ -1176,27 +1400,33 @@ def test_collective_variable(self): # 3 unique CV names in the Context: BondCV, AngleCVSingle, AngleCV cv_single_1 = openmm.CustomCVForce("4*BondCV") # We are going to use this name later too - cv_single_1.addCollectiveVariable('BondCV', copy.deepcopy(cv_distance)) - cv_single_2 = openmm.CustomCVForce("sin(AngleCVSingle)") # This is suppose to be unique - cv_single_2.addCollectiveVariable('AngleCVSingle', copy.deepcopy(cv_angle)) + cv_single_1.addCollectiveVariable("BondCV", copy.deepcopy(cv_distance)) + cv_single_2 = openmm.CustomCVForce( + "sin(AngleCVSingle)" + ) # This is suppose to be unique + cv_single_2.addCollectiveVariable("AngleCVSingle", copy.deepcopy(cv_angle)) cv_combined = openmm.CustomCVForce("4*BondCV + sin(AngleCV)") cv_combined.addCollectiveVariable("BondCV", cv_distance) cv_combined.addCollectiveVariable("AngleCV", cv_angle) for force in [cv_single_1, cv_single_2, cv_combined]: system_cv.addForce(force) - thermo_state = ThermodynamicState(system_cv, self.alanine_explicit_state.temperature) + thermo_state = ThermodynamicState( + system_cv, self.alanine_explicit_state.temperature + ) context = self.create_context(thermo_state) context.setPositions(self.alanine_explicit_positions) sampler_state = SamplerState.from_context(context) collective_variables = sampler_state.collective_variables - name_count = (('BondCV', 2), ('AngleCV', 1), ('AngleCVSingle', 1)) + name_count = (("BondCV", 2), ("AngleCV", 1), ("AngleCVSingle", 1)) # Ensure the CV's are all accounted for assert len(collective_variables.keys()) == 3 for name, count in name_count: # Ensure the CV's show up in the Context the number of times we expect them to assert len(collective_variables[name].keys()) == count # Ensure CVs which are the same in different forces are equal - assert len(set(collective_variables['BondCV'].values())) == 1 # Cast values of CV to set, make sure len == 1 + assert ( + len(set(collective_variables["BondCV"].values())) == 1 + ) # Cast values of CV to set, make sure len == 1 # Ensure invalidation with single replacement new_pos = copy.deepcopy(self.alanine_explicit_positions) new_pos[0] *= 2 @@ -1209,14 +1439,16 @@ def test_collective_variable(self): sampler_state.positions = new_pos assert sampler_state.collective_variables is None + # ============================================================================= # TEST COMPOUND STATE # ============================================================================= -class TestCompoundThermodynamicState(object): + +class TestCompoundThermodynamicState: """Test suite for states.CompoundThermodynamicState class.""" - class DummyState(object): + class DummyState: """A state that keeps track of a useless system parameter.""" standard_dummy_parameter = 1.0 @@ -1249,13 +1481,13 @@ def check_system_consistency(self, system): @staticmethod def is_context_compatible(context): parameters = context.getState(getParameters=True).getParameters() - if 'dummy_parameters' in parameters.keys(): + if "dummy_parameters" in parameters.keys(): return True else: return False def apply_to_context(self, context): - context.setParameter('dummy_parameter', self.dummy_parameter) + context.setParameter("dummy_parameter", self.dummy_parameter) def _on_setattr(self, standard_system, attribute_name, old_dummy_state): return False @@ -1269,8 +1501,8 @@ def _find_force_groups_to_update(self, context, current_context_state, memo): @classmethod def add_dummy_parameter(cls, system): """Add to system a CustomBondForce with a dummy parameter.""" - force = openmm.CustomBondForce('dummy_parameter') - force.addGlobalParameter('dummy_parameter', cls.standard_dummy_parameter) + force = openmm.CustomBondForce("dummy_parameter") + force.addGlobalParameter("dummy_parameter", cls.standard_dummy_parameter) max_force_group = cls._find_max_force_group(system) force.setForceGroup(max_force_group + 1) system.addForce(force) @@ -1281,7 +1513,7 @@ def _find_dummy_force(system): if isinstance(force, openmm.CustomBondForce): for parameter_id in range(force.getNumGlobalParameters()): parameter_name = force.getGlobalParameterName(parameter_id) - if parameter_name == 'dummy_parameter': + if parameter_name == "dummy_parameter": return force, parameter_id @classmethod @@ -1317,20 +1549,21 @@ def setup_class(cls): def test_dynamic_inheritance(self): """ThermodynamicState is inherited dinamically.""" - thermodynamic_state = ThermodynamicState(self.alanine_explicit, - self.std_temperature) + thermodynamic_state = ThermodynamicState( + self.alanine_explicit, self.std_temperature + ) compound_state = CompoundThermodynamicState(thermodynamic_state, []) assert isinstance(compound_state, ThermodynamicState) # Attributes are correctly read. - assert hasattr(compound_state, 'pressure') + assert hasattr(compound_state, "pressure") assert compound_state.pressure is None - assert hasattr(compound_state, 'temperature') + assert hasattr(compound_state, "temperature") assert compound_state.temperature == self.std_temperature # Properties and attributes are correctly set. - new_temperature = self.std_temperature + 1.0*unit.kelvin + new_temperature = self.std_temperature + 1.0 * unit.kelvin compound_state.pressure = self.std_pressure compound_state.temperature = new_temperature barostat = compound_state.barostat @@ -1339,16 +1572,24 @@ def test_dynamic_inheritance(self): def test_constructor_set_state(self): """IComposableState.set_state is called on construction.""" - thermodynamic_state = ThermodynamicState(self.alanine_explicit, self.std_temperature) + thermodynamic_state = ThermodynamicState( + self.alanine_explicit, self.std_temperature + ) - assert self.get_dummy_parameter(thermodynamic_state.system) != self.dummy_parameter - compound_state = CompoundThermodynamicState(thermodynamic_state, [self.dummy_state]) + assert ( + self.get_dummy_parameter(thermodynamic_state.system) != self.dummy_parameter + ) + compound_state = CompoundThermodynamicState( + thermodynamic_state, [self.dummy_state] + ) assert self.get_dummy_parameter(compound_state.system) == self.dummy_parameter def test_property_forwarding(self): """Forward properties to IComposableStates and update system.""" dummy_state = self.DummyState(self.dummy_parameter + 1) - thermodynamic_state = ThermodynamicState(self.alanine_explicit, self.std_temperature) + thermodynamic_state = ThermodynamicState( + self.alanine_explicit, self.std_temperature + ) compound_state = CompoundThermodynamicState(thermodynamic_state, [dummy_state]) # Properties are correctly read and set, and @@ -1363,19 +1604,25 @@ def test_property_forwarding(self): with nose.tools.assert_raises(AttributeError): compound_state.temp compound_state.temp = 0 - assert 'temp' in compound_state.__dict__ + assert "temp" in compound_state.__dict__ # If there are multiple composable states setting two different # values for the same attribute, an exception is raise. dummy_state2 = self.DummyState(dummy_state.dummy_parameter + 1) - compound_state = CompoundThermodynamicState(thermodynamic_state, [dummy_state, dummy_state2]) + compound_state = CompoundThermodynamicState( + thermodynamic_state, [dummy_state, dummy_state2] + ) with nose.tools.assert_raises(RuntimeError): compound_state.dummy_parameter def test_set_system(self): """CompoundThermodynamicState.system and set_system method.""" - thermodynamic_state = ThermodynamicState(self.alanine_explicit, self.std_temperature) - compound_state = CompoundThermodynamicState(thermodynamic_state, [self.dummy_state]) + thermodynamic_state = ThermodynamicState( + self.alanine_explicit, self.std_temperature + ) + compound_state = CompoundThermodynamicState( + thermodynamic_state, [self.dummy_state] + ) # Setting an inconsistent system for the dummy raises an error. system = compound_state.system @@ -1394,8 +1641,10 @@ def test_method_standardize_system(self): """CompoundThermodynamicState._standardize_system method.""" alanine_explicit = copy.deepcopy(self.alanine_explicit) thermodynamic_state = ThermodynamicState(alanine_explicit, self.std_temperature) - thermodynamic_state.pressure = self.std_pressure + 1.0*unit.bar - compound_state = CompoundThermodynamicState(thermodynamic_state, [self.dummy_state]) + thermodynamic_state.pressure = self.std_pressure + 1.0 * unit.bar + compound_state = CompoundThermodynamicState( + thermodynamic_state, [self.dummy_state] + ) # Standardizing the system fixes both ThermodynamicState and DummyState parameters. system = compound_state.system @@ -1405,7 +1654,9 @@ def test_method_standardize_system(self): compound_state._standardize_system(system) barostat = ThermodynamicState._find_barostat(system) assert barostat.getDefaultPressure() == self.std_pressure - assert self.get_dummy_parameter(system) == self.DummyState.standard_dummy_parameter + assert ( + self.get_dummy_parameter(system) == self.DummyState.standard_dummy_parameter + ) # Check that the standard system hash is correct. standard_hash = openmm.XmlSerializer.serialize(system).__hash__() @@ -1416,51 +1667,67 @@ def test_method_standardize_system(self): incompatible_state = ThermodynamicState(undummied_alanine, self.std_temperature) assert not compound_state.is_state_compatible(incompatible_state) - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) context = create_default_context(incompatible_state, integrator) assert not compound_state.is_context_compatible(context) def test_method_apply_to_context(self): """Test CompoundThermodynamicState.apply_to_context() method.""" dummy_parameter = self.DummyState.standard_dummy_parameter - thermodynamic_state = ThermodynamicState(self.alanine_explicit, self.std_temperature) + thermodynamic_state = ThermodynamicState( + self.alanine_explicit, self.std_temperature + ) thermodynamic_state.pressure = self.std_pressure self.DummyState.set_dummy_parameter(thermodynamic_state.system, dummy_parameter) - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) context = create_default_context(thermodynamic_state, integrator) barostat = ThermodynamicState._find_barostat(context.getSystem()) - assert context.getParameter('dummy_parameter') == dummy_parameter + assert context.getParameter("dummy_parameter") == dummy_parameter assert context.getParameter(barostat.Pressure()) == self.std_pressure / unit.bar - compound_state = CompoundThermodynamicState(thermodynamic_state, [self.dummy_state]) - new_pressure = thermodynamic_state.pressure + 1.0*unit.bar + compound_state = CompoundThermodynamicState( + thermodynamic_state, [self.dummy_state] + ) + new_pressure = thermodynamic_state.pressure + 1.0 * unit.bar compound_state.pressure = new_pressure compound_state.apply_to_context(context) - assert context.getParameter('dummy_parameter') == self.dummy_parameter + assert context.getParameter("dummy_parameter") == self.dummy_parameter assert context.getParameter(barostat.Pressure()) == new_pressure / unit.bar def test_method_find_force_groups_to_update(self): """Test CompoundThermodynamicState._find_force_groups_to_update() method.""" alanine_explicit = copy.deepcopy(self.alanine_explicit) thermodynamic_state = ThermodynamicState(alanine_explicit, self.std_temperature) - compound_state = CompoundThermodynamicState(thermodynamic_state, [self.dummy_state]) - context = create_default_context(compound_state, openmm.VerletIntegrator(2.0*unit.femtoseconds)) + compound_state = CompoundThermodynamicState( + thermodynamic_state, [self.dummy_state] + ) + context = create_default_context( + compound_state, openmm.VerletIntegrator(2.0 * unit.femtoseconds) + ) # No force group should be updated if the two states are identical. - assert compound_state._find_force_groups_to_update(context, compound_state, memo={}) == set() + assert ( + compound_state._find_force_groups_to_update( + context, compound_state, memo={} + ) + == set() + ) # If the dummy parameter changes, there should be 1 force group to update. compound_state2 = copy.deepcopy(compound_state) compound_state2.dummy_parameter -= 0.5 group = self.DummyState._find_max_force_group(context.getSystem()) - assert compound_state._find_force_groups_to_update(context, compound_state2, memo={}) == {group} + assert compound_state._find_force_groups_to_update( + context, compound_state2, memo={} + ) == {group} # ============================================================================= # TEST SERIALIZATION # ============================================================================= + def are_pickle_equal(state1, state2): """Check if they two ThermodynamicStates are identical.""" # Pickle internally uses __getstate__ so we are effectively @@ -1471,7 +1738,9 @@ def are_pickle_equal(state1, state2): def test_states_serialization(): """Test serialization compatibility with utils.serialize.""" test_system = testsystems.AlanineDipeptideImplicit() - thermodynamic_state = ThermodynamicState(test_system.system, temperature=300*unit.kelvin) + thermodynamic_state = ThermodynamicState( + test_system.system, temperature=300 * unit.kelvin + ) sampler_state = SamplerState(positions=test_system.positions) test_cases = [thermodynamic_state, sampler_state] @@ -1502,7 +1771,9 @@ def test_uncompressed_thermodynamic_state_serialization(): # Create uncompressed ThermodynamicState serialization. state._standardize_system(system) uncompressed_serialization = copy.deepcopy(compressed_serialization) - uncompressed_serialization['standard_system'] = openmm.XmlSerializer.serialize(system) + uncompressed_serialization["standard_system"] = openmm.XmlSerializer.serialize( + system + ) # First test serialization with cache. Copy # serialization so that we can use it again. @@ -1524,8 +1795,8 @@ def test_uncompressed_thermodynamic_state_serialization(): class ParameterStateExample(GlobalParameterState): standard_value = _GLOBAL_PARAMETER_STANDARD_VALUE - lambda_bonds = GlobalParameterState.GlobalParameter('lambda_bonds', standard_value) - gamma = GlobalParameterState.GlobalParameter('gamma', standard_value) + lambda_bonds = GlobalParameterState.GlobalParameter("lambda_bonds", standard_value) + gamma = GlobalParameterState.GlobalParameter("gamma", standard_value) def set_defined_parameters(self, value): for parameter_name, parameter_value in self._parameters.items(): @@ -1533,7 +1804,7 @@ def set_defined_parameters(self, value): self._parameters[parameter_name] = value -class TestGlobalParameterState(object): +class TestGlobalParameterState: """Test GlobalParameterState stand-alone functionality. The compatibility with CompoundThermodynamicState is tested in the next @@ -1545,37 +1816,50 @@ def setup_class(cls): """Create test systems and shared objects.""" # Define a diatomic molecule System with two custom forces # using the simple version and the suffix'ed version. - r0 = 0.15*unit.nanometers + r0 = 0.15 * unit.nanometers # Make sure that there is a force without defining a parameter. cls.parameters_default_values = { - 'lambda_bonds': 1.0, - 'gamma': 2.0, - 'lambda_bonds_mysuffix': 0.5, - 'gamma_mysuffix': None, + "lambda_bonds": 1.0, + "gamma": 2.0, + "lambda_bonds_mysuffix": 0.5, + "gamma_mysuffix": None, } r0_nanometers = r0.value_in_unit(unit.nanometers) # Shortcut in OpenMM units. system = openmm.System() - system.addParticle(40.0*unit.amu) - system.addParticle(40.0*unit.amu) + system.addParticle(40.0 * unit.amu) + system.addParticle(40.0 * unit.amu) # Add a force defining lambda_bonds and gamma global parameters. - custom_force = openmm.CustomBondForce('lambda_bonds^gamma*60000*(r-{})^2;'.format(r0_nanometers)) - custom_force.addGlobalParameter('lambda_bonds', cls.parameters_default_values['lambda_bonds']) - custom_force.addGlobalParameter('gamma', cls.parameters_default_values['gamma']) + custom_force = openmm.CustomBondForce( + f"lambda_bonds^gamma*60000*(r-{r0_nanometers})^2;" + ) + custom_force.addGlobalParameter( + "lambda_bonds", cls.parameters_default_values["lambda_bonds"] + ) + custom_force.addGlobalParameter("gamma", cls.parameters_default_values["gamma"]) custom_force.addBond(0, 1, []) system.addForce(custom_force) # Add a force defining the lambda_bonds_mysuffix global parameters. - custom_force_suffix = openmm.CustomBondForce('lambda_bonds_mysuffix*20000*(r-{})^2;'.format(r0_nanometers)) - custom_force_suffix.addGlobalParameter('lambda_bonds_mysuffix', cls.parameters_default_values['lambda_bonds_mysuffix']) + custom_force_suffix = openmm.CustomBondForce( + f"lambda_bonds_mysuffix*20000*(r-{r0_nanometers})^2;" + ) + custom_force_suffix.addGlobalParameter( + "lambda_bonds_mysuffix", + cls.parameters_default_values["lambda_bonds_mysuffix"], + ) custom_force_suffix.addBond(0, 1, []) system.addForce(custom_force_suffix) # Create a thermodynamic and sampler states. - cls.diatomic_molecule_ts = ThermodynamicState(system, temperature=300.0*unit.kelvin) + cls.diatomic_molecule_ts = ThermodynamicState( + system, temperature=300.0 * unit.kelvin + ) pos1 = [0.0, 0.0, 0.0] pos2 = [0.0, 0.0, r0_nanometers] - cls.diatomic_molecule_ss = SamplerState(positions=np.array([pos1, pos2]) * unit.nanometers) + cls.diatomic_molecule_ss = SamplerState( + positions=np.array([pos1, pos2]) * unit.nanometers + ) # Create a system with a duplicate force to test handling forces # defining the same parameters in different force groups. @@ -1583,7 +1867,9 @@ def setup_class(cls): custom_force.setForceGroup(30) system_force_groups = copy.deepcopy(system) system_force_groups.addForce(custom_force) - cls.diatomic_molecule_force_groups_ts = ThermodynamicState(system_force_groups, temperature=300.0*unit.kelvin) + cls.diatomic_molecule_force_groups_ts = ThermodynamicState( + system_force_groups, temperature=300.0 * unit.kelvin + ) # Create few incompatible systems for testing. An incompatible state # has a different set of defined global parameters. @@ -1592,7 +1878,7 @@ def setup_class(cls): # System without suffixed or non-suffixed parameters. for i in range(2): cls.incompatible_systems.append(copy.deepcopy(system)) - cls.incompatible_systems[i+1].removeForce(i) + cls.incompatible_systems[i + 1].removeForce(i) # System with the global parameters duplicated in two different force groups. cls.incompatible_systems.append(copy.deepcopy(system_force_groups)) @@ -1601,16 +1887,24 @@ def setup_class(cls): cls.incompatible_systems.append(copy.deepcopy(system)) custom_force = copy.deepcopy(cls.incompatible_systems[-1].getForce(1)) energy_function = custom_force.getEnergyFunction() - energy_function = energy_function.replace('lambda_bonds_mysuffix', 'lambda_bonds_mysuffix^gamma_mysuffix') + energy_function = energy_function.replace( + "lambda_bonds_mysuffix", "lambda_bonds_mysuffix^gamma_mysuffix" + ) custom_force.setEnergyFunction(energy_function) - custom_force.addGlobalParameter('gamma_mysuffix', cls.parameters_default_values['gamma']) + custom_force.addGlobalParameter( + "gamma_mysuffix", cls.parameters_default_values["gamma"] + ) cls.incompatible_systems[-1].addForce(custom_force) def read_system_state(self, system): states = [] - for suffix in [None, 'mysuffix']: + for suffix in [None, "mysuffix"]: try: - states.append(ParameterStateExample.from_system(system, parameters_name_suffix=suffix)) + states.append( + ParameterStateExample.from_system( + system, parameters_name_suffix=suffix + ) + ) except GlobalParameterError: continue return states @@ -1618,21 +1912,34 @@ def read_system_state(self, system): @staticmethod def test_constructor_parameters(): """Test GlobalParameterState constructor behave as expected.""" + class MyState(GlobalParameterState): - lambda_angles = GlobalParameterState.GlobalParameter('lambda_angles', standard_value=1.0) - lambda_sterics = GlobalParameterState.GlobalParameter('lambda_sterics', standard_value=1.0) + lambda_angles = GlobalParameterState.GlobalParameter( + "lambda_angles", standard_value=1.0 + ) + lambda_sterics = GlobalParameterState.GlobalParameter( + "lambda_sterics", standard_value=1.0 + ) # Raise an exception if parameter is not recognized. - with nose.tools.assert_raises_regexp(GlobalParameterError, 'Unknown parameters'): + with nose.tools.assert_raises_regexp( + GlobalParameterError, "Unknown parameters" + ): MyState(lambda_steric=1.0) # Typo. # Properties are initialized correctly. - test_cases = [{}, - {'lambda_angles': 1.0}, - {'lambda_sterics': 0.5, 'lambda_angles': 0.5}, - {'parameters_name_suffix': 'suffix'}, - {'parameters_name_suffix': 'suffix', 'lambda_angles': 1.0}, - {'parameters_name_suffix': 'suffix', 'lambda_sterics': 0.5, 'lambda_angles': 0.5}] + test_cases = [ + {}, + {"lambda_angles": 1.0}, + {"lambda_sterics": 0.5, "lambda_angles": 0.5}, + {"parameters_name_suffix": "suffix"}, + {"parameters_name_suffix": "suffix", "lambda_angles": 1.0}, + { + "parameters_name_suffix": "suffix", + "lambda_sterics": 0.5, + "lambda_angles": 0.5, + }, + ] for test_kwargs in test_cases: state = MyState(**test_kwargs) @@ -1643,41 +1950,53 @@ class MyState(GlobalParameterState): is_defined = parameter in test_kwargs # The "unsuffixed" parameter should not be controlled by the state. - if 'parameters_name_suffix' in test_kwargs: - with nose.tools.assert_raises_regexp(AttributeError, 'state does not control'): + if "parameters_name_suffix" in test_kwargs: + with nose.tools.assert_raises_regexp( + AttributeError, "state does not control" + ): getattr(state, parameter) # The state exposes a "suffixed" version of the parameter. - state_attribute = parameter + '_' + test_kwargs['parameters_name_suffix'] + state_attribute = ( + parameter + "_" + test_kwargs["parameters_name_suffix"] + ) else: state_attribute = parameter # Check if parameter should is defined or undefined (i.e. set to None) as expected. - err_msg = 'Parameter: {} (Test case: {})'.format(parameter, test_kwargs) + err_msg = f"Parameter: {parameter} (Test case: {test_kwargs})" if is_defined: - assert getattr(state, state_attribute) == test_kwargs[parameter], err_msg + assert ( + getattr(state, state_attribute) == test_kwargs[parameter] + ), err_msg else: assert getattr(state, state_attribute) is None, err_msg def test_from_system_constructor(self): """Test GlobalParameterState.from_system constructor.""" # A system exposing no global parameters controlled by the state raises an error. - with nose.tools.assert_raises_regexp(GlobalParameterError, 'no global parameters'): + with nose.tools.assert_raises_regexp( + GlobalParameterError, "no global parameters" + ): GlobalParameterState.from_system(openmm.System()) system = self.diatomic_molecule_ts.system state = ParameterStateExample.from_system(system) - state_suffix = ParameterStateExample.from_system(system, parameters_name_suffix='mysuffix') + state_suffix = ParameterStateExample.from_system( + system, parameters_name_suffix="mysuffix" + ) for parameter_name, parameter_value in self.parameters_default_values.items(): - if 'suffix' in parameter_name: + if "suffix" in parameter_name: controlling_state = state_suffix noncontrolling_state = state else: controlling_state = state noncontrolling_state = state_suffix - err_msg = '{}: {}'.format(parameter_name, parameter_value) - assert getattr(controlling_state, parameter_name) == parameter_value, err_msg + err_msg = f"{parameter_name}: {parameter_value}" + assert ( + getattr(controlling_state, parameter_name) == parameter_value + ), err_msg with nose.tools.assert_raises(AttributeError): getattr(noncontrolling_state, parameter_name), parameter_name @@ -1685,30 +2004,36 @@ def test_parameter_validator(self): """Test GlobalParameterState constructor behave as expected.""" class MyState(GlobalParameterState): - lambda_bonds = GlobalParameterState.GlobalParameter('lambda_bonds', standard_value=1.0) + lambda_bonds = GlobalParameterState.GlobalParameter( + "lambda_bonds", standard_value=1.0 + ) @lambda_bonds.validator def lambda_bonds(self, instance, new_value): if not (0.0 <= new_value <= 1.0): - raise ValueError('lambda_bonds must be between 0.0 and 1.0') + raise ValueError("lambda_bonds must be between 0.0 and 1.0") return new_value # Create system with incorrect initial parameter. system = self.diatomic_molecule_ts.system system.getForce(0).setGlobalParameterDefaultValue(0, 2.0) # lambda_bonds - system.getForce(1).setGlobalParameterDefaultValue(0, -1.0) # lambda_bonds_mysuffix + system.getForce(1).setGlobalParameterDefaultValue( + 0, -1.0 + ) # lambda_bonds_mysuffix - for suffix in [None, 'mysuffix']: + for suffix in [None, "mysuffix"]: # Raise an exception on init. - with nose.tools.assert_raises_regexp(ValueError, 'must be between'): + with nose.tools.assert_raises_regexp(ValueError, "must be between"): MyState(parameters_name_suffix=suffix, lambda_bonds=-1.0) - with nose.tools.assert_raises_regexp(ValueError, 'must be between'): + with nose.tools.assert_raises_regexp(ValueError, "must be between"): MyState.from_system(system, parameters_name_suffix=suffix) # Raise an exception when properties are set. state = MyState(parameters_name_suffix=suffix, lambda_bonds=1.0) - parameter_name = 'lambda_bonds' if suffix is None else 'lambda_bonds_' + suffix - with nose.tools.assert_raises_regexp(ValueError, 'must be between'): + parameter_name = ( + "lambda_bonds" if suffix is None else "lambda_bonds_" + suffix + ) + with nose.tools.assert_raises_regexp(ValueError, "must be between"): setattr(state, parameter_name, 5.0) def test_equality_operator(self): @@ -1717,8 +2042,12 @@ def test_equality_operator(self): state2 = ParameterStateExample(lambda_bonds=1.0) state3 = ParameterStateExample(lambda_bonds=0.9) state4 = ParameterStateExample(lambda_bonds=0.9, gamma=1.0) - state5 = ParameterStateExample(lambda_bonds=0.9, parameters_name_suffix='suffix') - state6 = ParameterStateExample(parameters_name_suffix='suffix', lambda_bonds=0.9, gamma=1.0) + state5 = ParameterStateExample( + lambda_bonds=0.9, parameters_name_suffix="suffix" + ) + state6 = ParameterStateExample( + parameters_name_suffix="suffix", lambda_bonds=0.9, gamma=1.0 + ) assert state1 == state2 assert state2 != state3 assert state3 != state4 @@ -1729,32 +2058,41 @@ def test_equality_operator(self): # States that control more variables are not equal. class MyState(ParameterStateExample): - extra_parameter = GlobalParameterState.GlobalParameter('extra_parameter', standard_value=1.0) + extra_parameter = GlobalParameterState.GlobalParameter( + "extra_parameter", standard_value=1.0 + ) + state7 = MyState(lambda_bonds=0.9) assert state3 != state7 # States defined by global parameter functions are evaluated correctly. state8 = copy.deepcopy(state1) - state8.set_function_variable('lambda1', state1.lambda_bonds*2.0) - state8.lambda_bonds = GlobalParameterFunction('lambda1 / 2') + state8.set_function_variable("lambda1", state1.lambda_bonds * 2.0) + state8.lambda_bonds = GlobalParameterFunction("lambda1 / 2") assert state1 == state8 - state8.set_function_variable('lambda1', state1.lambda_bonds) + state8.set_function_variable("lambda1", state1.lambda_bonds) assert state1 != state8 def test_apply_to_system(self): """Test method GlobalParameterState.apply_to_system().""" system = self.diatomic_molecule_ts.system state = ParameterStateExample.from_system(system) - state_suffix = ParameterStateExample.from_system(system, parameters_name_suffix='mysuffix') + state_suffix = ParameterStateExample.from_system( + system, parameters_name_suffix="mysuffix" + ) expected_system_values = copy.deepcopy(self.parameters_default_values) def check_system_values(): state, state_suffix = self.read_system_state(system) for parameter_name, parameter_value in expected_system_values.items(): - err_msg = 'parameter: {}, expected_value: {}'.format(parameter_name, parameter_value) - if 'suffix' in parameter_name: - assert getattr(state_suffix, parameter_name) == parameter_value, err_msg + err_msg = ( + f"parameter: {parameter_name}, expected_value: {parameter_value}" + ) + if "suffix" in parameter_name: + assert ( + getattr(state_suffix, parameter_name) == parameter_value + ), err_msg else: assert getattr(state, parameter_name) == parameter_value, err_msg @@ -1763,22 +2101,22 @@ def check_system_values(): # apply_to_system() modifies the state. state.lambda_bonds /= 2 - expected_system_values['lambda_bonds'] /= 2 + expected_system_values["lambda_bonds"] /= 2 state_suffix.lambda_bonds_mysuffix /= 2 - expected_system_values['lambda_bonds_mysuffix'] /= 2 + expected_system_values["lambda_bonds_mysuffix"] /= 2 for s in [state, state_suffix]: s.apply_to_system(system) check_system_values() # Raise an error if an extra parameter is defined in the system. state.gamma = None - err_msg = 'The system parameter gamma is not defined in this state.' + err_msg = "The system parameter gamma is not defined in this state." with nose.tools.assert_raises_regexp(GlobalParameterError, err_msg): state.apply_to_system(system) # Raise an error if an extra parameter is defined in the state. state_suffix.gamma_mysuffix = 2.0 - err_msg = 'Could not find global parameter gamma_mysuffix in the system.' + err_msg = "Could not find global parameter gamma_mysuffix in the system." with nose.tools.assert_raises_regexp(GlobalParameterError, err_msg): state_suffix.apply_to_system(system) @@ -1788,7 +2126,9 @@ def test_check_system_consistency(self): def check_not_consistency(states): for s in states: - with nose.tools.assert_raises_regexp(GlobalParameterError, 'Consistency check failed'): + with nose.tools.assert_raises_regexp( + GlobalParameterError, "Consistency check failed" + ): s.check_system_consistency(system) # A system is consistent with itself. @@ -1810,13 +2150,13 @@ def check_not_consistency(states): # Raise error if system has different lambda values. state, state_suffix = self.read_system_state(system) state.lambda_bonds /= 2 - state_suffix.lambda_bonds_mysuffix /=2 + state_suffix.lambda_bonds_mysuffix /= 2 check_not_consistency([state, state_suffix]) def test_apply_to_context(self): """Test method GlobalParameterState.apply_to_context.""" system = self.diatomic_molecule_ts.system - integrator = openmm.VerletIntegrator(1.0*unit.femtosecond) + integrator = openmm.VerletIntegrator(1.0 * unit.femtosecond) context = create_default_context(self.diatomic_molecule_ts, integrator) def check_not_applicable(states, error, context): @@ -1828,12 +2168,12 @@ def check_not_applicable(states, error, context): state, state_suffix = self.read_system_state(system) state.lambda_bonds = None state_suffix.lambda_bonds_mysuffix = None - check_not_applicable([state, state_suffix], 'undefined in this state', context) + check_not_applicable([state, state_suffix], "undefined in this state", context) # Raise error if the state defines global parameters that are undefined in the Context. state, state_suffix = self.read_system_state(system) state_suffix.gamma_mysuffix = 2.0 - check_not_applicable([state_suffix], 'Could not find parameter', context) + check_not_applicable([state_suffix], "Could not find parameter", context) # Test-precondition: Context parameters are different than the value we'll test. tested_value = 0.2 @@ -1859,11 +2199,15 @@ def test_standardize_system(self): def check_is_standard(states, is_standard): for s in states: - for parameter_name in s._get_controlled_parameters(s._parameters_name_suffix): + for parameter_name in s._get_controlled_parameters( + s._parameters_name_suffix + ): parameter_value = getattr(s, parameter_name) - err_msg = 'Parameter: {}; Value: {};'.format(parameter_name, parameter_value) + err_msg = f"Parameter: {parameter_name}; Value: {parameter_value};" if parameter_value is not None: - assert (parameter_value == standard_value) is is_standard, err_msg + assert ( + parameter_value == standard_value + ) is is_standard, err_msg # Test pre-condition: The system is not in the standard state. system.getForce(0).setGlobalParameterDefaultValue(0, 0.9) @@ -1879,13 +2223,9 @@ def check_is_standard(states, is_standard): def test_find_force_groups_to_update(self): """Test method GlobalParameterState._find_force_groups_to_update.""" system = self.diatomic_molecule_force_groups_ts.system - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) # Test cases are (force_groups, force_groups_suffix) - test_cases = [ - ([0], [0, 0]), - ([1], [5, 5]), - ([9], [4, 2]) - ] + test_cases = [([0], [0, 0]), ([1], [5, 5]), ([9], [4, 2])] for test_case in test_cases: for i, force_group in enumerate(test_case[0] + test_case[1]): @@ -1895,12 +2235,16 @@ def test_find_force_groups_to_update(self): # No force group should be updated if we don't change the global parameter. for state, force_groups in zip(states, test_case): - assert state._find_force_groups_to_update(context, state, memo={}) == set() + assert ( + state._find_force_groups_to_update(context, state, memo={}) == set() + ) # Change the lambdas one by one and check that the method # recognizes that the force group energy must be updated. current_state = copy.deepcopy(state) - for parameter_name in state._get_controlled_parameters(state._parameters_name_suffix): + for parameter_name in state._get_controlled_parameters( + state._parameters_name_suffix + ): # Check that the system defines the global variable. parameter_value = getattr(state, parameter_name) if parameter_value is None: @@ -1908,8 +2252,12 @@ def test_find_force_groups_to_update(self): # Change the current state. setattr(current_state, parameter_name, parameter_value / 2) - assert state._find_force_groups_to_update(context, current_state, memo={}) == set(force_groups) - setattr(current_state, parameter_name, parameter_value) # Reset current state. + assert state._find_force_groups_to_update( + context, current_state, memo={} + ) == set(force_groups) + setattr( + current_state, parameter_name, parameter_value + ) # Reset current state. del context def test_global_parameters_functions(self): @@ -1918,23 +2266,23 @@ def test_global_parameters_functions(self): state = ParameterStateExample.from_system(system) # Add two function variables to the state. - state.set_function_variable('lambda', 1.0) - state.set_function_variable('lambda2', 0.5) - assert state.get_function_variable('lambda') == 1.0 - assert state.get_function_variable('lambda2') == 0.5 + state.set_function_variable("lambda", 1.0) + state.set_function_variable("lambda2", 0.5) + assert state.get_function_variable("lambda") == 1.0 + assert state.get_function_variable("lambda2") == 0.5 # Cannot call an function variable as a supported parameter. with nose.tools.assert_raises(GlobalParameterError): - state.set_function_variable('lambda_bonds', 0.5) + state.set_function_variable("lambda_bonds", 0.5) # Assign string global parameter functions to parameters. - state.lambda_bonds = GlobalParameterFunction('lambda') - state.gamma = GlobalParameterFunction('(lambda + lambda2) / 2.0') + state.lambda_bonds = GlobalParameterFunction("lambda") + state.gamma = GlobalParameterFunction("(lambda + lambda2) / 2.0") assert state.lambda_bonds == 1.0 assert state.gamma == 0.75 # Setting function variables updates global parameter as well. - state.set_function_variable('lambda2', 0) + state.set_function_variable("lambda2", 0) assert state.gamma == 0.5 # --------------------------------------------------- @@ -1951,14 +2299,16 @@ def test_constructor_compound_state(self): state.set_defined_parameters(0.222) # CompoundThermodynamicState set the system state in the constructor. - compound_state = CompoundThermodynamicState(self.diatomic_molecule_ts, composable_states) + compound_state = CompoundThermodynamicState( + self.diatomic_molecule_ts, composable_states + ) new_system_states = self.read_system_state(compound_state.system) for state, new_state in zip(composable_states, new_system_states): assert state == new_state # Trying to set in the constructor undefined global parameters raise an exception. composable_states[1].gamma_mysuffix = 2.0 - err_msg = 'Could not find global parameter gamma_mysuffix in the system.' + err_msg = "Could not find global parameter gamma_mysuffix in the system." with nose.tools.assert_raises_regexp(GlobalParameterError, err_msg): CompoundThermodynamicState(self.diatomic_molecule_ts, composable_states) @@ -1967,14 +2317,16 @@ def test_global_parameters_compound_state(self): composable_states = self.read_system_state(self.diatomic_molecule_ts.system) for state in composable_states: state.set_defined_parameters(0.222) - compound_state = CompoundThermodynamicState(self.diatomic_molecule_ts, composable_states) + compound_state = CompoundThermodynamicState( + self.diatomic_molecule_ts, composable_states + ) # Defined properties can be assigned and read, unless they are undefined. for parameter_name, default_value in self.parameters_default_values.items(): if default_value is None: assert getattr(compound_state, parameter_name) is None # If undefined, setting the property should raise an error. - err_msg = 'Cannot set the parameter gamma_mysuffix in the system' + err_msg = "Cannot set the parameter gamma_mysuffix in the system" with nose.tools.assert_raises_regexp(GlobalParameterError, err_msg): setattr(compound_state, parameter_name, 2.0) continue @@ -1988,34 +2340,43 @@ def test_global_parameters_compound_state(self): system_states = self.read_system_state(compound_state.system) for state in system_states: for parameter_name in state._parameters: - assert getattr(state, parameter_name) == getattr(compound_state, parameter_name) + assert getattr(state, parameter_name) == getattr( + compound_state, parameter_name + ) # Same for global parameter function variables. - compound_state.set_function_variable('lambda', 0.25) - defined_parameters = {name for name, value in self.parameters_default_values.items() - if value is not None} + compound_state.set_function_variable("lambda", 0.25) + defined_parameters = { + name + for name, value in self.parameters_default_values.items() + if value is not None + } for parameter_name in defined_parameters: - setattr(compound_state, parameter_name, GlobalParameterFunction('lambda')) + setattr(compound_state, parameter_name, GlobalParameterFunction("lambda")) parameter_value = getattr(compound_state, parameter_name) - assert parameter_value == 0.25, '{}, {}'.format(parameter_name, parameter_value) + assert parameter_value == 0.25, f"{parameter_name}, {parameter_value}" system_states = self.read_system_state(compound_state.system) for state in system_states: for parameter_name in state._parameters: if parameter_name in defined_parameters: parameter_value = getattr(compound_state, parameter_name) - assert parameter_value == 0.25, '{}, {}'.format(parameter_name, parameter_value) + assert ( + parameter_value == 0.25 + ), f"{parameter_name}, {parameter_value}" def test_set_system_compound_state(self): """Setting inconsistent system in compound state raise errors.""" system = self.diatomic_molecule_ts.system composable_states = self.read_system_state(system) - compound_state = CompoundThermodynamicState(self.diatomic_molecule_ts, composable_states) + compound_state = CompoundThermodynamicState( + self.diatomic_molecule_ts, composable_states + ) for parameter_name, default_value in self.parameters_default_values.items(): if default_value is None: continue - elif 'suffix' in parameter_name: + elif "suffix" in parameter_name: original_state = composable_states[1] else: original_state = composable_states[0] @@ -2023,7 +2384,7 @@ def test_set_system_compound_state(self): # We create an incompatible state with the parameter set to a different value. incompatible_state = copy.deepcopy(original_state) original_value = getattr(incompatible_state, parameter_name) - setattr(incompatible_state, parameter_name, original_value/2) + setattr(incompatible_state, parameter_name, original_value / 2) incompatible_state.apply_to_system(system) # Setting an inconsistent system raise an error. @@ -2036,8 +2397,13 @@ def test_set_system_compound_state(self): # This doesn't happen if we fix the state. compound_state.set_system(system, fix_state=True) - new_state = incompatible_state.from_system(compound_state.system, original_state._parameters_name_suffix) - assert new_state == original_state, (str(new_state), str(incompatible_state)) + new_state = incompatible_state.from_system( + compound_state.system, original_state._parameters_name_suffix + ) + assert new_state == original_state, ( + str(new_state), + str(incompatible_state), + ) # Restore old value in system, and test next parameter. original_state.apply_to_system(system) @@ -2049,15 +2415,24 @@ def test_compatibility_compound_state(self): # Build all compound states. compound_states = [] for system in incompatible_systems: - thermodynamic_state = ThermodynamicState(system, temperature=300*unit.kelvin) + thermodynamic_state = ThermodynamicState( + system, temperature=300 * unit.kelvin + ) composable_states = self.read_system_state(system) - compound_states.append(CompoundThermodynamicState(thermodynamic_state, composable_states)) + compound_states.append( + CompoundThermodynamicState(thermodynamic_state, composable_states) + ) # Build all contexts for testing. - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) - contexts = [create_default_context(s, copy.deepcopy(integrator)) for s in compound_states] + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) + contexts = [ + create_default_context(s, copy.deepcopy(integrator)) + for s in compound_states + ] - for state_idx, (compound_state, context) in enumerate(zip(compound_states, contexts)): + for state_idx, (compound_state, context) in enumerate( + zip(compound_states, contexts) + ): # The state is compatible with itself. assert compound_state.is_state_compatible(compound_state) assert compound_state.is_context_compatible(context) @@ -2065,7 +2440,7 @@ def test_compatibility_compound_state(self): # Changing the values of the parameters do not affect # compatibility (only defined/undefined parameters do). altered_compound_state = copy.deepcopy(compound_state) - for parameter_name in ['gamma', 'lambda_bonds_mysuffix']: + for parameter_name in ["gamma", "lambda_bonds_mysuffix"]: try: new_value = getattr(compound_state, parameter_name) / 2 setattr(altered_compound_state, parameter_name, new_value) @@ -2076,7 +2451,7 @@ def test_compatibility_compound_state(self): # All other states are incompatible. Test only those that we # haven't tested yet, but test transitivity. - for incompatible_state_idx in range(state_idx+1, len(compound_states)): + for incompatible_state_idx in range(state_idx + 1, len(compound_states)): print(state_idx, incompatible_state_idx) incompatible_state = compound_states[incompatible_state_idx] incompatible_context = contexts[incompatible_state_idx] @@ -2095,16 +2470,40 @@ def test_reduced_potential_compound_state(self): # Build a mixed collection of compatible and incompatible thermodynamic states. thermodynamic_states = [ copy.deepcopy(self.diatomic_molecule_ts), - copy.deepcopy(self.diatomic_molecule_force_groups_ts) + copy.deepcopy(self.diatomic_molecule_force_groups_ts), ] compound_states = [] for ts_idx, ts in enumerate(thermodynamic_states): - compound_state = CompoundThermodynamicState(ts, self.read_system_state(ts.system)) - for state in [dict(lambda_bonds=1.0, gamma=1.0, lambda_bonds_mysuffix=1.0, gamma_mysuffix=1.0), - dict(lambda_bonds=0.5, gamma=1.0, lambda_bonds_mysuffix=1.0, gamma_mysuffix=1.0), - dict(lambda_bonds=0.5, gamma=1.0, lambda_bonds_mysuffix=1.0, gamma_mysuffix=0.5), - dict(lambda_bonds=0.1, gamma=0.5, lambda_bonds_mysuffix=0.2, gamma_mysuffix=0.5)]: + compound_state = CompoundThermodynamicState( + ts, self.read_system_state(ts.system) + ) + for state in [ + dict( + lambda_bonds=1.0, + gamma=1.0, + lambda_bonds_mysuffix=1.0, + gamma_mysuffix=1.0, + ), + dict( + lambda_bonds=0.5, + gamma=1.0, + lambda_bonds_mysuffix=1.0, + gamma_mysuffix=1.0, + ), + dict( + lambda_bonds=0.5, + gamma=1.0, + lambda_bonds_mysuffix=1.0, + gamma_mysuffix=0.5, + ), + dict( + lambda_bonds=0.1, + gamma=0.5, + lambda_bonds_mysuffix=0.2, + gamma_mysuffix=0.5, + ), + ]: for parameter_name, parameter_value in state.items(): try: setattr(compound_state, parameter_name, parameter_value) @@ -2121,7 +2520,7 @@ def test_reduced_potential_compound_state(self): obtained_energies = [] for compatible_group in compatible_groups: # Create context. - integrator = openmm.VerletIntegrator(2.0*unit.femtoseconds) + integrator = openmm.VerletIntegrator(2.0 * unit.femtoseconds) context = create_default_context(compatible_group[0], integrator) context.setPositions(positions) @@ -2131,7 +2530,9 @@ def test_reduced_potential_compound_state(self): expected_energies.append(state.reduced_potential(context)) # Compute with multi-state method. - compatible_energies = ThermodynamicState.reduced_potential_at_states(context, compatible_group) + compatible_energies = ThermodynamicState.reduced_potential_at_states( + context, compatible_group + ) # The first and the last state must be equal. assert np.isclose(compatible_energies[0], compatible_energies[-1]) @@ -2144,8 +2545,8 @@ def test_serialization(self): composable_states = self.read_system_state(self.diatomic_molecule_ts.system) # Add a global parameter function to test if they are serialized correctly. - composable_states[0].set_function_variable('lambda', 0.5) - composable_states[0].gamma = GlobalParameterFunction('lambda**2') + composable_states[0].set_function_variable("lambda", 0.5) + composable_states[0].gamma = GlobalParameterFunction("lambda**2") # Test serialization/deserialization of GlobalParameterState. for state in composable_states: @@ -2154,7 +2555,9 @@ def test_serialization(self): are_pickle_equal(state, deserialized_state) # Test serialization/deserialization of GlobalParameterState in CompoundState. - compound_state = CompoundThermodynamicState(self.diatomic_molecule_ts, composable_states) + compound_state = CompoundThermodynamicState( + self.diatomic_molecule_ts, composable_states + ) serialization = utils.serialize(compound_state) deserialized_state = utils.deserialize(serialization) are_pickle_equal(compound_state, deserialized_state) @@ -2164,58 +2567,68 @@ def test_create_thermodynamic_state_protocol(): """Test the method for efficiently creating a list of thermoydamic states.""" system = testsystems.AlchemicalAlanineDipeptide().system - thermo_state = ThermodynamicState(system, temperature=400*unit.kelvin) + thermo_state = ThermodynamicState(system, temperature=400 * unit.kelvin) # The method raises an exception when the protocol is empty. - with nose.tools.assert_raises_regexp(ValueError, 'No protocol'): + with nose.tools.assert_raises_regexp(ValueError, "No protocol"): create_thermodynamic_state_protocol(system, protocol={}) # The method raises an exception when different parameters have different lengths. - with nose.tools.assert_raises_regexp(ValueError, 'different lengths'): - protocol = {'temperature': [1.0, 2.0], - 'pressure': [4.0]} + with nose.tools.assert_raises_regexp(ValueError, "different lengths"): + protocol = {"temperature": [1.0, 2.0], "pressure": [4.0]} create_thermodynamic_state_protocol(system, protocol=protocol) # An exception is raised if the temperature is not specified with a System. - with nose.tools.assert_raises_regexp(ValueError, 'must specify the temperature'): - protocol = {'pressure': [5.0] * unit.atmosphere} + with nose.tools.assert_raises_regexp(ValueError, "must specify the temperature"): + protocol = {"pressure": [5.0] * unit.atmosphere} create_thermodynamic_state_protocol(system, protocol=protocol) # An exception is raised if a parameter is specified both as constant and protocol. - with nose.tools.assert_raises_regexp(ValueError, 'constants and protocol'): - protocol = {'temperature': [5.0, 10.0] * unit.kelvin} - const = {'temperature': 5.0 * unit.kelvin} + with nose.tools.assert_raises_regexp(ValueError, "constants and protocol"): + protocol = {"temperature": [5.0, 10.0] * unit.kelvin} + const = {"temperature": 5.0 * unit.kelvin} create_thermodynamic_state_protocol(system, protocol=protocol, constants=const) # Method works as expected with a reference System or ThermodynamicState. - protocol = {'temperature': [290, 310, 360]*unit.kelvin} + protocol = {"temperature": [290, 310, 360] * unit.kelvin} for reference in [system, thermo_state]: states = create_thermodynamic_state_protocol(reference, protocol=protocol) - for state, temp in zip(states, protocol['temperature']): + for state, temp in zip(states, protocol["temperature"]): assert state.temperature == temp assert len(states) == 3 # Same with CompoundThermodynamicState. from openmmtools.alchemy import AlchemicalState + alchemical_state = AlchemicalState.from_system(system) - protocol = {'temperature': [290, 310, 360]*unit.kelvin, - 'lambda_sterics': [1.0, 0.5, 0.0], - 'lambda_electrostatics': [0.75, 0.5, 0.25]} + protocol = { + "temperature": [290, 310, 360] * unit.kelvin, + "lambda_sterics": [1.0, 0.5, 0.0], + "lambda_electrostatics": [0.75, 0.5, 0.25], + } for reference in [system, thermo_state]: - states = create_thermodynamic_state_protocol(reference, protocol=protocol, - composable_states=alchemical_state) - for state, temp, sterics, electro in zip(states, protocol['temperature'], - protocol['lambda_sterics'], - protocol['lambda_electrostatics']): + states = create_thermodynamic_state_protocol( + reference, protocol=protocol, composable_states=alchemical_state + ) + for state, temp, sterics, electro in zip( + states, + protocol["temperature"], + protocol["lambda_sterics"], + protocol["lambda_electrostatics"], + ): assert state.temperature == temp assert state.lambda_sterics == sterics assert state.lambda_electrostatics == electro assert len(states) == 3 # Check that constants work correctly. - del protocol['temperature'] - const = {'temperature': 500*unit.kelvin} - states = create_thermodynamic_state_protocol(thermo_state, protocol=protocol, constants=const, - composable_states=alchemical_state) + del protocol["temperature"] + const = {"temperature": 500 * unit.kelvin} + states = create_thermodynamic_state_protocol( + thermo_state, + protocol=protocol, + constants=const, + composable_states=alchemical_state, + ) for state in states: - assert state.temperature == 500*unit.kelvin + assert state.temperature == 500 * unit.kelvin diff --git a/openmmtools/tests/test_storage_interface.py b/openmmtools/tests/test_storage_interface.py index 80256694b..c67ef9f5c 100644 --- a/openmmtools/tests/test_storage_interface.py +++ b/openmmtools/tests/test_storage_interface.py @@ -13,6 +13,7 @@ # ============================================================================================= import numpy as np + try: from openmm import unit except ImportError: # OpenMM < 7.6 @@ -30,10 +31,12 @@ # TEST HELPER FUNCTIONS # ============================================================================================= + def spawn_driver(path): """Create a driver that is used to test the StorageInterface class at path location""" return NetCDFIODriver(path) + # ============================================================================================= # STORAGE INTERFACE TESTING FUNCTIONS # ============================================================================================= @@ -42,18 +45,18 @@ def spawn_driver(path): def test_storage_interface_creation(): """Test that the storage interface can create a top level file and read from it""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) - si.add_metadata('name', 'data') - assert si.storage_driver.ncfile.getncattr('name') == 'data' + si.add_metadata("name", "data") + assert si.storage_driver.ncfile.getncattr("name") == "data" @tools.raises(Exception) def test_read_trap(): """Test that attempting to read a non-existent file fails""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) si.var1.read() @@ -62,7 +65,7 @@ def test_read_trap(): def test_variable_write_read(): """Test that a variable can be create and written to file""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) input_data = 4 @@ -74,7 +77,7 @@ def test_variable_write_read(): def test_variable_append_read(): """Test that a variable can be create and written to file""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) input_data = np.eye(3) * 4.0 @@ -88,7 +91,7 @@ def test_variable_append_read(): def test_at_index_write(): """Test that writing at a specific index of appended data works""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) input_data = 4 @@ -105,10 +108,10 @@ def test_at_index_write(): def test_unbound_read(): """Test that a variable can read from the file without previous binding""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) - input_data = 4*unit.kelvin + input_data = 4 * unit.kelvin si.four.write(input_data) si.storage_driver.close() del si @@ -121,15 +124,15 @@ def test_unbound_read(): def test_directory_creation(): """Test that automatic directory-like objects are created on the fly""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) - input_data = 'four' + input_data = "four" si.dir0.dir1.dir2.var.write(input_data) ncfile = si.storage_driver.ncfile target = ncfile for i in range(3): - my_dir = 'dir{}'.format(i) + my_dir = f"dir{i}" assert my_dir in target.groups target = target.groups[my_dir] si.storage_driver.close() @@ -138,7 +141,7 @@ def test_directory_creation(): si = StorageInterface(driver) target = si for i in range(3): - my_dir = 'dir{}'.format(i) + my_dir = f"dir{i}" target = getattr(target, my_dir) assert target.var.read() == input_data @@ -146,7 +149,7 @@ def test_directory_creation(): def test_multi_variable_creation(): """Test that multiple variables can be created in a single directory structure""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) input_data = [4.0, 4.0, 4.0] @@ -166,14 +169,14 @@ def test_multi_variable_creation(): def test_metadata_creation(): """Test that metadata can be added to variables and directories""" with temporary_directory() as tmp_dir: - test_store = tmp_dir + '/teststore.nc' + test_store = tmp_dir + "/teststore.nc" driver = spawn_driver(test_store) si = StorageInterface(driver) input_data = 4 si.dir0.var1.write(input_data) - si.dir0.add_metadata('AmIAGroup', 'yes') - si.dir0.var1.add_metadata('AmIAGroup', 'no') - dir0 = si.storage_driver.ncfile.groups['dir0'] - var1 = dir0.variables['var1'] - assert dir0.getncattr('AmIAGroup') == 'yes' - assert var1.getncattr('AmIAGroup') == 'no' + si.dir0.add_metadata("AmIAGroup", "yes") + si.dir0.var1.add_metadata("AmIAGroup", "no") + dir0 = si.storage_driver.ncfile.groups["dir0"] + var1 = dir0.variables["var1"] + assert dir0.getncattr("AmIAGroup") == "yes" + assert var1.getncattr("AmIAGroup") == "no" diff --git a/openmmtools/tests/test_storage_iodrivers.py b/openmmtools/tests/test_storage_iodrivers.py index 817903888..332d91939 100644 --- a/openmmtools/tests/test_storage_iodrivers.py +++ b/openmmtools/tests/test_storage_iodrivers.py @@ -10,6 +10,7 @@ # ============================================================================================= import numpy as np + try: from openmm import unit except ImportError: # OpenMM < 7.6 @@ -26,15 +27,16 @@ # NETCDFIODRIVER TESTING FUNCTIONS # ============================================================================================= + def test_netcdf_driver_group_manipulation(): """Test that the NetCDFIODriver can create groups, rebind to groups, and that they are on the file""" with temporary_directory() as tmp_dir: - nc_io_driver = NetCDFIODriver(tmp_dir + 'test.nc') - group2 = nc_io_driver.get_directory('group1/group2') - group1 = nc_io_driver.get_directory('group1') + nc_io_driver = NetCDFIODriver(tmp_dir + "test.nc") + group2 = nc_io_driver.get_directory("group1/group2") + group1 = nc_io_driver.get_directory("group1") ncfile = nc_io_driver.ncfile - ncgroup1 = ncfile.groups['group1'] - ncgroup2 = ncfile.groups['group1'].groups['group2'] + ncgroup1 = ncfile.groups["group1"] + ncgroup2 = ncfile.groups["group1"].groups["group2"] assert group1 is ncgroup1 assert group2 is ncgroup2 @@ -42,29 +44,29 @@ def test_netcdf_driver_group_manipulation(): def test_netcdf_driver_dimension_manipulation(): """Test that the NetCDFIODriver can check and create dimensions""" with temporary_directory() as tmp_dir: - nc_io_driver = NetCDFIODriver(tmp_dir + '/test.nc') + nc_io_driver = NetCDFIODriver(tmp_dir + "/test.nc") NetCDFIODriver.check_scalar_dimension(nc_io_driver) NetCDFIODriver.check_iterable_dimension(nc_io_driver, length=4) NetCDFIODriver.check_infinite_dimension(nc_io_driver) ncfile = nc_io_driver.ncfile dims = ncfile.dimensions - assert 'scalar' in dims - assert 'iterable4' in dims - assert 'iteration' in dims + assert "scalar" in dims + assert "iterable4" in dims + assert "iteration" in dims def test_netcdf_driver_metadata_creation(): """Test that the NetCDFIODriver can create metadata on different objects""" with temporary_directory() as tmp_dir: - nc_io_driver = NetCDFIODriver(tmp_dir + '/test.nc') - group1 = nc_io_driver.get_directory('group1') - nc_io_driver.add_metadata('root_metadata', 'IAm(G)Root!') - nc_io_driver.add_metadata('group_metadata', 'group1_metadata', path='/group1') + nc_io_driver = NetCDFIODriver(tmp_dir + "/test.nc") + group1 = nc_io_driver.get_directory("group1") + nc_io_driver.add_metadata("root_metadata", "IAm(G)Root!") + nc_io_driver.add_metadata("group_metadata", "group1_metadata", path="/group1") ncfile = nc_io_driver.ncfile - nc_metadata = ncfile.getncattr('root_metadata') - group_metadata = group1.getncattr('group_metadata') - assert nc_metadata == 'IAm(G)Root!' - assert group_metadata == 'group1_metadata' + nc_metadata = ncfile.getncattr("root_metadata") + group_metadata = group1.getncattr("group_metadata") + assert nc_metadata == "IAm(G)Root!" + assert group_metadata == "group1_metadata" # ============================================================================================= @@ -75,14 +77,14 @@ def test_netcdf_driver_metadata_creation(): def generic_type_codec_check(input_data, with_append=True): """Generic type codec test to ensure all callable functions are working""" with temporary_directory() as tmp_dir: - file_path = tmp_dir + '/test.nc' + file_path = tmp_dir + "/test.nc" nc_io_driver = NetCDFIODriver(file_path) input_type = type(input_data) # Create a write and an append of the data - write_path = 'data_write' + write_path = "data_write" data_write = nc_io_driver.create_storage_variable(write_path, input_type) if with_append: - append_path = 'group1/data_append' + append_path = "group1/data_append" data_append = nc_io_driver.create_storage_variable(append_path, input_type) # Store initial data (unbound write/append) data_write.write(input_data) @@ -115,7 +117,7 @@ def generic_type_codec_check(input_data, with_append=True): if with_append: del data_append, data_append_out # Reopen and test reading actions - nc_io_driver = NetCDFIODriver(file_path, access_mode='r') + nc_io_driver = NetCDFIODriver(file_path, access_mode="r") data_write = nc_io_driver.get_storage_variable(write_path) if with_append: data_append = nc_io_driver.get_storage_variable(append_path) @@ -130,7 +132,9 @@ def generic_type_codec_check(input_data, with_append=True): assert np.all(data_write_out == input_data) if with_append: try: - for key in data_write_out.keys(): # Must act on the data_write since it has the .keys method + for key in ( + data_write_out.keys() + ): # Must act on the data_write since it has the .keys method assert np.all(data_append_out[0][key] == input_data[key]) assert np.all(data_append_out[1][key] == input_data[key]) except AttributeError: @@ -141,11 +145,11 @@ def generic_type_codec_check(input_data, with_append=True): def generic_append_to_check(input_data, overwrite_data): """Generic function to test replacing data of appended dimension""" with temporary_directory() as tmp_dir: - file_path = tmp_dir + '/test.nc' + file_path = tmp_dir + "/test.nc" nc_io_driver = NetCDFIODriver(file_path) input_type = type(input_data) # Create a write and an append of the data - append_path = 'data_append' + append_path = "data_append" data_append = nc_io_driver.create_storage_variable(append_path, input_type) # Append data 3 times for i in range(3): @@ -154,11 +158,15 @@ def generic_append_to_check(input_data, overwrite_data): data_append.write(overwrite_data, at_index=1) data_append_out = data_append.read() try: - for key in input_data.keys(): # Must act on the data_write since it has the .keys method + for key in ( + input_data.keys() + ): # Must act on the data_write since it has the .keys method assert np.all(data_append_out[0][key] == input_data[key]) assert np.all(data_append_out[2][key] == input_data[key]) assert np.all(data_append_out[1][key] == overwrite_data[key]) - assert set(input_data.keys()) == set(data_append_out[0].keys()) # Assert keys match + assert set(input_data.keys()) == set( + data_append_out[0].keys() + ) # Assert keys match assert set(input_data.keys()) == set(data_append_out[2].keys()) except AttributeError: assert np.all(data_append_out[0] == input_data) @@ -184,9 +192,9 @@ def test_netcdf_float_type_codec(): def test_netcdf_string_type_codec(): """Test that the String type codec can read/write/append""" - input_data = 'four point oh' + input_data = "four point oh" generic_type_codec_check(input_data) - overwrite_data = 'five point not' + overwrite_data = "five point not" generic_append_to_check(input_data, overwrite_data) @@ -229,21 +237,21 @@ def test_netcdf_quantity_type_codec(): def test_netcdf_dictionary_type_codec(): """Test that the dictionary type codec can read/write/append with various unit and _value types""" input_data = { - 'count': 4, - 'ratio': 0.4, - 'name': 'four', - 'repeated': [4, 4, 4], - 'temperature': 4 * unit.kelvin, - 'box_vectors': (np.eye(3) * 4.0) * unit.nanometer + "count": 4, + "ratio": 0.4, + "name": "four", + "repeated": [4, 4, 4], + "temperature": 4 * unit.kelvin, + "box_vectors": (np.eye(3) * 4.0) * unit.nanometer, } generic_type_codec_check(input_data) overwrite_data = { - 'count': 5, - 'ratio': 0.5, - 'name': 'five', - 'repeated': [5, 5, 5], - 'temperature': 5 * unit.kelvin, - 'box_vectors': (np.eye(3) * 5.0) * unit.nanometer + "count": 5, + "ratio": 0.5, + "name": "five", + "repeated": [5, 5, 5], + "temperature": 5 * unit.kelvin, + "box_vectors": (np.eye(3) * 5.0) * unit.nanometer, } generic_append_to_check(input_data, overwrite_data) @@ -252,12 +260,12 @@ def test_netcdf_dictionary_type_codec(): def test_write_at_index_must_exist(): """Ensure that the write(data, at_index) must exist first""" with temporary_directory() as tmp_dir: - file_path = tmp_dir + '/test.nc' + file_path = tmp_dir + "/test.nc" nc_io_driver = NetCDFIODriver(file_path) input_data = 4 input_type = type(input_data) # Create a write and an append of the data - append_path = 'data_append' + append_path = "data_append" data_append = nc_io_driver.create_storage_variable(append_path, input_type) data_append.write(input_data, at_index=0) @@ -266,12 +274,12 @@ def test_write_at_index_must_exist(): def test_write_at_index_is_bound(): """Ensure that the write(data, at_index) cannot write to an index beyond""" with temporary_directory() as tmp_dir: - file_path = tmp_dir + '/test.nc' + file_path = tmp_dir + "/test.nc" nc_io_driver = NetCDFIODriver(file_path) input_data = 4 input_type = type(input_data) # Create a write and an append of the data - append_path = 'data_append' + append_path = "data_append" data_append = nc_io_driver.create_storage_variable(append_path, input_type) data_append.append(input_data) # Creates the first data data_append.write(input_data, at_index=1) # should fail for out of bounds index diff --git a/openmmtools/tests/test_testsystems.py b/openmmtools/tests/test_testsystems.py index ef4fcd29b..8a057815d 100644 --- a/openmmtools/tests/test_testsystems.py +++ b/openmmtools/tests/test_testsystems.py @@ -15,18 +15,20 @@ from functools import partial + def _equiv_topology(top_1, top_2): """Compare topologies using string reps of atoms and bonds""" - for (b1, b2) in zip(top_1.bonds(), top_2.bonds()): + for b1, b2 in zip(top_1.bonds(), top_2.bonds()): if str(b1) != str(b2): return False - for (a1, a2) in zip(top_1.atoms(), top_2.atoms()): + for a1, a2 in zip(top_1.atoms(), top_2.atoms()): if str(a1) != str(a2): return False return True + def get_all_subclasses(cls): """ Return all subclasses of a specified class. @@ -51,22 +53,23 @@ def get_all_subclasses(cls): return all_subclasses + def test_get_data_filename(): - """Testing retrieval of data files shipped with distro. - """ - relative_path = 'data/alanine-dipeptide-gbsa/alanine-dipeptide.prmtop' + """Testing retrieval of data files shipped with distro.""" + relative_path = "data/alanine-dipeptide-gbsa/alanine-dipeptide.prmtop" filename = testsystems.get_data_filename(relative_path) if not os.path.exists(filename): raise Exception("Could not locate data files. Expected %s" % relative_path) + def test_subrandom_particle_positions(): - """Testing deterministic subrandom particle position assignment. - """ + """Testing deterministic subrandom particle position assignment.""" # Test halton sequence. - x = testsystems.halton_sequence(2,100) + x = testsystems.halton_sequence(2, 100) # Test Sobol. from openmmtools import sobol + x = sobol.i4_sobol_generate(3, 100, 1) # Test subrandom positions. @@ -74,19 +77,24 @@ def test_subrandom_particle_positions(): box_vectors = openmm.System().getDefaultPeriodicBoxVectors() positions = testsystems.subrandom_particle_positions(nparticles, box_vectors) + def check_properties(testsystem): class_name = testsystem.__class__.__name__ property_list = testsystem.analytical_properties - state = testsystems.ThermodynamicState(temperature=300.0*unit.kelvin, pressure=1.0*unit.atmosphere) + state = testsystems.ThermodynamicState( + temperature=300.0 * unit.kelvin, pressure=1.0 * unit.atmosphere + ) if len(property_list) > 0: for property_name in property_list: - method = getattr(testsystem, 'get_' + property_name) - logging.info("%32s . %32s : %32s" % (class_name, property_name, str(method(state)))) + method = getattr(testsystem, "get_" + property_name) + logging.info( + "%32s . %32s : %32s" % (class_name, property_name, str(method(state))) + ) return + def test_properties_all_testsystems(): - """Testing computation of analytic properties for all systems. - """ + """Testing computation of analytic properties for all systems.""" testsystem_classes = get_all_subclasses(testsystems.TestSystem) logging.info("Testing analytical property computation:") for testsystem_class in testsystem_classes: @@ -102,10 +110,17 @@ def test_properties_all_testsystems(): logging.info(f.description) yield f + fast_testsystems = [ "HarmonicOscillator", "PowerOscillator", - "Diatom", "DiatomicFluid", "UnconstrainedDiatomicFluid", "ConstrainedDiatomicFluid", "DipolarFluid", "UnconstrainedDipolarFluid", "ConstrainedDipolarFluid", + "Diatom", + "DiatomicFluid", + "UnconstrainedDiatomicFluid", + "ConstrainedDiatomicFluid", + "DipolarFluid", + "UnconstrainedDipolarFluid", + "ConstrainedDipolarFluid", "ConstraintCoupledHarmonicOscillator", "HarmonicOscillatorArray", "SodiumChlorideCrystal", @@ -113,18 +128,37 @@ def test_properties_all_testsystems(): "LennardJonesFluid", "LennardJonesGrid", "CustomLennardJonesFluidMixture", - "WCAFluid", "DoubleWellDimer_WCAFluid", "DoubleWellChain_WCAFluid", + "WCAFluid", + "DoubleWellDimer_WCAFluid", + "DoubleWellChain_WCAFluid", "IdealGas", - "WaterBox", "FlexibleWaterBox", "FourSiteWaterBox", "FiveSiteWaterBox", "DischargedWaterBox", "DischargedWaterBoxHsites", "AlchemicalWaterBox", - "AlanineDipeptideVacuum", "AlanineDipeptideImplicit", + "WaterBox", + "FlexibleWaterBox", + "FourSiteWaterBox", + "FiveSiteWaterBox", + "DischargedWaterBox", + "DischargedWaterBoxHsites", + "AlchemicalWaterBox", + "AlanineDipeptideVacuum", + "AlanineDipeptideImplicit", "MethanolBox", "MolecularIdealGas", "CustomGBForceSystem", "AlchemicalLennardJonesCluster", "LennardJonesPair", - "TolueneVacuum", "TolueneImplicit", "TolueneImplicitHCT", "TolueneImplicitOBC1", "TolueneImplicitOBC2", "TolueneImplicitGBn", "TolueneImplicitGBn2", - "HostGuestVacuum", "HostGuestImplicit", "HostGuestImplicitHCT", 'HostGuestImplicitOBC1', - ] + "TolueneVacuum", + "TolueneImplicit", + "TolueneImplicitHCT", + "TolueneImplicitOBC1", + "TolueneImplicitOBC2", + "TolueneImplicitGBn", + "TolueneImplicitGBn2", + "HostGuestVacuum", + "HostGuestImplicit", + "HostGuestImplicitHCT", + "HostGuestImplicitOBC1", +] + def check_potential_energy(system, positions): """ @@ -156,14 +190,16 @@ def check_potential_energy(system, positions): # Clean up del context, integrator + def test_energy_all_testsystems(skip_slow_tests=True): - """Testing computation of potential energy for all systems. - """ + """Testing computation of potential energy for all systems.""" testsystem_classes = get_all_subclasses(testsystems.TestSystem) for testsystem_class in testsystem_classes: class_name = testsystem_class.__name__ if skip_slow_tests and not (class_name in fast_testsystems): - logging.info("Skipping potential energy test for testsystem %s." % class_name) + logging.info( + "Skipping potential energy test for testsystem %s." % class_name + ) continue # Create test. @@ -177,9 +213,9 @@ def test_energy_all_testsystems(skip_slow_tests=True): f.description = "Testing potential energy for testsystem %s" % class_name yield f + def check_topology(system, topology): - """Check the topology object contains the correct number of atoms. - """ + """Check the topology object contains the correct number of atoms.""" # Get number of particles from topology. nparticles_topology = 0 @@ -189,11 +225,11 @@ def check_topology(system, topology): # Get number of particles from system. nparticles_system = system.getNumParticles() - assert (nparticles_topology==nparticles_system) + assert nparticles_topology == nparticles_system + def test_topology_all_testsystems(): - """Testing topology contains correct number of atoms for all systems. - """ + """Testing topology contains correct number of atoms for all systems.""" testsystem_classes = get_all_subclasses(testsystems.TestSystem) for testsystem_class in testsystem_classes: @@ -210,6 +246,7 @@ def test_topology_all_testsystems(): f.description = "Testing topology for testsystem %s" % class_name yield f + def test_dw_systems_as_wca(): # check that the double-well systems are equivalent to WCA fluid in # certain limits @@ -221,6 +258,7 @@ def test_dw_systems_as_wca(): assert _equiv_topology(chain_1.topology, wca.topology) assert _equiv_topology(chain_0.topology, wca.topology) + def test_dw_systems_1_dimer(): # check that the double-well systems are equivalent when there's only # one dimer pair @@ -228,12 +266,14 @@ def test_dw_systems_1_dimer(): chain = testsystems.DoubleWellChain_WCAFluid(nchained=2) assert _equiv_topology(dimers.topology, chain.topology) + def test_double_well_dimer_errors(): with assert_raises(ValueError) as context: testsystems.DoubleWellDimer_WCAFluid(ndimers=-1) with assert_raises(ValueError) as context: testsystems.DoubleWellDimer_WCAFluid(ndimers=6, nparticles=10) + def test_double_well_chain_errors(): with assert_raises(ValueError) as context: testsystems.DoubleWellChain_WCAFluid(nchained=-1) diff --git a/openmmtools/tests/test_utils.py b/openmmtools/tests/test_utils.py index 1ad4c927e..925bc6c0d 100644 --- a/openmmtools/tests/test_utils.py +++ b/openmmtools/tests/test_utils.py @@ -8,6 +8,7 @@ Test utility functions in utils.py. """ + import abc import copy @@ -33,81 +34,109 @@ # TEST CONTEXT UTILITIES # ============================================================================= + def test_platform_supports_precision(): """Test that platform_supports_precision works correctly.""" for platform_index in range(openmm.Platform.getNumPlatforms()): platform = openmm.Platform.getPlatform(platform_index) platform_name = platform.getName() - supported_precisions = { precision for precision in ['single', 'mixed', 'double'] if platform_supports_precision(platform, precision) } - if platform_name == 'Reference': - if supported_precisions != set(['double']): - raise Exception(f"'Reference' platform should only support 'double' precision, but platform_supports_precision reports {supported_precisions}") - if platform_name == 'CUDA': - if supported_precisions != set(['single', 'mixed', 'double']): - raise Exception(f"'CUDA' platform should support 'mixed' precision, but platform_supports_precision reports {supported_precisions}") - if platform_name == 'CPU': - if supported_precisions != set(['mixed']): - raise Exception(f"'CPU' platform should support 'mixed' precision, but platform_supports_precision reports {supported_precisions}") + supported_precisions = { + precision + for precision in ["single", "mixed", "double"] + if platform_supports_precision(platform, precision) + } + if platform_name == "Reference": + if supported_precisions != {"double"}: + raise Exception( + f"'Reference' platform should only support 'double' precision, but platform_supports_precision reports {supported_precisions}" + ) + if platform_name == "CUDA": + if supported_precisions != {"single", "mixed", "double"}: + raise Exception( + f"'CUDA' platform should support 'mixed' precision, but platform_supports_precision reports {supported_precisions}" + ) + if platform_name == "CPU": + if supported_precisions != {"mixed"}: + raise Exception( + f"'CPU' platform should support 'mixed' precision, but platform_supports_precision reports {supported_precisions}" + ) def test_string_platform_supports_precision(): """Test that if we use a string for the platform name, it works""" assert platform_supports_precision("CPU", "mixed") + # ============================================================================= # TEST STRING MATHEMATICAL EXPRESSION PARSING UTILITIES # ============================================================================= + def test_sanitize_expression(): """Test that reserved keywords are substituted correctly.""" - prefix = '_sanitized__' + prefix = "_sanitized__" # Generate a bunch of test cases for each supported reserved keyword. test_cases = {} for word in _RESERVED_WORDS_PATTERNS: s_word = prefix + word # sanitized word - test_cases[word] = [(word, s_word), - ('(' + word + ')', '(' + s_word + ')'), # parenthesis - ('( ' + word + ' )', '( ' + s_word + ' )'), # parenthesis w/ spaces - (word + '_suffix', word + '_suffix'), # w/ suffix - ('prefix_' + word, 'prefix_' + word), # w/ prefix - ('2+' + word + '-' + word + '_suffix/(' + word + ' - 3)', # expression - '2+' + s_word + '-' + word + '_suffix/(' + s_word + ' - 3)')] + test_cases[word] = [ + (word, s_word), + ("(" + word + ")", "(" + s_word + ")"), # parenthesis + ("( " + word + " )", "( " + s_word + " )"), # parenthesis w/ spaces + (word + "_suffix", word + "_suffix"), # w/ suffix + ("prefix_" + word, "prefix_" + word), # w/ prefix + ( + "2+" + word + "-" + word + "_suffix/(" + word + " - 3)", # expression + "2+" + s_word + "-" + word + "_suffix/(" + s_word + " - 3)", + ), + ] # Run test cases. for word in _RESERVED_WORDS_PATTERNS: variables = {word: 5.0} for expression, result in test_cases[word]: - sanitized_expression, sanitized_variables = sanitize_expression(expression, variables) - assert sanitized_expression == result, '{}, {}'.format(sanitized_expression, result) + sanitized_expression, sanitized_variables = sanitize_expression( + expression, variables + ) + assert sanitized_expression == result, f"{sanitized_expression}, {result}" assert word not in sanitized_variables assert sanitized_variables[prefix + word] == 5.0 def test_math_eval(): """Test math_eval method.""" - test_cases = [('1 + 3', None, 4), - ('x + y', {'x': 1.5, 'y': 2}, 3.5), - ('(x + lambda) / z * 4', {'x': 1, 'lambda': 2, 'z': 3}, 4.0), - ('-((x + y) / z * 4)**2', {'x': 1, 'y': 2, 'z': 3}, -16.0), - ('ceil(0.8) + acos(x) + step(0.5 - x) + step(0.5)', {'x': 1}, 2), - ('step_hm(x)', {'x': 0}, 0.5), - ('myset & myset2', {'myset': {1,2,3}, 'myset2': {2,3,4}}, {2, 3}), - ('myset or myset2', {'myset': {1,2,3}, 'myset2': {2,3,4}}, {1, 2, 3, 4}), - ('(myset or my2set) & myset3', {'myset': {1, 2}, 'my2set': {3, 4}, 'myset3': {2, 3}}, {2, 3})] + test_cases = [ + ("1 + 3", None, 4), + ("x + y", {"x": 1.5, "y": 2}, 3.5), + ("(x + lambda) / z * 4", {"x": 1, "lambda": 2, "z": 3}, 4.0), + ("-((x + y) / z * 4)**2", {"x": 1, "y": 2, "z": 3}, -16.0), + ("ceil(0.8) + acos(x) + step(0.5 - x) + step(0.5)", {"x": 1}, 2), + ("step_hm(x)", {"x": 0}, 0.5), + ("myset & myset2", {"myset": {1, 2, 3}, "myset2": {2, 3, 4}}, {2, 3}), + ("myset or myset2", {"myset": {1, 2, 3}, "myset2": {2, 3, 4}}, {1, 2, 3, 4}), + ( + "(myset or my2set) & myset3", + {"myset": {1, 2}, "my2set": {3, 4}, "myset3": {2, 3}}, + {2, 3}, + ), + ] for expression, variables, result in test_cases: evaluated_expression = math_eval(expression, variables) - assert evaluated_expression == result, '{}, {}, {}'.format( - expression, evaluated_expression, result) + assert evaluated_expression == result, "{}, {}, {}".format( + expression, evaluated_expression, result + ) # ============================================================================= # TEST QUANTITY UTILITIES # ============================================================================= + def test_tracked_quantity(): """Test TrackedQuantity objects.""" + def reset(q): assert tracked_quantity.has_changed is True tracked_quantity.has_changed = False @@ -134,15 +163,15 @@ def reset(q): assert len(tracked_quantity) == 2 reset(tracked_quantity) - tracked_quantity.append(10.0*u) + tracked_quantity.append(10.0 * u) assert len(tracked_quantity) == 3 reset(tracked_quantity) - tracked_quantity.extend([11.0, 12.0]*u) + tracked_quantity.extend([11.0, 12.0] * u) assert len(tracked_quantity) == 5 reset(tracked_quantity) - element = 15.0*u + element = 15.0 * u tracked_quantity.insert(1, element) assert len(tracked_quantity) == 6 reset(tracked_quantity) @@ -157,44 +186,55 @@ def reset(q): else: # Check that numpy views are handled correctly. view = tracked_quantity[:3] - view[0] = 20.0*u - assert tracked_quantity[0] == 20.0*u + view[0] = 20.0 * u + assert tracked_quantity[0] == 20.0 * u reset(tracked_quantity) view2 = view[1:] - view2[0] = 30.0*u - assert tracked_quantity[1] == 30.0*u + view2[0] = 30.0 * u + assert tracked_quantity[1] == 30.0 * u reset(tracked_quantity) def test_is_quantity_close(): """Test is_quantity_close method.""" # (quantity1, quantity2, test_result) - test_cases = [(300.0*unit.kelvin, 300.000000004*unit.kelvin, True), - (300.0*unit.kelvin, 300.00000004*unit.kelvin, False), - (1.01325*unit.bar, 1.01325000006*unit.bar, True), - (1.01325*unit.bar, 1.0132500006*unit.bar, False)] + test_cases = [ + (300.0 * unit.kelvin, 300.000000004 * unit.kelvin, True), + (300.0 * unit.kelvin, 300.00000004 * unit.kelvin, False), + (1.01325 * unit.bar, 1.01325000006 * unit.bar, True), + (1.01325 * unit.bar, 1.0132500006 * unit.bar, False), + ] - err_msg = 'obtained: {}, expected: {} (quantity1: {}, quantity2: {})' + err_msg = "obtained: {}, expected: {} (quantity1: {}, quantity2: {})" for quantity1, quantity2, test_result in test_cases: - msg = "Test failed: ({}, {}, {})".format(quantity1, quantity2, test_result) + msg = f"Test failed: ({quantity1}, {quantity2}, {test_result})" assert is_quantity_close(quantity1, quantity2) == test_result, msg # Passing quantities with different units raise an exception. with nose.tools.assert_raises(TypeError): - is_quantity_close(300*unit.kelvin, 1*unit.atmosphere) + is_quantity_close(300 * unit.kelvin, 1 * unit.atmosphere) def test_quantity_from_string(): """Test that quantities can be derived from strings""" test_strings = [ - ('3', 3.0), # Handle basic float - ('meter', unit.meter), # Handle basic unit object - ('300 * kelvin', 300 * unit.kelvin), # Handle standard Quantity - ('" 0.3 * kilojoules_per_mole / watt**3"', 0.3 * unit.kilojoules_per_mole / unit.watt ** 3), # Handle division, exponent, nested string - ('1*meter / (4*second)', 0.25 * unit.meter / unit.second), # Handle compound math and parenthesis - ('1 * watt**2 /((1* kelvin)**3 / gram)', 1 * (unit.watt ** 2) * (unit.gram) / (unit.kelvin ** 3)), # Handle everything - ('/watt', unit.watt ** -1) # Handle special "inverse unit" case + ("3", 3.0), # Handle basic float + ("meter", unit.meter), # Handle basic unit object + ("300 * kelvin", 300 * unit.kelvin), # Handle standard Quantity + ( + '" 0.3 * kilojoules_per_mole / watt**3"', + 0.3 * unit.kilojoules_per_mole / unit.watt**3, + ), # Handle division, exponent, nested string + ( + "1*meter / (4*second)", + 0.25 * unit.meter / unit.second, + ), # Handle compound math and parenthesis + ( + "1 * watt**2 /((1* kelvin)**3 / gram)", + 1 * (unit.watt**2) * (unit.gram) / (unit.kelvin**3), + ), # Handle everything + ("/watt", unit.watt**-1), # Handle special "inverse unit" case ] for test_string in test_strings: @@ -206,21 +246,23 @@ def test_quantity_from_string(): # TEST SERIALIZATION UTILITIES # ============================================================================= -class MyClass(object): + +class MyClass: """Example of serializable class used by test_serialize_deserialize.""" + def __init__(self, a, b): self.a = a self.b = b def __getstate__(self): serialization = dict() - serialization['a'] = self.a - serialization['b'] = self.b + serialization["a"] = self.a + serialization["b"] = self.b return serialization def __setstate__(self, serialization): - self.a = serialization['a'] - self.b = serialization['b'] + self.a = serialization["a"] + self.b = serialization["b"] def add(self): return self.a + self.b @@ -233,9 +275,12 @@ def test_serialize_deserialize(): # Test serialization. serialization = serialize(my_instance) - expected_serialization = {'_serialized__module_name': 'test_utils', - '_serialized__class_name': 'MyClass', - 'a': 4, 'b': 5} + expected_serialization = { + "_serialized__module_name": "test_utils", + "_serialized__class_name": "MyClass", + "a": 4, + "b": 5, + } assert serialization == expected_serialization # Test deserialization. @@ -251,35 +296,44 @@ def test_serialize_deserialize(): # TEST METACLASS UTILITIES # ============================================================================= + def test_subhooked_abcmeta(): """Test class SubhookedABCMeta.""" + # Define an interface class IInterface(SubhookedABCMeta): @abc.abstractmethod - def my_method(self): pass + def my_method(self): + pass @abc.abstractproperty - def my_property(self): pass + def my_property(self): + pass @staticmethod @abc.abstractmethod - def my_static_method(): pass + def my_static_method(): + pass # Define object implementing the interface with duck typing - class InterfaceImplementation(object): - def my_method(self): pass + class InterfaceImplementation: + def my_method(self): + pass - def my_property(self): pass + def my_property(self): + pass @staticmethod - def my_static_method(): pass + def my_static_method(): + pass implementation_instance = InterfaceImplementation() assert isinstance(implementation_instance, IInterface) # Define incomplete implementation - class WrongInterfaceImplementation(object): - def my_method(self): pass + class WrongInterfaceImplementation: + def my_method(self): + pass implementation_instance = WrongInterfaceImplementation() assert not isinstance(implementation_instance, IInterface) @@ -288,10 +342,10 @@ def my_method(self): pass def test_find_all_subclasses(): """Test find_all_subclasses() function.""" # Create Python2-3 compatible abstract classes. - ABC = abc.ABCMeta('ABC', (), {}) + ABC = abc.ABCMeta("ABC", (), {}) # Diamond inheritance. - class A(object): + class A: pass class B(A): @@ -321,23 +375,29 @@ def m(self): # RESTORABLE OPENMM OBJECT # ============================================================================= -class TestRestorableOpenMMObject(object): + +class TestRestorableOpenMMObject: """Test the RestorableOpenMMObject utility class.""" @classmethod def setup_class(cls): """Example restorable classes for tests.""" - class DummyRestorableCustomForce(RestorableOpenMMObject, openmm.CustomBondForce): + + class DummyRestorableCustomForce( + RestorableOpenMMObject, openmm.CustomBondForce + ): def __init__(self, *args, **kwargs): - super(DummyRestorableCustomForce, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) - class DummyRestorableCustomIntegrator(RestorableOpenMMObject, openmm.CustomIntegrator): + class DummyRestorableCustomIntegrator( + RestorableOpenMMObject, openmm.CustomIntegrator + ): def __init__(self, *args, **kwargs): - super(DummyRestorableCustomIntegrator, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) - cls.dummy_force = DummyRestorableCustomForce('0.0;') - cls.dummier_force = DummyRestorableCustomForce('0.0;') - cls.dummy_integrator= DummyRestorableCustomIntegrator(2.0*unit.femtoseconds) + cls.dummy_force = DummyRestorableCustomForce("0.0;") + cls.dummier_force = DummyRestorableCustomForce("0.0;") + cls.dummy_integrator = DummyRestorableCustomIntegrator(2.0 * unit.femtoseconds) def test_restorable_openmm_object(self): """Test RestorableOpenMMObject classes can be serialized and copied correctly.""" @@ -346,13 +406,15 @@ def test_restorable_openmm_object(self): test_cases = [ (copy.deepcopy(self.dummy_force), True), (copy.deepcopy(self.dummy_integrator), True), - (openmm.CustomBondForce('K'), False) + (openmm.CustomBondForce("K"), False), ] for openmm_object, is_restorable in test_cases: assert RestorableOpenMMObject.is_restorable(openmm_object) is is_restorable - err_msg = '{}: {}, {}'.format(openmm_object, RestorableOpenMMObject.restore_interface(openmm_object), is_restorable) - assert RestorableOpenMMObject.restore_interface(openmm_object) is is_restorable, err_msg + err_msg = f"{openmm_object}: {RestorableOpenMMObject.restore_interface(openmm_object)}, {is_restorable}" + assert ( + RestorableOpenMMObject.restore_interface(openmm_object) is is_restorable + ), err_msg # Serializing/deserializing restore the class correctly. serialization = openmm.XmlSerializer.serialize(openmm_object) @@ -365,7 +427,7 @@ def test_restorable_openmm_object(self): copied_object = copy.deepcopy(deserialized_object) if is_restorable: assert type(copied_object) is type(openmm_object) - assert hasattr(copied_object, '_monkey_patching') + assert hasattr(copied_object, "_monkey_patching") def test_multiple_object_context_creation(self): """Test that it is possible to create contexts with multiple restorable objects. @@ -380,7 +442,7 @@ def test_multiple_object_context_creation(self): """ system = openmm.System() for i in range(4): - system.addParticle(1.0*unit.atom_mass_units) + system.addParticle(1.0 * unit.atom_mass_units) system.addForce(copy.deepcopy(self.dummy_force)) system.addForce(copy.deepcopy(self.dummier_force)) context = openmm.Context(system, copy.deepcopy(self.dummy_integrator)) @@ -434,7 +496,7 @@ def test_context_from_restorable_with_different_globals(self): def test_restorable_openmm_object_failure(self): """An exception is raised if the class has a restorable hash but the class can't be found.""" - force = openmm.CustomBondForce('0.0') + force = openmm.CustomBondForce("0.0") force_hash_parameter_name = self.dummy_force._hash_parameter_name force.addGlobalParameter(force_hash_parameter_name, 15.0) with nose.tools.assert_raises(RestorableOpenMMObjectError): @@ -445,9 +507,11 @@ def test_restorable_openmm_object_hash_collisions(self): restorable_classes = find_all_subclasses(RestorableOpenMMObject) # Test pre-condition: make sure that our custom forces and integrators are loaded. - restorable_classes_names = {restorable_cls.__name__ for restorable_cls in restorable_classes} - assert 'ThermostatedIntegrator' in restorable_classes_names - assert 'RadiallySymmetricRestraintForce' in restorable_classes_names + restorable_classes_names = { + restorable_cls.__name__ for restorable_cls in restorable_classes + } + assert "ThermostatedIntegrator" in restorable_classes_names + assert "RadiallySymmetricRestraintForce" in restorable_classes_names # Test hash collisions. all_hashes = set() @@ -457,10 +521,11 @@ def test_restorable_openmm_object_hash_collisions(self): assert len(all_hashes) == len(restorable_classes) -class TestEquilibrationUtils(object): +class TestEquilibrationUtils: """ Class for testing equilibration utility functions in openmmtools.utils.equilibration """ + def test_gentle_equilibration_setup(self): """ Test gentle equilibration implementation using the Alanine dipeptide in explicit solvent @@ -480,58 +545,110 @@ def test_gentle_equilibration_setup(self): topology = test_system.topology stages = [ - {'EOM': 'minimize', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': None, - 'restraint_selection': 'protein and not type H', - 'force_constant': 100 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD_interpolate', 'n_steps': 1, 'temperature': 100 * unit.kelvin, - 'temperature_end': 300 * unit.kelvin, - 'ensemble': 'NVT', 'restraint_selection': 'protein and not type H', - 'force_constant': 100 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 10 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 1, 'temperature': 300, 'ensemble': 'NPT', - 'restraint_selection': 'protein and not type H', - 'force_constant': 100 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 10 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and not type H', - 'force_constant': 10 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'minimize', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': None, - 'restraint_selection': 'protein and backbone', - 'force_constant': 10 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and backbone', - 'force_constant': 10 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and backbone', - 'force_constant': 1 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and backbone', - 'force_constant': 0.1 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 1, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': None, - 'force_constant': 0 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 2 * unit.femtoseconds}, + { + "EOM": "minimize", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": None, + "restraint_selection": "protein and not type H", + "force_constant": 100 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD_interpolate", + "n_steps": 1, + "temperature": 100 * unit.kelvin, + "temperature_end": 300 * unit.kelvin, + "ensemble": "NVT", + "restraint_selection": "protein and not type H", + "force_constant": 100 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 10 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 1, + "temperature": 300, + "ensemble": "NPT", + "restraint_selection": "protein and not type H", + "force_constant": 100 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 10 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and not type H", + "force_constant": 10 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "minimize", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": None, + "restraint_selection": "protein and backbone", + "force_constant": 10 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and backbone", + "force_constant": 10 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and backbone", + "force_constant": 1 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and backbone", + "force_constant": 0.1 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 1, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": None, + "force_constant": 0 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 2 * unit.femtoseconds, + }, ] with temporary_directory() as tmp_path: outfile_path = f"{tmp_path}/outfile.cif" - run_gentle_equilibration(topology, positions, system, stages, outfile_path, platform_name="CPU", - save_box_vectors=False) + run_gentle_equilibration( + topology, + positions, + system, + stages, + outfile_path, + platform_name="CPU", + save_box_vectors=False, + ) # TODO: Marking as not a test until we solve our GPU CI @nottest @@ -557,53 +674,105 @@ def test_gentle_equilibration_cuda(self): topology = test_system.topology stages = [ - {'EOM': 'minimize', 'n_steps': 10000, 'temperature': 300 * unit.kelvin, 'ensemble': None, - 'restraint_selection': 'protein and not type H', - 'force_constant': 100 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD_interpolate', 'n_steps': 100000, 'temperature': 100 * unit.kelvin, - 'temperature_end': 300 * unit.kelvin, - 'ensemble': 'NVT', 'restraint_selection': 'protein and not type H', - 'force_constant': 100 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 10 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 100000, 'temperature': 300, 'ensemble': 'NPT', - 'restraint_selection': 'protein and not type H', - 'force_constant': 100 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 10 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 250000, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and not type H', - 'force_constant': 10 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'minimize', 'n_steps': 10000, 'temperature': 300 * unit.kelvin, 'ensemble': None, - 'restraint_selection': 'protein and backbone', - 'force_constant': 10 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 100000, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and backbone', - 'force_constant': 10 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 100000, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and backbone', - 'force_constant': 1 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 100000, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': 'protein and backbone', - 'force_constant': 0.1 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 1 * unit.femtoseconds}, - {'EOM': 'MD', 'n_steps': 2500000, 'temperature': 300 * unit.kelvin, 'ensemble': 'NPT', - 'restraint_selection': None, - 'force_constant': 0 * unit.kilocalories_per_mole / unit.angstrom ** 2, - 'collision_rate': 2 / unit.picoseconds, - 'timestep': 2 * unit.femtoseconds}, + { + "EOM": "minimize", + "n_steps": 10000, + "temperature": 300 * unit.kelvin, + "ensemble": None, + "restraint_selection": "protein and not type H", + "force_constant": 100 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD_interpolate", + "n_steps": 100000, + "temperature": 100 * unit.kelvin, + "temperature_end": 300 * unit.kelvin, + "ensemble": "NVT", + "restraint_selection": "protein and not type H", + "force_constant": 100 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 10 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 100000, + "temperature": 300, + "ensemble": "NPT", + "restraint_selection": "protein and not type H", + "force_constant": 100 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 10 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 250000, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and not type H", + "force_constant": 10 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "minimize", + "n_steps": 10000, + "temperature": 300 * unit.kelvin, + "ensemble": None, + "restraint_selection": "protein and backbone", + "force_constant": 10 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 100000, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and backbone", + "force_constant": 10 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 100000, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and backbone", + "force_constant": 1 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 100000, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": "protein and backbone", + "force_constant": 0.1 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 1 * unit.femtoseconds, + }, + { + "EOM": "MD", + "n_steps": 2500000, + "temperature": 300 * unit.kelvin, + "ensemble": "NPT", + "restraint_selection": None, + "force_constant": 0 * unit.kilocalories_per_mole / unit.angstrom**2, + "collision_rate": 2 / unit.picoseconds, + "timestep": 2 * unit.femtoseconds, + }, ] - run_gentle_equilibration(topology, positions, system, stages, "outfile.cif", platform_name="CUDA", - save_box_vectors=False) + run_gentle_equilibration( + topology, + positions, + system, + stages, + "outfile.cif", + platform_name="CUDA", + save_box_vectors=False, + )