Skip to content

Commit

Permalink
fix(jax): set threads for XLA and TF
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 8, 2024
1 parent 3701566 commit 00e3c10
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
12 changes: 12 additions & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os

from deepmd.env import (
get_default_nthreads,
set_default_nthreads,
)

set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + (
" --xla_cpu_multi_thread_eigen=false"
f" intra_op_parallelism_threads={inter_nthreads}"
f" inter_op_parallelism_threads={inter_nthreads}"
)
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
Expand Down
10 changes: 10 additions & 0 deletions deepmd/jax/jax2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import tensorflow as tf

from deepmd.env import (
get_default_nthreads,
set_default_nthreads,
)

if not tf.executing_eagerly():
# TF disallow temporary eager execution
raise RuntimeError(
Expand All @@ -9,3 +14,8 @@
"If you are converting a model between different backends, "
"considering converting to the `.dp` format first."
)

set_default_nthreads()
inter_nthreads, intra_nthreads = get_default_nthreads()
tf.config.threading.set_inter_op_parallelism_threads(inter_nthreads)
tf.config.threading.set_intra_op_parallelism_threads(intra_nthreads)

0 comments on commit 00e3c10

Please sign in to comment.