diff --git a/deepmd/pt/model/model/pair_tab.py b/deepmd/pt/model/model/pair_tab.py index 6f0782289a..4701f26e04 100644 --- a/deepmd/pt/model/model/pair_tab.py +++ b/deepmd/pt/model/model/pair_tab.py @@ -91,6 +91,14 @@ def distinguish_types(self) -> bool: # to match DPA1 and DPA2. return False + def serialize(self) -> dict: + # place holder, implemantated in future PR + raise NotImplementedError + + def deserialize(cls): + # place holder, implemantated in future PR + raise NotImplementedError + def forward_atomic( self, extended_coord, diff --git a/source/tests/pt/test_rotation.py b/source/tests/pt/test_rotation.py index 58ec80e0d6..63d5a0b563 100644 --- a/source/tests/pt/test_rotation.py +++ b/source/tests/pt/test_rotation.py @@ -111,22 +111,18 @@ def test_rotation(self): result1 = self.model(**get_data(self.origin_batch)) result2 = self.model(**get_data(self.rotated_batch)) rotation = torch.from_numpy(self.rotation).to(env.DEVICE) - self.assertTrue(result1["energy"] == result2["energy"]) + self.assertAlmostEqual(result1["energy"], result2["energy"]) if "force" in result1: - self.assertTrue( - torch.allclose( - result2["force"][0], torch.matmul(rotation, result1["force"][0].T).T - ) + torch.testing.assert_close( + result2["force"][0], torch.matmul(rotation, result1["force"][0].T).T ) if "virial" in result1: - self.assertTrue( - torch.allclose( - result2["virial"][0].view([3, 3]), - torch.matmul( - torch.matmul(rotation, result1["virial"][0].view([3, 3]).T), - rotation.T, - ), - ) + torch.testing.assert_close( + result2["virial"][0].view([3, 3]), + torch.matmul( + torch.matmul(rotation, result1["virial"][0].view([3, 3]).T), + rotation.T, + ), )