From 1eecf104564e9d7f850fc3e2b469036908d9d807 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 11 Nov 2024 00:01:37 -0500 Subject: [PATCH] feat(jax): reformat nlist in the TF model Format the neighbor list in the TF model to convert the dynamic shape to the determined shape, so the TF model can accept the neighbor list with a dynamic shape. Signed-off-by: Jinzhe Zeng --- deepmd/jax/jax2tf/format_nlist.py | 71 ++++++++++++++++++ deepmd/jax/jax2tf/serialization.py | 9 ++- source/jax2tf_tests/test_format_nlist.py | 91 ++++++++++++++++++++++++ 3 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 deepmd/jax/jax2tf/format_nlist.py create mode 100644 source/jax2tf_tests/test_format_nlist.py diff --git a/deepmd/jax/jax2tf/format_nlist.py b/deepmd/jax/jax2tf/format_nlist.py new file mode 100644 index 0000000000..f0c630206f --- /dev/null +++ b/deepmd/jax/jax2tf/format_nlist.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + + +@tf.function(autograph=True) +def format_nlist( + extended_coord: tnp.ndarray, + nlist: tnp.ndarray, + nsel: int, + rcut: float, +): + """Format neighbor list. + + If nnei == nsel, do nothing; + If nnei < nsel, pad -1; + If nnei > nsel, sort by distance and truncate. + + Parameters + ---------- + extended_coord + The extended coordinates of the atoms. + shape: nf x nall x 3 + nlist + The neighbor list. + shape: nf x nloc x nnei + nsel + The number of selected neighbors. + rcut + The cutoff radius. + + Returns + ------- + nlist + The formatted neighbor list. + shape: nf x nloc x nsel + """ + nlist_shape = tf.shape(nlist) + n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2] + extended_coord = extended_coord.reshape([n_nf, -1, 3]) + + if n_nsel < nsel: + # make a copy before revise + ret = tnp.concatenate( + [ + nlist, + tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype), + ], + axis=-1, + ) + + elif n_nsel > nsel: + # make a copy before revise + m_real_nei = nlist >= 0 + ret = tnp.where(m_real_nei, nlist, 0) + coord0 = extended_coord[:, :n_nloc, :] + index = ret.reshape(n_nf, n_nloc * n_nsel, 1) + index = tnp.repeat(index, 3, axis=2) + coord1 = tnp.take_along_axis(extended_coord, index, axis=1) + coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3) + rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1) + rr2 = tnp.where(m_real_nei, rr2, float("inf")) + rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1) + ret = tnp.take_along_axis(ret, ret_mapping, axis=2) + ret = tnp.where(rr2 > rcut * rcut, -1, ret) + ret = ret[..., :nsel] + else: # n_nsel == nsel: + ret = nlist + # do a reshape any way; this will tell the xla the shape without any dynamic shape + ret = tnp.reshape(ret, [n_nf, n_nloc, nsel]) + return ret diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 7e560f6008..2af8656572 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -10,6 +10,9 @@ jax2tf, ) +from deepmd.jax.jax2tf.format_nlist import ( + format_nlist, +) from deepmd.jax.jax2tf.make_model import ( model_call_from_call_lower, ) @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial( input_signature=[ tf.TensorSpec([None, None, 3], tf.float64), tf.TensorSpec([None, None], tf.int32), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None, None], tf.int64), tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial( def call_lower_without_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): + nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], lambda: exported_whether_do_atomic_virial( @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial( input_signature=[ tf.TensorSpec([None, None, 3], tf.float64), tf.TensorSpec([None, None], tf.int32), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None, None], tf.int64), tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), ], ) def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], lambda: exported_whether_do_atomic_virial( diff --git a/source/jax2tf_tests/test_format_nlist.py b/source/jax2tf_tests/test_format_nlist.py new file mode 100644 index 0000000000..9b95db2cf1 --- /dev/null +++ b/source/jax2tf_tests/test_format_nlist.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.format_nlist import ( + format_nlist, +) +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) + +GLOBAL_SEED = 20241110 + + +class TestFormatNlist(tf.test.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 3 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = tnp.array( + [[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64 + ) + self.icoord = tnp.array( + [[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]], + dtype=tnp.float64, + ) + self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32) + self.nsel = [10, 10] + self.rcut = 1.01 + + self.ecoord, self.eatype, mapping = extend_coord_with_ghosts( + self.icoord, self.atype, self.cell, self.rcut + ) + self.nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + + def test_format_nlist_equal(self): + nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut) + self.assertAllEqual(nlist, self.nlist) + + def test_format_nlist_less(self): + nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut, + sum(self.nsel) - 5, + distinguish_types=False, + ) + nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) + self.assertAllEqual(nlist, self.nlist) + + def test_format_nlist_large(self): + nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut, + sum(self.nsel) + 5, + distinguish_types=False, + ) + # random shuffle + shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2])) + nlist = tnp.take(nlist, shuffle_idx, axis=2) + nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) + # we only need to ensure the result is correct, no need to check the order + self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1)) + + def test_format_nlist_larger_rcut(self): + nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut * 2, + 40, + distinguish_types=False, + ) + # random shuffle + shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2])) + nlist = tnp.take(nlist, shuffle_idx, axis=2) + nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) + # we only need to ensure the result is correct, no need to check the order + self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))