diff --git a/deepmd/pt/model/task/polarizability.py b/deepmd/pt/model/task/polarizability.py index 94073cb217..2aef51569b 100644 --- a/deepmd/pt/model/task/polarizability.py +++ b/deepmd/pt/model/task/polarizability.py @@ -226,7 +226,12 @@ def compute_output_stats( cur_constant_matrix = np.zeros(self.ntypes) if sys_atom_polar.shape[0] < self.ntypes: # pad zeros in case atype.max() + 1 != ntypes - sys_atom_polar = np.concatenate([sys_atom_polar, np.zeros((self.ntypes - sys_atom_polar.shape[0], 9))]) + sys_atom_polar = np.concatenate( + [ + sys_atom_polar, + np.zeros((self.ntypes - sys_atom_polar.shape[0], 9)), + ] + ) for itype in range(self.ntypes): cur_constant_matrix[itype] = np.mean( np.diagonal(sys_atom_polar[itype].reshape(3, 3)) @@ -279,11 +284,7 @@ def forward( eye = torch.eye(3, device=env.DEVICE) eye = eye.repeat(nframes, nloc, 1, 1) # (nframes, nloc, 3, 3) - bias = bias * eye - out = ( - out - + bias - * self.scale[atype] - ) + bias = bias * eye + out = out + bias * self.scale[atype] return {self.var_name: out.to(env.GLOBAL_PT_FLOAT_PRECISION)} diff --git a/deepmd/tf/fit/polar.py b/deepmd/tf/fit/polar.py index 7de241f1ce..08b0740459 100644 --- a/deepmd/tf/fit/polar.py +++ b/deepmd/tf/fit/polar.py @@ -202,8 +202,9 @@ def compute_output_stats(self, all_stat): polar_bias.append( np.sum( - all_stat["atomic_polarizability"][ss][:, index_lis, :]/nframes, - axis=(0,1), + all_stat["atomic_polarizability"][ss][:, index_lis, :] + / nframes, + axis=(0, 1), ).reshape((1, 9)) ) else: # No atomic polar in this system, so it should have global polar @@ -227,7 +228,9 @@ def compute_output_stats(self, all_stat): sys_matrix[-1][0, itype] = len(index_lis) # add polar_bias - polar_bias.append(np.mean(all_stat["polarizability"][ss],axis=0).reshape((1, 9))) + polar_bias.append( + np.mean(all_stat["polarizability"][ss], axis=0).reshape((1, 9)) + ) matrix, bias = ( np.concatenate(sys_matrix, axis=0), diff --git a/source/tests/consistent/model/test_polar_stat.py b/source/tests/consistent/model/test_polar_stat.py index b5d7ec3b93..bfe11ad709 100644 --- a/source/tests/consistent/model/test_polar_stat.py +++ b/source/tests/consistent/model/test_polar_stat.py @@ -1,29 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later import unittest + import numpy as np import torch + +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) from deepmd.pt.utils import ( env, ) -from deepmd.pt.model.task.polarizability import PolarFittingNet -from deepmd.tf.fit.polar import PolarFittingSeA from deepmd.pt.utils.utils import ( to_numpy_array, ) +from deepmd.tf.fit.polar import ( + PolarFittingSeA, +) + class TestConsistency(unittest.TestCase): def setUp(self) -> None: - types = torch.randint(0,4,(1,5), device=env.DEVICE) - types = torch.cat((types,types,types), dim=0) + types = torch.randint(0, 4, (1, 5), device=env.DEVICE) + types = torch.cat((types, types, types), dim=0) ntypes = 4 - atomic_polarizability = torch.rand(3,5,9) - polarizability = torch.rand(3,9) + atomic_polarizability = torch.rand(3, 5, 9) + polarizability = torch.rand(3, 9) find_polarizability = torch.rand(1) find_atomic_polarizability = torch.rand(1) - self.sampled = [{"type": types, - "find_atomic_polarizability": find_atomic_polarizability, - "atomic_polarizability": atomic_polarizability, - "polarizability": polarizability, - "find_polarizability": find_polarizability}] + self.sampled = [ + { + "type": types, + "find_atomic_polarizability": find_atomic_polarizability, + "atomic_polarizability": atomic_polarizability, + "polarizability": polarizability, + "find_polarizability": find_polarizability, + } + ] self.all_stat = {k: [v.numpy()] for d in self.sampled for k, v in d.items()} self.tfpolar = PolarFittingSeA( ntypes=ntypes, @@ -36,7 +48,6 @@ def setUp(self) -> None: dim_descrpt=1, embedding_width=1, ) - def test_atomic_consistency(self): self.tfpolar.compute_output_stats(self.all_stat) @@ -48,12 +59,16 @@ def test_atomic_consistency(self): def test_global_consistency(self): self.sampled[0]["find_atomic_polarizability"] = -1 - self.sampled[0]["polarizability"] = self.sampled[0]["atomic_polarizability"].sum(dim=1) + self.sampled[0]["polarizability"] = self.sampled[0][ + "atomic_polarizability" + ].sum(dim=1) self.all_stat["find_atomic_polarizability"] = [-1] - self.all_stat["polarizability"] = [(self.all_stat["atomic_polarizability"][0]).sum(axis=1)] + self.all_stat["polarizability"] = [ + (self.all_stat["atomic_polarizability"][0]).sum(axis=1) + ] self.tfpolar.compute_output_stats(self.all_stat) tfbias = self.tfpolar.constant_matrix self.ptpolar.compute_output_stats(self.sampled) ptbias = self.ptpolar.constant_matrix print(tfbias, to_numpy_array(ptbias)) - np.testing.assert_allclose(tfbias, to_numpy_array(ptbias)) \ No newline at end of file + np.testing.assert_allclose(tfbias, to_numpy_array(ptbias))