Skip to content

Commit

Permalink
Merge branch 'reformat-nlist' into tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz committed Nov 11, 2024
2 parents 37c8739 + 1eecf10 commit 0ccad5b
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 2 deletions.
71 changes: 71 additions & 0 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


@tf.function(autograph=True)
def format_nlist(
extended_coord: tnp.ndarray,
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
"""Format neighbor list.
If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.
Parameters
----------
extended_coord
The extended coordinates of the atoms.
shape: nf x nall x 3
nlist
The neighbor list.
shape: nf x nloc x nnei
nsel
The number of selected neighbors.
rcut
The cutoff radius.
Returns
-------
nlist
The formatted neighbor list.
shape: nf x nloc x nsel
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = extended_coord.reshape([n_nf, -1, 3])

if n_nsel < nsel:
# make a copy before revise
ret = tnp.concatenate(
[
nlist,
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
],
axis=-1,
)

elif n_nsel > nsel:
# make a copy before revise
m_real_nei = nlist >= 0
ret = tnp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
index = tnp.repeat(index, 3, axis=2)
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nsel]
else: # n_nsel == nsel:
ret = nlist
# do a reshape any way; this will tell the xla the shape without any dynamic shape
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
return ret
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
jax2tf,
)

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.make_model import (
model_call_from_call_lower,
)
Expand Down Expand Up @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
Expand All @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial(
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand All @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand Down
91 changes: 91 additions & 0 deletions source/jax2tf_tests/test_format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)

GLOBAL_SEED = 20241110


class TestFormatNlist(tf.test.TestCase):
def setUp(self):
self.nf = 3
self.nloc = 3
self.ns = 5 * 5 * 3
self.nall = self.ns * self.nloc
self.cell = tnp.array(
[[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64
)
self.icoord = tnp.array(
[[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]],
dtype=tnp.float64,
)
self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32)
self.nsel = [10, 10]
self.rcut = 1.01

self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(
self.icoord, self.atype, self.cell, self.rcut
)
self.nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=False,
)

def test_format_nlist_equal(self):
nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_less(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) - 5,
distinguish_types=False,
)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_large(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 5,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

def test_format_nlist_larger_rcut(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut * 2,
40,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

0 comments on commit 0ccad5b

Please sign in to comment.