Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 6, 2024
1 parent 97548d3 commit 1ae79a9
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def compute_output_stats(
sampled = merged()
else:
sampled = merged

sys_matrix, polar_bias = [], []
for sys in len(sampled):

Check failure

Code scanning / CodeQL

Non-iterable used in for loop Error

This for-loop may attempt to iterate over a
non-iterable instance
of class
int
.
if sampled[sys]["find_atomic_polarizability"] > 0.0:
Expand All @@ -206,21 +206,28 @@ def compute_output_stats(
sys_matrix.append(torch.zeros((1, self.ntypes)))
# this gives the number of atoms of type itype in the system
sys_matrix[-1][0, itype] = type_mask.sum().item()
expanded_mask = type_mask.unsqueeze(-1).expand((*type_mask.shape, 9))
expanded_mask = type_mask.unsqueeze(-1).expand(
(*type_mask.shape, 9)
)
polar_bias.append(
torch.sum(
(sampled[sys]["atomic_polarizability"]*expanded_mask).reshape(-1,9),
(
sampled[sys]["atomic_polarizability"]
* expanded_mask
).reshape(-1, 9),
dim=0,
).reshape((1, 9))
)
else:
if (not sampled[sys]["find_polarizability"] > 0.0):
if not sampled[sys]["find_polarizability"] > 0.0:
continue
sys_matrix.append(torch.zeros((1, self.ntypes)))
for itype in range(self.ntypes):
type_mask = sampled[sys]["type"] == itype
sys_matrix[-1][0, itype] = type_mask.sum().item()
polar_bias.append(sampled[sys]["polarizability"].reshape((1, 9)))
polar_bias.append(
sampled[sys]["polarizability"].reshape((1, 9))
)
matrix, bias = (
torch.cat(sys_matrix, dim=0),
torch.cat(polar_bias, dim=0),
Expand All @@ -229,17 +236,13 @@ def compute_output_stats(
constant_matrix = []
for itype in range(self.ntypes):
constant_matrix.append(
torch.mean(torch.diagonal(atom_polar[itype].reshape(3,3)))
torch.mean(torch.diagonal(atom_polar[itype].reshape(3, 3)))
)

self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE)
if stat_file_path is not None:
stat_file_path.save_numpy(self.constant_matrix.detach().cpu().numpy())





def forward(
self,
descriptor: torch.Tensor,
Expand Down

0 comments on commit 1ae79a9

Please sign in to comment.