Commit

add cross-platform AutoBatchSize
See #3118 and dptech-corp/deepmd-pytorch#137.
Each backend subclass needs to implement `is_gpu_available` and `is_oom_error` (an illustrative sketch is given below).

Signed-off-by: Jinzhe Zeng <[email protected]>
njzjz committed Jan 13, 2024
1 parent 2096b80 commit fe0d4ad
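
The base class referenced in the commit message leaves the two hooks abstract, so each backend wires in its own GPU and OOM detection. The following is a minimal illustrative sketch, not code from this commit: the import path and the `AutoBatchSizeBase` alias are taken from the TensorFlow subclass in the diff below, while `CPUOnlyAutoBatchSize` and the `MemoryError` mapping are hypothetical, assuming the `deepmd_utils` package from this repository is importable.

# Illustrative sketch only; mirrors the pattern of the TensorFlow subclass in this
# commit, but CPUOnlyAutoBatchSize is a hypothetical backend, not part of the commit.
from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class CPUOnlyAutoBatchSize(AutoBatchSizeBase):
    def is_gpu_available(self) -> bool:
        # A backend with no GPU support reports False, so the batch size is not
        # grown beyond its initial value (matching the CPU behaviour described
        # in the removed docstring below).
        return False

    def is_oom_error(self, e: Exception) -> bool:
        # Map the backend's out-of-memory exception(s) onto a boolean; a pure
        # NumPy backend would typically only see MemoryError.
        return isinstance(e, MemoryError)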
Showing 2 changed files with 250 additions and 191 deletions.
208 changes: 17 additions & 191 deletions deepmd/utils/batch_size.py
@@ -1,12 +1,4 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import logging
-import os
-from typing import (
-    Callable,
-    Tuple,
-)
-
-import numpy as np
 from packaging.version import (
     Version,
 )
@@ -18,197 +10,31 @@
 from deepmd.utils.errors import (
     OutOfMemoryError,
 )
+from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
 
-log = logging.getLogger(__name__)
 
-
-class AutoBatchSize:
-    """This class allows DeePMD-kit to automatically decide the maximum
-    batch size that will not cause an OOM error.
-    Notes
-    -----
-    In some CPU environments, the program may be directly killed when OOM. In
-    this case, by default the batch size will not be increased for CPUs. The
-    environment variable `DP_INFER_BATCH_SIZE` can be set as the batch size.
-    In other cases, we assume all OOM error will raise :class:`OutOfMemoryError`.
-    Parameters
-    ----------
-    initial_batch_size : int, default: 1024
-        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
-        is not set
-    factor : float, default: 2.
-        increased factor
-    Attributes
-    ----------
-    current_batch_size : int
-        current batch size (number of total atoms)
-    maximum_working_batch_size : int
-        maximum working batch size
-    minimal_not_working_batch_size : int
-        minimal not working batch size
-    """
-
-    def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None:
-        # See also PyTorchLightning/pytorch-lightning#1638
-        # TODO: discuss a proper initial batch size
-        self.current_batch_size = initial_batch_size
-        DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0))
-        if DP_INFER_BATCH_SIZE > 0:
-            self.current_batch_size = DP_INFER_BATCH_SIZE
-            self.maximum_working_batch_size = DP_INFER_BATCH_SIZE
-            self.minimal_not_working_batch_size = self.maximum_working_batch_size + 1
-        else:
-            self.maximum_working_batch_size = initial_batch_size
-            if (
-                Version(TF_VERSION) >= Version("1.14")
-                and tf.config.experimental.get_visible_devices("GPU")
-            ) or tf.test.is_gpu_available():
-                self.minimal_not_working_batch_size = 2**31
-            else:
-                self.minimal_not_working_batch_size = (
-                    self.maximum_working_batch_size + 1
-                )
-                log.warning(
-                    "You can use the environment variable DP_INFER_BATCH_SIZE to "
-                    "control the inference batch size (nframes * natoms). "
-                    "The default value is %d." % initial_batch_size
-                )
-
-        self.factor = factor
-
-    def execute(
-        self, callable: Callable, start_index: int, natoms: int
-    ) -> Tuple[int, tuple]:
-        """Execute a method with given batch size.
-        Parameters
-        ----------
-        callable : Callable
-            The method should accept the batch size and start_index as parameters,
-            and returns executed batch size and data.
-        start_index : int
-            start index
-        natoms : int
-            natoms
+class AutoBatchSize(AutoBatchSizeBase):
+    def is_gpu_available(self) -> bool:
+        """Check if GPU is available.
         Returns
         -------
-        int
-            executed batch size * number of atoms
-        tuple
-            result from callable, None if failing to execute
-        Raises
-        ------
-        OutOfMemoryError
-            OOM when batch size is 1
+        bool
+            True if GPU is available
         """
-        if natoms > 0:
-            batch_nframes = self.current_batch_size // natoms
-        else:
-            batch_nframes = self.current_batch_size
-        try:
-            n_batch, result = callable(max(batch_nframes, 1), start_index)
-        except OutOfMemoryError as e:
-            # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
-            # but luckily we only need to catch once
-            self.minimal_not_working_batch_size = min(
-                self.minimal_not_working_batch_size, self.current_batch_size
-            )
-            if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
-                self.maximum_working_batch_size = int(
-                    self.minimal_not_working_batch_size / self.factor
-                )
-            if self.minimal_not_working_batch_size <= natoms:
-                raise OutOfMemoryError(
-                    "The callable still throws an out-of-memory (OOM) error even when batch size is 1!"
-                ) from e
-            # adjust the next batch size
-            self._adjust_batch_size(1.0 / self.factor)
-            return 0, None
-        else:
-            n_tot = n_batch * natoms
-            self.maximum_working_batch_size = max(
-                self.maximum_working_batch_size, n_tot
-            )
-            # adjust the next batch size
-            if (
-                n_tot + natoms > self.current_batch_size
-                and self.current_batch_size * self.factor
-                < self.minimal_not_working_batch_size
-            ):
-                self._adjust_batch_size(self.factor)
-            return n_batch, result
+        return (
+            Version(TF_VERSION) >= Version("1.14")
+            and tf.config.experimental.get_visible_devices("GPU")
+        ) or tf.test.is_gpu_available()
 
-    def _adjust_batch_size(self, factor: float):
-        old_batch_size = self.current_batch_size
-        self.current_batch_size = int(self.current_batch_size * factor)
-        log.info(
-            "Adjust batch size from %d to %d"
-            % (old_batch_size, self.current_batch_size)
-        )
-
-    def execute_all(
-        self, callable: Callable, total_size: int, natoms: int, *args, **kwargs
-    ) -> Tuple[np.ndarray]:
-        """Execute a method with all given data.
+    def is_oom_error(self, e: Exception) -> bool:
+        """Check if the exception is an OOM error.
         Parameters
         ----------
-        callable : Callable
-            The method should accept *args and **kwargs as input and return the similar array.
-        total_size : int
-            Total size
-        natoms : int
-            The number of atoms
-        *args
-            Variable length argument list.
-        **kwargs
-            If 2D np.ndarray, assume the first axis is batch; otherwise do nothing.
+        e : Exception
+            Exception
         """
-
-        def execute_with_batch_size(
-            batch_size: int, start_index: int
-        ) -> Tuple[int, Tuple[np.ndarray]]:
-            end_index = start_index + batch_size
-            end_index = min(end_index, total_size)
-            return (end_index - start_index), callable(
-                *[
-                    (
-                        vv[start_index:end_index]
-                        if isinstance(vv, np.ndarray) and vv.ndim > 1
-                        else vv
-                    )
-                    for vv in args
-                ],
-                **{
-                    kk: (
-                        vv[start_index:end_index]
-                        if isinstance(vv, np.ndarray) and vv.ndim > 1
-                        else vv
-                    )
-                    for kk, vv in kwargs.items()
-                },
-            )
-
-        index = 0
-        results = []
-        while index < total_size:
-            n_batch, result = self.execute(execute_with_batch_size, index, natoms)
-            if not isinstance(result, tuple):
-                result = (result,)
-            index += n_batch
-            if n_batch:
-                for rr in result:
-                    rr.reshape((n_batch, -1))
-                results.append(result)
-
-        r = tuple([np.concatenate(r, axis=0) for r in zip(*results)])
-        if len(r) == 1:
-            # avoid returning tuple if callable doesn't return tuple
-            r = r[0]
-        return r
+        # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
+        # but luckily we only need to catch once
+        return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))
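
The removed docstrings above describe how callers drive the auto-batching: `execute_all` slices 2D array arguments along the first (batch) axis and repeatedly calls `execute`, which shrinks the batch on OOM and grows it otherwise, and `DP_INFER_BATCH_SIZE` pins the batch size up front. Assuming the base class keeps the same constructor defaults and `execute_all` signature as the removed methods (a plausible reading of this refactor, but not shown in this diff), a hypothetical caller-side sketch looks like:

# Hypothetical usage sketch; assumes the AutoBatchSize subclass from this diff
# and that execute_all keeps the signature of the removed method above.
import os

import numpy as np

from deepmd.utils.batch_size import AutoBatchSize

# Optional: pin the batch size (nframes * natoms) instead of letting it adapt.
os.environ.setdefault("DP_INFER_BATCH_SIZE", "8192")

# Defaults (initial_batch_size=1024, factor=2.0) come from the removed __init__,
# presumably now provided by AutoBatchSizeBase.
auto_batch = AutoBatchSize()

nframes, natoms = 100, 192
coords = np.random.rand(nframes, natoms * 3)  # 2D input: first axis is the batch

def compute_energy(batch_coords: np.ndarray) -> np.ndarray:
    # Stand-in for a model evaluation that may raise an OOM error on the backend.
    return batch_coords.sum(axis=1, keepdims=True)

# execute_all slices `coords` along axis 0 into batches whose size is adapted
# automatically; results from all batches are concatenated back together.
energies = auto_batch.execute_all(compute_energy, nframes, natoms, coords)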