Skip to content

Commit

Permalink
use jax auto batch size
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 26, 2024
1 parent c9fee92 commit 2256c6b
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepmd/jax/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
OutputVariableCategory,
OutputVariableDef,
)
from deepmd.dpmodel.utils.batch_size import (
AutoBatchSize,
)
from deepmd.dpmodel.utils.serialization import (
load_dp_model,
)
Expand Down Expand Up @@ -52,6 +49,9 @@
from deepmd.jax.model.hlo import (
HLO,
)
from deepmd.jax.utils.auto_batch_size import (
AutoBatchSize,
)

if TYPE_CHECKING:
import ase.neighborlist
Expand Down
59 changes: 59 additions & 0 deletions deepmd/jax/utils/auto_batch_size.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import jaxlib

from deepmd.jax.env import (
jax,
)
from deepmd.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
    """Auto batch size for the JAX backend.

    Defines the device and out-of-memory checks in terms of JAX.

    Parameters
    ----------
    initial_batch_size : int, default: 1024
        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
        is not set
    factor : float, default: 2.
        increased factor
    """

    def __init__(
        self,
        initial_batch_size: int = 1024,
        factor: float = 2.0,
    ) -> None:
        super().__init__(
            initial_batch_size=initial_batch_size,
            factor=factor,
        )

    def is_gpu_available(self) -> bool:
        """Check if GPU is available.

        Returns
        -------
        bool
            True if the platform of the first device reported by
            ``jax.devices()`` is ``"gpu"``
        """
        return jax.devices()[0].platform == "gpu"

    def is_oom_error(self, e: Exception) -> bool:
        """Check if the exception is an OOM error.

        Parameters
        ----------
        e : Exception
            Exception

        Returns
        -------
        bool
            True if ``e`` is an ``XlaRuntimeError`` or ``ValueError`` whose
            first argument is a string containing ``RESOURCE_EXHAUSTED:``
        """
        if not isinstance(e, (jaxlib.xla_extension.XlaRuntimeError, ValueError)):
            return False
        # An exception raised without arguments (or with a non-string first
        # argument) carries no status text to inspect; report "not OOM"
        # instead of failing with IndexError/TypeError inside the check.
        message = e.args[0] if e.args else None
        if not isinstance(message, str):
            return False
        # Only the RESOURCE_EXHAUSTED status is matched here. Several sources,
        # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924, consider
        # CUSOLVER_STATUS_INTERNAL_ERROR another out-of-memory symptom; that
        # message is NOT checked by this method.
        return "RESOURCE_EXHAUSTED:" in message

0 comments on commit 2256c6b

Please sign in to comment.