Skip to content

Commit

Permalink
feat: directional nlist (deepmodeling#4052)
Browse files Browse the repository at this point in the history
this pr implement the directional nlist, with which the central and
neighboring atoms are treated as different atoms.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Introduced a new function for building a directional neighbor list,
enhancing detailed neighbor analysis.
- **Bug Fixes**
- Improved handling of neighbor lists, ensuring proper management of
virtual atoms and accurate sizing.
- **Tests**
- Added new tests for the directional neighbor list functionality and
modified existing tests to cover additional scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Han Wang <[email protected]>
Co-authored-by: Han Wang <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Duo <[email protected]>
  • Loading branch information
4 people authored Aug 14, 2024
1 parent f1e5dbc commit 05323f3
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 6 deletions.
128 changes: 123 additions & 5 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def build_neighbor_list(
).view(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
# nloc x 3
coord0 = coord1[:, : nloc * 3]
# nloc x nall x 3
Expand All @@ -126,8 +125,26 @@ def build_neighbor_list(
# nloc x (nall-1)
rr = rr[:, :, 1:]
nlist = nlist[:, :, 1:]

return _trim_mask_distinguish_nlist(
is_vir, atype, rr, nlist, rcut, sel, distinguish_types
)


def _trim_mask_distinguish_nlist(
is_vir_cntl: torch.Tensor,
atype_neig: torch.Tensor,
rr: torch.Tensor,
nlist: torch.Tensor,
rcut: float,
sel: List[int],
distinguish_types: bool,
) -> torch.Tensor:
"""Trim the size of nlist, mask if any central atom is virtual, distinguish types if necessary."""
nsel = sum(sel)
# nloc x nsel
nnei = rr.shape[2]
batch_size, nloc, nnei = rr.shape
assert batch_size == is_vir_cntl.shape[0]
if nsel <= nnei:
rr = rr[:, :, :nsel]
nlist = nlist[:, :, :nsel]
Expand All @@ -147,15 +164,116 @@ def build_neighbor_list(
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = torch.where(
torch.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist
torch.logical_or((rr > rcut), is_vir_cntl[:, :nloc, None]), -1, nlist
)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
return nlist_distinguish_types(nlist, atype_neig, sel)
else:
return nlist


def build_directional_neighbor_list(
coord_cntl: torch.Tensor,
atype_cntl: torch.Tensor,
coord_neig: torch.Tensor,
atype_neig: torch.Tensor,
rcut: float,
sel: Union[int, List[int]],
distinguish_types: bool = True,
) -> torch.Tensor:
"""Build directional neighbor list.
With each central atom, all the neighbor atoms in the cut-off radius will
be recorded in the neighbor list. The maximum neighbors is nsel. If the real
number of neighbors is larger than nsel, the neighbors will be sorted with the
distance and the first nsel neighbors are kept.
Important: the central and neighboring atoms are assume to be different atoms.
Parameters
----------
coord_central : torch.Tensor
coordinates of central atoms. assumed to be local atoms.
shape [batch_size, nloc_central x 3]
atype_central : torch.Tensor
atomic types of central atoms. shape [batch_size, nloc_central]
if type < 0 the atom is treated as virtual atoms.
coord_neighbor : torch.Tensor
extended coordinates of neighbors atoms. shape [batch_size, nall_neighbor x 3]
atype_central : torch.Tensor
extended atomic types of neighbors atoms. shape [batch_size, nall_neighbor]
if type < 0 the atom is treated as virtual atoms.
rcut : float
cut-off radius
sel : int or List[int]
maximal number of neighbors (of each type).
if distinguish_types==True, nsel should be list and
the length of nsel should be equal to number of
types.
distinguish_types : bool
distinguish different types.
Returns
-------
neighbor_list : torch.Tensor
Neighbor list of shape [batch_size, nloc_central, nsel], the neighbors
are stored in an ascending order. If the number of neighbors is less than nsel,
the positions are masked with -1. The neighbor list of an atom looks like
|------ nsel ------|
xx xx xx xx -1 -1 -1
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.
"""
batch_size = coord_cntl.shape[0]
coord_cntl = coord_cntl.view(batch_size, -1)
nloc_cntl = coord_cntl.shape[1] // 3
coord_neig = coord_neig.view(batch_size, -1)
nall_neig = coord_neig.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord_neig.numel() > 0:
xmax = torch.max(coord_cntl) + 2.0 * rcut
else:
xmax = (
torch.zeros(1, dtype=coord_neig.dtype, device=coord_neig.device)
+ 2.0 * rcut
)
# nf x nloc
is_vir_cntl = atype_cntl < 0
# nf x nall
is_vir_neig = atype_neig < 0
# nf x nloc x 3
coord_cntl = coord_cntl.view(batch_size, nloc_cntl, 3)
# nf x nall x 3
coord_neig = torch.where(
is_vir_neig[:, :, None], xmax, coord_neig.view(batch_size, nall_neig, 3)
).view(batch_size, nall_neig, 3)
# nsel
if isinstance(sel, int):
sel = [sel]
# nloc x nall x 3
diff = coord_neig[:, None, :, :] - coord_cntl[:, :, None, :]
assert list(diff.shape) == [batch_size, nloc_cntl, nall_neig, 3]
# nloc x nall
rr = torch.linalg.norm(diff, dim=-1)
rr, nlist = torch.sort(rr, dim=-1)

# We assume that the central and neighbor atoms are diffferent,
# thus we do not need to exclude self-neighbors.
# # if central atom has two zero distances, sorting sometimes can not exclude itself
# rr -= torch.eye(nloc_cntl, nall_neig, dtype=rr.dtype, device=rr.device).unsqueeze(0)
# rr, nlist = torch.sort(rr, dim=-1)
# # nloc x (nall-1)
# rr = rr[:, :, 1:]
# nlist = nlist[:, :, 1:]

return _trim_mask_distinguish_nlist(
is_vir_cntl, atype_neig, rr, nlist, rcut, sel, distinguish_types
)


def nlist_distinguish_types(
nlist: torch.Tensor,
atype: torch.Tensor,
Expand Down
70 changes: 69 additions & 1 deletion source/tests/pt/model/test_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
env,
)
from deepmd.pt.utils.nlist import (
build_directional_neighbor_list,
build_multiple_neighbor_list,
build_neighbor_list,
extend_coord_with_ghosts,
Expand Down Expand Up @@ -62,6 +63,7 @@ def test_build_notype(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
# test normal sel
nlist = build_neighbor_list(
ecoord,
eatype,
Expand All @@ -70,14 +72,29 @@ def test_build_notype(self):
sum(self.nsel),
distinguish_types=False,
)
torch.testing.assert_close(nlist[0], nlist[1])
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc[nlist_mask] = -1
torch.testing.assert_close(
torch.sort(nlist_loc, dim=-1)[0],
torch.sort(self.ref_nlist, dim=-1)[0],
)
# test a very large sel
nlist = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 300, # +300, real nnei==224
distinguish_types=False,
)
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc[nlist_mask] = -1
torch.testing.assert_close(
torch.sort(nlist_loc, descending=True, dim=-1)[0][:, : sum(self.nsel)],
torch.sort(self.ref_nlist, descending=True, dim=-1)[0],
)

def test_build_type(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
Expand Down Expand Up @@ -218,3 +235,54 @@ def test_extend_coord(self):
rtol=self.prec,
atol=self.prec,
)

def test_build_directional_nlist(self):
"""Directional nlist is tested against the standard nlist implementation."""
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
for distinguish_types, mysel in zip([True, False], [sum(self.nsel), 300]):
# full neighbor list
nlist_full = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=distinguish_types,
)
# central as part of the system
nlist = build_directional_neighbor_list(
ecoord[:, 3:6],
eatype[:, 1:2],
torch.concat(
[
ecoord[:, 0:3],
torch.zeros(
[self.nf, 3], dtype=dtype, device=env.DEVICE
), # placeholder
ecoord[:, 6:],
],
dim=1,
),
torch.concat(
[
eatype[:, 0:1],
-1
* torch.ones(
[self.nf, 1], dtype=int, device=env.DEVICE
), # placeholder
eatype[:, 2:],
],
dim=1,
),
self.rcut,
mysel,
distinguish_types=distinguish_types,
)
torch.testing.assert_close(nlist[0], nlist[1])
torch.testing.assert_close(nlist[0], nlist[2])
torch.testing.assert_close(
torch.sort(nlist[0], descending=True, dim=-1)[0][:, : sum(self.nsel)],
torch.sort(nlist_full[0][1:2], descending=True, dim=-1)[0],
)

0 comments on commit 05323f3

Please sign in to comment.