From 1c18950d512a1e7649824c95a260c3d322d5f7fc Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 31 May 2024 17:08:35 +0800 Subject: [PATCH] fix: bugs in uts for polar and dipole fit (#3837) Fix following trivial bugs in dipole and polar fit uts: 1. `box` was not used in `extend_input_and_build_neighbor_list` (which means they were all tested in nopbc mode, if shifted coord is outside the box (sometimes) and normalized explicitly, results are not the same.) Input for fitting also used extended_atype instead of atype. (Only same when nopbc.) 2. Using of `mixed_types` is disordered, mismatched with descriptor or sometimes with nlist. Now only use `mixed_types`==False since the descriptor output is not in mixed types. ## Summary by CodeRabbit - **Tests** - Improved consistency in parameter handling for various test methods. - Updated `mixed_types` parameter to dynamically use `self.dd0.mixed_types()` across multiple test functions for better flexibility and accuracy. --------- Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com> Co-authored-by: Jinzhe Zeng --- source/tests/pt/model/test_dipole_fitting.py | 40 ++++++++------ .../pt/model/test_polarizability_fitting.py | 52 +++++++++++++------ 2 files changed, 61 insertions(+), 31 deletions(-) diff --git a/source/tests/pt/model/test_dipole_fitting.py b/source/tests/pt/model/test_dipole_fitting.py index fa4be9171c..db266c6c8b 100644 --- a/source/tests/pt/model/test_dipole_fitting.py +++ b/source/tests/pt/model/test_dipole_fitting.py @@ -73,8 +73,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for mixed_types, nfp, nap in itertools.product( - [True, False], + for nfp, nap in itertools.product( [0, 3], [0, 4], ): @@ -84,7 +83,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) ft1 = DPDipoleFitting.deserialize(ft0.serialize()) ft2 = DipoleFittingNet.deserialize(ft1.serialize()) @@ -159,9 +158,10 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) + # use larger cell to rotate only coord and shift to the center of cell + cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) rng = np.random.default_rng() - for mixed_types, nfp, nap in itertools.product( - [True, False], + for nfp, nap in itertools.product( [0, 3], [0, 4], ): @@ -171,7 +171,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) if nfp > 0: ifp = torch.tensor( @@ -196,7 +196,12 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, not mixed_types + xyz + self.shift, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=cell_rot, ) rd0, gr0, _, _, _ = self.dd0( @@ -205,7 +210,7 @@ def test_rot(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -220,7 +225,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=False, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) res = [] for idx_perm in [[0, 1, 2, 3, 4], [1, 0, 4, 3, 2]]: @@ -231,7 +236,12 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, True + coord[idx_perm], + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -240,7 +250,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["dipole"]) np.testing.assert_allclose( @@ -261,7 +271,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) res = [] for xyz in [self.coord, coord_s]: @@ -271,7 +281,7 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False + xyz, atype, self.rcut, self.sel, self.dd0.mixed_types(), box=self.cell ) rd0, gr0, _, _, _ = self.dd0( @@ -280,7 +290,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["dipole"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -305,7 +315,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = DipoleModel(self.dd0, self.ft0, self.type_mapping) diff --git a/source/tests/pt/model/test_polarizability_fitting.py b/source/tests/pt/model/test_polarizability_fitting.py index 84d6bd91ab..6826807a45 100644 --- a/source/tests/pt/model/test_polarizability_fitting.py +++ b/source/tests/pt/model/test_polarizability_fitting.py @@ -59,8 +59,7 @@ def test_consistency( self.atype_ext[:, : self.nloc], dtype=int, device=env.DEVICE ) - for mixed_types, nfp, nap, fit_diag, scale in itertools.product( - [True, False], + for nfp, nap, fit_diag, scale in itertools.product( [0, 3], [0, 4], [True, False], @@ -72,7 +71,7 @@ def test_consistency( embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=mixed_types, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -166,9 +165,10 @@ def test_rot(self): atype = self.atype.reshape(1, 5) rmat = torch.tensor(special_ortho_group.rvs(3), dtype=dtype, device=env.DEVICE) coord_rot = torch.matmul(self.coord, rmat) + # use larger cell to rotate only coord and shift to the center of cell + cell_rot = 10.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) - for mixed_types, nfp, nap, fit_diag, scale in itertools.product( - [True, False], + for nfp, nap, fit_diag, scale in itertools.product( [0, 3], [0, 4], [True, False], @@ -180,7 +180,7 @@ def test_rot(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=nfp, numb_aparam=nap, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -207,7 +207,12 @@ def test_rot(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz + self.shift, atype, self.rcut, self.sel, mixed_types + xyz + self.shift, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=cell_rot, ) rd0, gr0, _, _, _ = self.dd0( @@ -216,7 +221,7 @@ def test_rot(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap) + ret0 = ft0(rd0, atype, gr0, fparam=ifp, aparam=iap) res.append(ret0["polarizability"]) np.testing.assert_allclose( to_numpy_array(res[1]), @@ -237,7 +242,7 @@ def test_permu(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -250,7 +255,12 @@ def test_permu(self): _, nlist, ) = extend_input_and_build_neighbor_list( - coord[idx_perm], atype, self.rcut, self.sel, False + coord[idx_perm], + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -259,7 +269,7 @@ def test_permu(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=None, aparam=None) + ret0 = ft0(rd0, atype, gr0, fparam=None, aparam=None) res.append(ret0["polarizability"]) np.testing.assert_allclose( @@ -269,7 +279,12 @@ def test_permu(self): def test_trans(self): atype = self.atype.reshape(1, 5) - coord_s = self.coord + self.shift + coord_s = torch.matmul( + torch.remainder( + torch.matmul(self.coord + self.shift, torch.linalg.inv(self.cell)), 1.0 + ), + self.cell, + ) for fit_diag, scale in itertools.product([True, False], [None, self.scale]): ft0 = PolarFittingNet( self.nt, @@ -277,7 +292,7 @@ def test_trans(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), fit_diag=fit_diag, scale=scale, ).to(env.DEVICE) @@ -289,7 +304,12 @@ def test_trans(self): _, nlist, ) = extend_input_and_build_neighbor_list( - xyz, atype, self.rcut, self.sel, False + xyz, + atype, + self.rcut, + self.sel, + self.dd0.mixed_types(), + box=self.cell, ) rd0, gr0, _, _, _ = self.dd0( @@ -298,7 +318,7 @@ def test_trans(self): nlist, ) - ret0 = ft0(rd0, extended_atype, gr0, fparam=0, aparam=0) + ret0 = ft0(rd0, atype, gr0, fparam=0, aparam=0) res.append(ret0["polarizability"]) np.testing.assert_allclose(to_numpy_array(res[0]), to_numpy_array(res[1])) @@ -323,7 +343,7 @@ def setUp(self): embedding_width=self.dd0.get_dim_emb(), numb_fparam=0, numb_aparam=0, - mixed_types=True, + mixed_types=self.dd0.mixed_types(), ).to(env.DEVICE) self.type_mapping = ["O", "H", "B"] self.model = PolarModel(self.dd0, self.ft0, self.type_mapping)