Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Debug and implementation new features #157

Closed
wants to merge 17 commits into from
Closed
2 changes: 1 addition & 1 deletion dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def matchTemplate(graph, template):
return False, {}, {}

def match_func(n1, n2):
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"]
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] and n1["name"] == n2["name"]

def edge_match(e1, e2):
if len(e1) == 0 and len(e2) == 0:
Expand Down
66 changes: 65 additions & 1 deletion dmff/classical/inter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Iterable, Tuple, Optional
import jax
import jax.numpy as jnp
import jax.config as config
config.update("jax_enable_x64", True)
import numpy as np

from ..utils import pair_buffer_scales, regularize_pairs
from ..admp.pme import energy_pme
from ..admp.recip import generate_pme_recip
Expand Down Expand Up @@ -332,3 +333,66 @@ def get_energy(positions, box, pairs, bcc, mscales):
charges = self.init_charges + jnp.dot(self.top_mat, bcc).flatten()
return get_energy_kernel(positions, box, pairs, charges, mscales)
return get_energy


class CustomGBForce:
def __init__(
self,
map_charge,
map_radius,
map_scale,
epsilon_1=1.0,
epsilon_solv=78.3,
alpha=1,
beta=0.8,
gamma=4.85,
) -> None:
self.map_charge = map_charge
self.map_radius = map_radius
self.map_scale = map_scale
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.exp_solv = epsilon_solv
self.eps_1 = epsilon_1

def generate_get_energy(self):
# @jax.jit
def get_energy(positions, box, pairs, Ipairs, charges, radius, scales):
def calI(posList, radMap, scalMap, rhoMap, pairMap):
I = jnp.array([])
for i in range(len(radMap)):
posj = posList[Ipairs[i]]
rhoj = rhoMap[Ipairs[i]]
scalj = scalMap[Ipairs[i]]
posi = posList[i]
rhoi = rhoMap[i]
r = jnp.sqrt(jnp.sum(jnp.power(posi-posj,2),axis=1))
sr2 = rhoj * scalj
D = jnp.abs(r - sr2)
L = jnp.maximum(D, rhoi)
C = 2 * (1 / rhoi - 1 / L) * jnp.heaviside(sr2 - r - rhoi, 1)
U = r + sr2
I = jnp.append(I, jnp.sum(0.5 * jnp.heaviside(r + sr2 - rhoi, 1) * (
1 / L - 1 / U + 0.25 * (1 / U ** 2 - 1 / L ** 2) * (
r - sr2 ** 2 / r) + 0.5 * jnp.log(L / U) / r + C)))

return I

chargeMap = charges[self.map_charge]
radiusMap = radius[self.map_radius]
scalesMap = scales[self.map_scale]
rhoMap = radiusMap - 0.009

# effective radius
IList = calI(positions, radiusMap, scalesMap, rhoMap, Ipairs)
psi = IList*rhoMap
rEff = 1/(1/rhoMap-jnp.tanh(self.alpha*psi-self.beta*jnp.power(psi, 2)+self.gamma*jnp.power(psi, 3))/radiusMap)
# surface area term energy
Ese = jnp.sum(28.3919551*(radiusMap+0.14)**2*jnp.power(radiusMap/rEff, 6)-0.5*138.935456*(1/self.eps_1-1/self.exp_solv)*chargeMap**2/rEff)
dr_norm = jnp.linalg.norm(positions[pairs[:,0]] - positions[pairs[:,1]], axis=1)
chargepro = chargeMap[pairs[:, 0]] * chargeMap[pairs[:, 1]]
rEffpro = rEff[pairs[:, 0]] * rEff[pairs[:, 1]]
Egb = jnp.sum(-138.935456*(1/self.eps_1-1/self.exp_solv)*chargepro/jnp.sqrt(jnp.power(dr_norm, 2)+rEffpro*jnp.exp(-jnp.power(dr_norm,2)/(4*rEffpro))))
return Ese + Egb
return get_energy
75 changes: 75 additions & 0 deletions dmff/classical/intra.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,78 @@ def refresh_calculators(self):
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)


class Custom1_5BondJaxForce:
def __init__(self, p1idx, p2idx, prmidx):
self.p1idx = p1idx
self.p2idx = p2idx
self.prmidx = prmidx
self.refresh_calculators()

def generate_get_energy(self):
def get_energy(positions, box, pairs, k, length):
p1 = positions[self.p1idx,:]
p2 = positions[self.p2idx,:]
kprm = k[self.prmidx]
b0prm = length[self.prmidx]
dist = distance(p1, p2)
return jnp.sum(0.5 * kprm * jnp.power(dist - b0prm, 2))

return get_energy

def update_env(self, attr, val):
"""
Update the environment of the calculator
"""
setattr(self, attr, val)
self.refresh_calculators()

def refresh_calculators(self):
"""
refresh the energy and force calculators according to the current environment
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)


class CustomTorsionJaxForce:
def __init__(self, p1idx, p2idx, p3idx, p4idx, prmidx, order):
self.p1idx = p1idx
self.p2idx = p2idx
self.p3idx = p3idx
self.p4idx = p4idx
self.prmidx = prmidx
self.order = order
self.refresh_calculators()

def generate_get_energy(self):
if len(self.p1idx) == 0:
return lambda positions, box, pairs, k, psi, shift: 0.0
def get_energy(positions, box, pairs, k, psi, shift):
p1 = positions[self.p1idx, :]
p2 = positions[self.p2idx, :]
p3 = positions[self.p3idx, :]
p4 = positions[self.p4idx, :]
kp = k[self.prmidx]
psip = psi[self.prmidx]
shiftp = shift[self.prmidx]
dih = dihedral(p1, p2, p3, p4)
ener = kp * (jnp.cos(self.order * dih - psip)) + shiftp
return jnp.sum(ener)

return get_energy

def update_env(self, attr, val):
"""
Update the environment of the calculator
"""
setattr(self, attr, val)
self.refresh_calculators()

def refresh_calculators(self):
"""
refresh the energy and force calculators according to the current environment
"""
self.get_energy = self.generate_get_energy()
self.get_forces = value_and_grad(self.get_energy)
Loading