Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to skip initializing the jax distributed system #1125

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ jax_distributed_initialization_timeout: 300 # This is the default timeout in htt
# Note there are two separate initializations - the jax coordination service (aka jax.distributed.initialize) and the backend (e.g. PjRT), the timeout above refers
# only to the jax coordination service.
jax_debug_log_modules: "" # Set this to "jax" to enable jax verbose logging such as for the jax coordination service initialization.
skip_jax_distributed_system: False # If True we will not initialize the jax distributed system.
# Currently the jax distributed is needed on cloud TPUs for async checkpointing.
# However when run on google internal TPUs the coordination service is started automatically
# and we should set this to True.

# We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
# Learning rate schedule has either two or three parts:
Expand Down
3 changes: 3 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ def maybe_initialize_jax_distributed_system(raw_keys):

For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments.
"""
if raw_keys["skip_jax_distributed_system"]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Thanks for the log! So we know the reason to skip the initialization. Could you help add accordingly logs to following 2 cases? inference_benchmark_test & compile_topology?

max_logging.log("Skipping jax distributed system due to skip_jax_distributed_system=True flag.")
return
if raw_keys["inference_benchmark_test"]:
# Disable initialization for inference benmark test.
return
Expand Down
Loading