Skip to content

Commit

Permalink
feat(jax): reformat nlist in the TF model (#4336)
Browse files Browse the repository at this point in the history
Reformat the neighbor list in the TF model to convert the dynamic shape
to the determined shape so the TF model can accept the neighbor list
with a dynamic shape.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
- Introduced a new function to format neighbor lists based on selected
neighbors and cutoff radius.
- Enhanced deserialization process to incorporate the new formatting
function for improved neighbor list handling.

- **Tests**
- Added a new test suite for the neighbor list formatting function,
ensuring its functionality under various scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Nov 12, 2024
1 parent 4a9ed88 commit 560d82e
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 2 deletions.
71 changes: 71 additions & 0 deletions deepmd/jax/jax2tf/format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


@tf.function(autograph=True)
def format_nlist(
extended_coord: tnp.ndarray,
nlist: tnp.ndarray,
nsel: int,
rcut: float,
):
"""Format neighbor list.
If nnei == nsel, do nothing;
If nnei < nsel, pad -1;
If nnei > nsel, sort by distance and truncate.
Parameters
----------
extended_coord
The extended coordinates of the atoms.
shape: nf x nall x 3
nlist
The neighbor list.
shape: nf x nloc x nnei
nsel
The number of selected neighbors.
rcut
The cutoff radius.
Returns
-------
nlist
The formatted neighbor list.
shape: nf x nloc x nsel
"""
nlist_shape = tf.shape(nlist)
n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2]
extended_coord = extended_coord.reshape([n_nf, -1, 3])

if n_nsel < nsel:
# make a copy before revise
ret = tnp.concatenate(
[
nlist,
tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype),
],
axis=-1,
)

elif n_nsel > nsel:
# make a copy before revise
m_real_nei = nlist >= 0
ret = tnp.where(m_real_nei, nlist, 0)
coord0 = extended_coord[:, :n_nloc, :]
index = ret.reshape(n_nf, n_nloc * n_nsel, 1)
index = tnp.repeat(index, 3, axis=2)
coord1 = tnp.take_along_axis(extended_coord, index, axis=1)
coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3)
rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1)
rr2 = tnp.where(m_real_nei, rr2, float("inf"))
rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1)
ret = tnp.take_along_axis(ret, ret_mapping, axis=2)
ret = tnp.where(rr2 > rcut * rcut, -1, ret)
ret = ret[..., :nsel]
else: # n_nsel == nsel:
ret = nlist
# do a reshape any way; this will tell the xla the shape without any dynamic shape
ret = tnp.reshape(ret, [n_nf, n_nloc, nsel])
return ret
9 changes: 7 additions & 2 deletions deepmd/jax/jax2tf/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
jax2tf,
)

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.make_model import (
model_call_from_call_lower,
)
Expand Down Expand Up @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
Expand All @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial(
def call_lower_without_atomic_virial(
coord, atype, nlist, mapping, fparam, aparam
):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand All @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial(
input_signature=[
tf.TensorSpec([None, None, 3], tf.float64),
tf.TensorSpec([None, None], tf.int32),
tf.TensorSpec([None, None, model.get_nnei()], tf.int64),
tf.TensorSpec([None, None, None], tf.int64),
tf.TensorSpec([None, None], tf.int64),
tf.TensorSpec([None, model.get_dim_fparam()], tf.float64),
tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64),
],
)
def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam):
nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut())
return tf.cond(
tf.shape(coord)[1] == tf.shape(nlist)[1],
lambda: exported_whether_do_atomic_virial(
Expand Down
91 changes: 91 additions & 0 deletions source/jax2tf_tests/test_format_nlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

from deepmd.jax.jax2tf.format_nlist import (
format_nlist,
)
from deepmd.jax.jax2tf.nlist import (
build_neighbor_list,
extend_coord_with_ghosts,
)

GLOBAL_SEED = 20241110


class TestFormatNlist(tf.test.TestCase):
def setUp(self):
self.nf = 3
self.nloc = 3
self.ns = 5 * 5 * 3
self.nall = self.ns * self.nloc
self.cell = tnp.array(
[[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64
)
self.icoord = tnp.array(
[[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]],
dtype=tnp.float64,
)
self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32)
self.nsel = [10, 10]
self.rcut = 1.01

self.ecoord, self.eatype, mapping = extend_coord_with_ghosts(
self.icoord, self.atype, self.cell, self.rcut
)
self.nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=False,
)

def test_format_nlist_equal(self):
nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_less(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) - 5,
distinguish_types=False,
)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
self.assertAllEqual(nlist, self.nlist)

def test_format_nlist_large(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 5,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

def test_format_nlist_larger_rcut(self):
nlist = build_neighbor_list(
self.ecoord,
self.eatype,
self.nloc,
self.rcut * 2,
40,
distinguish_types=False,
)
# random shuffle
shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2]))
nlist = tnp.take(nlist, shuffle_idx, axis=2)
nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut)
# we only need to ensure the result is correct, no need to check the order
self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))

0 comments on commit 560d82e

Please sign in to comment.