From 89dff77b116e76eaa711545864f5d4ea67ea5c80 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 2 Mar 2024 23:07:04 +0800 Subject: [PATCH] ut: add null test --- source/tests/pt/model/test_null_input.py | 145 +++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 source/tests/pt/model/test_null_input.py diff --git a/source/tests/pt/model/test_null_input.py b/source/tests/pt/model/test_null_input.py new file mode 100644 index 0000000000..4d52d4e27a --- /dev/null +++ b/source/tests/pt/model/test_null_input.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import torch + +from deepmd.pt.infer.deep_eval import ( + eval_model, +) +from deepmd.pt.model.model import ( + get_model, + get_zbl_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from .test_permutation import ( + model_dpa1, + model_dpa2, + model_hybrid, + model_se_e2_a, + model_zbl, +) + +dtype = torch.float64 + + +class NullTest: + def test_nloc_1( + self, + ): + natoms = 1 + # torch.manual_seed(1000) + cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) + # large box to exclude images + cell = (cell + cell.T) + 100.0 * torch.eye(3, device=env.DEVICE) + coord = torch.rand([natoms, 3], dtype=dtype, device=env.DEVICE) + atype = torch.Tensor([0], device=env.DEVICE) + e0, f0, v0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + prec = 1e-10 + expect_e_shape = [1] + expect_f = torch.zeros([natoms, 3], dtype=dtype, device=env.DEVICE) + expect_v = torch.zeros([9], dtype=dtype, device=env.DEVICE) + self.assertEqual(list(ret0["energy"].shape), expect_e_shape) + self.assertFalse(np.isnan(to_numpy_array(ret0["energy"])[0])) + torch.testing.assert_close(ret0["force"], expect_f, rtol=prec, atol=prec) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close(ret0["virial"], expect_v, rtol=prec, atol=prec) + + def test_nloc_2_far( + self, + ): + natoms = 2 + cell = torch.rand([3, 3], dtype=dtype, device=env.DEVICE) + # large box to exclude images + cell = (cell + cell.T) + 3000.0 * torch.eye(3, device=env.DEVICE) + coord = torch.rand([1, 3], dtype=dtype, device=env.DEVICE) + # 2 far-away atoms + coord = torch.cat([coord, coord + 100.0], dim=0) + atype = torch.Tensor([0, 2], device=env.DEVICE) + e0, f0, v0 = eval_model( + self.model, coord.unsqueeze(0), cell.unsqueeze(0), atype + ) + ret0 = { + "energy": e0.squeeze(0), + "force": f0.squeeze(0), + "virial": v0.squeeze(0), + } + prec = 1e-10 + expect_e_shape = [1] + expect_f = torch.zeros([natoms, 3], dtype=dtype, device=env.DEVICE) + expect_v = torch.zeros([9], dtype=dtype, device=env.DEVICE) + self.assertEqual(list(ret0["energy"].shape), expect_e_shape) + self.assertFalse(np.isnan(to_numpy_array(ret0["energy"])[0])) + torch.testing.assert_close(ret0["force"], expect_f, rtol=prec, atol=prec) + if not hasattr(self, "test_virial") or self.test_virial: + torch.testing.assert_close(ret0["virial"], expect_v, rtol=prec, atol=prec) + + +class TestEnergyModelSeA(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_se_e2_a) + self.type_split = False + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA1(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestEnergyModelDPA2(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa2) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +class TestForceModelDPA2(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa2) + model_params["fitting_net"]["type"] = "direct_force_ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestEnergyModelHybrid(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + + +@unittest.skip("hybrid not supported at the moment") +class TestForceModelHybrid(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_hybrid) + model_params["fitting_net"]["type"] = "direct_force_ener" + self.type_split = True + self.test_virial = False + self.model = get_model(model_params).to(env.DEVICE) + + +@unittest.skip("FAILED at the moment") +class TestEnergyModelZBL(unittest.TestCase, NullTest): + def setUp(self): + model_params = copy.deepcopy(model_zbl) + self.type_split = False + self.model = get_zbl_model(model_params).to(env.DEVICE)