Skip to content

Commit

Permalink
Also set XLA_PYTHON_CLIENT_ALLOCATOR="platform"
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Nov 21, 2024
1 parent e2a18e0 commit f1b3167
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions project/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,24 @@
@pytest.fixture(autouse=True, scope="session")
def prevent_jax_from_reserving_all_the_vram():
# note; not using monkeypatch because we want this to be session-scoped.
val_before = os.environ.get("XLA_PYTHON_CLIENT_PREALLOCATE")
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# allocator_before = os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR")
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

yield

if val_before is None:
os.environ.pop("XLA_PYTHON_CLIENT_PREALLOCATE")
else:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = val_before
@contextmanager
def change_env(variable_name: str, value: str):
val_before = os.environ.get(variable_name)
os.environ[variable_name] = value
yield
if val_before is None:
os.environ.pop(variable_name)
else:
os.environ[variable_name] = val_before

# Set these so that we can use torch and jax during tests on the same GPU (and so that Jax lets
# go of the VRAM it doesn't need anymore.
# See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html for more info.
with (
change_env("XLA_PYTHON_CLIENT_PREALLOCATE", "false"),
change_env("XLA_PYTHON_CLIENT_ALLOCATOR", "platform"),
):
yield


@pytest.fixture(autouse=True)
Expand Down

0 comments on commit f1b3167

Please sign in to comment.