Skip to content

Commit

Permalink
Upload
Browse files Browse the repository at this point in the history
  • Loading branch information
WangXinyan940 committed Oct 26, 2023
1 parent 23a1238 commit 7e94269
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 23 deletions.
1 change: 1 addition & 0 deletions dmff/admp/multipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax.numpy as jnp
from ..utils import jit_condition
from jax import vmap
import jax

# This module deals with the transformations and rotations of multipoles

Expand Down
30 changes: 18 additions & 12 deletions dmff/admp/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False,
# turn off pme if lpme is False, this is useful when doing cluster calculations
self.lpme = lpme
if self.lpme is False:
self.kappa = 0
self.kappa = 0.0
self.K1 = 0
self.K2 = 0
self.K3 = 0
Expand Down Expand Up @@ -119,23 +119,25 @@ def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales,
self.construct_local_frames, self.pme_recip,
self.kappa, self.K1, self.K2, self.K3, self.lmax, True, lpme=self.lpme)
self.energy_fn = energy_fn
self.grad_U_fn = grad(self.energy_fn, argnums=(4))
self.grad_pos_fn = grad(self.energy_fn, argnums=(0))
self.U_ind = jnp.zeros((self.n_atoms, 3))
self.grad_U_fn = grad(energy_fn, argnums=(4))
self.grad_pos_fn = grad(energy_fn, argnums=(0))
self.U_ind = U_ind = jnp.zeros((self.n_atoms, 3))
# this is the wrapper that include a Uind optimizer
def get_energy(
positions, box, pairs,
Q_local, pol, tholes, mScales, pScales, dScales,
U_init = self.U_ind, aux = None):
self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind(
U_init = U_ind, aux = None):
U_ind, lconverg, n_cycle = self.optimize_Uind(
positions, box, pairs, Q_local, pol, tholes,
mScales, pScales, dScales,
U_init=U_init, steps_pol=self.steps_pol)
# here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr !
# self.U_ind = jax.lax.stop_gradient(U_ind)
energy = self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales)
energy = energy_fn(positions, box, pairs, Q_local, U_ind, pol, tholes, mScales, pScales, dScales)
if aux is not None:
aux["U_ind"] = self.U_ind
aux["U_ind"] = U_ind
aux["lconverg"] = lconverg
aux["n_cycle"] = n_cycle
return energy, aux
else:
return energy
Expand Down Expand Up @@ -658,7 +660,7 @@ def pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscal
Vji0 = Vji0 + cd*qiQJ[1]
# D-D m0
Vij1 += dd_m0 * qiQI[1]
Vji1 += dd_m0 * qiQJ[1]
Vji1 += dd_m0 * qiQJ[1]
# D-D m1
Vij2 = dd_m1*qiQI[2]
Vji2 = dd_m1*qiQJ[2]
Expand Down Expand Up @@ -740,6 +742,7 @@ def pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscal
raise ValueError(f"Invalid lmax {lmax}. Valid values are 0, 1, 2")

if lpol:
# return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji)) + jnp.array(0.5) * (jnp.sum(qiUindJ*Vijdd) + jnp.sum(qiUindI*Vjidd))
return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji)) + jnp.array(0.5) * (jnp.sum(qiUindJ*Vijdd) + jnp.sum(qiUindI*Vjidd))
else:
return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji))
Expand Down Expand Up @@ -835,8 +838,7 @@ def pme_real(positions, box, pairs,
qiUindJ = None

# everything should be pair-specific now
ene = jnp.sum(
pme_real_kernel(
elist = pme_real_kernel(
norm_dr,
qiQI,
qiQJ,
Expand All @@ -852,7 +854,10 @@ def pme_real(positions, box, pairs,
lmax,
lpol
) * buffer_scales
ene = jnp.sum(
elist
)
jax.debug.print("elist: {}", elist)

return ene

Expand Down Expand Up @@ -898,5 +903,6 @@ def pol_penalty(U_ind, pol):
'''
# this is to remove the singularity when pol=0
pol_pi = trim_val_0(pol)
Uind_norm = jnp.linalg.norm(U_ind + 1e-10, axis=1)
# pol_pi = pol/(jnp.exp((-pol+1e-08)*1e10)+1) + 1e-08/(jnp.exp((pol-1e-08)*1e10)+1)
return jnp.sum(0.5/pol_pi*(U_ind**2).T) * DIELECTRIC
return jnp.sum(0.5/pol_pi*(Uind_norm**2)) * DIELECTRIC
5 changes: 5 additions & 0 deletions dmff/admp/recip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def generate_pme_recip(Ck_fn, kappa, gamma, pme_order, K1, K2, K3, lmax):
bspline_range = jnp.arange(-pme_order//2, pme_order//2)
n_mesh = pme_order**3
shifts = jnp.array(jnp.meshgrid(bspline_range, bspline_range, bspline_range)).T.reshape((1, n_mesh, 3))

if K1 == K2 == K3 == 0:
def pme_recip(positions, box, Q):
return jnp.zeros((1, ))
return pme_recip

def pme_recip(positions, box, Q):
'''
Expand Down
17 changes: 6 additions & 11 deletions dmff/generators/admp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,19 +1142,14 @@ def createPotential(
# here box is only used to setup ewald parameters, no need to be differentiable
box = topdata.getPeriodicBoxVectors()
if box is not None:
noPBC = False
box = jnp.array(box) * 10.0
else:
noPBC = True
box = jnp.eye(3) * 1.0e6
box = jnp.eye(3)
# get the admp calculator
if noPBC:
rc = 10000.0
if unit.is_quantity(nonbondedCutoff):
rc = nonbondedCutoff.value_in_unit(unit.angstrom)
else:
if unit.is_quantity(nonbondedCutoff):
rc = nonbondedCutoff.value_in_unit(unit.angstrom)
else:
rc = nonbondedCutoff * 10.0
rc = nonbondedCutoff * 10.0

# build covalent map
covalent_map = topdata.buildCovMat()
Expand Down Expand Up @@ -1353,8 +1348,8 @@ def createPotential(
has_aux = False

def potential_fn(positions, box, pairs, params, aux=None):
positions = positions * 10
box = box * 10
positions = positions * 10.0
box = box * 10.0
Q_local = params["ADMPPmeForce"]["Q_local"][map_atomtype]
if self.lpol:
pol = params["ADMPPmeForce"]["pol"][map_poltype]
Expand Down

0 comments on commit 7e94269

Please sign in to comment.