From 6d7a086d5d99f6f5ed7cf624674e887b48442514 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Mon, 23 Dec 2024 23:10:18 +0000 Subject: [PATCH] Add option to skip initializing the jax distributed system --- MaxText/configs/base.yml | 4 ++++ MaxText/max_utils.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 8c667cbc3..7b221a007 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -346,6 +346,10 @@ jax_distributed_initialization_timeout: 300 # This is the default timeout in htt # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. +skip_jax_distributed_system: False # If True we will not initialize the jax distributed system. +# Currently the jax distributed is needed on cloud TPUs for async checkpointing. +# However when run on google internal TPUs the coordination service is started automatically +# and we should set this to True. # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 # Learning rate schedule has either two or three parts: diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 06ab3f6a4..edc4b395c 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -219,6 +219,9 @@ def maybe_initialize_jax_distributed_system(raw_keys): For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ + if raw_keys["skip_jax_distributed_system"]: + max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") + return if raw_keys["inference_benchmark_test"]: # Disable initialization for inference benmark test. return