Skip to content

Commit

Permalink
Clean the way of aux_data implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 23, 2023
1 parent 337a8cf commit db28782
Show file tree
Hide file tree
Showing 7 changed files with 589 additions and 239 deletions.
59 changes: 21 additions & 38 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,12 @@ def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False,
def generate_get_energy(self):
# if the force field is not polarizable
if not self.lpol:
if self.has_aux:
def get_energy(positions, box, pairs, Q_local, mScales, aux):
return energy_pme(positions, box, pairs,
Q_local, None, None, None,
mScales, None, None,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme), aux
else:
def get_energy(positions, box, pairs, Q_local, mScales):
return energy_pme(positions, box, pairs,
Q_local, None, None, None,
mScales, None, None,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme)
def get_energy(positions, box, pairs, Q_local, mScales):
return energy_pme(positions, box, pairs,
Q_local, None, None, None,
mScales, None, None,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, False, lpme=self.lpme)
return get_energy
else:
# this is the bare energy calculator, with Uind as explicit input
Expand All @@ -131,31 +123,22 @@ def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales,
self.grad_pos_fn = grad(self.energy_fn, argnums=(0))
self.U_ind = jnp.zeros((self.n_atoms, 3))
# this is the wrapper that include a Uind optimizer
if self.has_aux:
def get_energy(
positions, box, pairs,
Q_local, pol, tholes, mScales, pScales, dScales,
aux):
self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
positions, box, pairs, Q_local, pol, tholes,
mScales, pScales, dScales,
U_init=aux["U_ind"], steps_pol=self.steps_pol)
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
def get_energy(
positions, box, pairs,
Q_local, pol, tholes, mScales, pScales, dScales,
U_init = self.U_ind, aux = None):
self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
positions, box, pairs, Q_local, pol, tholes,
mScales, pScales, dScales,
U_init=U_init, steps_pol=self.steps_pol)
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
energy = self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales)
if aux is not None:
aux["U_ind"] = self.U_ind
return self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales), aux
else:
def get_energy(
positions, box, pairs,
Q_local, pol, tholes, mScales, pScales, dScales,
U_init=self.U_ind):
self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
positions, box, pairs, Q_local, pol, tholes,
mScales, pScales, dScales,
U_init=U_init, steps_pol=self.steps_pol)
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
return self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales)
return energy, aux
else:
return energy
return get_energy


Expand Down
Loading

0 comments on commit db28782

Please sign in to comment.