diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 3a18f150a4..c4b2c772f8 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -342,9 +342,8 @@ def __init__( self.reinit_exclude(exclude_types) self.sel = sel - self.sec = torch.tensor( - np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE - ) + # should be on CPU to avoid D2H, as it is used as slice index + self.sec = [0, *np.cumsum(self.sel).tolist()] self.split_sel = self.sel self.nnei = sum(sel) self.ndescrpt = self.nnei * 4