diff --git a/source/tests/pt/model/test_polar_stat.py b/source/tests/pt/model/test_polar_stat.py index 401c368d35..898b51e514 100644 --- a/source/tests/pt/model/test_polar_stat.py +++ b/source/tests/pt/model/test_polar_stat.py @@ -1,29 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later import unittest + import numpy as np import torch + +from deepmd.pt.model.task.polarizability import ( + PolarFittingNet, +) from deepmd.pt.utils import ( env, ) -from deepmd.pt.model.task.polarizability import PolarFittingNet -from deepmd.tf.fit.polar import PolarFittingSeA from deepmd.pt.utils.utils import ( to_numpy_array, ) +from deepmd.tf.fit.polar import ( + PolarFittingSeA, +) + class TestConsistency(unittest.TestCase): def setUp(self) -> None: - types = torch.randint(0,4,(1,5), device=env.DEVICE) - types = torch.cat((types,types,types), dim=0) + types = torch.randint(0, 4, (1, 5), device=env.DEVICE) + types = torch.cat((types, types, types), dim=0) ntypes = 4 - atomic_polarizability = torch.rand(3,5,9) - polarizability = torch.rand(3,9) + atomic_polarizability = torch.rand(3, 5, 9) + polarizability = torch.rand(3, 9) find_polarizability = torch.rand(1) find_atomic_polarizability = torch.rand(1) - self.sampled = [{"type": types, - "find_atomic_polarizability": find_atomic_polarizability, - "atomic_polarizability": atomic_polarizability, - "polarizability": polarizability, - "find_polarizability": find_polarizability}] + self.sampled = [ + { + "type": types, + "find_atomic_polarizability": find_atomic_polarizability, + "atomic_polarizability": atomic_polarizability, + "polarizability": polarizability, + "find_polarizability": find_polarizability, + } + ] self.all_stat = {k: [v.numpy()] for d in self.sampled for k, v in d.items()} self.tfpolar = PolarFittingSeA( ntypes=ntypes, @@ -36,7 +48,6 @@ def setUp(self) -> None: dim_descrpt=1, embedding_width=1, ) - def test_atomic_consistency(self): self.tfpolar.compute_output_stats(self.all_stat) @@ -48,12 +59,16 @@ def test_atomic_consistency(self): def test_global_consistency(self): self.sampled[0]["find_atomic_polarizability"] = -1 - self.sampled[0]["polarizability"] = self.sampled[0]["atomic_polarizability"].sum(dim=1) + self.sampled[0]["polarizability"] = self.sampled[0][ + "atomic_polarizability" + ].sum(dim=1) self.all_stat["find_atomic_polarizability"] = [-1] - self.all_stat["polarizability"] = list(self.all_stat["atomic_polarizability"][0].sum(axis=1)) + self.all_stat["polarizability"] = list( + self.all_stat["atomic_polarizability"][0].sum(axis=1) + ) self.tfpolar.compute_output_stats(self.all_stat) tfbias = self.tfpolar.constant_matrix self.ptpolar.compute_output_stats(self.sampled) ptbias = self.ptpolar.constant_matrix print(tfbias, to_numpy_array(ptbias)) - np.testing.assert_allclose(tfbias, to_numpy_array(ptbias)) \ No newline at end of file + np.testing.assert_allclose(tfbias, to_numpy_array(ptbias))