From e7aeca024da0ca4dad30f4bcf428e730d49faf64 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 25 Sep 2024 03:44:39 -0400 Subject: [PATCH] fix reshape Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/utils/exclude_mask.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/exclude_mask.py b/deepmd/dpmodel/utils/exclude_mask.py index a8e8dc7ef3..e744a726f6 100644 --- a/deepmd/dpmodel/utils/exclude_mask.py +++ b/deepmd/dpmodel/utils/exclude_mask.py @@ -127,7 +127,8 @@ def build_type_exclude_mask( index = xp.reshape( xp.where(nlist == -1, xp.full_like(nlist, nall), nlist), (nf, nloc * nnei) ) - type_j = xp_take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei) + type_j = xp_take_along_axis(ae, index, axis=1) + type_j = xp.reshape(type_j, (nf, nloc, nnei)) type_ij = type_i[:, :, None] + type_j # nf x (nloc x nnei) type_ij = xp.reshape(type_ij, (nf, nloc * nnei))