Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 19, 2024
1 parent 1dc7fd9 commit 0afee21
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
5 changes: 4 additions & 1 deletion deepmd/dpmodel/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def apply_out_stat(
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = np.zeros(ntypes, dtype=dtype)
temp = np.mean(np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2), axis=1)
temp = np.mean(
np.diagonal(out_bias[kk].reshape(ntypes, 3, 3), axis1=1, axis2=2),
axis=1,
)
modified_bias = temp[atype]

# (nframes, nloc, 1)
Expand Down
7 changes: 6 additions & 1 deletion deepmd/pt/model/atomic_model/polar_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ def apply_out_stat(
for kk in self.bias_keys:
ntypes = out_bias[kk].shape[0]
temp = torch.zeros(ntypes, dtype=dtype, device=device)
temp = torch.mean(torch.diagonal(out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1), dim=-1)
temp = torch.mean(
torch.diagonal(
out_bias[kk].reshape(ntypes, 3, 3), dim1=-2, dim2=-1
),
dim=-1,
)
modified_bias = temp[atype]

# (nframes, nloc, 1)
Expand Down

0 comments on commit 0afee21

Please sign in to comment.