Skip to content

Commit

Permalink
pt: process frames in parallel for env mat stat (#3293)
Browse files Browse the repository at this point in the history
Resolves
#3285 (comment).

We don't even need to consider whether the data uses mixed type in this
piece of code. The data can be reshaped to the atom level, and type
masks can be used to get an environmental matrix with certain types. No
frame-level things are involved here.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 19, 2024
1 parent 9f6ff1e commit ab35468
Showing 1 changed file with 19 additions and 30 deletions.
49 changes: 19 additions & 30 deletions deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,38 +128,27 @@ def iter(
# TODO: export rcut_smth from DescriptorBlock
self.descriptor.rcut_smth,
)
# reshape to nframes * nloc at the atom level,
# so nframes/mixed_type do not matter
env_mat = env_mat.view(
coord.shape[0], coord.shape[1], self.descriptor.get_nsel(), 4
coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4
)

if "real_natoms_vec" not in system:
end_indexes = torch.cumsum(natoms[0, 2:], 0)
start_indexes = torch.cat(
[
torch.zeros(1, dtype=torch.int32, device=env.DEVICE),
end_indexes[:-1],
]
)
for type_i in range(self.descriptor.get_ntypes()):
dd = env_mat[
:, start_indexes[type_i] : end_indexes[type_i], :, :
] # all descriptors for this element
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :, :, :1]
env_mats[f"a_{type_i}"] = dd[:, :, :, 1:]
yield self.compute_stat(env_mats)
else:
for frame_item in range(env_mat.shape[0]):
dd_ff = env_mat[frame_item]
atype_frame = atype[frame_item]
for type_i in range(self.descriptor.get_ntypes()):
type_idx = atype_frame == type_i
dd = dd_ff[type_idx]
dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
env_mats[f"a_{type_i}"] = dd[:, 1:]
yield self.compute_stat(env_mats)
atype = atype.view(coord.shape[0] * coord.shape[1])
# (1, nloc) eq (ntypes, 1), so broadcast is possible
# shape: (ntypes, nloc)
type_idx = torch.eq(
atype.view(1, -1),
torch.arange(
self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32
).view(-1, 1),
)
for type_i in range(self.descriptor.get_ntypes()):
dd = env_mat[type_idx[type_i]]
dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
env_mats[f"a_{type_i}"] = dd[:, 1:]
yield self.compute_stat(env_mats)

def get_hash(self) -> str:
"""Get the hash of the environment matrix.
Expand Down

0 comments on commit ab35468

Please sign in to comment.