Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 8, 2024
1 parent cd3ac47 commit f416c99
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
15 changes: 10 additions & 5 deletions deepmd/pt/model/task/polarizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ def __init__(
self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
).view(ntypes, 1)
self.shift_diag = shift_diag
self.constant_matrix = torch.zeros(ntypes, dtype = env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
self.constant_matrix = torch.zeros(
ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
super().__init__(
var_name=kwargs.pop("var_name", "polar"),
ntypes=ntypes,
Expand Down Expand Up @@ -215,7 +217,9 @@ def compute_output_stats(
else:
if not sampled[sys]["find_polarizability"] > 0.0:
continue

Check warning on line 219 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L219

Added line #L219 was not covered by tests
sys_type_count = np.zeros((nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION)
sys_type_count = np.zeros(
(nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION
)
for itype in range(self.ntypes):
type_mask = sampled[sys]["type"] == itype
sys_type_count[:, itype] = type_mask.sum(dim=1).numpy(
Expand All @@ -227,8 +231,10 @@ def compute_output_stats(
sys_atom_polar = compute_stats_from_redu(
sys_bias_redu, sys_type_count, rcond=self.rcond
)[0]
cur_constant_matrix = np.zeros(self.ntypes, dtype= env.GLOBAL_NP_FLOAT_PRECISION)

cur_constant_matrix = np.zeros(
self.ntypes, dtype=env.GLOBAL_NP_FLOAT_PRECISION
)

for itype in range(self.ntypes):
cur_constant_matrix[itype] = np.mean(
np.diagonal(sys_atom_polar[itype].reshape(3, 3))
Expand All @@ -242,7 +248,6 @@ def compute_output_stats(
stat_file_path.save_numpy(self.constant_matrix)

Check warning on line 248 in deepmd/pt/model/task/polarizability.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/task/polarizability.py#L248

Added line #L248 was not covered by tests
self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE)


def forward(
self,
descriptor: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_polar_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class TestConsistency(unittest.TestCase):
def setUp(self) -> None:
types = torch.randint(0, 4, (1, 5), device=env.DEVICE)
types = torch.cat((types, types, types), dim=0)
types[:,-1] = 3
types[:, -1] = 3
ntypes = 4
atomic_polarizability = torch.rand((3, 5, 9), device=env.DEVICE)
polarizability = torch.rand((3, 9), device=env.DEVICE)
Expand Down

0 comments on commit f416c99

Please sign in to comment.