Commit

polish
lzy-dev committed Dec 4, 2024
1 parent 613b035 commit 3014e91
Showing 3 changed files with 16 additions and 10 deletions.
6 changes: 3 additions & 3 deletions examples/aquila/conf/train/train_aquila_3b.yaml
@@ -14,8 +14,8 @@ system:
     log_interval: 1
     log_throughput: True
     tensorboard_log_interval: 1
-    wandb_project: "train-aquila-1B"
-    wandb_exp_name: "train-test-1B"
+    wandb_project: "train-aquila-3B"
+    wandb_exp_name: "train-test-3B"
   checkpoint:
     load: outputs_llama3/checkpoint_mc
     ckpt_format: torch
@@ -94,7 +94,7 @@ data:
   data_path: ${data_path:??}
   split: 1
   tokenizer:
-    tokenizer_type: QwenTokenizerFS
+    tokenizer_type: Qwen2TokenizerFS
     tokenizer_path: examples/aquila/qwentokenizer
     vocab_size: 151936
     make_vocab_size_divisible_by: 64
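As a quick sanity check on the renamed fields, here is a minimal sketch (assuming PyYAML is installed; the inline string mirrors only the fields touched by this commit, not the full train_aquila_3b.yaml):

```python
# Minimal sketch, not part of the FlagScale repo: parse the updated fields with
# PyYAML and confirm the 3B naming and the Qwen2 tokenizer type line up.
import yaml

updated_fragment = """
wandb_project: "train-aquila-3B"
wandb_exp_name: "train-test-3B"
tokenizer:
  tokenizer_type: Qwen2TokenizerFS
  tokenizer_path: examples/aquila/qwentokenizer
  vocab_size: 151936
"""

cfg = yaml.safe_load(updated_fragment)
assert "3B" in cfg["wandb_project"] and "3B" in cfg["wandb_exp_name"]
assert cfg["tokenizer"]["tokenizer_type"] == "Qwen2TokenizerFS"
print("config fragment is consistent with the 3B setup")
```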
5 changes: 5 additions & 0 deletions flagscale/train/arguments.py
@@ -110,6 +110,11 @@ def pre_validate_args(self):
                 'pipeline_model_parallel_split_rank not supported with process_meshes set!'
             self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size

+            # if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree
+            if self.args.untie_embeddings_and_output_weights == False:
+                assert all(hetero_process_meshes_tp[0] == hetero_process_meshes_tp[-1]), \
+                    f"if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree!"
+
             # Virtual parallel size.
             if self.args.enable_hetero:
                 assert self.args.num_layers_per_virtual_pipeline_stage == None, \
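The intent of the new check in pre_validate_args, restated as a standalone sketch (simplified and hypothetical, not the FlagScale code path; hetero_process_meshes_tp is taken here to be a plain list of per-stage tensor-parallel degrees):

```python
# Hypothetical, simplified version of the constraint this commit adds: when the
# input and output embeddings share weights (untie_embeddings_and_output_weights
# is False), the first and last pipeline stages must use the same tensor-parallel
# degree, otherwise the shared embedding weight cannot be kept in sync.
def check_tied_embedding_tp(per_stage_tp, untie_embeddings_and_output_weights):
    if not untie_embeddings_and_output_weights:
        assert per_stage_tp[0] == per_stage_tp[-1], (
            "tied embeddings require the first and last pipeline stage "
            "to use the same tensor-parallel degree"
        )

# Hypothetical per-stage TP degrees: first and last stage both use TP=2, so this passes.
check_tied_embedding_tp([2, 4, 2], untie_embeddings_and_output_weights=False)
```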
15 changes: 8 additions & 7 deletions megatron/megatron/core/distributed/finalize_model_grads.py
@@ -18,13 +18,14 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
     sync.
     """

-    embed_group = parallel_state.get_embedding_group()
-    if not isinstance(embed_group, list):
-        embed_group = [embed_group]
-    if (
-        parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
-        and torch.distributed.get_world_size(embed_group[0]) > 1
-    ):
+    if (parallel_state.is_rank_in_embedding_group(ignore_virtual=True)):
+        embed_group = parallel_state.get_embedding_group()
+        if not isinstance(embed_group, list):
+            embed_group = [embed_group]
+    else:
+        return
+
+    if (torch.distributed.get_world_size(embed_group[0]) > 1):
         if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
             model_module = model[0]
         elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
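The effect of the reordering in _allreduce_word_embedding_grads, as a self-contained sketch: membership in the embedding group is checked first, ranks outside it return early, and get_embedding_group() is only called on ranks that actually belong to the group (the stubs below are hypothetical stand-ins, not the Megatron parallel_state API):

```python
# Hypothetical stand-ins for parallel_state / torch.distributed, used only to
# illustrate the reordered control flow introduced by this commit.
def is_rank_in_embedding_group():
    return False  # e.g. a middle pipeline stage that holds no embedding weights

def get_embedding_group():
    return ["embedding-group"]  # may be a single process group or a list of groups

def allreduce_word_embedding_grads_sketch():
    if is_rank_in_embedding_group():
        embed_group = get_embedding_group()
        if not isinstance(embed_group, list):
            embed_group = [embed_group]
    else:
        return  # nothing to all-reduce on this rank
    # ... the world-size check and the embedding-grad all-reduce would follow here,
    # operating on embed_group[0] ...
    _ = embed_group[0]

allreduce_word_embedding_grads_sketch()
```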
