diff --git a/dmff/admp/multipole.py b/dmff/admp/multipole.py index 1fb8e197f..0a9adbfb6 100644 --- a/dmff/admp/multipole.py +++ b/dmff/admp/multipole.py @@ -3,6 +3,7 @@ import jax.numpy as jnp from ..utils import jit_condition from jax import vmap +import jax # This module deals with the transformations and rotations of multipoles diff --git a/dmff/admp/pme.py b/dmff/admp/pme.py index 646122e3e..3a05e9d02 100755 --- a/dmff/admp/pme.py +++ b/dmff/admp/pme.py @@ -78,7 +78,7 @@ def __init__(self, box, axis_type, axis_indices, rc, ethresh, lmax, lpol=False, # turn off pme if lpme is False, this is useful when doing cluster calculations self.lpme = lpme if self.lpme is False: - self.kappa = 0 + self.kappa = 0.0 self.K1 = 0 self.K2 = 0 self.K3 = 0 @@ -119,23 +119,25 @@ def energy_fn(positions, box, pairs, Q_local, Uind_global, pol, tholes, mScales, self.construct_local_frames, self.pme_recip, self.kappa, self.K1, self.K2, self.K3, self.lmax, True, lpme=self.lpme) self.energy_fn = energy_fn - self.grad_U_fn = grad(self.energy_fn, argnums=(4)) - self.grad_pos_fn = grad(self.energy_fn, argnums=(0)) - self.U_ind = jnp.zeros((self.n_atoms, 3)) + self.grad_U_fn = grad(energy_fn, argnums=(4)) + self.grad_pos_fn = grad(energy_fn, argnums=(0)) + self.U_ind = U_ind = jnp.zeros((self.n_atoms, 3)) # this is the wrapper that include a Uind optimizer def get_energy( positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, - U_init = self.U_ind, aux = None): - self.U_ind, self.lconverg, self.n_cycle = self.optimize_Uind( + U_init = U_ind, aux = None): + U_ind, lconverg, n_cycle = self.optimize_Uind( positions, box, pairs, Q_local, pol, tholes, mScales, pScales, dScales, U_init=U_init, steps_pol=self.steps_pol) # here we rely on Feynman-Hellman theorem, drop the term dV/dU*dU/dr ! # self.U_ind = jax.lax.stop_gradient(U_ind) - energy = self.energy_fn(positions, box, pairs, Q_local, self.U_ind, pol, tholes, mScales, pScales, dScales) + energy = energy_fn(positions, box, pairs, Q_local, U_ind, pol, tholes, mScales, pScales, dScales) if aux is not None: - aux["U_ind"] = self.U_ind + aux["U_ind"] = U_ind + aux["lconverg"] = lconverg + aux["n_cycle"] = n_cycle return energy, aux else: return energy @@ -658,7 +660,7 @@ def pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscal Vji0 = Vji0 + cd*qiQJ[1] # D-D m0 Vij1 += dd_m0 * qiQI[1] - Vji1 += dd_m0 * qiQJ[1] + Vji1 += dd_m0 * qiQJ[1] # D-D m1 Vij2 = dd_m1*qiQI[2] Vji2 = dd_m1*qiQJ[2] @@ -740,6 +742,7 @@ def pme_real_kernel(dr, qiQI, qiQJ, qiUindI, qiUindJ, thole1, thole2, dmp, mscal raise ValueError(f"Invalid lmax {lmax}. Valid values are 0, 1, 2") if lpol: + # return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji)) + jnp.array(0.5) * (jnp.sum(qiUindJ*Vijdd) + jnp.sum(qiUindI*Vjidd)) return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji)) + jnp.array(0.5) * (jnp.sum(qiUindJ*Vijdd) + jnp.sum(qiUindI*Vjidd)) else: return jnp.array(0.5) * (jnp.sum(qiQJ*Vij) + jnp.sum(qiQI*Vji)) @@ -835,8 +838,7 @@ def pme_real(positions, box, pairs, qiUindJ = None # everything should be pair-specific now - ene = jnp.sum( - pme_real_kernel( + elist = pme_real_kernel( norm_dr, qiQI, qiQJ, @@ -852,7 +854,10 @@ def pme_real(positions, box, pairs, lmax, lpol ) * buffer_scales + ene = jnp.sum( + elist ) + jax.debug.print("elist: {}", elist) return ene @@ -898,5 +903,6 @@ def pol_penalty(U_ind, pol): ''' # this is to remove the singularity when pol=0 pol_pi = trim_val_0(pol) + Uind_norm = jnp.linalg.norm(U_ind + 1e-10, axis=1) # pol_pi = pol/(jnp.exp((-pol+1e-08)*1e10)+1) + 1e-08/(jnp.exp((pol-1e-08)*1e10)+1) - return jnp.sum(0.5/pol_pi*(U_ind**2).T) * DIELECTRIC + return jnp.sum(0.5/pol_pi*(Uind_norm**2)) * DIELECTRIC diff --git a/dmff/admp/recip.py b/dmff/admp/recip.py index 5557eb026..9aa1202e0 100755 --- a/dmff/admp/recip.py +++ b/dmff/admp/recip.py @@ -15,6 +15,11 @@ def generate_pme_recip(Ck_fn, kappa, gamma, pme_order, K1, K2, K3, lmax): bspline_range = jnp.arange(-pme_order//2, pme_order//2) n_mesh = pme_order**3 shifts = jnp.array(jnp.meshgrid(bspline_range, bspline_range, bspline_range)).T.reshape((1, n_mesh, 3)) + + if K1 == K2 == K3 == 0: + def pme_recip(positions, box, Q): + return jnp.zeros((1, )) + return pme_recip def pme_recip(positions, box, Q): ''' diff --git a/dmff/generators/admp.py b/dmff/generators/admp.py index 442ccefb7..c63ba57eb 100644 --- a/dmff/generators/admp.py +++ b/dmff/generators/admp.py @@ -1142,19 +1142,14 @@ def createPotential( # here box is only used to setup ewald parameters, no need to be differentiable box = topdata.getPeriodicBoxVectors() if box is not None: - noPBC = False box = jnp.array(box) * 10.0 else: - noPBC = True - box = jnp.eye(3) * 1.0e6 + box = jnp.eye(3) # get the admp calculator - if noPBC: - rc = 10000.0 + if unit.is_quantity(nonbondedCutoff): + rc = nonbondedCutoff.value_in_unit(unit.angstrom) else: - if unit.is_quantity(nonbondedCutoff): - rc = nonbondedCutoff.value_in_unit(unit.angstrom) - else: - rc = nonbondedCutoff * 10.0 + rc = nonbondedCutoff * 10.0 # build covalent map covalent_map = topdata.buildCovMat() @@ -1353,8 +1348,8 @@ def createPotential( has_aux = False def potential_fn(positions, box, pairs, params, aux=None): - positions = positions * 10 - box = box * 10 + positions = positions * 10.0 + box = box * 10.0 Q_local = params["ADMPPmeForce"]["Q_local"][map_atomtype] if self.lpol: pol = params["ADMPPmeForce"]["pol"][map_poltype]