diff --git a/deepmd/jax/jax2tf/format_nlist.py b/deepmd/jax/jax2tf/format_nlist.py new file mode 100644 index 0000000000..f0c630206f --- /dev/null +++ b/deepmd/jax/jax2tf/format_nlist.py @@ -0,0 +1,71 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + + +@tf.function(autograph=True) +def format_nlist( + extended_coord: tnp.ndarray, + nlist: tnp.ndarray, + nsel: int, + rcut: float, +): + """Format neighbor list. + + If nnei == nsel, do nothing; + If nnei < nsel, pad -1; + If nnei > nsel, sort by distance and truncate. + + Parameters + ---------- + extended_coord + The extended coordinates of the atoms. + shape: nf x nall x 3 + nlist + The neighbor list. + shape: nf x nloc x nnei + nsel + The number of selected neighbors. + rcut + The cutoff radius. + + Returns + ------- + nlist + The formatted neighbor list. + shape: nf x nloc x nsel + """ + nlist_shape = tf.shape(nlist) + n_nf, n_nloc, n_nsel = nlist_shape[0], nlist_shape[1], nlist_shape[2] + extended_coord = extended_coord.reshape([n_nf, -1, 3]) + + if n_nsel < nsel: + # make a copy before revise + ret = tnp.concatenate( + [ + nlist, + tnp.full([n_nf, n_nloc, nsel - n_nsel], -1, dtype=nlist.dtype), + ], + axis=-1, + ) + + elif n_nsel > nsel: + # make a copy before revise + m_real_nei = nlist >= 0 + ret = tnp.where(m_real_nei, nlist, 0) + coord0 = extended_coord[:, :n_nloc, :] + index = ret.reshape(n_nf, n_nloc * n_nsel, 1) + index = tnp.repeat(index, 3, axis=2) + coord1 = tnp.take_along_axis(extended_coord, index, axis=1) + coord1 = coord1.reshape(n_nf, n_nloc, n_nsel, 3) + rr2 = tnp.sum(tnp.square(coord0[:, :, None, :] - coord1), axis=-1) + rr2 = tnp.where(m_real_nei, rr2, float("inf")) + rr2, ret_mapping = tnp.sort(rr2, axis=-1), tnp.argsort(rr2, axis=-1) + ret = tnp.take_along_axis(ret, ret_mapping, axis=2) + ret = tnp.where(rr2 > rcut * rcut, -1, ret) + ret = ret[..., :nsel] + else: # n_nsel == nsel: + ret = nlist + # do a reshape any way; this will tell the xla the shape without any dynamic shape + ret = tnp.reshape(ret, [n_nf, n_nloc, nsel]) + return ret diff --git a/deepmd/jax/jax2tf/serialization.py b/deepmd/jax/jax2tf/serialization.py index 7e560f6008..2af8656572 100644 --- a/deepmd/jax/jax2tf/serialization.py +++ b/deepmd/jax/jax2tf/serialization.py @@ -10,6 +10,9 @@ jax2tf, ) +from deepmd.jax.jax2tf.format_nlist import ( + format_nlist, +) from deepmd.jax.jax2tf.make_model import ( model_call_from_call_lower, ) @@ -76,7 +79,7 @@ def call_lower_with_fixed_do_atomic_virial( input_signature=[ tf.TensorSpec([None, None, 3], tf.float64), tf.TensorSpec([None, None], tf.int32), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None, None], tf.int64), tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), @@ -85,6 +88,7 @@ def call_lower_with_fixed_do_atomic_virial( def call_lower_without_atomic_virial( coord, atype, nlist, mapping, fparam, aparam ): + nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], lambda: exported_whether_do_atomic_virial( @@ -102,13 +106,14 @@ def call_lower_without_atomic_virial( input_signature=[ tf.TensorSpec([None, None, 3], tf.float64), tf.TensorSpec([None, None], tf.int32), - tf.TensorSpec([None, None, model.get_nnei()], tf.int64), + tf.TensorSpec([None, None, None], tf.int64), tf.TensorSpec([None, None], tf.int64), tf.TensorSpec([None, model.get_dim_fparam()], tf.float64), tf.TensorSpec([None, None, model.get_dim_aparam()], tf.float64), ], ) def call_lower_with_atomic_virial(coord, atype, nlist, mapping, fparam, aparam): + nlist = format_nlist(coord, nlist, model.get_nnei(), model.get_rcut()) return tf.cond( tf.shape(coord)[1] == tf.shape(nlist)[1], lambda: exported_whether_do_atomic_virial( diff --git a/source/jax2tf_tests/test_format_nlist.py b/source/jax2tf_tests/test_format_nlist.py new file mode 100644 index 0000000000..9b95db2cf1 --- /dev/null +++ b/source/jax2tf_tests/test_format_nlist.py @@ -0,0 +1,91 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf +import tensorflow.experimental.numpy as tnp + +from deepmd.jax.jax2tf.format_nlist import ( + format_nlist, +) +from deepmd.jax.jax2tf.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) + +GLOBAL_SEED = 20241110 + + +class TestFormatNlist(tf.test.TestCase): + def setUp(self): + self.nf = 3 + self.nloc = 3 + self.ns = 5 * 5 * 3 + self.nall = self.ns * self.nloc + self.cell = tnp.array( + [[[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]]], dtype=tnp.float64 + ) + self.icoord = tnp.array( + [[[0.035, 0.062, 0.064], [0.085, 0.058, 0.021], [0.537, 0.553, 0.124]]], + dtype=tnp.float64, + ) + self.atype = tnp.array([[1, 0, 1]], dtype=tnp.int32) + self.nsel = [10, 10] + self.rcut = 1.01 + + self.ecoord, self.eatype, mapping = extend_coord_with_ghosts( + self.icoord, self.atype, self.cell, self.rcut + ) + self.nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut, + sum(self.nsel), + distinguish_types=False, + ) + + def test_format_nlist_equal(self): + nlist = format_nlist(self.ecoord, self.nlist, sum(self.nsel), self.rcut) + self.assertAllEqual(nlist, self.nlist) + + def test_format_nlist_less(self): + nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut, + sum(self.nsel) - 5, + distinguish_types=False, + ) + nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) + self.assertAllEqual(nlist, self.nlist) + + def test_format_nlist_large(self): + nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut, + sum(self.nsel) + 5, + distinguish_types=False, + ) + # random shuffle + shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2])) + nlist = tnp.take(nlist, shuffle_idx, axis=2) + nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) + # we only need to ensure the result is correct, no need to check the order + self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1)) + + def test_format_nlist_larger_rcut(self): + nlist = build_neighbor_list( + self.ecoord, + self.eatype, + self.nloc, + self.rcut * 2, + 40, + distinguish_types=False, + ) + # random shuffle + shuffle_idx = tf.random.shuffle(tf.range(nlist.shape[2])) + nlist = tnp.take(nlist, shuffle_idx, axis=2) + nlist = format_nlist(self.ecoord, nlist, sum(self.nsel), self.rcut) + # we only need to ensure the result is correct, no need to check the order + self.assertAllEqual(tnp.sort(nlist, axis=-1), tnp.sort(self.nlist, axis=-1))