From 3014e91f9753c53a421212301b1ac1d3f8b4c0cc Mon Sep 17 00:00:00 2001 From: lzy-dev Date: Wed, 4 Dec 2024 10:18:48 +0800 Subject: [PATCH] polish --- examples/aquila/conf/train/train_aquila_3b.yaml | 6 +++--- flagscale/train/arguments.py | 5 +++++ .../core/distributed/finalize_model_grads.py | 15 ++++++++------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/examples/aquila/conf/train/train_aquila_3b.yaml b/examples/aquila/conf/train/train_aquila_3b.yaml index c2268b491..1a062459d 100644 --- a/examples/aquila/conf/train/train_aquila_3b.yaml +++ b/examples/aquila/conf/train/train_aquila_3b.yaml @@ -14,8 +14,8 @@ system: log_interval: 1 log_throughput: True tensorboard_log_interval: 1 - wandb_project: "train-aquila-1B" - wandb_exp_name: "train-test-1B" + wandb_project: "train-aquila-3B" + wandb_exp_name: "train-test-3B" checkpoint: load: outputs_llama3/checkpoint_mc ckpt_format: torch @@ -94,7 +94,7 @@ data: data_path: ${data_path:??} split: 1 tokenizer: - tokenizer_type: QwenTokenizerFS + tokenizer_type: Qwen2TokenizerFS tokenizer_path: examples/aquila/qwentokenizer vocab_size: 151936 make_vocab_size_divisible_by: 64 diff --git a/flagscale/train/arguments.py b/flagscale/train/arguments.py index 4f11cbdea..8d1447043 100644 --- a/flagscale/train/arguments.py +++ b/flagscale/train/arguments.py @@ -110,6 +110,11 @@ def pre_validate_args(self): 'pipeline_model_parallel_split_rank not supported with process_meshes set!' self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size + # if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree + if self.args.untie_embeddings_and_output_weights == False: + assert hetero_process_meshes_tp[0] == hetero_process_meshes_tp[-1], \ + f"if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree!" + # Virtual parallel size. 
if self.args.enable_hetero: assert self.args.num_layers_per_virtual_pipeline_stage == None, \ diff --git a/megatron/megatron/core/distributed/finalize_model_grads.py b/megatron/megatron/core/distributed/finalize_model_grads.py index 19d297e62..e394dbcd2 100644 --- a/megatron/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/megatron/core/distributed/finalize_model_grads.py @@ -18,13 +18,14 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf sync. """ - embed_group = parallel_state.get_embedding_group() - if not isinstance(embed_group, list): - embed_group = [embed_group] - if ( - parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and torch.distributed.get_world_size(embed_group[0]) > 1 - ): + if (parallel_state.is_rank_in_embedding_group(ignore_virtual=True)): + embed_group = parallel_state.get_embedding_group() + if not isinstance(embed_group, list): + embed_group = [embed_group] + else: + return + + if (torch.distributed.get_world_size(embed_group[0]) > 1): if parallel_state.is_pipeline_first_stage(ignore_virtual=True): model_module = model[0] elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):