diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 0e066227e..6c82b2be7 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -346,6 +346,10 @@ jax_distributed_initialization_timeout: 300 # This is the default timeout in htt # Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers # only to the jax coordination service. jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization. +skip_jax_distributed_system: False # If True we will not initialize the jax distributed system. +# Currently the jax distributed is needed on cloud TPUs for async checkpointing. +# However when run on google internal TPUs the coordination service is started automatically +# and we should set this to True so we won't try to initialize a second time manually. # We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 # Learning rate schedule has either two or three parts: diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 2967402a6..f6e9f54e5 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -254,6 +254,9 @@ def maybe_initialize_jax_distributed_system(raw_keys): For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. """ + if raw_keys["skip_jax_distributed_system"]: + max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.") + return if raw_keys["inference_benchmark_test"]: # Disable initialization for inference benmark test. return