diff --git a/source/tests/__init__.py b/source/tests/__init__.py index 6ceb116d85..5ca68af64d 100644 --- a/source/tests/__init__.py +++ b/source/tests/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import os + +# set XLA FLAGS before any jax import +os.environ["XLA_FLAGS"] = " ".join( + ( + "--xla_cpu_multi_thread_eigen=false", + "intra_op_parallelism_threads=1", + "inter_op_parallelism_threads=1", + ) +) diff --git a/source/tests/jax/__init__.py b/source/tests/jax/__init__.py index 52e6a17be2..6ceb116d85 100644 --- a/source/tests/jax/__init__.py +++ b/source/tests/jax/__init__.py @@ -1,10 +1 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import os - -os.environ["XLA_FLAGS"] = " ".join( - ( - "--xla_cpu_multi_thread_eigen=false", - "intra_op_parallelism_threads=1", - "inter_op_parallelism_threads=1", - ) -)