From 82aaa0db8b2e484d2179112b509bb8bcadc6ab1f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 29 Oct 2024 14:51:23 -0400 Subject: [PATCH] feat(jax): neighbor stat (#4258) ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced `NeighborStat` and `NeighborStatOP` classes for enhanced neighbor statistics computation. - Added `AutoBatchSize` class to manage automatic batch sizing in deep learning applications. - **Improvements** - Enhanced `JAXBackend` functionality with implemented properties for neighbor statistics and serialization. - Refactored neighbor counting logic for better clarity and modularity. - **Tests** - Updated unit tests for `neighbor_stat` to support multiple backends (TensorFlow, PyTorch, NumPy, JAX). - Removed outdated test files to streamline testing processes. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> --- deepmd/backend/jax.py | 10 +- deepmd/dpmodel/utils/neighbor_stat.py | 35 +++--- deepmd/jax/utils/auto_batch_size.py | 59 ++++++++++ deepmd/jax/utils/neighbor_stat.py | 104 ++++++++++++++++++ .../common/dpmodel/test_neighbor_stat.py | 69 ------------ .../{pt => consistent}/test_neighbor_stat.py | 24 +++- source/tests/tf/test_neighbor_stat.py | 68 ------------ 7 files changed, 210 insertions(+), 159 deletions(-) create mode 100644 deepmd/jax/utils/auto_batch_size.py create mode 100644 deepmd/jax/utils/neighbor_stat.py delete mode 100644 source/tests/common/dpmodel/test_neighbor_stat.py rename source/tests/{pt => consistent}/test_neighbor_stat.py (77%) delete mode 100644 source/tests/tf/test_neighbor_stat.py diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index bb2fba5a7c..7131f4d534 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -33,9 +33,9 @@ class JAXBackend(Backend): """The formal name of the backend.""" features: ClassVar[Backend.Feature] = ( Backend.Feature.IO - # Backend.Feature.ENTRY_POINT + | Backend.Feature.ENTRY_POINT # | Backend.Feature.DEEP_EVAL - # | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" suffixes: ClassVar[list[str]] = [".jax"] @@ -82,7 +82,11 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. """ - raise NotImplementedError + from deepmd.jax.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat @property def serialize_hook(self) -> Callable[[str], dict]: diff --git a/deepmd/dpmodel/utils/neighbor_stat.py b/deepmd/dpmodel/utils/neighbor_stat.py index 43ca2cadd1..3aea8ceeb9 100644 --- a/deepmd/dpmodel/utils/neighbor_stat.py +++ b/deepmd/dpmodel/utils/neighbor_stat.py @@ -6,6 +6,7 @@ Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -68,42 +69,42 @@ def call( np.ndarray The maximal number of neighbors """ + xp = array_api_compat.array_namespace(coord, atype) nframes = coord.shape[0] - coord = coord.reshape(nframes, -1, 3) + coord = xp.reshape(coord, (nframes, -1, 3)) nloc = coord.shape[1] - coord = coord.reshape(nframes, nloc * 3) + coord = xp.reshape(coord, (nframes, nloc * 3)) extend_coord, extend_atype, _ = extend_coord_with_ghosts( coord, atype, cell, self.rcut ) - coord1 = extend_coord.reshape(nframes, -1) + coord1 = xp.reshape(extend_coord, (nframes, -1)) nall = coord1.shape[1] // 3 coord0 = coord1[:, : nloc * 3] diff = ( - coord1.reshape([nframes, -1, 3])[:, None, :, :] - - coord0.reshape([nframes, -1, 3])[:, :, None, :] + xp.reshape(coord1, [nframes, -1, 3])[:, None, :, :] + - xp.reshape(coord0, [nframes, -1, 3])[:, :, None, :] ) assert list(diff.shape) == [nframes, nloc, nall, 3] # remove the diagonal elements - mask = np.eye(nloc, nall, dtype=bool) - diff[:, mask] = np.inf - rr2 = np.sum(np.square(diff), axis=-1) - min_rr2 = np.min(rr2, axis=-1) + mask = xp.eye(nloc, nall, dtype=xp.bool) + mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3)) + diff = xp.where(mask, xp.full_like(diff, xp.inf), diff) + rr2 = xp.sum(xp.square(diff), axis=-1) + min_rr2 = xp.min(rr2, axis=-1) # count the number of neighbors if not self.mixed_types: mask = rr2 < self.rcut**2 - nnei = np.zeros((nframes, nloc, self.ntypes), dtype=int) + nneis = [] for ii in range(self.ntypes): - nnei[:, :, ii] = np.sum( - mask & (extend_atype == ii)[:, None, :], axis=-1 - ) + nneis.append(xp.sum(mask & (extend_atype == ii)[:, None, :], axis=-1)) + nnei = xp.stack(nneis, axis=-1) else: mask = rr2 < self.rcut**2 # virtual type (<0) are not counted - nnei = np.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1).reshape( - nframes, nloc, 1 - ) - max_nnei = np.max(nnei, axis=1) + nnei = xp.sum(mask & (extend_atype >= 0)[:, None, :], axis=-1) + nnei = xp.reshape(nnei, (nframes, nloc, 1)) + max_nnei = xp.max(nnei, axis=1) return min_rr2, max_nnei diff --git a/deepmd/jax/utils/auto_batch_size.py b/deepmd/jax/utils/auto_batch_size.py new file mode 100644 index 0000000000..eec6766ae2 --- /dev/null +++ b/deepmd/jax/utils/auto_batch_size.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import jaxlib + +from deepmd.jax.env import ( + jax, +) +from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase + + +class AutoBatchSize(AutoBatchSizeBase): + """Auto batch size. + + Parameters + ---------- + initial_batch_size : int, default: 1024 + initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE + is not set + factor : float, default: 2. + increased factor + + """ + + def __init__( + self, + initial_batch_size: int = 1024, + factor: float = 2.0, + ): + super().__init__( + initial_batch_size=initial_batch_size, + factor=factor, + ) + + def is_gpu_available(self) -> bool: + """Check if GPU is available. + + Returns + ------- + bool + True if GPU is available + """ + return jax.devices()[0].platform == "gpu" + + def is_oom_error(self, e: Exception) -> bool: + """Check if the exception is an OOM error. + + Parameters + ---------- + e : Exception + Exception + """ + # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error, + # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924 + # (the meaningless error message should be considered as a bug in cusolver) + if isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)) and ( + "RESOURCE_EXHAUSTED:" in e.args[0] + ): + return True + return False diff --git a/deepmd/jax/utils/neighbor_stat.py b/deepmd/jax/utils/neighbor_stat.py new file mode 100644 index 0000000000..6d9bc872e8 --- /dev/null +++ b/deepmd/jax/utils/neighbor_stat.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Iterator, +) +from typing import ( + Optional, +) + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.dpmodel.utils.neighbor_stat import ( + NeighborStatOP, +) +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.auto_batch_size import ( + AutoBatchSize, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat + + +class NeighborStat(BaseNeighborStat): + """Neighbor statistics using JAX. + + Parameters + ---------- + ntypes : int + The num of atom types + rcut : float + The cut-off radius + mixed_type : bool, optional, default=False + Treat all types as a single type. + """ + + def __init__( + self, + ntypes: int, + rcut: float, + mixed_type: bool = False, + ) -> None: + super().__init__(ntypes, rcut, mixed_type) + self.op = NeighborStatOP(ntypes, rcut, mixed_type) + self.auto_batch_size = AutoBatchSize() + + def iterator( + self, data: DeepmdDataSystem + ) -> Iterator[tuple[np.ndarray, float, str]]: + """Iterator method for producing neighbor statistics data. + + Yields + ------ + np.ndarray + The maximal number of neighbors + float + The squared minimal distance between two atoms + str + The directory of the data system + """ + for ii in range(len(data.system_dirs)): + for jj in data.data_systems[ii].dirs: + data_set = data.data_systems[ii] + data_set_data = data_set._load_set(jj) + minrr2, max_nnei = self.auto_batch_size.execute_all( + self._execute, + data_set_data["coord"].shape[0], + data_set.get_natoms(), + data_set_data["coord"], + data_set_data["type"], + data_set_data["box"] if data_set.pbc else None, + ) + yield np.max(max_nnei, axis=0), np.min(minrr2), jj + + def _execute( + self, + coord: np.ndarray, + atype: np.ndarray, + cell: Optional[np.ndarray], + ): + """Execute the operation. + + Parameters + ---------- + coord + The coordinates of atoms. + atype + The atom types. + cell + The cell. + """ + minrr2, max_nnei = self.op( + to_jax_array(coord), + to_jax_array(atype), + to_jax_array(cell), + ) + minrr2 = to_numpy_array(minrr2) + max_nnei = to_numpy_array(max_nnei) + return minrr2, max_nnei diff --git a/source/tests/common/dpmodel/test_neighbor_stat.py b/source/tests/common/dpmodel/test_neighbor_stat.py deleted file mode 100644 index 8dd700f608..0000000000 --- a/source/tests/common/dpmodel/test_neighbor_stat.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import shutil -import unittest - -import dpdata -import numpy as np - -from deepmd.entrypoints.neighbor_stat import ( - neighbor_stat, -) - -from ...seed import ( - GLOBAL_SEED, -) - - -def gen_sys(nframes): - rng = np.random.default_rng(GLOBAL_SEED) - natoms = 1000 - data = {} - X, Y, Z = np.mgrid[0:2:3j, 0:2:3j, 0:2:3j] - positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T # + 0.1 - data["coords"] = np.repeat(positions[np.newaxis, :, :], nframes, axis=0) - data["forces"] = rng.random([nframes, natoms, 3]) - data["cells"] = np.array([3.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 3.0]).reshape( - 1, 3, 3 - ) - data["energies"] = rng.random([nframes, 1]) - data["atom_names"] = ["TYPE"] - data["atom_numbs"] = [27] - data["atom_types"] = np.repeat(0, 27) - return data - - -class TestNeighborStat(unittest.TestCase): - def setUp(self): - data0 = gen_sys(1) - sys0 = dpdata.LabeledSystem() - sys0.data = data0 - sys0.to_deepmd_npy("system_0", set_size=1) - - def tearDown(self): - shutil.rmtree("system_0") - - def test_neighbor_stat(self): - for rcut in (0.0, 1.0, 2.0, 4.0): - for mixed_type in (True, False): - with self.subTest(rcut=rcut, mixed_type=mixed_type): - rcut += 1e-3 # prevent numerical errors - min_nbor_dist, max_nbor_size = neighbor_stat( - system="system_0", - rcut=rcut, - type_map=["TYPE", "NO_THIS_TYPE"], - mixed_type=mixed_type, - backend="numpy", - ) - upper = np.ceil(rcut) + 1 - X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper] - positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T - # distance to (0,0,0) - distance = np.linalg.norm(positions, axis=1) - expected_neighbors = np.count_nonzero( - np.logical_and(distance > 0, distance <= rcut) - ) - self.assertAlmostEqual(min_nbor_dist, 1.0, 6) - ret = [expected_neighbors] - if not mixed_type: - ret.append(0) - np.testing.assert_array_equal(max_nbor_size, ret) diff --git a/source/tests/pt/test_neighbor_stat.py b/source/tests/consistent/test_neighbor_stat.py similarity index 77% rename from source/tests/pt/test_neighbor_stat.py rename to source/tests/consistent/test_neighbor_stat.py index 08ba453d74..55181a6903 100644 --- a/source/tests/pt/test_neighbor_stat.py +++ b/source/tests/consistent/test_neighbor_stat.py @@ -12,6 +12,11 @@ from ..seed import ( GLOBAL_SEED, ) +from .common import ( + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_TF, +) def gen_sys(nframes): @@ -42,7 +47,7 @@ def setUp(self): def tearDown(self): shutil.rmtree("system_0") - def test_neighbor_stat(self): + def run_neighbor_stat(self, backend): for rcut in (0.0, 1.0, 2.0, 4.0): for mixed_type in (True, False): with self.subTest(rcut=rcut, mixed_type=mixed_type): @@ -52,7 +57,7 @@ def test_neighbor_stat(self): rcut=rcut, type_map=["TYPE", "NO_THIS_TYPE"], mixed_type=mixed_type, - backend="pytorch", + backend=backend, ) upper = np.ceil(rcut) + 1 X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper] @@ -67,3 +72,18 @@ def test_neighbor_stat(self): if not mixed_type: ret.append(0) np.testing.assert_array_equal(max_nbor_size, ret) + + @unittest.skipUnless(INSTALLED_TF, "tensorflow is not installed") + def test_neighbor_stat_tf(self): + self.run_neighbor_stat("tensorflow") + + @unittest.skipUnless(INSTALLED_PT, "pytorch is not installed") + def test_neighbor_stat_pt(self): + self.run_neighbor_stat("pytorch") + + def test_neighbor_stat_dp(self): + self.run_neighbor_stat("numpy") + + @unittest.skipUnless(INSTALLED_JAX, "jax is not installed") + def test_neighbor_stat_jax(self): + self.run_neighbor_stat("jax") diff --git a/source/tests/tf/test_neighbor_stat.py b/source/tests/tf/test_neighbor_stat.py deleted file mode 100644 index 22b7790958..0000000000 --- a/source/tests/tf/test_neighbor_stat.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: LGPL-3.0-or-later -import shutil -import unittest - -import dpdata -import numpy as np - -from deepmd.tf.entrypoints.neighbor_stat import ( - neighbor_stat, -) - -from ..seed import ( - GLOBAL_SEED, -) - - -def gen_sys(nframes): - rng = np.random.default_rng(GLOBAL_SEED) - natoms = 1000 - data = {} - X, Y, Z = np.mgrid[0:2:3j, 0:2:3j, 0:2:3j] - positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T # + 0.1 - data["coords"] = np.repeat(positions[np.newaxis, :, :], nframes, axis=0) - data["forces"] = rng.random([nframes, natoms, 3]) - data["cells"] = np.array([3.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 3.0]).reshape( - 1, 3, 3 - ) - data["energies"] = rng.random([nframes, 1]) - data["atom_names"] = ["TYPE"] - data["atom_numbs"] = [27] - data["atom_types"] = np.repeat(0, 27) - return data - - -class TestNeighborStat(unittest.TestCase): - def setUp(self): - data0 = gen_sys(1) - sys0 = dpdata.LabeledSystem() - sys0.data = data0 - sys0.to_deepmd_npy("system_0", set_size=1) - - def tearDown(self): - shutil.rmtree("system_0") - - def test_neighbor_stat(self): - for rcut in (0.0, 1.0, 2.0, 4.0): - for mixed_type in (True, False): - with self.subTest(rcut=rcut, mixed_type=mixed_type): - rcut += 1e-3 # prevent numerical errors - min_nbor_dist, max_nbor_size = neighbor_stat( - system="system_0", - rcut=rcut, - type_map=["TYPE", "NO_THIS_TYPE"], - mixed_type=mixed_type, - ) - upper = np.ceil(rcut) + 1 - X, Y, Z = np.mgrid[-upper:upper, -upper:upper, -upper:upper] - positions = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T - # distance to (0,0,0) - distance = np.linalg.norm(positions, axis=1) - expected_neighbors = np.count_nonzero( - np.logical_and(distance > 0, distance <= rcut) - ) - self.assertAlmostEqual(min_nbor_dist, 1.0, 6) - ret = [expected_neighbors] - if not mixed_type: - ret.append(0) - np.testing.assert_array_equal(max_nbor_size, ret)