diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 1b90433b00..393f45319a 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -1,6 +1,18 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import os +from deepmd.env import ( + get_default_nthreads, + set_default_nthreads, +) + +set_default_nthreads() +inter_nthreads, intra_nthreads = get_default_nthreads() +os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + ( + " --xla_cpu_multi_thread_eigen=false" + f" intra_op_parallelism_threads={inter_nthreads}" + f" inter_op_parallelism_threads={inter_nthreads}" +) os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" import jax diff --git a/deepmd/jax/jax2tf/__init__.py b/deepmd/jax/jax2tf/__init__.py index 88a928f04d..8551ca1f60 100644 --- a/deepmd/jax/jax2tf/__init__.py +++ b/deepmd/jax/jax2tf/__init__.py @@ -1,6 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import tensorflow as tf +from deepmd.env import ( + get_default_nthreads, + set_default_nthreads, +) + if not tf.executing_eagerly(): # TF disallow temporary eager execution raise RuntimeError( @@ -9,3 +14,8 @@ "If you are converting a model between different backends, " "considering converting to the `.dp` format first." ) + +set_default_nthreads() +inter_nthreads, intra_nthreads = get_default_nthreads() +tf.config.threading.set_inter_op_parallelism_threads(inter_nthreads) +tf.config.threading.set_intra_op_parallelism_threads(intra_nthreads)