Skip to content

Commit

Permalink
pyupgrade
Browse files Browse the repository at this point in the history
ruff-ed
  • Loading branch information
mikemhenry committed Aug 27, 2024
1 parent c710d68 commit 562e0d3
Show file tree
Hide file tree
Showing 15 changed files with 5,045 additions and 2,317 deletions.
2,036 changes: 1,449 additions & 587 deletions openmmtools/tests/test_alchemy.py

Large diffs are not rendered by default.

200 changes: 116 additions & 84 deletions openmmtools/tests/test_cache.py

Large diffs are not rendered by default.

89 changes: 67 additions & 22 deletions openmmtools/tests/test_forcefactories.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
# TESTING UTILITIES
# =============================================================================


def create_context(system, integrator, platform=None):
"""Create a Context.
Expand All @@ -51,6 +52,7 @@ def create_context(system, integrator, platform=None):
# UTILITY FUNCTIONS
# =============================================================================


def compute_forces(system, positions, platform=None, force_group=-1):
"""Compute forces of the system in the given positions.
Expand All @@ -72,7 +74,9 @@ def compute_forces(system, positions, platform=None, force_group=-1):
return forces


def compare_system_forces(reference_system, alchemical_system, positions, name="", platform=None):
def compare_system_forces(
reference_system, alchemical_system, positions, name="", platform=None
):
"""Check that the forces of reference and modified systems are close.
Parameters
Expand All @@ -90,19 +94,36 @@ def compare_system_forces(reference_system, alchemical_system, positions, name="
"""
# Compute forces
reference_force = compute_forces(reference_system, positions, platform=platform) / GLOBAL_FORCE_UNIT
alchemical_force = compute_forces(alchemical_system, positions, platform=platform) / GLOBAL_FORCE_UNIT
reference_force = (
compute_forces(reference_system, positions, platform=platform)
/ GLOBAL_FORCE_UNIT
)
alchemical_force = (
compute_forces(alchemical_system, positions, platform=platform)
/ GLOBAL_FORCE_UNIT
)

# Check that error is small.
def magnitude(vec):
return np.sqrt(np.mean(np.sum(vec**2, axis=1)))

relative_error = magnitude(alchemical_force - reference_force) / magnitude(reference_force)
relative_error = magnitude(alchemical_force - reference_force) / magnitude(
reference_force
)
if np.any(np.abs(relative_error) > MAX_FORCE_RELATIVE_ERROR):
err_msg = ("Maximum allowable relative force error exceeded (was {:.8f}; allowed {:.8f}).\n"
"alchemical_force = {:.8f}, reference_force = {:.8f}, difference = {:.8f}")
raise Exception(err_msg.format(relative_error, MAX_FORCE_RELATIVE_ERROR, magnitude(alchemical_force),
magnitude(reference_force), magnitude(alchemical_force-reference_force)))
err_msg = (
"Maximum allowable relative force error exceeded (was {:.8f}; allowed {:.8f}).\n"
"alchemical_force = {:.8f}, reference_force = {:.8f}, difference = {:.8f}"
)
raise Exception(
err_msg.format(
relative_error,
MAX_FORCE_RELATIVE_ERROR,
magnitude(alchemical_force),
magnitude(reference_force),
magnitude(alchemical_force - reference_force),
)
)


def generate_new_positions(system, positions, platform=None, nsteps=50):
Expand Down Expand Up @@ -137,23 +158,29 @@ def generate_new_positions(system, positions, platform=None, nsteps=50):
# TEST FORCE FACTORIES FUNCTIONS
# =============================================================================


def test_restrain_atoms():
"""Check that the restrained molecule's centroid is in the origin."""
host_guest = testsystems.HostGuestExplicit()
topology = mdtraj.Topology.from_openmm(host_guest.topology)
sampler_state = states.SamplerState(positions=host_guest.positions)
thermodynamic_state = states.ThermodynamicState(host_guest.system, temperature=300*unit.kelvin,
pressure=1.0*unit.atmosphere)
thermodynamic_state = states.ThermodynamicState(
host_guest.system, temperature=300 * unit.kelvin, pressure=1.0 * unit.atmosphere
)

# Restrain all the host carbon atoms.
restrained_atoms = [atom.index for atom in topology.atoms
if atom.element.symbol is 'C' and atom.index <= 125]
restrained_atoms = [
atom.index
for atom in topology.atoms
if atom.element.symbol == "C" and atom.index <= 125
]
restrain_atoms(thermodynamic_state, sampler_state, restrained_atoms)

# Compute host center_of_geometry.
centroid = np.mean(sampler_state.positions[:126], axis=0)
assert np.allclose(centroid, np.zeros(3))


def test_replace_reaction_field():
"""Check that replacing reaction-field electrostatics with Custom*Force
yields minimal force differences with original system.
Expand All @@ -164,35 +191,53 @@ def test_replace_reaction_field():
"""
test_cases = [
testsystems.AlanineDipeptideExplicit(nonbondedMethod=openmm.app.CutoffPeriodic),
testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.CutoffPeriodic)
testsystems.HostGuestExplicit(nonbondedMethod=openmm.app.CutoffPeriodic),
]
platform = openmm.Platform.getPlatformByName('Reference')
platform = openmm.Platform.getPlatformByName("Reference")
for test_system in test_cases:
test_name = test_system.__class__.__name__

# Replace reaction field.
modified_rf_system = replace_reaction_field(test_system.system, switch_width=None)
modified_rf_system = replace_reaction_field(
test_system.system, switch_width=None
)

# Make sure positions are not at minimum.
positions = generate_new_positions(test_system.system, test_system.positions)

# Test forces.
f = partial(compare_system_forces, test_system.system, modified_rf_system, positions,
name=test_name, platform=platform)
f.description = "Testing replace_reaction_field on system {}".format(test_name)
f = partial(
compare_system_forces,
test_system.system,
modified_rf_system,
positions,
name=test_name,
platform=platform,
)
f.description = f"Testing replace_reaction_field on system {test_name}"
yield f

for test_system in test_cases:
test_name = test_system.__class__.__name__

# Replace reaction field.
modified_rf_system = replace_reaction_field(test_system.system, switch_width=None, shifted=True)
modified_rf_system = replace_reaction_field(
test_system.system, switch_width=None, shifted=True
)

# Make sure positions are not at minimum.
positions = generate_new_positions(test_system.system, test_system.positions)

# Test forces.
f = partial(compare_system_forces, test_system.system, modified_rf_system, positions,
name=test_name, platform=platform)
f.description = "Testing replace_reaction_field on system {} with shifted=True".format(test_name)
f = partial(
compare_system_forces,
test_system.system,
modified_rf_system,
positions,
name=test_name,
platform=platform,
)
f.description = (
f"Testing replace_reaction_field on system {test_name} with shifted=True"
)
yield f
Loading

0 comments on commit 562e0d3

Please sign in to comment.