Skip to content

Commit

Permalink
feat(jax): neighbor stat (#4258)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced `NeighborStat` and `NeighborStatOP` classes for enhanced
    neighbor statistics computation.
  - Added `AutoBatchSize` class to manage automatic batch sizing in deep
    learning applications.

- **Improvements**
  - Enhanced `JAXBackend` functionality with implemented properties for
    neighbor statistics and serialization.
  - Refactored neighbor counting logic for better clarity and modularity.

- **Tests**
  - Updated unit tests for `neighbor_stat` to support multiple backends
    (TensorFlow, PyTorch, NumPy, JAX).
  - Removed outdated test files to streamline testing processes.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and coderabbitai[bot] authored Oct 29, 2024
1 parent b647547 commit 82aaa0d
Show file tree
Hide file tree
Showing 7 changed files with 210 additions and 159 deletions.
10 changes: 7 additions & 3 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ class JAXBackend(Backend):
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
| Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
| Backend.Feature.NEIGHBOR_STAT
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = [".jax"]
Expand Down Expand Up @@ -82,7 +82,11 @@ def neighbor_stat(self) -> type["NeighborStat"]:
type[NeighborStat]
The neighbor statistics of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.neighbor_stat import (
NeighborStat,
)

return NeighborStat

@property
def serialize_hook(self) -> Callable[[str], dict]:
Expand Down
35 changes: 18 additions & 17 deletions deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Optional,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel.common import (
Expand Down Expand Up @@ -68,42 +69,42 @@ def call(
np.ndarray
The maximal number of neighbors
"""
xp = array_api_compat.array_namespace(coord, atype)
nframes = coord.shape[0]
coord = coord.reshape(nframes, -1, 3)
coord = xp.reshape(coord, (nframes, -1, 3))
nloc = coord.shape[1]
coord = coord.reshape(nframes, nloc * 3)
coord = xp.reshape(coord, (nframes, nloc * 3))
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
coord, atype, cell, self.rcut
)

coord1 = extend_coord.reshape(nframes, -1)
coord1 = xp.reshape(extend_coord, (nframes, -1))
nall = coord1.shape[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
coord1.reshape([nframes, -1, 3])[:, None, :, :]
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :]
- xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :]
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = np.eye(nloc, nall, dtype=bool)
diff[:, mask] = np.inf
rr2 = np.sum(np.square(diff), axis=-1)
min_rr2 = np.min(rr2, axis=-1)
mask = xp.eye(nloc, nall, dtype=xp.bool)
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
rr2 = xp.sum(xp.square(diff), axis=-1)
min_rr2 = xp.min(rr2, axis=-1)
# count the number of neighbors
if not self.mixed_types:
mask = rr2 < self.rcut**2
nnei = np.zeros((nframes, nloc, self.ntypes), dtype=int)
nneis = []
for ii in range(self.ntypes):
nnei[:, :, ii] = np.sum(
mask & (extend_atype == ii)[:, None, :], axis=-1
)
nneis.append(xp.sum(mask & (extend_atype == ii)[:, None, :], axis=-1))
nnei = xp.stack(nneis, axis=-1)
else:
mask = rr2 < self.rcut**2
# virtual type (<0) are not counted
nnei = np.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1).reshape(
nframes, nloc, 1
)
max_nnei = np.max(nnei, axis=1)
nnei = xp.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1)
nnei = xp.reshape(nnei, (nframes, nloc, 1))
max_nnei = xp.max(nnei, axis=1)
return min_rr2, max_nnei


Expand Down
59 changes: 59 additions & 0 deletions deepmd/jax/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import jaxlib

from deepmd.jax.env import (
jax,
)
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
    """Auto batch size for the JAX backend.

    Parameters
    ----------
    initial_batch_size : int, default: 1024
        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
        is not set
    factor : float, default: 2.
        increased factor
    """

    def __init__(
        self,
        initial_batch_size: int = 1024,
        factor: float = 2.0,
    ) -> None:
        super().__init__(
            initial_batch_size=initial_batch_size,
            factor=factor,
        )

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if the platform of the first JAX device is ``"gpu"``
        """
        return jax.devices()[0].platform == "gpu"

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        An exception is treated as out-of-memory when it is an
        ``XlaRuntimeError`` or a ``ValueError`` whose first argument is a
        string containing ``"RESOURCE_EXHAUSTED:"``.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if the exception is recognized as an OOM error
        """
        if not isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)):
            return False
        # Exceptions may be raised without arguments or with a non-string
        # first argument; such exceptions are not OOM errors and must not
        # make this check itself raise IndexError/TypeError.
        message = e.args[0] if e.args else None
        return isinstance(message, str) and "RESOURCE_EXHAUSTED:" in message
104 changes: 104 additions & 0 deletions deepmd/jax/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Iterator,
)
from typing import (
Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStatOP,
)
from deepmd.jax.common import (
to_jax_array,
)
from deepmd.jax.utils.auto_batch_size import (
AutoBatchSize,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat


class NeighborStat(BaseNeighborStat):
    """Neighbor statistics using JAX.

    Parameters
    ----------
    ntypes : int
        The num of atom types
    rcut : float
        The cut-off radius
    mixed_type : bool, optional, default=False
        Treat all types as a single type.
    """

    def __init__(
        self,
        ntypes: int,
        rcut: float,
        mixed_type: bool = False,
    ) -> None:
        super().__init__(ntypes, rcut, mixed_type)
        # The array-API operator does the counting; batches are sized
        # automatically by the JAX AutoBatchSize helper.
        self.op = NeighborStatOP(ntypes, rcut, mixed_type)
        self.auto_batch_size = AutoBatchSize()

    def iterator(
        self, data: DeepmdDataSystem
    ) -> Iterator[tuple[np.ndarray, float, str]]:
        """Produce neighbor statistics for every set directory of every system.

        Yields
        ------
        np.ndarray
            The maximal number of neighbors
        float
            The squared minimal distance between two atoms
        str
            The directory of the data system
        """
        for sys_idx in range(len(data.system_dirs)):
            data_set = data.data_systems[sys_idx]
            for set_dir in data_set.dirs:
                loaded = data_set._load_set(set_dir)
                coord = loaded["coord"]
                # Non-periodic systems are evaluated without a cell.
                box = loaded["box"] if data_set.pbc else None
                minrr2, max_nnei = self.auto_batch_size.execute_all(
                    self._execute,
                    coord.shape[0],
                    data_set.get_natoms(),
                    coord,
                    loaded["type"],
                    box,
                )
                # Reduce over frames: largest neighbor count per type and
                # the smallest squared distance seen in this set.
                yield np.max(max_nnei, axis=0), np.min(minrr2), set_dir

    def _execute(
        self,
        coord: np.ndarray,
        atype: np.ndarray,
        cell: Optional[np.ndarray],
    ):
        """Run the neighbor-stat operator on one batch.

        Inputs are converted to JAX arrays before the call and the two
        results are converted back to NumPy arrays.

        Parameters
        ----------
        coord
            The coordinates of atoms.
        atype
            The atom types.
        cell
            The cell.
        """
        jax_inputs = (
            to_jax_array(coord),
            to_jax_array(atype),
            to_jax_array(cell),
        )
        minrr2, max_nnei = self.op(*jax_inputs)
        return to_numpy_array(minrr2), to_numpy_array(max_nnei)
69 changes: 0 additions & 69 deletions source/tests/common/dpmodel/test_neighbor_stat.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from ..seed import (
GLOBAL_SEED,
)
from .common import (
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
)


def gen_sys(nframes):
Expand Down Expand Up @@ -42,7 +47,7 @@ def setUp(self):
def tearDown(self):
shutil.rmtree("system_0")

def test_neighbor_stat(self):
def run_neighbor_stat(self, backend):
for rcut in (0.0, 1.0, 2.0, 4.0):
for mixed_type in (True, False):
with self.subTest(rcut=rcut, mixed_type=mixed_type):
Expand All @@ -52,7 +57,7 @@ def test_neighbor_stat(self):
rcut=rcut,
type_map=["TYPE", "NO_THIS_TYPE"],
mixed_type=mixed_type,
backend="pytorch",
backend=backend,
)
upper = np.ceil(rcut) + 1
X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper]
Expand All @@ -67,3 +72,18 @@ def test_neighbor_stat(self):
if not mixed_type:
ret.append(0)
np.testing.assert_array_equal(max_nbor_size, ret)

@unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed")
def test_neighbor_stat_tf(self):
self.run_neighbor_stat("tensorflow")

@unittest.skipUnless(INSTALLED_PT, "pytorch is not installed")
def test_neighbor_stat_pt(self):
self.run_neighbor_stat("pytorch")

def test_neighbor_stat_dp(self):
self.run_neighbor_stat("numpy")

@unittest.skipUnless(INSTALLED_JAX, "jax is not installed")
def test_neighbor_stat_jax(self):
self.run_neighbor_stat("jax")
Loading

0 comments on commit 82aaa0d

Please sign in to comment.