From 48c88180356d881cff00203e590cb62063fb25ef Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Wed, 28 Feb 2024 21:32:51 -0500
Subject: [PATCH] pt: fix se_a type_one_side performance degradation (#3361)

The code in this PR is ugly, but applying a mask causes a performance
degradation of ~3 ms/step. When a mask is applied, `aten::nonzero` shows a
high host time because it forces a host-device synchronization:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/86b3518c-206d-410d-928e-2f605746147c)

After fixing:

![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/af9e86fa-7908-4bbb-ace7-58b4602e167f)

See https://github.com/pytorch/pytorch/issues/12461 for more information.

Signed-off-by: Jinzhe Zeng
---
 deepmd/pt/model/descriptor/se_a.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py
index 6c29636d6d..8a211c977d 100644
--- a/deepmd/pt/model/descriptor/se_a.py
+++ b/deepmd/pt/model/descriptor/se_a.py
@@ -472,23 +472,34 @@ def forward(
             if self.type_one_side:
                 ii = embedding_idx
                 # torch.jit is not happy with slice(None)
-                ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
+                # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
+                # applying a mask seems to cause performance degradation
+                ti_mask = None
             else:
                 # ti: center atom type, ii: neighbor type...
                 ii = embedding_idx // self.ntypes
                 ti = embedding_idx % self.ntypes
                 ti_mask = atype.ravel().eq(ti)
             # nfnl x nt
-            mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
+            if ti_mask is not None:
+                mm = exclude_mask[ti_mask, self.sec[ii] : self.sec[ii + 1]]
+            else:
+                mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
             # nfnl x nt x 4
-            rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :]
+            if ti_mask is not None:
+                rr = dmatrix[ti_mask, self.sec[ii] : self.sec[ii + 1], :]
+            else:
+                rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
             rr = rr * mm[:, :, None]
             ss = rr[:, :, :1]
             # nfnl x nt x ng
             gg = ll.forward(ss)
             # nfnl x 4 x ng
             gr = torch.matmul(rr.permute(0, 2, 1), gg)
-            xyz_scatter[ti_mask] += gr
+            if ti_mask is not None:
+                xyz_scatter[ti_mask] += gr
+            else:
+                xyz_scatter += gr
         xyz_scatter /= self.nnei
         xyz_scatter_1 = xyz_scatter.permute(0, 2, 1)
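
For readers profiling similar hotspots, the sketch below distills the pattern this diff adopts. It is illustrative only: the function name `masked_accumulate` and the tensor shapes are assumptions, not code from `se_a.py`. The point is that boolean indexing such as `out[mask]` dispatches `aten::nonzero`, which must copy the selected-row count back to the host (a synchronization), whereas a plain in-place op on the full tensor stays asynchronous.

```python
# Minimal sketch of the pattern used in this patch; names and shapes are
# illustrative, not taken from se_a.py.
from typing import Optional

import torch


def masked_accumulate(
    out: torch.Tensor, update: torch.Tensor, mask: Optional[torch.Tensor]
) -> None:
    """Accumulate `update` into the rows of `out` selected by `mask`.

    Passing mask=None instead of an all-True mask lets the hot path take a
    plain in-place add, avoiding the aten::nonzero call (and the host-device
    synchronization it forces) that boolean indexing would incur.
    """
    if mask is not None:
        # Partial update: only some rows match, so boolean indexing is
        # unavoidable here and will synchronize.
        out[mask] += update
    else:
        # All rows selected (the type_one_side case): a plain add keeps
        # the kernel launch fully asynchronous.
        out += update


# Usage: the type_one_side branch passes mask=None; the other branch
# passes the per-type boolean mask.
nfnl, ng = 8, 16
xyz_scatter = torch.zeros(nfnl, 4, ng)
gr = torch.ones(nfnl, 4, ng)
masked_accumulate(xyz_scatter, gr, None)  # no synchronization on this path
```

Representing "select everything" as `mask=None` rather than as `torch.ones(nfnl, dtype=torch.bool)` gives the same result, but the all-True mask would still pay the `aten::nonzero` cost, which is why the diff trades the uglier `if ti_mask is not None` branches for the speedup.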