diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 78917c5224..2279b8ea5a 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -134,9 +134,16 @@ def iter( coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4 ) atype = atype.view(coord.shape[0] * coord.shape[1]) + # (1, nloc) eq (ntypes, 1), so broadcast is possible + # shape: (ntypes, nloc) + type_idx = torch.eq( + atype.view(1, -1), + torch.arange( + self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32 + ).view(-1, 1), + ) for type_i in range(self.descriptor.get_ntypes()): - type_idx = atype == type_i - dd = env_mat[type_idx] + dd = env_mat[type_idx[type_i]] dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 env_mats = {} env_mats[f"r_{type_i}"] = dd[:, :1]