Skip to content

Commit

Permalink
fix GPU test OOM problem
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 31, 2024
1 parent 664c70b commit 8e8347c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions deepmd/tf/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ def _get_package_constants(

op_module = get_module("deepmd_op")
op_grads_module = get_module("op_grads")
# prevent OOM when using with other backends
# tf.config doesn't work for unclear reason
set_env_if_empty("TF_FORCE_GPU_ALLOW_GROWTH", "true", verbose=False)

# FLOAT_PREC
GLOBAL_TF_FLOAT_PRECISION = tf.dtypes.as_dtype(GLOBAL_NP_FLOAT_PRECISION)
Expand Down
9 changes: 9 additions & 0 deletions source/tests/pt/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import pytest
import torch


@pytest.fixture(scope="package", autouse=True)
def clear_cuda_memory(request):
yield
torch.cuda.empty_cache()

0 comments on commit 8e8347c

Please sign in to comment.