From 52fbe03bfad5bfb12f111954a48d1cd76af1402c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 18 Feb 2024 18:26:08 -0500 Subject: [PATCH] move torch.eq out of loop Signed-off-by: Jinzhe Zeng --- deepmd/pt/utils/env_mat_stat.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 78917c5224..2279b8ea5a 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -134,9 +134,16 @@ def iter( coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4 ) atype = atype.view(coord.shape[0] * coord.shape[1]) + # (1, nloc) eq (ntypes, 1), so broadcast is possible + # shape: (ntypes, nloc) + type_idx = torch.eq( + atype.view(1, -1), + torch.arange( + self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 + ).view(-1, 1), + ) for type_i in range(self.descriptor.get_ntypes()): - type_idx = atype == type_i - dd = env_mat[type_idx] + dd = env_mat[type_idx[type_i]] dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 env_mats = {} env_mats[f"r_{type_i}"] = dd[:, :1]