Skip to content

Commit

Permalink
Keep the order of atoms when they are equivalent and have the same na…
Browse files Browse the repository at this point in the history
…mes in topology and template
  • Loading branch information
WangXinyan940 committed Dec 7, 2023
1 parent 80eee20 commit 4244b59
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
18 changes: 16 additions & 2 deletions dmff/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,28 @@
import warnings
warnings.warn("RDKit is not installed. SMIRKS pattern matching cannot be used.")

def is_same_list(l1, l2):
if len(l1) != len(l2):
return False
for nn in range(len(l1)):
if l1[nn] != l2[nn]:
return False
return True

def matchTemplate(graph, template):
if graph.number_of_nodes() != template.number_of_nodes():
# print("Node with different number of nodes.")
return False, {}, {}

def match_func(n1, n2):
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"]
name_graph = sorted([i[1]['name'] for i in graph.nodes.data()])
name_template = sorted([i[1]['name'] for i in template.nodes.data()])

if is_same_list(name_graph, name_template):
def match_func(n1, n2):
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"] and n1['name'] == n2['name']
else:
def match_func(n1, n2):
return n1["element"] == n2["element"] and n1["external_bond"] == n2["external_bond"]

def edge_match(e1, e2):
if len(e1) == 0 and len(e2) == 0:
Expand Down
14 changes: 13 additions & 1 deletion dmff/classical/intra.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, vmap
from ..admp.spatial import v_pbc_shift


def distance(p1v, p2v):
Expand All @@ -14,6 +16,13 @@ def angle(p1v, p2v, p3v):
vzz = v1[:, 2] * v2[:, 2]
return jnp.arccos(vxx + vyy + vzz)

@jax.vmap
def angle_v(v1, v2):
# compute the angle between v1 and v2
v1n = v1 / jnp.linalg.norm(v1)
v2n = v2 / jnp.linalg.norm(v2)
return jnp.arccos(jnp.dot(v1n, v2n))


def dihedral(i, j, k, l):
b1, b2, b3 = j - i, k - j, l - k
Expand Down Expand Up @@ -72,12 +81,15 @@ def __init__(self, p1idx, p2idx, p3idx, prmidx):

def generate_get_energy(self):
def get_energy(positions, box, pairs, k, theta0):
box_inv = jnp.linalg.inv(box + jnp.eye(3) * 1e-36)
p1 = positions[self.p1idx,:]
p2 = positions[self.p2idx,:]
p3 = positions[self.p3idx,:]
v1 = v_pbc_shift(p1 - p2, box, box_inv)
v2 = v_pbc_shift(p3 - p2, box, box_inv)
kprm = k[self.prmidx]
theta0prm = theta0[self.prmidx]
ang = angle(p1, p2, p3)
ang = angle_v(v1, v2)
return jnp.sum(0.5 * kprm * jnp.power(ang - theta0prm, 2))

return get_energy
Expand Down
2 changes: 0 additions & 2 deletions dmff/generators/classical.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
angle_a3 = jnp.array(angle_a3)
angle_indices = jnp.array(angle_indices)

# 创建势函数
harmonic_angle_force = HarmonicAngleJaxForce(
angle_a1, angle_a2, angle_a3, angle_indices)
harmonic_angle_energy = harmonic_angle_force.generate_get_energy()
Expand All @@ -427,7 +426,6 @@ def createPotential(self, topdata: DMFFTopology, nonbondedMethod,
if "has_aux" in kwargs and kwargs["has_aux"]:
has_aux = True

# 包装成统一的potential_function函数形式,传入四个参数:positions, box, pairs, parameters。
def potential_fn(positions: jnp.ndarray, box: jnp.ndarray, pairs: jnp.ndarray, params: ParamSet, aux=None):
isinstance_jnp(positions, box, params)
energy = harmonic_angle_energy(
Expand Down

0 comments on commit 4244b59

Please sign in to comment.