From 0d7c740807d641f7be68d942fd8dd87bd2cf260a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 3 Nov 2024 04:55:02 -0500 Subject: [PATCH] revert make_model.py Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/make_model.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 98a93c7500..95d97262df 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -503,14 +503,11 @@ def _format_nlist( index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - # jax raises NaN error using norm - # but note: we don't actually need to sqrt here; the squared value is enough - # rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr2 = xp.sum(xp.square(coord0[:, :, None, :] - coord1), axis=-1) - rr2 = xp.where(m_real_nei, rr2, float("inf")) - rr2, ret_mapping = xp.sort(rr2, axis=-1), xp.argsort(rr2, axis=-1) + rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = xp.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) ret = xp.take_along_axis(ret, ret_mapping, axis=2) - ret = xp.where(rr2 > rcut * rcut, -1, ret) + ret = xp.where(rr > rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: