add cross-platform AutoBatchSize #3143

Merged
merged 1 commit on Jan 14, 2024
208 changes: 17 additions & 191 deletions deepmd/utils/batch_size.py
@@ -1,12 +1,4 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import logging
-import os
-from typing import (
-    Callable,
-    Tuple,
-)
-
-import numpy as np
 from packaging.version import (
     Version,
 )
@@ -18,197 +10,31 @@
 from deepmd.utils.errors import (
     OutOfMemoryError,
 )
+from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
 
-log = logging.getLogger(__name__)
-
 
-class AutoBatchSize:
-    """This class allows DeePMD-kit to automatically decide the maximum
-    batch size that will not cause an OOM error.
-
-    Notes
-    -----
-    In some CPU environments, the program may be directly killed when OOM. In
-    this case, by default the batch size will not be increased for CPUs. The
-    environment variable `DP_INFER_BATCH_SIZE` can be set as the batch size.
-
-    In other cases, we assume all OOM error will raise :class:`OutOfMemoryError`.
-
-    Parameters
-    ----------
-    initial_batch_size : int, default: 1024
-        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
-        is not set
-    factor : float, default: 2.
-        increased factor
-
-    Attributes
-    ----------
-    current_batch_size : int
-        current batch size (number of total atoms)
-    maximum_working_batch_size : int
-        maximum working batch size
-    minimal_not_working_batch_size : int
-        minimal not working batch size
-    """
-
-    def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None:
-        # See also PyTorchLightning/pytorch-lightning#1638
-        # TODO: discuss a proper initial batch size
-        self.current_batch_size = initial_batch_size
-        DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0))
-        if DP_INFER_BATCH_SIZE > 0:
-            self.current_batch_size = DP_INFER_BATCH_SIZE
-            self.maximum_working_batch_size = DP_INFER_BATCH_SIZE
-            self.minimal_not_working_batch_size = self.maximum_working_batch_size + 1
-        else:
-            self.maximum_working_batch_size = initial_batch_size
-            if (
-                Version(TF_VERSION) >= Version("1.14")
-                and tf.config.experimental.get_visible_devices("GPU")
-            ) or tf.test.is_gpu_available():
-                self.minimal_not_working_batch_size = 2**31
-            else:
-                self.minimal_not_working_batch_size = (
-                    self.maximum_working_batch_size + 1
-                )
-                log.warning(
-                    "You can use the environment variable DP_INFER_BATCH_SIZE to"
-                    "control the inference batch size (nframes * natoms). "
-                    "The default value is %d." % initial_batch_size
-                )
-
-        self.factor = factor
-
-    def execute(
-        self, callable: Callable, start_index: int, natoms: int
-    ) -> Tuple[int, tuple]:
-        """Excuate a method with given batch size.
-
-        Parameters
-        ----------
-        callable : Callable
-            The method should accept the batch size and start_index as parameters,
-            and returns executed batch size and data.
-        start_index : int
-            start index
-        natoms : int
-            natoms
+class AutoBatchSize(AutoBatchSizeBase):
+    def is_gpu_available(self) -> bool:
+        """Check if GPU is available.
 
         Returns
         -------
-        int
-            executed batch size * number of atoms
-        tuple
-            result from callable, None if failing to execute
-
-        Raises
-        ------
-        OutOfMemoryError
-            OOM when batch size is 1
+        bool
+            True if GPU is available
         """
-        if natoms > 0:
-            batch_nframes = self.current_batch_size // natoms
-        else:
-            batch_nframes = self.current_batch_size
-        try:
-            n_batch, result = callable(max(batch_nframes, 1), start_index)
-        except OutOfMemoryError as e:
-            # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
-            # but luckily we only need to catch once
-            self.minimal_not_working_batch_size = min(
-                self.minimal_not_working_batch_size, self.current_batch_size
-            )
-            if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
-                self.maximum_working_batch_size = int(
-                    self.minimal_not_working_batch_size / self.factor
-                )
-            if self.minimal_not_working_batch_size <= natoms:
-                raise OutOfMemoryError(
-                    "The callable still throws an out-of-memory (OOM) error even when batch size is 1!"
-                ) from e
-            # adjust the next batch size
-            self._adjust_batch_size(1.0 / self.factor)
-            return 0, None
-        else:
-            n_tot = n_batch * natoms
-            self.maximum_working_batch_size = max(
-                self.maximum_working_batch_size, n_tot
-            )
-            # adjust the next batch size
-            if (
-                n_tot + natoms > self.current_batch_size
-                and self.current_batch_size * self.factor
-                < self.minimal_not_working_batch_size
-            ):
-                self._adjust_batch_size(self.factor)
-            return n_batch, result
+        return (
+            Version(TF_VERSION) >= Version("1.14")
+            and tf.config.experimental.get_visible_devices("GPU")
+        ) or tf.test.is_gpu_available()
 
-    def _adjust_batch_size(self, factor: float):
-        old_batch_size = self.current_batch_size
-        self.current_batch_size = int(self.current_batch_size * factor)
-        log.info(
-            "Adjust batch size from %d to %d"
-            % (old_batch_size, self.current_batch_size)
-        )
-
-    def execute_all(
-        self, callable: Callable, total_size: int, natoms: int, *args, **kwargs
-    ) -> Tuple[np.ndarray]:
-        """Excuate a method with all given data.
+    def is_oom_error(self, e: Exception) -> bool:
+        """Check if the exception is an OOM error.
 
         Parameters
         ----------
-        callable : Callable
-            The method should accept *args and **kwargs as input and return the similiar array.
-        total_size : int
-            Total size
-        natoms : int
-            The number of atoms
-        *args
-            Variable length argument list.
-        **kwargs
-            If 2D np.ndarray, assume the first axis is batch; otherwise do nothing.
+        e : Exception
+            Exception
         """
-
-        def execute_with_batch_size(
-            batch_size: int, start_index: int
-        ) -> Tuple[int, Tuple[np.ndarray]]:
-            end_index = start_index + batch_size
-            end_index = min(end_index, total_size)
-            return (end_index - start_index), callable(
-                *[
-                    (
-                        vv[start_index:end_index]
-                        if isinstance(vv, np.ndarray) and vv.ndim > 1
-                        else vv
-                    )
-                    for vv in args
-                ],
-                **{
-                    kk: (
-                        vv[start_index:end_index]
-                        if isinstance(vv, np.ndarray) and vv.ndim > 1
-                        else vv
-                    )
-                    for kk, vv in kwargs.items()
-                },
-            )
-
-        index = 0
-        results = []
-        while index < total_size:
-            n_batch, result = self.execute(execute_with_batch_size, index, natoms)
-            if not isinstance(result, tuple):
-                result = (result,)
-            index += n_batch
-            if n_batch:
-                for rr in result:
-                    rr.reshape((n_batch, -1))
-                results.append(result)
-
-        r = tuple([np.concatenate(r, axis=0) for r in zip(*results)])
-        if len(r) == 1:
-            # avoid returning tuple if callable doesn't return tuple
-            r = r[0]
-        return r
+        # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
+        # but luckily we only need to catch once
+        return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))
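
For orientation, a minimal sketch of the pattern this diff introduces: a backend subclasses the shared AutoBatchSizeBase and supplies only the two hooks shown above (is_gpu_available and is_oom_error), while batch-size growth, shrinking on OOM, and execute_all batching live in the base class. The CPU-only backend, the MemoryError mapping, and model_eval below are hypothetical illustrations, and the sketch assumes the base class in deepmd_utils keeps the constructor, CPU behavior, and execute_all semantics documented in the removed implementation above.

import numpy as np

from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class CPUOnlyAutoBatchSize(AutoBatchSizeBase):
    """Hypothetical CPU-only backend: only the two hooks are implemented."""

    def is_gpu_available(self) -> bool:
        # No GPU here, so (as in the removed __init__) the batch size is
        # never grown beyond the initial / DP_INFER_BATCH_SIZE value.
        return False

    def is_oom_error(self, e: Exception) -> bool:
        # Map this backend's allocation failure onto the base class's OOM
        # handling, mirroring the TF subclass in the diff.
        return isinstance(e, MemoryError)


def model_eval(coords: np.ndarray) -> np.ndarray:
    # Stand-in for a real inference call; returns one value per frame.
    return coords.sum(axis=1)


# execute_all() slices 2D arrays along the first (batch) axis, retries with a
# smaller batch when is_oom_error() matches, and concatenates the results.
auto_batch = CPUOnlyAutoBatchSize(initial_batch_size=1024)
coords = np.random.rand(100, 9)  # 100 frames, 3 atoms x 3 coordinates
energies = auto_batch.execute_all(model_eval, 100, 3, coords)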