From f416c994fc4b267e1e7c2d8a1c59102516da879c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Mar 2024 02:11:22 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/task/polarizability.py | 15 ++++++++++----- source/tests/pt/model/test_polar_stat.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index cdcb6e8d6c..5a81789284 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -119,7 +119,9 @@ def __init__( self.scale, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE ).view(ntypes, 1) self.shift_diag = shift_diag - self.constant_matrix = torch.zeros(ntypes, dtype = env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) + self.constant_matrix = torch.zeros( + ntypes, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE + ) super().__init__( var_name=kwargs.pop("var_name", "polar"), ntypes=ntypes, @@ -215,7 +217,9 @@ def compute_output_stats( else: if not sampled[sys]["find_polarizability"] > 0.0: continue - sys_type_count = np.zeros((nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION) + sys_type_count = np.zeros( + (nframs, self.ntypes), dtype=env.GLOBAL_NP_FLOAT_PRECISION + ) for itype in range(self.ntypes): type_mask = sampled[sys]["type"] == itype sys_type_count[:, itype] = type_mask.sum(dim=1).numpy( @@ -227,8 +231,10 @@ def compute_output_stats( sys_atom_polar = compute_stats_from_redu( sys_bias_redu, sys_type_count, rcond=self.rcond )[0] - cur_constant_matrix = np.zeros(self.ntypes, dtype= env.GLOBAL_NP_FLOAT_PRECISION) - + cur_constant_matrix = np.zeros( + self.ntypes, dtype=env.GLOBAL_NP_FLOAT_PRECISION + ) + for itype in range(self.ntypes): cur_constant_matrix[itype] = np.mean( np.diagonal(sys_atom_polar[itype].reshape(3, 3)) @@ -242,7 +248,6 @@ def compute_output_stats( stat_file_path.save_numpy(self.constant_matrix) self.constant_matrix = torch.tensor(constant_matrix, device=env.DEVICE) - def forward( self, descriptor: torch.Tensor, diff --git a/source/tests/pt/model/test_polar_stat.py b/source/tests/pt/model/test_polar_stat.py index 8f8ac78bc1..ca3b037011 100644 --- a/source/tests/pt/model/test_polar_stat.py +++ b/source/tests/pt/model/test_polar_stat.py @@ -22,7 +22,7 @@ class TestConsistency(unittest.TestCase): def setUp(self) -> None: types = torch.randint(0, 4, (1, 5), device=env.DEVICE) types = torch.cat((types, types, types), dim=0) - types[:,-1] = 3 + types[:, -1] = 3 ntypes = 4 atomic_polarizability = torch.rand((3, 5, 9), device=env.DEVICE) polarizability = torch.rand((3, 9), device=env.DEVICE)