Skip to content

Commit

Permalink
move torch.eq out of loop
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Feb 18, 2024
1 parent 117acd9 commit 52fbe03
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions deepmd/pt/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,16 @@ def iter(
coord.shape[0] * coord.shape[1], self.descriptor.get_nsel(), 4
)
atype = atype.view(coord.shape[0] * coord.shape[1])

Check warning on line 136 in deepmd/pt/utils/env_mat_stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/env_mat_stat.py#L136

Added line #L136 was not covered by tests
# (1, nloc) eq (ntypes, 1), so broadcast is possible
# shape: (ntypes, nloc)
type_idx = torch.eq(

Check warning on line 139 in deepmd/pt/utils/env_mat_stat.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/env_mat_stat.py#L139

Added line #L139 was not covered by tests
atype.view(1, -1),
torch.arange(
self.descriptor.get_ntypes(), device=env.DEVICE, dtype=torch.int32
).view(-1, 1),
)
for type_i in range(self.descriptor.get_ntypes()):
type_idx = atype == type_i
dd = env_mat[type_idx]
dd = env_mat[type_idx[type_i]]
dd = dd.reshape([-1, 4]) # typen_atoms * nnei, 4
env_mats = {}
env_mats[f"r_{type_i}"] = dd[:, :1]
Expand Down

0 comments on commit 52fbe03

Please sign in to comment.