-
Notifications
You must be signed in to change notification settings - Fork 520
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced `NeighborStat` and `NeighborStatOP` classes for enhanced neighbor statistics computation.
  - Added `AutoBatchSize` class to manage automatic batch sizing in deep learning applications.
- **Improvements**
  - Enhanced `JAXBackend` functionality with implemented properties for neighbor statistics and serialization.
  - Refactored neighbor counting logic for better clarity and modularity.
- **Tests**
  - Updated unit tests for `neighbor_stat` to support multiple backends (TensorFlow, PyTorch, NumPy, JAX).
  - Removed outdated test files to streamline testing processes.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b647547
commit 82aaa0d
Showing
7 changed files
with
210 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
|
||
import jaxlib | ||
|
||
from deepmd.jax.env import ( | ||
jax, | ||
) | ||
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase | ||
|
||
|
||
class AutoBatchSize(AutoBatchSizeBase):
    """Auto batch size for the JAX backend.

    Parameters
    ----------
    initial_batch_size : int, default: 1024
        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
        is not set
    factor : float, default: 2.
        increased factor
    """

    def __init__(
        self,
        initial_batch_size: int = 1024,
        factor: float = 2.0,
    ) -> None:
        super().__init__(
            initial_batch_size=initial_batch_size,
            factor=factor,
        )

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if the first device reported by ``jax.devices()`` is on
            the ``"gpu"`` platform
        """
        return jax.devices()[0].platform == "gpu"

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        An exception is treated as out-of-memory when it is an
        ``XlaRuntimeError`` or a ``ValueError`` whose first argument is a
        string containing ``"RESOURCE_EXHAUSTED:"``.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if the exception is recognized as an OOM error
        """
        if not isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)):
            return False
        # An exception may be raised without arguments, or with a non-string
        # first argument (e.g. ``ValueError()`` or ``ValueError(5)``).
        # Indexing or substring-testing it blindly would raise IndexError /
        # TypeError here and hide the original exception from the caller.
        message = e.args[0] if e.args else None
        return isinstance(message, str) and "RESOURCE_EXHAUSTED:" in message
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from collections.abc import ( | ||
Iterator, | ||
) | ||
from typing import ( | ||
Optional, | ||
) | ||
|
||
import numpy as np | ||
|
||
from deepmd.dpmodel.common import ( | ||
to_numpy_array, | ||
) | ||
from deepmd.dpmodel.utils.neighbor_stat import ( | ||
NeighborStatOP, | ||
) | ||
from deepmd.jax.common import ( | ||
to_jax_array, | ||
) | ||
from deepmd.jax.utils.auto_batch_size import ( | ||
AutoBatchSize, | ||
) | ||
from deepmd.utils.data_system import ( | ||
DeepmdDataSystem, | ||
) | ||
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat | ||
|
||
|
||
class NeighborStat(BaseNeighborStat):
    """Neighbor statistics using JAX.

    Parameters
    ----------
    ntypes : int
        The num of atom types
    rcut : float
        The cut-off radius
    mixed_type : bool, optional, default=False
        Treat all types as a single type.
    """

    def __init__(
        self,
        ntypes: int,
        rcut: float,
        mixed_type: bool = False,
    ) -> None:
        super().__init__(ntypes, rcut, mixed_type)
        # Operator called by ``_execute`` on JAX arrays.
        self.op = NeighborStatOP(ntypes, rcut, mixed_type)
        # Drives ``_execute`` over the frames of each data set.
        self.auto_batch_size = AutoBatchSize()

    def iterator(
        self, data: DeepmdDataSystem
    ) -> Iterator[tuple[np.ndarray, float, str]]:
        """Produce neighbor statistics for every set of every data system.

        Parameters
        ----------
        data : DeepmdDataSystem
            The data systems to walk through.

        Yields
        ------
        np.ndarray
            The maximal number of neighbors
        float
            The squared minimal distance between two atoms
        str
            The directory of the data system
        """
        for sys_idx in range(len(data.system_dirs)):
            system = data.data_systems[sys_idx]
            for set_dir in system.dirs:
                frames = system._load_set(set_dir)
                nframes = frames["coord"].shape[0]
                natoms = system.get_natoms()
                coord = frames["coord"]
                atype = frames["type"]
                # Only periodic systems pass a box.
                cell = frames["box"] if system.pbc else None
                minrr2, max_nnei = self.auto_batch_size.execute_all(
                    self._execute,
                    nframes,
                    natoms,
                    coord,
                    atype,
                    cell,
                )
                # Reduce over the leading axis of the collected results.
                yield np.max(max_nnei, axis=0), np.min(minrr2), set_dir

    def _execute(
        self,
        coord: np.ndarray,
        atype: np.ndarray,
        cell: Optional[np.ndarray],
    ):
        """Run the neighbor statistics operator on one batch.

        Parameters
        ----------
        coord
            The coordinates of atoms.
        atype
            The atom types.
        cell
            The cell.

        Returns
        -------
        minrr2
            First output of the operator, converted with ``to_numpy_array``.
        max_nnei
            Second output of the operator, converted with ``to_numpy_array``.
        """
        jax_inputs = (
            to_jax_array(coord),
            to_jax_array(atype),
            to_jax_array(cell),
        )
        minrr2, max_nnei = self.op(*jax_inputs)
        return to_numpy_array(minrr2), to_numpy_array(max_nnei)
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.