Skip to content

Commit

Permalink
Merge pull request #1072 from AI-Hypercomputer:mattdavidow-assert-dcn
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 702067169
  • Loading branch information
maxtext authors committed Dec 2, 2024
2 parents 5b960b1 + 5f43228 commit 5cdbabb
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def validate_keys(keys):
" use_replicator_service and replicator_backup_interval_minutes"
)

validate_multiple_slices(keys)
if keys["num_experts"] > 1:
validate_megablox_parallelism(keys)

Expand Down Expand Up @@ -488,6 +489,27 @@ def update_model_vars(base_config_path, raw_keys, config_name: str):
return updated_keys


def validate_multiple_slices(raw_keys):
if (
math.fabs(
math.prod(
[
raw_keys["dcn_data_parallelism"],
raw_keys["dcn_pipeline_parallelism"],
raw_keys["dcn_fsdp_parallelism"],
raw_keys["dcn_fsdp_transpose_parallelism"],
raw_keys["dcn_sequence_parallelism"],
raw_keys["dcn_tensor_parallelism"],
raw_keys["dcn_expert_parallelism"],
raw_keys["dcn_autoregressive_parallelism"],
]
)
)
> 1
):
assert raw_keys["num_slices"] > 1, "DCN parallelism requested but only one slice available."


def validate_megablox_parallelism(raw_keys):
if raw_keys["megablox"] and (
using_sequence_parallelism(raw_keys) or using_pipeline_parallelism(raw_keys) or using_expert_parallelism(raw_keys)
Expand Down

0 comments on commit 5cdbabb

Please sign in to comment.