From 050c20c6044839795c4c56f1562a31b2bbb78bd0 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Fri, 8 Nov 2024 15:23:34 -0500 Subject: [PATCH] set xla flags before any imports Signed-off-by: Jinzhe Zeng --- source/tests/__init__.py | 10 ++++++++++ source/tests/jax/__init__.py | 9 --------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/source/tests/__init__.py b/source/tests/__init__.py index 6ceb116d85..5ca68af64d 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + +# set XLA FLAGS before any jax import +os.environ["XLA_FLAGS"] = " ".join( + ( + "--xla_cpu_multi_thread_eigen=false", + "intra_op_parallelism_threads=1", + "inter_op_parallelism_threads=1", + ) +) diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py index 52e6a17be2..6ceb116d85 100644 --- a/source/tests/jax/__init__.py +++ b/source/tests/jax/__init__.py @@ -1,10 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os - -os.environ["XLA_FLAGS"] = " ".join( - ( - "--xla_cpu_multi_thread_eigen=false", - "intra_op_parallelism_threads=1", - "inter_op_parallelism_threads=1", - ) -)