Skip to content

Commit

Permalink
use espfit.analysis.BaseDataLoader to load trajectories
Browse files Browse the repository at this point in the history
  • Loading branch information
kntkb committed Apr 7, 2024
1 parent 1591436 commit 3455a6f
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions espfit/utils/sampler/reweight.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class SetupSamplerReweight(object):
def __init__(self):
self.samplers = None
self.weights_neff_dict = dict() # {'target_name': {'weights': w_i}, {'neff': neff}}
self.force_group_names = ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce']
self.force_group_names = ['HarmonicBondForce', 'HarmonicAngleForce', 'PeriodicTorsionForce', 'NonbondedForce']
self.exclude_n_frames = 0.1


Expand All @@ -44,7 +44,7 @@ def run(self):
None
"""
for sampler in self.samplers:
_logger.info(f'Running simulation for {sampler.target_name} for {sampler.nsteps} steps...')
_logger.info(f'Re-run simulation for {sampler.target_name}')
sampler.minimize()
sampler.run()

Expand All @@ -67,24 +67,26 @@ def get_effective_sample_size(self, temporary_samplers):
from espfit.utils.units import KB_T_KCALPERMOL
from espfit.app.analysis import BaseDataLoader

_logger.info(f'Compute effective sample size and sampling weights')

if self.samplers is None:
_logger.info('No samplers found. Return effective sample size -1.')
return -1

for sampler, temporary_sampler in zip(self.samplers, temporary_samplers):
_logger.info(f'Compute effective sample size and sampling weights for {sampler.target_name}')
_logger.info(f'Analyzing {sampler.target_name}...')

# Get temperature
temp0 = sampler.temperature._value
temp1 = temporary_sampler.temperature._value
assert temp0 == temp1, f'Temperature should be equivalent but got sampler {temp0} K and temporary sampler {temp1} K'
beta = 1 / (KB_T_KCALPERMOL * temp0)
_logger.info(f'beta temperature in kcal/mol: {beta}')
#_logger.debug(f'beta temperature in kcal/mol: {beta}')

# Get position from trajectory
#traj = mdtraj.load(sampler.output_directory_path + '/traj.nc', top=sampler.output_directory_path + '/solvated.pdb')
baseloader = BaseDataLoader()
baseloader = BaseDataLoader(atomSubset=sampler.atomSubset)
baseloader.load_traj(input_directory_path=sampler.output_directory_path, exclude_n_frames=self.exclude_n_frames)
_logger.info(f'Found {baseloader.traj.n_frames} frames in trajectory')
_logger.info(f'Found {baseloader.traj.n_frames} frames from {sampler.output_directory_path}')

# Compute weights and effective sample size
w_arr = []
Expand All @@ -93,7 +95,6 @@ def get_effective_sample_size(self, temporary_samplers):
sampler.simulation.context.setPositions(baseloader.traj.openmm_positions(i))
for gid, force in enumerate(sampler.simulation.system.getForces()):
if force.getName() in self.force_group_names:
print(f'{force.getName()}')
try:
potential_energy += sampler.simulation.context.getState(getEnergy=True, groups={gid}).getPotentialEnergy()
except:
Expand All @@ -102,11 +103,10 @@ def get_effective_sample_size(self, temporary_samplers):
temporary_sampler.simulation.context.setPositions(baseloader.traj.openmm_positions(i))
for gid, force in enumerate(temporary_sampler.simulation.system.getForces()):
if force.getName() in self.force_group_names:
print(f'{force.getName()}')
try:
reduced_potential_energy += sampler.simulation.context.getState(getEnergy=True, groups={gid}).getPotentialEnergy()
reduced_potential_energy += temporary_sampler.simulation.context.getState(getEnergy=True, groups={gid}).getPotentialEnergy()
except:
reduced_potential_energy = sampler.simulation.context.getState(getEnergy=True, groups={gid}).getPotentialEnergy()
reduced_potential_energy = temporary_sampler.simulation.context.getState(getEnergy=True, groups={gid}).getPotentialEnergy()
# deltaU = U(x0, theta1) - U(x0, theta0)
delta = (reduced_potential_energy - potential_energy).value_in_unit(kcalpermol)
# w = ln(exp(-beta * delta))
Expand All @@ -116,16 +116,22 @@ def get_effective_sample_size(self, temporary_samplers):
_logger.info(f'U(x0, theta0): {potential_energy.value_in_unit(kcalpermol):10.3f} kcal/mol')
_logger.info(f'U(x0, theta1): {reduced_potential_energy.value_in_unit(kcalpermol):10.3f} kcal/mol')
_logger.info(f'deltaU: {delta:10.3f} kcal/mol')
_logger.info(f'ln_w: {w:10.3f}')
_logger.info(f'w: {w:10.3f}')

# Compute weights and effective sample size (ratio: 0 to 1)
# Prevent RuntimeWarning: overflow encountered in exp
w_arr = np.float128(w_arr)
w_i = np.exp(w_arr) / np.sum(np.exp(w_arr))
neff = np.sum(w_i) ** 2 / np.sum(w_i ** 2) / len(w_i)
_logger.info(f'w_i_sum: {np.sum(w_i):10.3f}')
_logger.info(f'neff: {neff:10.3f}')
_logger.debug(f'w_i_sum: {np.sum(w_i):10.3f}')
_logger.debug(f'neff: {neff:10.3f}')

# Check if sum of weights is 1
sum_w = abs(np.sum(w_i))
assert abs(1 - sum_w) <= 0.001, f"Weight sum {sum_w} is greater than tolerance 0.001"

self.weights_neff_dict[f'{sampler.target_name}'] = {'neff': neff, 'weights': w_i}
_logger.info(f'{self.weights_neff_dict}')
_logger.debug(f'{self.weights_neff_dict}')
neffs = [self.weights_neff_dict[key]['neff'] for key in self.weights_neff_dict.keys()]

return min(neffs)
Expand All @@ -141,7 +147,7 @@ def compute_loss(self):
"""
loss_list = []
for sampler in self.samplers:
_logger.info(f'Compute loss for {sampler.target_name}')
_logger.debug(f'Compute loss for {sampler.target_name}')
loss = self._compute_loss_per_system(sampler) # torch.tensor
loss_list.append(loss)

Expand Down

0 comments on commit 3455a6f

Please sign in to comment.