diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 990847c1de..42d1e67138 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -56,22 +56,19 @@ def reinit_pair_exclude( def atomic_output_def(self) -> FittingOutputDef: old_def = self.fitting_output_def() - if self.atom_excl is None: - return old_def - else: - old_list = list(old_def.get_data().values()) - return FittingOutputDef( - old_list # noqa:RUF005 - + [ - OutputVariableDef( - name="mask", - shape=[1], - reduciable=False, - r_differentiable=False, - c_differentiable=False, - ) - ] - ) + old_list = list(old_def.get_data().values()) + return FittingOutputDef( + old_list # noqa:RUF005 + + [ + OutputVariableDef( + name="mask", + shape=[1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + ) + ] + ) def forward_common_atomic( self, @@ -82,6 +79,37 @@ def forward_common_atomic( fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, ) -> Dict[str, np.ndarray]: + """Common interface for atomic inference. + + This method accept extended coordinates, extended atom typs, neighbor list, + and predict the atomic contribution of the fit property. + + Parameters + ---------- + extended_coord + extended coodinates, shape: nf x (nall x 3) + extended_atype + extended atom typs, shape: nf x nall + for a type < 0 indicating the atomic is virtual. + nlist + neighbor list, shape: nf x nloc x nsel + mapping + extended to local index mapping, shape: nf x nall + fparam + frame parameters, shape: nf x dim_fparam + aparam + atomic parameter, shape: nf x nloc x dim_aparam + + Returns + ------- + ret_dict + dict of output atomic properties. + should implement the definition of `fitting_output_def`. + ret_dict["mask"] of shape nf x nloc will be provided. + ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real. + ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual. + + """ _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: @@ -89,24 +117,28 @@ def forward_common_atomic( # exclude neighbors in the nlist nlist = np.where(pair_mask == 1, nlist, -1) + ext_atom_mask = self.make_atom_mask(extended_atype) ret_dict = self.forward_atomic( extended_coord, - extended_atype, + np.where(ext_atom_mask, extended_atype, 0), nlist, mapping=mapping, fparam=fparam, aparam=aparam, ) + # nf x nloc + atom_mask = ext_atom_mask[:, :nloc].astype(np.int32) if self.atom_excl is not None: - atom_mask = self.atom_excl.build_type_exclude_mask(atype) - for kk in ret_dict.keys(): - out_shape = ret_dict[kk].shape - ret_dict[kk] = ( - ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) - * atom_mask[:, :, None] - ).reshape(out_shape) - ret_dict["mask"] = atom_mask + atom_mask *= self.atom_excl.build_type_exclude_mask(atype) + + for kk in ret_dict.keys(): + out_shape = ret_dict[kk].shape + ret_dict[kk] = ( + ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) + * atom_mask[:, :, None] + ).reshape(out_shape) + ret_dict["mask"] = atom_mask return ret_dict diff --git a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py index dfbb4e435a..936c2b0943 100644 --- a/deepmd/dpmodel/atomic_model/make_base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/make_base_atomic_model.py @@ -136,6 +136,28 @@ def serialize(self) -> dict: def deserialize(cls, data: dict): pass + def make_atom_mask( + self, + atype: t_tensor, + ) -> t_tensor: + """The atoms with type < 0 are treated as virutal atoms, + which serves as place-holders for multi-frame calculations + with different number of atoms in different frames. + + Parameters + ---------- + atype + Atom types. >= 0 for real atoms <0 for virtual atoms. + + Returns + ------- + mask + True for real atoms and False for virutal atoms. + + """ + # supposed to be supported by all backends + return atype >= 0 + def do_grad_r( self, var_name: Optional[str] = None, diff --git a/deepmd/dpmodel/utils/nlist.py b/deepmd/dpmodel/utils/nlist.py index e5631bf2e3..ca8b18023b 100644 --- a/deepmd/dpmodel/utils/nlist.py +++ b/deepmd/dpmodel/utils/nlist.py @@ -15,7 +15,7 @@ ## translated from torch implemantation by chatgpt def build_neighbor_list( - coord1: np.ndarray, + coord: np.ndarray, atype: np.ndarray, nloc: int, rcut: float, @@ -26,10 +26,11 @@ def build_neighbor_list( Parameters ---------- - coord1 : np.ndarray + coord : np.ndarray exptended coordinates of shape [batch_size, nall x 3] atype : np.ndarray extended atomic types of shape [batch_size, nall] + type < 0 the atom is treat as virtual atoms. nloc : int number of local atoms. rcut : float @@ -54,11 +55,20 @@ def build_neighbor_list( if distinguish_types==True and we have two types |---- nsel[0] -----| |---- nsel[1] -----| xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + For virtual atoms all neighboring positions are filled with -1. """ - batch_size = coord1.shape[0] - coord1 = coord1.reshape(batch_size, -1) - nall = coord1.shape[1] // 3 + batch_size = coord.shape[0] + coord = coord.reshape(batch_size, -1) + nall = coord.shape[1] // 3 + # fill virtual atoms with large coords so they are not neighbors of any + # real atom. + xmax = np.max(coord) + 2.0 * rcut + # nf x nall + is_vir = atype < 0 + coord1 = np.where(is_vir[:, :, None], xmax, coord.reshape(-1, nall, 3)).reshape( + -1, nall * 3 + ) if isinstance(sel, int): sel = [sel] nsel = sum(sel) @@ -88,7 +98,7 @@ def build_neighbor_list( axis=-1, ) assert list(nlist.shape) == [batch_size, nloc, nsel] - nlist = np.where((rr > rcut), -1, nlist) + nlist = np.where(np.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist) if distinguish_types: return nlist_distinguish_types(nlist, atype, sel) diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index d3a1cfb459..c921538203 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -58,24 +58,44 @@ def reinit_pair_exclude( else: self.pair_excl = PairExcludeMask(self.get_ntypes(), self.pair_exclude_types) + # to make jit happy... + def make_atom_mask( + self, + atype: torch.Tensor, + ) -> torch.Tensor: + """The atoms with type < 0 are treated as virutal atoms, + which serves as place-holders for multi-frame calculations + with different number of atoms in different frames. + + Parameters + ---------- + atype + Atom types. >= 0 for real atoms <0 for virtual atoms. + + Returns + ------- + mask + True for real atoms and False for virutal atoms. + + """ + # supposed to be supported by all backends + return atype >= 0 + def atomic_output_def(self) -> FittingOutputDef: old_def = self.fitting_output_def() - if self.atom_excl is None: - return old_def - else: - old_list = list(old_def.get_data().values()) - return FittingOutputDef( - old_list # noqa:RUF005 - + [ - OutputVariableDef( - name="mask", - shape=[1], - reduciable=False, - r_differentiable=False, - c_differentiable=False, - ) - ] - ) + old_list = list(old_def.get_data().values()) + return FittingOutputDef( + old_list # noqa:RUF005 + + [ + OutputVariableDef( + name="mask", + shape=[1], + reduciable=False, + r_differentiable=False, + c_differentiable=False, + ) + ] + ) def forward_common_atomic( self, @@ -86,6 +106,37 @@ def forward_common_atomic( fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, ) -> Dict[str, torch.Tensor]: + """Common interface for atomic inference. + + This method accept extended coordinates, extended atom typs, neighbor list, + and predict the atomic contribution of the fit property. + + Parameters + ---------- + extended_coord + extended coodinates, shape: nf x (nall x 3) + extended_atype + extended atom typs, shape: nf x nall + for a type < 0 indicating the atomic is virtual. + nlist + neighbor list, shape: nf x nloc x nsel + mapping + extended to local index mapping, shape: nf x nall + fparam + frame parameters, shape: nf x dim_fparam + aparam + atomic parameter, shape: nf x nloc x dim_aparam + + Returns + ------- + ret_dict + dict of output atomic properties. + should implement the definition of `fitting_output_def`. + ret_dict["mask"] of shape nf x nloc will be provided. + ret_dict["mask"][ff,ii] == 1 indicating the ii-th atom of the ff-th frame is real. + ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual. + + """ _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] @@ -94,24 +145,28 @@ def forward_common_atomic( # exclude neighbors in the nlist nlist = torch.where(pair_mask == 1, nlist, -1) + ext_atom_mask = self.make_atom_mask(extended_atype) ret_dict = self.forward_atomic( extended_coord, - extended_atype, + torch.where(ext_atom_mask, extended_atype, 0), nlist, mapping=mapping, fparam=fparam, aparam=aparam, ) + # nf x nloc + atom_mask = ext_atom_mask[:, :nloc].to(torch.int32) if self.atom_excl is not None: - atom_mask = self.atom_excl(atype) - for kk in ret_dict.keys(): - out_shape = ret_dict[kk].shape - ret_dict[kk] = ( - ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) - * atom_mask[:, :, None] - ).reshape(out_shape) - ret_dict["mask"] = atom_mask + atom_mask *= self.atom_excl(atype) + + for kk in ret_dict.keys(): + out_shape = ret_dict[kk].shape + ret_dict[kk] = ( + ret_dict[kk].reshape([out_shape[0], out_shape[1], -1]) + * atom_mask[:, :, None] + ).view(out_shape) + ret_dict["mask"] = atom_mask return ret_dict diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index 7e92f44e8d..cdee6e3722 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -51,7 +51,7 @@ def extend_input_and_build_neighbor_list( def build_neighbor_list( - coord1: torch.Tensor, + coord: torch.Tensor, atype: torch.Tensor, nloc: int, rcut: float, @@ -62,10 +62,11 @@ def build_neighbor_list( Parameters ---------- - coord1 : torch.Tensor + coord : torch.Tensor exptended coordinates of shape [batch_size, nall x 3] atype : torch.Tensor extended atomic types of shape [batch_size, nall] + if type < 0 the atom is treat as virtual atoms. nloc : int number of local atoms. rcut : float @@ -90,11 +91,20 @@ def build_neighbor_list( if distinguish_types==True and we have two types |---- nsel[0] -----| |---- nsel[1] -----| xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1 + For virtual atoms all neighboring positions are filled with -1. """ - batch_size = coord1.shape[0] - coord1 = coord1.view(batch_size, -1) - nall = coord1.shape[1] // 3 + batch_size = coord.shape[0] + coord = coord.view(batch_size, -1) + nall = coord.shape[1] // 3 + # fill virtual atoms with large coords so they are not neighbors of any + # real atom. + xmax = torch.max(coord) + 2.0 * rcut + # nf x nall + is_vir = atype < 0 + coord1 = torch.where(is_vir[:, :, None], xmax, coord.view(-1, nall, 3)).view( + -1, nall * 3 + ) if isinstance(sel, int): sel = [sel] nsel = sum(sel) @@ -133,7 +143,9 @@ def build_neighbor_list( dim=-1, ) assert list(nlist.shape) == [batch_size, nloc, nsel] - nlist = nlist.masked_fill((rr > rcut), -1) + nlist = torch.where( + torch.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist + ) if distinguish_types: return nlist_distinguish_types(nlist, atype, sel) diff --git a/source/tests/common/dpmodel/case_single_frame_with_nlist.py b/source/tests/common/dpmodel/case_single_frame_with_nlist.py index c260a18527..828e090cad 100644 --- a/source/tests/common/dpmodel/case_single_frame_with_nlist.py +++ b/source/tests/common/dpmodel/case_single_frame_with_nlist.py @@ -72,3 +72,53 @@ def setUp(self): nlist1 = inv_perm[nlist1] nlist1 = np.where(mask, -1, nlist1) self.nlist = np.concatenate([self.nlist, nlist1], axis=0) + + +class TestCaseSingleFrameWithNlistWithVirtual: + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 4 + self.nall = 5 + self.nf, self.nt = 2, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall, 3]) + self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall]) + # sel = [5, 2] + self.sel = [5, 2] + self.nlist = np.array( + [ + [2, 4, -1, -1, -1, 3, -1], + [-1, -1, -1, -1, -1, -1, -1], + [0, -1, -1, -1, -1, 3, -1], + [0, 2, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.rcut = 2.2 + self.rcut_smth = 0.4 + # permutations + self.perm = np.array([3, 0, 1, 2, 4], dtype=np.int32) + inv_perm = np.argsort(self.perm) + # permute the coord and atype + self.coord_ext = np.concatenate( + [self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 + ).reshape(self.nf, self.nall * 3) + self.atype_ext = np.concatenate( + [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 + ) + # permute the nlist + nlist1 = self.nlist[:, self.perm[: self.nloc], :] + mask = nlist1 == -1 + nlist1 = inv_perm[nlist1] + nlist1 = np.where(mask, -1, nlist1) + self.nlist = np.concatenate([self.nlist, nlist1], axis=0) + self.get_real_mapping = np.array([[0, 2, 3], [0, 1, 3]], dtype=np.int32) + self.atol = 1e-12 diff --git a/source/tests/common/dpmodel/test_dp_atomic_model.py b/source/tests/common/dpmodel/test_dp_atomic_model.py index ac49280b82..c69de6161d 100644 --- a/source/tests/common/dpmodel/test_dp_atomic_model.py +++ b/source/tests/common/dpmodel/test_dp_atomic_model.py @@ -16,6 +16,7 @@ from .case_single_frame_with_nlist import ( TestCaseSingleFrameWithNlist, + TestCaseSingleFrameWithNlistWithVirtual, ) @@ -92,10 +93,8 @@ def test_excl_consistency(self): # check output def out_names = [vv.name for vv in md0.atomic_output_def().get_data().values()] - if atom_excl == []: - self.assertEqual(out_names, ["energy"]) - else: - self.assertEqual(out_names, ["energy", "mask"]) + self.assertEqual(out_names, ["energy", "mask"]) + if atom_excl != []: for ii in md0.atomic_output_def().get_data().values(): if ii.name == "mask": self.assertEqual(ii.shape, [1]) @@ -115,3 +114,49 @@ def test_excl_consistency(self): np.testing.assert_array_equal(ret0["mask"], expected) else: raise ValueError(f"not expected atom_excl {atom_excl}") + + +class TestDPAtomicModelVirtualConsistency(unittest.TestCase): + def setUp(self): + self.case0 = TestCaseSingleFrameWithNlist() + self.case1 = TestCaseSingleFrameWithNlistWithVirtual() + self.case0.setUp() + self.case1.setUp() + + def test_virtual_consistency(self): + nf, _, _ = self.case0.nlist.shape + ds = DescrptSeA( + self.case0.rcut, + self.case0.rcut_smth, + self.case0.sel, + ) + ft = InvarFitting( + "energy", + self.case0.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ) + type_map = ["foo", "bar"] + md1 = DPAtomicModel(ds, ft, type_map=type_map) + + args0 = [self.case0.coord_ext, self.case0.atype_ext, self.case0.nlist] + # args0 = [np.array(ii) for ii in args0] + args1 = [self.case1.coord_ext, self.case1.atype_ext, self.case1.nlist] + # args1 = [np.array(ii) for ii in args1] + + ret0 = md1.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) + + for dd in range(self.case0.nf): + np.testing.assert_allclose( + ret0["energy"][dd], + ret1["energy"][dd, self.case1.get_real_mapping[dd], :], + ) + expected_mask = np.array( + [ + [1, 0, 1, 1], + [1, 1, 0, 1], + ] + ) + np.testing.assert_equal(ret1["mask"], expected_mask) diff --git a/source/tests/common/dpmodel/test_dp_model.py b/source/tests/common/dpmodel/test_dp_model.py index c3de1f4cdf..9121c7cd07 100644 --- a/source/tests/common/dpmodel/test_dp_model.py +++ b/source/tests/common/dpmodel/test_dp_model.py @@ -87,7 +87,10 @@ def test_prec_consistency(self): self.assertEqual(model_l_ret_32[ii].dtype, np.float64) else: self.assertEqual(model_l_ret_32[ii].dtype, np.float32) - self.assertEqual(model_l_ret_64[ii].dtype, np.float64) + if ii != "mask": + self.assertEqual(model_l_ret_64[ii].dtype, np.float64) + else: + self.assertEqual(model_l_ret_64[ii].dtype, np.int32) np.testing.assert_allclose( model_l_ret_32[ii], model_l_ret_64[ii], @@ -138,8 +141,10 @@ def test_prec_consistency(self): self.assertEqual(model_l_ret_32[ii].dtype, np.float64) else: self.assertEqual(model_l_ret_32[ii].dtype, np.float32) - self.assertEqual(model_l_ret_64[ii].dtype, np.float64) - self.assertEqual(model_l_ret_64[ii].dtype, np.float64) + if ii != "mask": + self.assertEqual(model_l_ret_64[ii].dtype, np.float64) + else: + self.assertEqual(model_l_ret_64[ii].dtype, np.int32) np.testing.assert_allclose( model_l_ret_32[ii], model_l_ret_64[ii], diff --git a/source/tests/common/dpmodel/test_nlist.py b/source/tests/common/dpmodel/test_nlist.py index 35145cde39..ee8a7139e7 100644 --- a/source/tests/common/dpmodel/test_nlist.py +++ b/source/tests/common/dpmodel/test_nlist.py @@ -125,12 +125,12 @@ def test_nlist_lt(self): class TestNeighList(unittest.TestCase): def setUp(self): self.nf = 3 - self.nloc = 2 + self.nloc = 3 self.ns = 5 * 5 * 3 self.nall = self.ns * self.nloc self.cell = np.array([[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype) - self.icoord = np.array([[0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) - self.atype = np.array([0, 1], dtype=np.int32) + self.icoord = np.array([[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype) + self.atype = np.array([-1, 0, 1], dtype=np.int32) [self.cell, self.icoord, self.atype] = [ np.expand_dims(ii, 0) for ii in [self.cell, self.icoord, self.atype] ] @@ -144,8 +144,9 @@ def setUp(self): self.nsel = [10, 10] self.ref_nlist = np.array( [ - [0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1], - [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1], + [-1] * sum(self.nsel), + [1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], ] ) @@ -269,7 +270,7 @@ def test_extend_coord(self): ) np.testing.assert_allclose( cc, - np.array([30, 30, 30, 30, 30], dtype=np.int32), + np.array([self.ns * self.nloc // 5] * 5, dtype=np.int32), rtol=self.prec, atol=self.prec, ) @@ -282,7 +283,7 @@ def test_extend_coord(self): ) np.testing.assert_allclose( cc, - np.array([30, 30, 30, 30, 30], dtype=np.int32), + np.array([self.ns * self.nloc // 5] * 5, dtype=np.int32), rtol=self.prec, atol=self.prec, ) @@ -295,7 +296,7 @@ def test_extend_coord(self): ) np.testing.assert_allclose( cc, - np.array([50, 50, 50], dtype=np.int32), + np.array([self.ns * self.nloc // 3] * 3, dtype=np.int32), rtol=self.prec, atol=self.prec, ) diff --git a/source/tests/pt/model/test_dp_atomic_model.py b/source/tests/pt/model/test_dp_atomic_model.py index 6daaeef2ef..4a35b4676a 100644 --- a/source/tests/pt/model/test_dp_atomic_model.py +++ b/source/tests/pt/model/test_dp_atomic_model.py @@ -27,6 +27,7 @@ from .test_env_mat import ( TestCaseSingleFrameWithNlist, + TestCaseSingleFrameWithNlistWithVirtual, ) dtype = env.GLOBAL_PT_FLOAT_PRECISION @@ -166,10 +167,8 @@ def test_excl_consistency(self): # check output def out_names = [vv.name for vv in md0.atomic_output_def().get_data().values()] - if atom_excl == []: - self.assertEqual(out_names, ["energy"]) - else: - self.assertEqual(out_names, ["energy", "mask"]) + self.assertEqual(out_names, ["energy", "mask"]) + if atom_excl != []: for ii in md0.atomic_output_def().get_data().values(): if ii.name == "mask": self.assertEqual(ii.shape, [1]) @@ -189,3 +188,49 @@ def test_excl_consistency(self): np.testing.assert_array_equal(to_numpy_array(ret0["mask"]), expected) else: raise ValueError(f"not expected atom_excl {atom_excl}") + + +class TestDPAtomicModelVirtualConsistency(unittest.TestCase): + def setUp(self): + self.case0 = TestCaseSingleFrameWithNlist() + self.case1 = TestCaseSingleFrameWithNlistWithVirtual() + self.case0.setUp() + self.case1.setUp() + + def test_virtual_consistency(self): + nf, _, _ = self.case0.nlist.shape + ds = DescrptSeA( + self.case0.rcut, + self.case0.rcut_smth, + self.case0.sel, + ) + ft = InvarFitting( + "energy", + self.case0.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + ) + type_map = ["foo", "bar"] + md1 = DPAtomicModel(ds, ft, type_map=type_map).to(env.DEVICE) + + args0 = [self.case0.coord_ext, self.case0.atype_ext, self.case0.nlist] + args0 = [to_torch_tensor(ii) for ii in args0] + args1 = [self.case1.coord_ext, self.case1.atype_ext, self.case1.nlist] + args1 = [to_torch_tensor(ii) for ii in args1] + + ret0 = md1.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) + + for dd in range(self.case0.nf): + np.testing.assert_allclose( + to_numpy_array(ret0["energy"])[dd], + to_numpy_array(ret1["energy"])[dd, self.case1.get_real_mapping[dd], :], + ) + expected_mask = np.array( + [ + [1, 0, 1, 1], + [1, 1, 0, 1], + ] + ) + np.testing.assert_equal(to_numpy_array(ret1["mask"]), expected_mask) diff --git a/source/tests/pt/model/test_dp_model.py b/source/tests/pt/model/test_dp_model.py index c0b152b3d3..7470cf96d0 100644 --- a/source/tests/pt/model/test_dp_model.py +++ b/source/tests/pt/model/test_dp_model.py @@ -237,7 +237,10 @@ def test_prec_consistency(self): self.assertEqual(model_l_ret_32[ii].dtype, torch.float64) else: self.assertEqual(model_l_ret_32[ii].dtype, torch.float32) - self.assertEqual(model_l_ret_64[ii].dtype, torch.float64) + if ii != "mask": + self.assertEqual(model_l_ret_64[ii].dtype, torch.float64) + else: + self.assertEqual(model_l_ret_64[ii].dtype, torch.int32) np.testing.assert_allclose( to_numpy_array(model_l_ret_32[ii]), to_numpy_array(model_l_ret_64[ii]), @@ -377,7 +380,10 @@ def test_prec_consistency(self): self.assertEqual(model_l_ret_32[ii].dtype, torch.float64) else: self.assertEqual(model_l_ret_32[ii].dtype, torch.float32) - self.assertEqual(model_l_ret_64[ii].dtype, torch.float64) + if ii != "mask": + self.assertEqual(model_l_ret_64[ii].dtype, torch.float64) + else: + self.assertEqual(model_l_ret_64[ii].dtype, torch.int32) np.testing.assert_allclose( to_numpy_array(model_l_ret_32[ii]), to_numpy_array(model_l_ret_64[ii]), diff --git a/source/tests/pt/model/test_env_mat.py b/source/tests/pt/model/test_env_mat.py index 615e7c6230..e18093b2f1 100644 --- a/source/tests/pt/model/test_env_mat.py +++ b/source/tests/pt/model/test_env_mat.py @@ -64,6 +64,56 @@ def setUp(self): self.atol = 1e-12 +class TestCaseSingleFrameWithNlistWithVirtual: + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 4 + self.nall = 5 + self.nf, self.nt = 2, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall, 3]) + self.atype_ext = np.array([0, -1, 0, 1, 0], dtype=int).reshape([1, self.nall]) + # sel = [5, 2] + self.sel = [5, 2] + self.nlist = np.array( + [ + [2, 4, -1, -1, -1, 3, -1], + [-1, -1, -1, -1, -1, -1, -1], + [0, -1, -1, -1, -1, 3, -1], + [0, 2, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.rcut = 2.2 + self.rcut_smth = 0.4 + # permutations + self.perm = np.array([3, 0, 1, 2, 4], dtype=np.int32) + inv_perm = np.argsort(self.perm) + # permute the coord and atype + self.coord_ext = np.concatenate( + [self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0 + ).reshape(self.nf, self.nall * 3) + self.atype_ext = np.concatenate( + [self.atype_ext, self.atype_ext[:, self.perm]], axis=0 + ) + # permute the nlist + nlist1 = self.nlist[:, self.perm[: self.nloc], :] + mask = nlist1 == -1 + nlist1 = inv_perm[nlist1] + nlist1 = np.where(mask, -1, nlist1) + self.nlist = np.concatenate([self.nlist, nlist1], axis=0) + self.get_real_mapping = np.array([[0, 2, 3], [0, 1, 3]], dtype=np.int32) + self.atol = 1e-12 + + class TestCaseSingleFrameWithoutNlist: def setUp(self): # nloc == 3, nall == 4 diff --git a/source/tests/pt/model/test_nlist.py b/source/tests/pt/model/test_nlist.py index 616af93081..244b3804c8 100644 --- a/source/tests/pt/model/test_nlist.py +++ b/source/tests/pt/model/test_nlist.py @@ -22,16 +22,16 @@ class TestNeighList(unittest.TestCase): def setUp(self): self.nf = 3 - self.nloc = 2 + self.nloc = 3 self.ns = 5 * 5 * 3 self.nall = self.ns * self.nloc self.cell = torch.tensor( [[1, 0, 0], [0.4, 0.8, 0], [0.1, 0.3, 2.1]], dtype=dtype, device=env.DEVICE ) self.icoord = torch.tensor( - [[0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype, device=env.DEVICE + [[0, 0, 0], [0, 0, 0], [0.5, 0.5, 0.1]], dtype=dtype, device=env.DEVICE ) - self.atype = torch.tensor([0, 1], dtype=torch.int, device=env.DEVICE) + self.atype = torch.tensor([-1, 0, 1], dtype=torch.int, device=env.DEVICE) [self.cell, self.icoord, self.atype] = [ ii.unsqueeze(0) for ii in [self.cell, self.icoord, self.atype] ] @@ -51,8 +51,9 @@ def setUp(self): # mapping[0], type_split=True, ) self.ref_nlist = torch.tensor( [ - [0, 0, 0, 0, 0, 0, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1], - [0, 0, 0, 0, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1], + [-1] * sum(self.nsel), + [1, 1, 1, 1, 1, 1, -1, -1, -1, -1, 2, 2, 2, 2, -1, -1, -1, -1, -1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, 2, 2, 2, 2, 2, 2, -1, -1, -1, -1], ], device=env.DEVICE, ) @@ -181,7 +182,9 @@ def test_extend_coord(self): ) torch.testing.assert_close( cc, - torch.tensor([30, 30, 30, 30, 30], dtype=torch.long, device=env.DEVICE), + torch.tensor( + [self.ns * self.nloc // 5] * 5, dtype=torch.long, device=env.DEVICE + ), rtol=self.prec, atol=self.prec, ) @@ -194,7 +197,9 @@ def test_extend_coord(self): ) torch.testing.assert_close( cc, - torch.tensor([30, 30, 30, 30, 30], dtype=torch.long, device=env.DEVICE), + torch.tensor( + [self.ns * self.nloc // 5] * 5, dtype=torch.long, device=env.DEVICE + ), rtol=self.prec, atol=self.prec, ) @@ -207,7 +212,9 @@ def test_extend_coord(self): ) torch.testing.assert_close( cc, - torch.tensor([50, 50, 50], dtype=torch.long, device=env.DEVICE), + torch.tensor( + [self.ns * self.nloc // 3] * 3, dtype=torch.long, device=env.DEVICE + ), rtol=self.prec, atol=self.prec, )