diff --git a/deepmd/model_format/make_model.py b/deepmd/model_format/make_model.py index dd23b6ea0e..4e0996995c 100644 --- a/deepmd/model_format/make_model.py +++ b/deepmd/model_format/make_model.py @@ -9,6 +9,7 @@ from .nlist import ( build_neighbor_list, extend_coord_with_ghosts, + nlist_distinguish_types, ) from .output_def import ( ModelOutputDef, @@ -133,9 +134,9 @@ def call( def call_lower( self, - extended_coord, - extended_atype, - nlist, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, mapping: Optional[np.ndarray] = None, fparam: Optional[np.ndarray] = None, aparam: Optional[np.ndarray] = None, @@ -171,6 +172,7 @@ def call_lower( """ nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.reshape(nframes, -1, 3) + nlist = self.format_nlist(extended_coord, extended_atype, nlist) atomic_ret = self.forward_atomic( extended_coord, extended_atype, @@ -187,4 +189,88 @@ def call_lower( ) return model_predict + def format_nlist( + self, + extended_coord: np.ndarray, + extended_atype: np.ndarray, + nlist: np.ndarray, + ): + """Format the neighbor list. + + 1. If the number of neighbors in the `nlist` is equal to sum(self.sel), + it does nothong + + 2. If the number of neighbors in the `nlist` is smaller than sum(self.sel), + the `nlist` is pad with -1. + + 3. If the number of neighbors in the `nlist` is larger than sum(self.sel), + the nearest sum(sel) neighbors will be preseved. + + Known limitations: + + In the case of self.distinguish_types, the nlist is always formatted. + May have side effact on the efficiency. + + Parameters + ---------- + extended_coord + coodinates in extended region. nf x nall x 3 + extended_atype + atomic type in extended region. nf x nall + nlist + neighbor list. nf x nloc x nsel + + Returns + ------- + formated_nlist + the formated nlist. + + """ + n_nf, n_nloc, n_nnei = nlist.shape + distinguish_types = self.distinguish_types() + ret = self._format_nlist(extended_coord, nlist, sum(self.get_sel())) + if distinguish_types: + ret = nlist_distinguish_types(ret, extended_atype, self.get_sel()) + return ret + + def _format_nlist( + self, + extended_coord: np.ndarray, + nlist: np.ndarray, + nnei: int, + ): + n_nf, n_nloc, n_nnei = nlist.shape + extended_coord = extended_coord.reshape([n_nf, -1, 3]) + nall = extended_coord.shape[1] + rcut = self.get_rcut() + + if n_nnei < nnei: + # make a copy before revise + ret = np.concatenate( + [ + nlist, + -1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + ], + axis=-1, + ) + elif n_nnei > nnei: + # make a copy before revise + m_real_nei = nlist >= 0 + ret = np.where(m_real_nei, nlist, 0) + coord0 = extended_coord[:, :n_nloc, :] + index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) + coord1 = np.take_along_axis(extended_coord, index, axis=1) + coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) + rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = np.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1) + ret = np.take_along_axis(ret, ret_mapping, axis=2) + ret = np.where(rr > rcut, -1, ret) + ret = ret[..., :nnei] + else: # n_nnei == nnei: + # copy anyway... + ret = nlist + assert ret.shape[-1] == nnei + return ret + return CM diff --git a/deepmd/model_format/nlist.py b/deepmd/model_format/nlist.py index 1083e1ab48..bc6592d52b 100644 --- a/deepmd/model_format/nlist.py +++ b/deepmd/model_format/nlist.py @@ -26,9 +26,9 @@ def build_neighbor_list( Parameters ---------- - coord1 : torch.Tensor + coord1 : np.ndarray exptended coordinates of shape [batch_size, nall x 3] - atype : torch.Tensor + atype : np.ndarray extended atomic types of shape [batch_size, nall] nloc : int number of local atoms. @@ -44,7 +44,7 @@ def build_neighbor_list( Returns ------- - neighbor_list : torch.Tensor + neighbor_list : np.ndarray Neighbor list of shape [batch_size, nloc, nsel], the neighbors are stored in an ascending order. If the number of neighbors is less than nsel, the positions are masked @@ -88,26 +88,39 @@ def build_neighbor_list( assert list(nlist.shape) == [batch_size, nloc, nsel] nlist = np.where((rr > rcut), -1, nlist) - if not distinguish_types: - return nlist + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) else: - ret_nlist = [] - tmp_atype = np.tile(atype[:, None], [1, nloc, 1]) - mask = nlist == -1 - tnlist_0 = nlist - tnlist_0[mask] = 0 - tnlist = np.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() - tnlist = np.where(mask, -1, tnlist) - snsel = tnlist.shape[2] - for ii, ss in enumerate(sel): - pick_mask = (tnlist == ii).astype(np.int32) - sorted_indices = np.argsort(-pick_mask, kind="stable", axis=-1) - pick_mask_sorted = -np.sort(-pick_mask, axis=-1) - inlist = np.take_along_axis(nlist, sorted_indices, axis=2) - inlist = np.where(~pick_mask_sorted.astype(bool), -1, inlist) - ret_nlist.append(np.split(inlist, [ss, snsel - ss], axis=-1)[0]) - ret = np.concatenate(ret_nlist, axis=-1) - return ret + return nlist + + +def nlist_distinguish_types( + nlist: np.ndarray, + atype: np.ndarray, + sel: List[int], +): + """Given a nlist that does not distinguish atom types, return a nlist that + distinguish atom types. + + """ + nf, nloc, _ = nlist.shape + ret_nlist = [] + tmp_atype = np.tile(atype[:, None], [1, nloc, 1]) + mask = nlist == -1 + tnlist_0 = nlist.copy() + tnlist_0[mask] = 0 + tnlist = np.take_along_axis(tmp_atype, tnlist_0, axis=2).squeeze() + tnlist = np.where(mask, -1, tnlist) + snsel = tnlist.shape[2] + for ii, ss in enumerate(sel): + pick_mask = (tnlist == ii).astype(np.int32) + sorted_indices = np.argsort(-pick_mask, kind="stable", axis=-1) + pick_mask_sorted = -np.sort(-pick_mask, axis=-1) + inlist = np.take_along_axis(nlist, sorted_indices, axis=2) + inlist = np.where(~pick_mask_sorted.astype(bool), -1, inlist) + ret_nlist.append(np.split(inlist, [ss, snsel - ss], axis=-1)[0]) + ret = np.concatenate(ret_nlist, axis=-1) + return ret def get_multiple_nlist_key(rcut: float, nsel: int) -> str: @@ -127,9 +140,9 @@ def build_multiple_neighbor_list( Parameters ---------- - coord : torch.Tensor + coord : np.ndarray exptended coordinates of shape [batch_size, nall x 3] - nlist : torch.Tensor + nlist : np.ndarray Neighbor list of shape [batch_size, nloc, nsel], the neighbors should be stored in an ascending order. rcuts : List[float] @@ -139,7 +152,7 @@ def build_multiple_neighbor_list( Returns ------- - nlist_dict : Dict[str, torch.Tensor] + nlist_dict : Dict[str, np.ndarray] A dict of nlists, key given by get_multiple_nlist_key(rc, nsel) value being the corresponding nlist. @@ -185,20 +198,22 @@ def extend_coord_with_ghosts( Parameters ---------- - coord : torch.Tensor + coord : np.ndarray original coordinates of shape [-1, nloc*3]. - atype : torch.Tensor + atype : np.ndarray atom type of shape [-1, nloc]. - cell : torch.Tensor + cell : np.ndarray simulation cell tensor of shape [-1, 9]. + rcut : float + the cutoff radius Returns ------- - extended_coord: torch.Tensor + extended_coord: np.ndarray extended coordinates of shape [-1, nall*3]. - extended_atype: torch.Tensor + extended_atype: np.ndarray extended atom type of shape [-1, nall]. - index_mapping: torch.Tensor + index_mapping: np.ndarray maping extended index to the local index """ diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index 6ba5ea23d3..b7c1c7edcf 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -16,6 +16,7 @@ from deepmd.pt.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, + nlist_distinguish_types, ) from deepmd.pt.utils.region import ( normalize_coord, @@ -164,6 +165,7 @@ def forward_common_lower( """ nframes, nall = extended_atype.shape[:2] extended_coord = extended_coord.view(nframes, -1, 3) + nlist = self.format_nlist(extended_coord, extended_atype, nlist) atomic_ret = self.forward_atomic( extended_coord, extended_atype, @@ -180,4 +182,93 @@ def forward_common_lower( ) return model_predict + def format_nlist( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + ): + """Format the neighbor list. + + 1. If the number of neighbors in the `nlist` is equal to sum(self.sel), + it does nothong + + 2. If the number of neighbors in the `nlist` is smaller than sum(self.sel), + the `nlist` is pad with -1. + + 3. If the number of neighbors in the `nlist` is larger than sum(self.sel), + the nearest sum(sel) neighbors will be preseved. + + Known limitations: + + In the case of self.distinguish_types, the nlist is always formatted. + May have side effact on the efficiency. + + Parameters + ---------- + extended_coord + coodinates in extended region. nf x nall x 3 + extended_atype + atomic type in extended region. nf x nall + nlist + neighbor list. nf x nloc x nsel + + Returns + ------- + formated_nlist + the formated nlist. + + """ + n_nf, n_nloc, n_nnei = nlist.shape + distinguish_types = self.distinguish_types() + nlist = self._format_nlist(extended_coord, nlist, sum(self.get_sel())) + if distinguish_types: + nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel()) + return nlist + + def _format_nlist( + self, + extended_coord: torch.Tensor, + nlist: torch.Tensor, + nnei: int, + ): + n_nf, n_nloc, n_nnei = nlist.shape + # nf x nall x 3 + extended_coord = extended_coord.view([n_nf, -1, 3]) + nall = extended_coord.shape[1] + rcut = self.get_rcut() + + if n_nnei < nnei: + nlist = torch.cat( + [ + nlist, + -1 + * torch.ones( + [n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype + ).to(nlist.device), + ], + dim=-1, + ) + elif n_nnei > nnei: + m_real_nei = nlist >= 0 + nlist = torch.where(m_real_nei, nlist, 0) + # nf x nloc x 3 + coord0 = extended_coord[:, :n_nloc, :] + # nf x (nloc x nnei) x 3 + index = nlist.view(n_nf, n_nloc * n_nnei, 1).expand(-1, -1, 3) + coord1 = torch.gather(extended_coord, 1, index) + # nf x nloc x nnei x 3 + coord1 = coord1.view(n_nf, n_nloc, n_nnei, 3) + # nf x nloc x nnei + rr = torch.linalg.norm(coord0[:, :, None, :] - coord1, dim=-1) + rr = torch.where(m_real_nei, rr, float("inf")) + rr, nlist_mapping = torch.sort(rr, dim=-1) + nlist = torch.gather(nlist, 2, nlist_mapping) + nlist = torch.where(rr > rcut, -1, nlist) + nlist = nlist[..., :nnei] + else: # n_nnei == nnei: + pass # great! + assert nlist.shape[-1] == nnei + return nlist + return CM diff --git a/deepmd/pt/utils/nlist.py b/deepmd/pt/utils/nlist.py index a63e94ba5f..fdb2627f04 100644 --- a/deepmd/pt/utils/nlist.py +++ b/deepmd/pt/utils/nlist.py @@ -90,7 +90,7 @@ def build_neighbor_list( nlist = torch.cat( [ nlist, - torch.ones([batch_size, nloc, nsel - nnei], dtype=torch.long).to( + torch.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype).to( rr.device ), ], @@ -99,35 +99,46 @@ def build_neighbor_list( assert list(nlist.shape) == [batch_size, nloc, nsel] nlist = nlist.masked_fill((rr > rcut), -1) - if not distinguish_types: - return nlist + if distinguish_types: + return nlist_distinguish_types(nlist, atype, sel) else: - ret_nlist = [] - # nloc x nall - tmp_atype = torch.tile(atype.unsqueeze(1), [1, nloc, 1]) - mask = nlist == -1 + return nlist + + +def nlist_distinguish_types( + nlist: torch.Tensor, + atype: torch.Tensor, + sel: List[int], +): + """Given a nlist that does not distinguish atom types, return a nlist that + distinguish atom types. + + """ + nf, nloc, nnei = nlist.shape + ret_nlist = [] + # nloc x nall + tmp_atype = torch.tile(atype.unsqueeze(1), [1, nloc, 1]) + mask = nlist == -1 + # nloc x s(nsel) + tnlist = torch.gather( + tmp_atype, + 2, + nlist.masked_fill(mask, 0), + ) + tnlist = tnlist.masked_fill(mask, -1) + snsel = tnlist.shape[2] + for ii, ss in enumerate(sel): # nloc x s(nsel) - tnlist = torch.gather( - tmp_atype, - 2, - nlist.masked_fill(mask, 0), - ) - tnlist = tnlist.masked_fill(mask, -1) - snsel = tnlist.shape[2] - for ii, ss in enumerate(sel): - # nloc x s(nsel) - # to int because bool cannot be sort on GPU - pick_mask = (tnlist == ii).to(torch.int32) - # nloc x s(nsel), stable sort, nearer neighbors first - pick_mask, imap = torch.sort( - pick_mask, dim=-1, descending=True, stable=True - ) - # nloc x s(nsel) - inlist = torch.gather(nlist, 2, imap) - inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) - # nloc x nsel[ii] - ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) - return torch.concat(ret_nlist, dim=-1) + # to int because bool cannot be sort on GPU + pick_mask = (tnlist == ii).to(torch.int32) + # nloc x s(nsel), stable sort, nearer neighbors first + pick_mask, imap = torch.sort(pick_mask, dim=-1, descending=True, stable=True) + # nloc x s(nsel) + inlist = torch.gather(nlist, 2, imap) + inlist = inlist.masked_fill(~(pick_mask.to(torch.bool)), -1) + # nloc x nsel[ii] + ret_nlist.append(torch.split(inlist, [ss, snsel - ss], dim=-1)[0]) + return torch.concat(ret_nlist, dim=-1) # build_neighbor_list = torch.vmap( @@ -232,6 +243,8 @@ def extend_coord_with_ghosts( atom type of shape [-1, nloc]. cell : torch.Tensor simulation cell tensor of shape [-1, 9]. + rcut : float + the cutoff radius Returns ------- diff --git a/source/tests/common/test_model_format_utils.py b/source/tests/common/test_model_format_utils.py index 0a457e0802..59423fb8da 100644 --- a/source/tests/common/test_model_format_utils.py +++ b/source/tests/common/test_model_format_utils.py @@ -333,7 +333,7 @@ def setUp(self): [ [1, 3, -1, -1, -1, 2, -1], [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, 0, -1], + [0, 1, -1, -1, -1, -1, -1], ], dtype=int, ).reshape([1, self.nloc, sum(self.sel)]) @@ -567,6 +567,104 @@ def test_self_consistency( np.testing.assert_allclose(ret0["energy_redu"], ret1["energy_redu"]) +class TestDPModelFormatNlist(unittest.TestCase): + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 5 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + [2.3, 0, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + # sel = [5, 2] + self.sel = [5, 2] + self.expected_nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.atype_ext = np.array([0, 0, 1, 0, 1], dtype=int).reshape([1, self.nall]) + self.rcut_smth = 0.4 + self.rcut = 2.1 + + nf, nloc, nnei = self.expected_nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ) + type_map = ["foo", "bar"] + self.md = DPModel(ds, ft, type_map=type_map) + + def test_nlist_eq(self): + # n_nnei == nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + self.coord_ext, + self.atype_ext, + nlist, + ) + np.testing.assert_allclose(self.expected_nlist, nlist1) + + def test_nlist_st(self): + # n_nnei < nnei + nlist = np.array( + [ + [1, 3, -1, 2], + [0, -1, -1, 2], + [0, 1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + self.coord_ext, + self.atype_ext, + nlist, + ) + np.testing.assert_allclose(self.expected_nlist, nlist1) + + def test_nlist_lt(self): + # n_nnei > nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1, -1, 4], + [0, -1, 4, -1, -1, 2, -1, 3, -1], + [0, 1, -1, -1, -1, 4, -1, -1, 3], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + self.coord_ext, + self.atype_ext, + nlist, + ) + np.testing.assert_allclose(self.expected_nlist, nlist1) + + class TestRegion(unittest.TestCase): def setUp(self): self.cell = np.array( diff --git a/source/tests/pt/test_dp_mode.py b/source/tests/pt/test_dp_mode.py index 08b4d6cf03..d81b6ea13b 100644 --- a/source/tests/pt/test_dp_mode.py +++ b/source/tests/pt/test_dp_mode.py @@ -256,3 +256,102 @@ def test_jit(self): # TODO: dirty hack to avoid data stat!!! md0 = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) torch.jit.script(md0) + + +class TestDPModelFormatNlist(unittest.TestCase): + def setUp(self): + # nloc == 3, nall == 4 + self.nloc = 3 + self.nall = 5 + self.nf, self.nt = 1, 2 + self.coord_ext = np.array( + [ + [0, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, -2, 0], + [2.3, 0, 0], + ], + dtype=np.float64, + ).reshape([1, self.nall * 3]) + # sel = [5, 2] + self.sel = [5, 2] + self.expected_nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=int, + ).reshape([1, self.nloc, sum(self.sel)]) + self.atype_ext = np.array([0, 0, 1, 0, 1], dtype=int).reshape([1, self.nall]) + self.rcut_smth = 0.4 + self.rcut = 2.0 + + nf, nloc, nnei = self.expected_nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + distinguish_types=ds.distinguish_types(), + ).to(env.DEVICE) + type_map = ["foo", "bar"] + # TODO: dirty hack to avoid data stat!!! + self.md = DPModel(ds, ft, type_map=type_map, resuming=True).to(env.DEVICE) + + def test_nlist_eq(self): + # n_nnei == nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1], + [0, -1, -1, -1, -1, 2, -1], + [0, 1, -1, -1, -1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + to_torch_tensor(self.coord_ext), + to_torch_tensor(self.atype_ext), + to_torch_tensor(nlist), + ) + np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1)) + + def test_nlist_st(self): + # n_nnei < nnei + nlist = np.array( + [ + [1, 3, -1, 2], + [0, -1, -1, 2], + [0, 1, -1, -1], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + to_torch_tensor(self.coord_ext), + to_torch_tensor(self.atype_ext), + to_torch_tensor(nlist), + ) + np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1)) + + def test_nlist_lt(self): + # n_nnei > nnei + nlist = np.array( + [ + [1, 3, -1, -1, -1, 2, -1, -1, 4], + [0, -1, 4, -1, -1, 2, -1, 3, -1], + [0, 1, -1, -1, -1, 4, -1, -1, 3], + ], + dtype=np.int64, + ).reshape([1, self.nloc, -1]) + nlist1 = self.md.format_nlist( + to_torch_tensor(self.coord_ext), + to_torch_tensor(self.atype_ext), + to_torch_tensor(nlist), + ) + np.testing.assert_allclose(self.expected_nlist, to_numpy_array(nlist1)) diff --git a/source/tests/pt/test_env_mat.py b/source/tests/pt/test_env_mat.py index 73707a3099..d9f8d59b1c 100644 --- a/source/tests/pt/test_env_mat.py +++ b/source/tests/pt/test_env_mat.py @@ -47,7 +47,7 @@ def setUp(self): [ [1, 3, -1, -1, -1, 2, -1], [0, -1, -1, -1, -1, 2, -1], - [0, 1, -1, -1, -1, 0, -1], + [0, 1, -1, -1, -1, -1, -1], ], dtype=int, ).reshape([1, self.nloc, sum(self.sel)])