-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add neighbor stat support with NumPy and PyTorch implementation (#3271)
I also tested `examples/water`,`examples/nopbc`, and the ANI-1x dataset (compared to the screenshot in #1624) to confirm consistent results. Besides, as the OP supports multiple frames, the PT implementation only takes 9 s on ANI-1x, which is much faster than the TF implementation, which took over 10 min as shown in #1624. ![image](https://github.com/deepmodeling/deepmd-kit/assets/9496702/c3cc1950-33a7-435c-90f4-c18d196b202d) Signed-off-by: Jinzhe Zeng <[email protected]>
- Loading branch information
Showing
10 changed files
with
698 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from typing import ( | ||
Iterator, | ||
Optional, | ||
Tuple, | ||
) | ||
|
||
import numpy as np | ||
|
||
from deepmd.dpmodel.common import ( | ||
NativeOP, | ||
) | ||
from deepmd.dpmodel.utils.nlist import ( | ||
extend_coord_with_ghosts, | ||
) | ||
from deepmd.utils.data_system import ( | ||
DeepmdDataSystem, | ||
) | ||
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat | ||
|
||
|
||
class NeighborStatOP(NativeOP): | ||
"""Class for getting neighbor statics data information. | ||
Parameters | ||
---------- | ||
ntypes | ||
The num of atom types | ||
rcut | ||
The cut-off radius | ||
distinguish_types : bool, optional | ||
If False, treat all types as a single type. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
ntypes: int, | ||
rcut: float, | ||
distinguish_types: bool, | ||
) -> None: | ||
self.rcut = rcut | ||
self.ntypes = ntypes | ||
self.distinguish_types = distinguish_types | ||
|
||
def call( | ||
self, | ||
coord: np.ndarray, | ||
atype: np.ndarray, | ||
cell: Optional[np.ndarray], | ||
) -> Tuple[float, np.ndarray]: | ||
"""Calculate the neareest neighbor distance between atoms, maximum nbor size of | ||
atoms and the output data range of the environment matrix. | ||
Parameters | ||
---------- | ||
coord | ||
The coordinates of atoms. | ||
atype | ||
The atom types. | ||
cell | ||
The cell. | ||
Returns | ||
------- | ||
float | ||
The minimal squared distance between two atoms | ||
np.ndarray | ||
The maximal number of neighbors | ||
""" | ||
nframes = coord.shape[0] | ||
coord = coord.reshape(nframes, -1, 3) | ||
nloc = coord.shape[1] | ||
coord = coord.reshape(nframes, nloc * 3) | ||
extend_coord, extend_atype, _ = extend_coord_with_ghosts( | ||
coord, atype, cell, self.rcut | ||
) | ||
|
||
coord1 = extend_coord.reshape(nframes, -1) | ||
nall = coord1.shape[1] // 3 | ||
coord0 = coord1[:, : nloc * 3] | ||
diff = ( | ||
coord1.reshape([nframes, -1, 3])[:, None, :, :] | ||
- coord0.reshape([nframes, -1, 3])[:, :, None, :] | ||
) | ||
assert list(diff.shape) == [nframes, nloc, nall, 3] | ||
# remove the diagonal elements | ||
mask = np.eye(nloc, nall, dtype=bool) | ||
diff[:, mask] = np.inf | ||
rr2 = np.sum(np.square(diff), axis=-1) | ||
min_rr2 = np.min(rr2, axis=-1) | ||
# count the number of neighbors | ||
if self.distinguish_types: | ||
mask = rr2 < self.rcut**2 | ||
nnei = np.zeros((nframes, nloc, self.ntypes), dtype=int) | ||
for ii in range(self.ntypes): | ||
nnei[:, :, ii] = np.sum( | ||
mask & (extend_atype == ii)[:, None, :], axis=-1 | ||
) | ||
else: | ||
mask = rr2 < self.rcut**2 | ||
# virtual type (<0) are not counted | ||
nnei = np.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1).reshape( | ||
nframes, nloc, 1 | ||
) | ||
max_nnei = np.max(nnei, axis=1) | ||
return min_rr2, max_nnei | ||
|
||
|
||
class NeighborStat(BaseNeighborStat): | ||
"""Neighbor statistics using pure NumPy. | ||
Parameters | ||
---------- | ||
ntypes : int | ||
The num of atom types | ||
rcut : float | ||
The cut-off radius | ||
one_type : bool, optional, default=False | ||
Treat all types as a single type. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
ntypes: int, | ||
rcut: float, | ||
one_type: bool = False, | ||
) -> None: | ||
super().__init__(ntypes, rcut, one_type) | ||
self.op = NeighborStatOP(ntypes, rcut, not one_type) | ||
|
||
def iterator( | ||
self, data: DeepmdDataSystem | ||
) -> Iterator[Tuple[np.ndarray, float, str]]: | ||
"""Abstract method for producing data. | ||
Yields | ||
------ | ||
np.ndarray | ||
The maximal number of neighbors | ||
float | ||
The squared minimal distance between two atoms | ||
str | ||
The directory of the data system | ||
""" | ||
for ii in range(len(data.system_dirs)): | ||
for jj in data.data_systems[ii].dirs: | ||
data_set = data.data_systems[ii] | ||
data_set_data = data_set._load_set(jj) | ||
minrr2, max_nnei = self.op( | ||
data_set_data["coord"], | ||
data_set_data["type"], | ||
data_set_data["box"] if data_set.pbc else None, | ||
) | ||
yield np.max(max_nnei, axis=0), np.min(minrr2), jj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import logging | ||
from typing import ( | ||
List, | ||
) | ||
|
||
from deepmd.common import ( | ||
expand_sys_str, | ||
) | ||
from deepmd.utils.data_system import ( | ||
DeepmdDataSystem, | ||
) | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
def neighbor_stat( | ||
*, | ||
system: str, | ||
rcut: float, | ||
type_map: List[str], | ||
one_type: bool = False, | ||
backend: str = "tensorflow", | ||
**kwargs, | ||
): | ||
"""Calculate neighbor statistics. | ||
Parameters | ||
---------- | ||
system : str | ||
system to stat | ||
rcut : float | ||
cutoff radius | ||
type_map : list[str] | ||
type map | ||
one_type : bool, optional, default=False | ||
treat all types as a single type | ||
backend : str, optional, default="tensorflow" | ||
backend to use | ||
**kwargs | ||
additional arguments | ||
Examples | ||
-------- | ||
>>> neighbor_stat( | ||
... system=".", | ||
... rcut=6.0, | ||
... type_map=[ | ||
... "C", | ||
... "H", | ||
... "O", | ||
... "N", | ||
... "P", | ||
... "S", | ||
... "Mg", | ||
... "Na", | ||
... "HW", | ||
... "OW", | ||
... "mNa", | ||
... "mCl", | ||
... "mC", | ||
... "mH", | ||
... "mMg", | ||
... "mN", | ||
... "mO", | ||
... "mP", | ||
... ], | ||
... ) | ||
min_nbor_dist: 0.6599510670195264 | ||
max_nbor_size: [23, 26, 19, 16, 2, 2, 1, 1, 72, 37, 5, 0, 31, 29, 1, 21, 20, 5] | ||
""" | ||
if backend == "tensorflow": | ||
from deepmd.tf.utils.neighbor_stat import ( | ||
NeighborStat, | ||
) | ||
elif backend == "pytorch": | ||
from deepmd.pt.utils.neighbor_stat import ( | ||
NeighborStat, | ||
) | ||
elif backend == "numpy": | ||
from deepmd.dpmodel.utils.neighbor_stat import ( | ||
NeighborStat, | ||
) | ||
else: | ||
raise ValueError(f"Invalid backend {backend}") | ||
all_sys = expand_sys_str(system) | ||
if not len(all_sys): | ||
raise RuntimeError("Did not find valid system") | ||
data = DeepmdDataSystem( | ||
systems=all_sys, | ||
batch_size=1, | ||
test_size=1, | ||
rcut=rcut, | ||
type_map=type_map, | ||
) | ||
data.get_batch() | ||
nei = NeighborStat(data.get_ntypes(), rcut, one_type=one_type) | ||
min_nbor_dist, max_nbor_size = nei.get_stat(data) | ||
log.info("min_nbor_dist: %f" % min_nbor_dist) | ||
log.info("max_nbor_size: %s" % str(max_nbor_size)) | ||
return min_nbor_dist, max_nbor_size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.