From b810945efbbad9369949e1d6ef626c1117008b9c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 6 Mar 2024 00:09:36 -0500 Subject: [PATCH] pt: avoid D2H in se_e2_a sec is used as slice index, so it should not stored on the GPU, otherwise D2H will happen to create the tensor with the shape. Signed-off-by: Jinzhe Zeng --- deepmd/pt/model/descriptor/se_a.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 3a18f150a4..c4b2c772f8 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -342,9 +342,8 @@ def __init__( self.reinit_exclude(exclude_types) self.sel = sel - self.sec = torch.tensor( - np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE - ) + # should be on CPU to avoid D2H, as it is used as slice index + self.sec = [0, *np.cumsum(self.sel).tolist()] self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4