From 6f0bda58f8fb74151217e7fcb5f8228bf2d2d5e7 Mon Sep 17 00:00:00 2001 From: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Date: Mon, 8 Apr 2024 16:32:43 +0800 Subject: [PATCH] test: ut for the smoothness when pair exclusion presents (#3650) Co-authored-by: Han Wang --- source/tests/pt/model/test_smooth.py | 29 ++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/source/tests/pt/model/test_smooth.py b/source/tests/pt/model/test_smooth.py index 4f5be912cf..1a75caebdc 100644 --- a/source/tests/pt/model/test_smooth.py +++ b/source/tests/pt/model/test_smooth.py @@ -39,7 +39,9 @@ def test( natoms = 10 cell = 8.6 * torch.eye(3, dtype=dtype, device=env.DEVICE) - atype = torch.randint(0, 3, [natoms], device=env.DEVICE) + atype0 = torch.arange(3, dtype=dtype, device=env.DEVICE) + atype1 = torch.randint(0, 3, [natoms - 3], device=env.DEVICE) + atype = torch.cat([atype0, atype1]).view([natoms]) coord0 = torch.tensor( [ 0.0, @@ -148,7 +150,6 @@ def setUp(self): self.epsilon, self.aprec = None, None -# @unittest.skip("dpa-1 not smooth at the moment") class TestEnergyModelDPA1(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_dpa1) @@ -160,6 +161,30 @@ def setUp(self): self.aprec = 1e-5 +class TestEnergyModelDPA1Excl1(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + model_params["pair_exclude_types"] = [[0, 1]] + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + +class TestEnergyModelDPA1Excl12(unittest.TestCase, SmoothTest): + def setUp(self): + model_params = copy.deepcopy(model_dpa1) + model_params["pair_exclude_types"] = [[0, 1], [0, 2]] + self.type_split = True + self.model = get_model(model_params).to(env.DEVICE) + # less degree of smoothness, + # error can be systematically removed by reducing epsilon + self.epsilon = 1e-5 + self.aprec = 1e-5 + + class TestEnergyModelDPA2(unittest.TestCase, SmoothTest): def setUp(self): model_params = copy.deepcopy(model_dpa2)