diff --git a/deepmd/pt/utils/env_mat_stat.py b/deepmd/pt/utils/env_mat_stat.py index 5247ce08ba..78917c5224 100644 --- a/deepmd/pt/utils/env_mat_stat.py +++ b/deepmd/pt/utils/env_mat_stat.py @@ -128,38 +128,20 @@ def iter( # TODO: export rcut_smth from DescriptorBlock self.descriptor.rcut_smth, ) + # reshape to nframes * nloc at the atom level, + # so nframes/mixed_type do not matter env_mat = env_mat.view( - coord.shape[0], coord.shape[1], self.descriptor.get_nsel(), 4 + coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4 ) - - if "real_natoms_vec" not in system: - end_indexes = torch.cumsum(natoms[0, 2:], 0) - start_indexes = torch.cat( - [ - torch.zeros(1, dtype=torch.int32, device=env.DEVICE), - end_indexes[:-1], - ] - ) - for type_i in range(self.descriptor.get_ntypes()): - dd = env_mat[ - :, start_indexes[type_i] : end_indexes[type_i], :, : - ] # all descriptors for this element - env_mats = {} - env_mats[f"r_{type_i}"] = dd[:, :, :, :1] - env_mats[f"a_{type_i}"] = dd[:, :, :, 1:] - yield self.compute_stat(env_mats) - else: - for frame_item in range(env_mat.shape[0]): - dd_ff = env_mat[frame_item] - atype_frame = atype[frame_item] - for type_i in range(self.descriptor.get_ntypes()): - type_idx = atype_frame == type_i - dd = dd_ff[type_idx] - dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 - env_mats = {} - env_mats[f"r_{type_i}"] = dd[:, :1] - env_mats[f"a_{type_i}"] = dd[:, 1:] - yield self.compute_stat(env_mats) + atype = atype.view(coord.shape[0] * coord.shape[1]) + for type_i in range(self.descriptor.get_ntypes()): + type_idx = atype == type_i + dd = env_mat[type_idx] + dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4 + env_mats = {} + env_mats[f"r_{type_i}"] = dd[:, :1] + env_mats[f"a_{type_i}"] = dd[:, 1:] + yield self.compute_stat(env_mats) def get_hash(self) -> str: """Get the hash of the environment matrix.