diff --git a/examples/aquila/conf/config.yaml b/examples/aquila/conf/config.yaml index 2070a15f0..d6ec0c544 100644 --- a/examples/aquila/conf/config.yaml +++ b/examples/aquila/conf/config.yaml @@ -1,5 +1,5 @@ defaults: - - train: demo + - train: demo - _self_ experiment: @@ -9,6 +9,8 @@ experiment: type: train backend: megatron entrypoint: ./flagscale/train/train_aquila.py + # cmds: + # before_start: source /root/miniconda3/bin/activate flagscale runner: backend: torchrun nnodes: 1 diff --git a/examples/aquila/conf/train/train_aquila_3b.yaml b/examples/aquila/conf/train/train_aquila_3b.yaml new file mode 100644 index 000000000..dc7c1f5a3 --- /dev/null +++ b/examples/aquila/conf/train/train_aquila_3b.yaml @@ -0,0 +1,89 @@ +system: + reset_position_ids: True + reset_attention_mask: True + add_qkv_bias: True + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 2 + disable_bias_linear: True + use_flash_attn: True + use_distributed_optimizer: True + precision: + bf16: True + initial_loss_scale: 522893 + min_loss_scale: 1.0 + attention_softmax_in_fp32: True + accumulate_allreduce_grads_in_fp32: True + logging: + log_interval: 1 + log_throughput: True + tensorboard_log_interval: 1 + wandb-log-model: False + wandb-log-model-interval: 1 + wandb_project: "train-aquila-3B" + wandb_exp_name: "train-test-3B" + checkpoint: + load: outputs_llama3/checkpoint_mc + ckpt_format: torch + save_interval: 2385 + + # hetero: + # enable_hetero: True + # hetero_use_cpu_communication: False + # use_partial_reduce_for_shared_embedding: True + # # mesh format [tp1,cp1,ep1,dp1,pp1,(tp2,cp2...)] + + # hetero_pipeline_layer_split: [12,12] + # hetero_process_meshes: [1,1,1,4,1, 1,1,1,4,1] + # hetero_device_types: ["A800", "A800"] + + # standalone_embedding_stage: False + # hetero_current_device_type: "A800" +model: + transformer_impl: transformer_engine + num_layers: 24 + hidden_size: 1024 + num_attention_heads: 16 + group_query_attention: True + num_query_groups: 2 + seq_length: 4096 + max_position_embeddings: 4096 # only for adding position embeddings + norm_epsilon: 1e-6 + use_rotary_position_embeddings: true + no_position_embedding: true + rotary_base: 1000000 + swiglu: true + multiple_of: 256 + hidden_dim_multiplier: 2 # ffn_hidden_size 11008 + normalization: RMSNorm + position_embedding_type: rope + untie_embeddings_and_output_weights: False + init_method_std: 0.02 + attention_dropout: 0.0 + hidden_dropout: 0.0 + weight_decay: 0.1 + clip_grad: 1.0 + train_samples: 29297664 #120B tokens + eval_iters: 0 + micro_batch_size: 2 + global_batch_size: 1024 + seed: 42 + + optimizer: + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + lr_scheduler: + lr: 5.0e-3 + min_lr: 5.0e-4 + lr_warmup_samples: 10 + lr_decay_style: cosine + +data: + data_path: {data_path:??} + split: 1 + no_mmap_bin_files: true + tokenizer: + tokenizer_type: QwenTokenizerFS + tokenizer_path: examples/aquila/qwentokenizer + vocab_size: 151851 + make_vocab_size_divisible_by: 64 \ No newline at end of file diff --git a/flagscale/runner/runner_train.py b/flagscale/runner/runner_train.py index e459e5ecb..7a8dadc51 100644 --- a/flagscale/runner/runner_train.py +++ b/flagscale/runner/runner_train.py @@ -225,7 +225,7 @@ def _generate_run_script_train( f.write(f"\n") f.write(f"cd {root_dir}\n") f.write(f"\n") - f.write(f"export PYTHONPATH={megatron_dir}:{root_dir}\n") + f.write(f"export PYTHONPATH={megatron_dir}:{root_dir}:${{PYTHONPATH}}\n") f.write(f"\n") f.write(f'cmd="{cmd}"\n') f.write(f"\n") diff --git a/flagscale/train/arguments.py 
b/flagscale/train/arguments.py index 4f11cbdea..55ef080b4 100644 --- a/flagscale/train/arguments.py +++ b/flagscale/train/arguments.py @@ -110,6 +110,18 @@ def pre_validate_args(self): 'pipeline_model_parallel_split_rank not supported with process_meshes set!' self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size + # if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree + if self.args.untie_embeddings_and_output_weights == False: + assert hetero_process_meshes_tp[0] == hetero_process_meshes_tp[-1], \ + f"if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree!" + assert self.args.hetero_use_cpu_communication == False, \ + f"if untie_embeddings_and_output_weights is False, the hetero_use_cpu_communication should be False currently!" + if hetero_process_meshes_dp[0] != hetero_process_meshes_dp[-1]: + assert self.args.use_partial_reduce_for_shared_embedding == True, \ + f"if untie_embeddings_and_output_weights is False and hetero_process_meshes_dp[0] and hetero_process_meshes_dp[-1] are different, "\ + "the use_partial_reduce_for_shared_embedding should be True currently!" + + # Virtual parallel size. if self.args.enable_hetero: assert self.args.num_layers_per_virtual_pipeline_stage == None, \ diff --git a/flagscale/train/hetero/p2p_communication.py b/flagscale/train/hetero/p2p_communication.py index f89c9ccc6..df34c0362 100644 --- a/flagscale/train/hetero/p2p_communication.py +++ b/flagscale/train/hetero/p2p_communication.py @@ -240,7 +240,11 @@ def recv_backward_hetero(tensor_shape: Shape, config: ModelParallelConfig) -> to config=config, group=group, ) - output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced + if dp_end - dp_start != tensor_shape[1]: + dp_coef = float((dp_end - dp_start)) / float(tensor_shape[1]) + output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced * dp_coef + else: + output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced if config.timers is not None: config.timers('backward-recv').stop() @@ -330,6 +334,10 @@ def send_backward_hetero(input_tensor_grad: torch.Tensor, config: ModelParallelC for tensor_slice in tensor_slices: dst_rank, (dp_start, dp_end), (sp_start, sp_end), local_hidden_size = tensor_slice input_tensor_grad_sliced = input_tensor_grad[sp_start:sp_end, dp_start:dp_end, :] + dp_coef = 1.0 + if dp_end - dp_start != input_tensor_grad.shape[1]: + dp_coef = float(input_tensor_grad.shape[1]) / float((dp_end - dp_start)) + group = None pp_groups = para_ctx.get_pipeline_model_parallel_group() for pp_group in pp_groups: @@ -339,7 +347,7 @@ def send_backward_hetero(input_tensor_grad: torch.Tensor, config: ModelParallelC break _communicate( tensor_send_next=None, - tensor_send_prev=input_tensor_grad_sliced.contiguous() if "gloo" not in group.name() else input_tensor_grad_sliced.cpu(), + tensor_send_prev=input_tensor_grad_sliced.contiguous() * dp_coef if "gloo" not in group.name() else input_tensor_grad_sliced.cpu() * dp_coef, recv_prev=False, recv_next=False, tensor_shape=None, @@ -405,7 +413,11 @@ def send_forward_recv_backward_hetero( config=config, group=group, ) - output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced + if dp_end - dp_start != tensor_shape[1]: + dp_coef = float((dp_end - dp_start)) / float(tensor_shape[1]) + output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = 
output_tensor_grad_sliced * dp_coef + else: + output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced if config.timers is not None: config.timers('forward-send-backward-recv').stop() if output_tensor_grad is not None and output_tensor_grad.device == torch.device("cpu"): @@ -453,6 +465,9 @@ def send_backward_recv_forward_hetero( dst_rank, (dp_start, dp_end), (sp_start, sp_end), local_hidden_size = tensor_slice input_tensor_grad_sliced = input_tensor_grad[sp_start:sp_end, dp_start:dp_end, :] tensor_shape_sliced = (sp_end - sp_start, dp_end - dp_start, local_hidden_size) + dp_coef = 1.0 + if dp_end - dp_start != input_tensor_grad.shape[1]: + dp_coef = float(input_tensor_grad.shape[1]) / float((dp_end - dp_start)) group = None for pp_group in pp_groups: pp_group_ranks = torch.distributed.get_process_group_ranks(pp_group) @@ -461,7 +476,7 @@ def send_backward_recv_forward_hetero( break input_tensor_sliced, _, _ = _communicate( tensor_send_next=None, - tensor_send_prev=input_tensor_grad_sliced.contiguous() if "gloo" not in group.name() else input_tensor_grad_sliced.cpu(), + tensor_send_prev=input_tensor_grad_sliced.contiguous() * dp_coef if "gloo" not in group.name() else input_tensor_grad_sliced.cpu() * dp_coef, recv_prev=True, recv_next=False, tensor_shape=tensor_shape_sliced, diff --git a/flagscale/train/hetero/parallel_context.py b/flagscale/train/hetero/parallel_context.py index 0d73b0abe..14a70c20e 100644 --- a/flagscale/train/hetero/parallel_context.py +++ b/flagscale/train/hetero/parallel_context.py @@ -1003,11 +1003,7 @@ def get_embedding_group(self): """Get the embedding group the caller rank belongs to.""" groups = self._global_process_groups.get("embd", None) assert groups is not None, 'embedding group is not initialized' - for group in groups: - if self._rank in self._global_process_group_to_ranks[group]: - embd_group = group - break - return embd_group + return groups def get_position_embedding_group(self): """Get the position embedding group the caller rank belongs to.""" diff --git a/megatron/megatron/core/distributed/finalize_model_grads.py b/megatron/megatron/core/distributed/finalize_model_grads.py index 8c0dad8ab..caddcf070 100644 --- a/megatron/megatron/core/distributed/finalize_model_grads.py +++ b/megatron/megatron/core/distributed/finalize_model_grads.py @@ -115,10 +115,14 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf sync. """ - if ( - parallel_state.is_rank_in_embedding_group(ignore_virtual=True) - and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1 - ): + if (parallel_state.is_rank_in_embedding_group(ignore_virtual=True)): + embed_group = parallel_state.get_embedding_group() + if not isinstance(embed_group, list): + embed_group = [embed_group] + else: + return + + if (torch.distributed.get_world_size(embed_group[0]) > 1): if parallel_state.is_pipeline_first_stage(ignore_virtual=True): model_module = model[0] elif parallel_state.is_pipeline_last_stage(ignore_virtual=True): @@ -126,13 +130,41 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf else: # We do not support an interleaved schedule for models with encoders yet. 
model_module = model[0] + use_dist_opt = False + if hasattr(model_module, "ddp_config"): + use_dist_opt = model_module.ddp_config.use_distributed_optimizer model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True) if model_module.share_embeddings_and_output_weights: weight = model_module.shared_embedding_or_output_weight() grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad" orig_grad = getattr(weight, grad_attr) grad = _unshard_if_dtensor(orig_grad) - torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group()) + if use_dist_opt: + if config.use_partial_reduce_for_shared_embedding: + dp_world_size = parallel_state.get_data_parallel_world_size() + dp_rank = parallel_state.get_data_parallel_rank() + assert grad.shape[0] % dp_world_size == 0, f"grad shape: {grad.shape[0]}, dp_world_size: {dp_world_size}" + per_partition_size = grad.shape[0] // dp_world_size + if len(embed_group) == 1: + offset = per_partition_size * dp_rank + torch.distributed.all_reduce(grad[offset:offset+per_partition_size, :], group=embed_group[0]) + else: + group_idx = 0 + per_partition_size = per_partition_size // len(embed_group) + for group in embed_group: + offset = per_partition_size * (dp_rank * len(embed_group) + group_idx) + torch.distributed.all_reduce(grad[offset : offset + per_partition_size, :], group=group) + group_idx += 1 + else: # megatron default method + torch.distributed.all_reduce(grad, group=embed_group[0]) + else: + if len(embed_group) == 1: # megatron default method + torch.distributed.all_reduce(grad, group=embed_group[0]) + else: + original_grad_data = grad.clone().detach().data + for group in embed_group: + grad.data.copy_(original_grad_data) + torch.distributed.all_reduce(grad, group=group) setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad)) diff --git a/megatron/megatron/core/models/common/language_module/language_module.py b/megatron/megatron/core/models/common/language_module/language_module.py index 7075e57f9..715442d80 100644 --- a/megatron/megatron/core/models/common/language_module/language_module.py +++ b/megatron/megatron/core/models/common/language_module/language_module.py @@ -98,9 +98,16 @@ def setup_embeddings_and_output_layer(self) -> None: if parallel_state.is_rank_in_embedding_group(): weight = self.shared_embedding_or_output_weight() weight.data = weight.data.cuda() - torch.distributed.all_reduce( - weight.data, group=parallel_state.get_embedding_group() - ) + embedding_group = parallel_state.get_embedding_group() + if not isinstance(embedding_group, list): + torch.distributed.all_reduce( + weight.data, group=parallel_state.get_embedding_group() + ) + else: + original_weight = weight.clone().detach().data + for group in embedding_group: + weight.data.copy_(original_weight) + torch.distributed.all_reduce(weight.data, group=group) elif not getattr(LanguageModule, "embedding_warning_printed", False): logging.getLogger(__name__).warning( diff --git a/megatron/megatron/core/optimizer/clip_grads.py b/megatron/megatron/core/optimizer/clip_grads.py index 210191b30..0fef2c885 100644 --- a/megatron/megatron/core/optimizer/clip_grads.py +++ b/megatron/megatron/core/optimizer/clip_grads.py @@ -141,9 +141,9 @@ def get_grad_norm_fp32( # For cpu comminication tensor_device = get_device_type_for_comm(grad_stats_parallel_group) if isinstance(grad_stats_parallel_group, list): - original_total_norm = total_norm - for group in grad_stats_parallel_group: - total_norm = original_total_norm + original_total_norm =
total_norm.clone().detach() + for group in grad_stats_parallel_group: + total_norm.data = original_total_norm.data.clone() total_norm = total_norm.to(tensor_device) torch.distributed.all_reduce( total_norm, op=torch.distributed.ReduceOp.SUM, group=group ) @@ -236,9 +236,9 @@ def count_zeros_fp32( ) # Sum across all model-parallel GPUs. if isinstance(grad_stats_parallel_group, list): - original_total_num_zeros = total_num_zeros + original_total_num_zeros = total_num_zeros.clone().detach() for group in grad_stats_parallel_group: - total_num_zeros = original_total_num_zeros + total_num_zeros.data = original_total_num_zeros.data.clone() torch.distributed.all_reduce( total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=group ) diff --git a/megatron/megatron/core/transformer/attention.py b/megatron/megatron/core/transformer/attention.py index 583e3c1e6..7a1d7cd4d 100644 --- a/megatron/megatron/core/transformer/attention.py +++ b/megatron/megatron/core/transformer/attention.py @@ -508,22 +508,44 @@ def __init__( ) if submodules.q_layernorm is not None: - self.q_layernorm = build_module( - submodules.q_layernorm, - hidden_size=self.hidden_size_per_attention_head, - config=self.config, - eps=self.config.layernorm_epsilon, - ) + if not self.config.qk_layernorm_hidden_dim: + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + tp_world_size = get_tensor_model_parallel_world_size() + assert tp_world_size <= 1, "TP world size must be 1 for qk_layernorm_hidden_dim" + # nums_head_cur_rank = divide(self.config.num_attention_heads, tp_world_size) + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.query_projection_size, + config=self.config, + eps=self.config.layernorm_epsilon, + ) else: self.q_layernorm = None if submodules.k_layernorm is not None: - self.k_layernorm = build_module( - submodules.k_layernorm, - hidden_size=self.hidden_size_per_attention_head, - config=self.config, - eps=self.config.layernorm_epsilon, - ) + if not self.config.qk_layernorm_hidden_dim: + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.hidden_size_per_attention_head, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + else: + tp_world_size = get_tensor_model_parallel_world_size() + assert tp_world_size <= 1, "TP world size must be 1 for qk_layernorm_hidden_dim" + # nums_head_cur_rank = divide(self.config.num_attention_heads, tp_world_size) + self.k_layernorm = build_module( + submodules.k_layernorm, + hidden_size=self.kv_projection_size, + config=self.config, + eps=self.config.layernorm_epsilon, + ) else: self.k_layernorm = None @@ -640,10 +662,24 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) if self.q_layernorm is not None: - query = self.q_layernorm(query) + if not self.config.qk_layernorm_hidden_dim: + query = self.q_layernorm(query) + else: + # [sq, b, np, hn] -> [sq, b, 1, np * hn] + query_shape = list(query.shape) + query = query.reshape(query.size(0), query.size(1), 1, -1) + query = self.q_layernorm(query) + query = query.reshape(*query_shape) if self.k_layernorm is not None: - key = self.k_layernorm(key) + if not self.config.qk_layernorm_hidden_dim: + key = self.k_layernorm(key) + else: + # [sq, b, ng, hn] -> [sq, b, 1, ng * hn] + key_shape = list(key.shape) + key =
key.reshape(key.size(0), key.size(1), 1, -1) + key = self.k_layernorm(key) + key = key.reshape(*key_shape) if self.config.test_mode: self.run_realtime_tests() diff --git a/megatron/megatron/core/transformer/transformer_block.py b/megatron/megatron/core/transformer/transformer_block.py index 246a9fcf7..a62eec281 100755 --- a/megatron/megatron/core/transformer/transformer_block.py +++ b/megatron/megatron/core/transformer/transformer_block.py @@ -83,7 +83,7 @@ def get_num_layers_to_build(config: TransformerConfig) -> int: pipeline_ranks = config.pipeline_model_parallel_size num_layers_per_pipeline_rank = config.num_layers // pipeline_ranks - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None and parallel_state.get_virtual_pipeline_model_parallel_world_size() > 1: # Interleaved pipeline parallelism: # Number of layers in each model chunk is the number of layers in the stage, # divided by the number of model chunks in a stage. diff --git a/megatron/megatron/core/transformer/transformer_config.py b/megatron/megatron/core/transformer/transformer_config.py index e734e9387..e6092903e 100644 --- a/megatron/megatron/core/transformer/transformer_config.py +++ b/megatron/megatron/core/transformer/transformer_config.py @@ -350,7 +350,13 @@ class TransformerConfig(ModelParallelConfig): config_logger_dir: str = "" """When non-empty, dumps entry-point configs to config_logger_dir""" + + qk_layernorm_hidden_dim: bool = False + """Whether to apply LayerNorm to the query and key embeddings on the hidden dimension rather than the head dimension.""" + use_partial_reduce_for_shared_embedding: bool = False + """Whether to use partial reduce for shared embedding.""" + flash_decode: bool = False """ Use the optimized flash decoding kernel during inference.
""" diff --git a/megatron/megatron/training/arguments.py b/megatron/megatron/training/arguments.py index e400f8bc6..d4430d307 100644 --- a/megatron/megatron/training/arguments.py +++ b/megatron/megatron/training/arguments.py @@ -1784,6 +1784,8 @@ def _add_distributed_args(parser): 'affects the encoder embedding.)') group.add_argument('--use-distributed-optimizer', action='store_true', help='Use distributed optimizer.') + group.add_argument('--use-partial-reduce-for-shared-embedding', action='store_true', + help='Use partial reduce for shared word embedding.') group.add_argument('--no-shared-fs', action='store_true', help='Indicate whether not running on a shared file system.') group.add_argument('--ulysses-sp-parallel-size', type=int, default=1, @@ -2118,6 +2120,8 @@ def _add_vision_args(parser): # regularization arguments group.add_argument('--qk-layernorm', action='store_true', help='Whether to layer normalize the q and k attention embeddings.') + group.add_argument('--qk-layernorm-hidden-dim', action='store_true', + help='Whether to layer normalize the q and k attention embeddings on hidden dimension rather than head dimension') return parser diff --git a/megatron/megatron/training/global_vars.py b/megatron/megatron/training/global_vars.py index 3cd53a451..70f78c5aa 100644 --- a/megatron/megatron/training/global_vars.py +++ b/megatron/megatron/training/global_vars.py @@ -129,11 +129,11 @@ def set_global_writers(args): ranks_tensor = torch.tensor([0 for _ in range(size)], dtype=torch.int, device=comm_device) orig_ranks = torch.tensor([i for i in range(size)], dtype=torch.int, device=comm_device) if is_last_rank(): - ranks_tensor = orig_ranks ranks_list = torch.distributed.get_process_group_ranks(mp_groups[-1]) - ranks_tensor = torch.tensor(ranks_list, dtype=torch.int, device=comm_device) + ranks_tensor = torch.tensor(ranks_list, dtype=torch.int, device=comm_device) + orig_ranks = ranks_tensor.clone().detach() for group in mp_groups: - ranks_tensor = orig_ranks + ranks_tensor = orig_ranks.clone() torch.distributed.all_reduce(ranks_tensor, group=group) if torch.distributed.get_rank() in ranks_tensor.tolist(): _set_wandb_writer(args) diff --git a/megatron/megatron/training/utils.py b/megatron/megatron/training/utils.py index 32e2bf17a..806b3d9d3 100644 --- a/megatron/megatron/training/utils.py +++ b/megatron/megatron/training/utils.py @@ -105,11 +105,23 @@ def calc_params_l2_norm(model): group=data_parallel_group) # Sum across all model-parallel GPUs(tensor + pipeline). - torch.distributed.all_reduce( - norm_2, - op=torch.distributed.ReduceOp.SUM, - group=mpu.get_model_parallel_group() - ) + mp_groups = mpu.get_model_parallel_group() + if isinstance(mp_groups, list): + if len(mp_groups) > 1: + assert mpu.get_expert_model_parallel_world_size() <= 1, f"Expert model parallelism is not supported with heterogeneous model parallelism" + original_norm_2 = norm_2.clone().detach() + for mp_group in mp_groups: + norm_2 = original_norm_2.clone() + torch.distributed.all_reduce(norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mp_group) + else: + # Sum across all model-parallel GPUs(tensor + pipeline). 
+ torch.distributed.all_reduce( + norm_2, + op=torch.distributed.ReduceOp.SUM, + group=mpu.get_model_parallel_group() + ) # Calculate moe norm if len(moe_params_data) > 0: moe_norm, _ = multi_tensor_applier( diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/conf/dp2dp4_shared_embedding.yaml b/tests/functional_tests/test_cases/hetero_train/aquila/conf/dp2dp4_shared_embedding.yaml new file mode 100644 index 000000000..2c49401c2 --- /dev/null +++ b/tests/functional_tests/test_cases/hetero_train/aquila/conf/dp2dp4_shared_embedding.yaml @@ -0,0 +1,34 @@ +defaults: + - _self_ + - train: dp2dp4_shared_embedding + +experiment: + exp_name: dp2dp4_shared_embedding + exp_dir: tests/functional_tests/test_cases/hetero_train/aquila/results_test/dp2dp4_shared_embedding + task: + type: train + backend: megatron + entrypoint: flagscale/train/train_aquila.py + runner: + backend: torchrun + ssh_port: null + shell_cmds: null + envs: + HYDRA_FULL_ERROR: 1 + CUDA_VISIBLE_DEVICES: "0,1,2,3,4,5" + CUDA_DEVICE_MAX_CONNECTIONS: 1 + CUBLAS_WORKSPACE_CONFIG: ":4096:8" + NCCL_ALGO: "Tree" + NVTE_APPLY_QK_LAYER_SCALING: 0 + NVTE_ALLOW_NONDETERMINISTIC_ALGO: 0 + NVTE_FLASH_ATTN: 0 + NVTE_FUSED_ATTN: 0 + CUDNN_BENCHMARK: "false" + CUDNN_DETERMINISTIC: "true" + # cmds: + # before_start: source /root/miniconda3/bin/activate flagscale +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/dp2dp4_shared_embedding.yaml b/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/dp2dp4_shared_embedding.yaml new file mode 100644 index 000000000..c723bcc28 --- /dev/null +++ b/tests/functional_tests/test_cases/hetero_train/aquila/conf/train/dp2dp4_shared_embedding.yaml @@ -0,0 +1,81 @@ +system: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 2 + reset_position_ids: True + reset_attention_mask: True + add_qkv_bias: True + sequence_parallel: True + disable_bias_linear: True + use_distributed_optimizer: True + hetero: + enable_hetero: True + use_partial_reduce_for_shared_embedding: True + hetero_pipeline_layer_split: [4, 4] + hetero_process_meshes: [1,1,1,2,1, 1,1,1,4,1] + hetero_device_types: ["A800", "A800"] + + standalone_embedding_stage: False + hetero_current_device_type: "A800" + precision: + fp16: True + initial_loss_scale: 522893 + min_loss_scale: 1.0 + attention_softmax_in_fp32: True + accumulate_allreduce_grads_in_fp32: True + logging: + log_interval: 1 + no_log_loss_scale_to_tensorboard: true + checkpoint: + no_save_optim: true + no_save_rng: true + save_interval: 100000 + tensorboard_log_interval: 999999 + +model: + deterministic_mode: true + use_mcore_models: true + transformer_impl: transformer_engine + num_layers: 8 + hidden_size: 512 + num_attention_heads: 8 + seq_length: 1024 + max_position_embeddings: 1024 + norm_epsilon: 1e-5 + use_rotary_position_embeddings: true + no_position_embedding: true + swiglu: true + multiple_of: 256 + normalization: RMSNorm + # rotary_interleaved_patch: true + untie_embeddings_and_output_weights: false + init_method_std: 0.02 + attention_dropout: 0.0 + hidden_dropout: 0.0 + weight_decay: 0.1 + clip_grad: 1.0 + train_iters: 10 + micro_batch_size: 4 + global_batch_size: 1024 + seed: 42 + + optimizer: + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + lr_scheduler: + lr: 2.0e-5 + min_lr: 2.0e-6 + lr_warmup_samples: 0 + lr_warmup_fraction: 0.01 + lr_decay_style: cosine + +data: + data_path: /home/gitlab-runner/data/pile_wikipedia_demo/pile_wikipedia_demo + 
# data_path: /share/project/lizhiyu/FlagScale/build/data/pile_wikipedia_demo + split: 1 + tokenizer: + tokenizer_type: AquilaTokenizerFS + vocab_file: ./examples/aquila/tokenizer/vocab.json + merge_file: ./examples/aquila/tokenizer/merges.txt + special_tokens_file: ./examples/aquila/tokenizer/special_tokens.txt + vocab_size: 100008 \ No newline at end of file diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/dp2dp4_shared_embedding.json b/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/dp2dp4_shared_embedding.json new file mode 100644 index 000000000..15e0d87dc --- /dev/null +++ b/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/dp2dp4_shared_embedding.json @@ -0,0 +1,2 @@ +{"lm loss:": {"values": [[11.5554, 11.55801, 11.34487, 11.19957, 11.07778, 10.98522, 10.92735, 10.88371, 10.86305, 10.84247]]}} + diff --git a/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.json b/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.json index 05985eef7..19cd8a6c0 100644 --- a/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.json +++ b/tests/functional_tests/test_cases/hetero_train/aquila/results_gold/tp2dp1pp1_tp2dp2pp1_tp1dp2pp1.json @@ -1,2 +1,2 @@ -{"lm loss:": {"values": [[11.62049, 11.61899, 11.41394, 11.27374, 11.15945, 11.07574, 11.01715, 10.97403, 10.95042, 10.93264]]}} +{"lm loss:": {"values": [[11.62049, 11.61899, 11.41389, 11.27375, 11.15958, 11.07644, 11.01809, 10.97522, 10.95196, 10.93447]]}} diff --git a/tests/scripts/functional_tests/config.yml b/tests/scripts/functional_tests/config.yml index 504e48620..b5548d362 100644 --- a/tests/scripts/functional_tests/config.yml +++ b/tests/scripts/functional_tests/config.yml @@ -17,6 +17,7 @@ hetero_train: test_cases: -tp2pp1_tp4pp1_tp2pp1 -tp2dp1pp1_tp2dp2pp1_tp1dp2pp1 + -dp2dp4_shared_embedding # Add in the feature # inference: diff --git a/tools/checkpoint/aquila/args.py b/tools/checkpoint/aquila/args.py new file mode 100644 index 000000000..e62c42b00 --- /dev/null +++ b/tools/checkpoint/aquila/args.py @@ -0,0 +1,93 @@ +import os +import json + + +def load_args_hf2mg(args): + + # Read llama args. + llama_args_path = os.path.join(args.load, "config.json") + with open(llama_args_path) as f: + llama_args = json.load(f) + + # Update Megatron args. 
+ args.attention_dropout = llama_args["attention_dropout"] + args.hidden_dropout = llama_args["attention_dropout"] + args.hidden_size = llama_args["hidden_size"] + args.swiglu = llama_args["hidden_act"] == "silu" + args.init_method_std = llama_args["initializer_range"] + args.ffn_hidden_size = llama_args["intermediate_size"] + args.max_position_embeddings = llama_args["max_position_embeddings"] + args.model_type = llama_args["model_type"] + args.num_attention_heads = llama_args["num_attention_heads"] + args.num_layers = llama_args["num_hidden_layers"] + args.num_query_groups = llama_args["num_key_value_heads"] + args.norm_epsilon = llama_args["rms_norm_eps"] + args.rotary_seq_len_interpolation_factor = ( + None if llama_args["rope_scaling"] == "null" else llama_args["rope_scaling"] + ) + args.rotary_base = llama_args["rope_theta"] + args.untie_embeddings_and_output_weights = not llama_args["tie_word_embeddings"] + args.bf16 = llama_args["torch_dtype"] == "bfloat16" + args.fp16 = llama_args["torch_dtype"] == "float16" + args.vocab_size = llama_args["vocab_size"] + args.padded_vocab_size = llama_args["vocab_size"] + + args.seq_length = 4096 + args.global_batch_size = 1024 + args.iteration = 1 # '0', 'release' don't work + args.add_position_embedding = False + args.group_query_attention = True + args.normalization = "RMSNorm" + args.use_rotary_position_embeddings = True + args.add_bias_linear = False + args.add_qkv_bias = False + args.make_vocab_size_divisible_by = 64 + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + args.norm_has_bias = False + + +def save_args_mg2hf(args): + from .llama_model.configuration_llama import LlamaConfig + + config = LlamaConfig( + vocab_size=args.vocab_size, + hidden_size=args.hidden_size, + intermediate_size=args.ffn_hidden_size, + num_hidden_layers=args.encoder_num_layers, + num_attention_heads=args.num_attention_heads, + num_key_value_heads=args.num_query_groups, + hidden_act="silu" if args.swiglu else False, + max_position_embeddings=args.max_position_embeddings, + initializer_range=args.init_method_std, + rms_norm_eps=args.norm_epsilon, + use_cache=True, + tie_word_embeddings=not args.untie_embeddings_and_output_weights, + rope_theta=args.rotary_base, + rope_scaling=args.rotary_seq_len_interpolation_factor, + attention_bias=args.add_qkv_bias, + attention_dropout=args.attention_dropout, + torch_dtype=args.params_dtype, + bias_dropout_fusion=args.bias_dropout_fusion, + end_weight_decay=args.end_weight_decay, + global_batch_size=args.global_batch_size, + hidden_dropout=args.hidden_dropout, + lr=args.lr, + lr_decay_style=args.lr_decay_style, + make_vocab_size_divisible_by=args.make_vocab_size_divisible_by, + masked_softmax_fusion=args.masked_softmax_fusion, + min_lr=args.min_lr, + norm_init_weight=args.norm_init_weight, + perform_initialization=args.perform_initialization, + reset_attention_mask=args.reset_attention_mask, + reset_position_ids=args.reset_position_ids, + rotary_base=args.rotary_base, + seed=args.seed, + split=args.split, + start_weight_decay=args.start_weight_decay, + use_flash_attn=args.use_flash_attn, + weight_decay_incr_style=args.weight_decay_incr_style, + ) + config.save_pretrained(args.save) + + return config diff --git a/tools/checkpoint/aquila/ckpt.py b/tools/checkpoint/aquila/ckpt.py new file mode 100644 index 000000000..8183f87b4 --- /dev/null +++ b/tools/checkpoint/aquila/ckpt.py @@ -0,0 +1,203 @@ +import torch + +import sys +sys.path.append("..") +from mixtral.ckpt import ( + get_hf_attn_ckpt, + set_hf_attn_ckpt, + 
get_embedding_ckpt, + get_final_norm_ckpt, + get_output_layer_ckpt, + set_embedding_ckpt, + set_final_norm_ckpt, + set_output_layer_ckpt, +) + + +def get_hf_mlp_ckpt(message, model, layer_id, args): + assert args.swiglu is True + + tf_layer = model.model.layers[layer_id] + message["mlp l0 weight W"] = tf_layer.mlp.gate_proj.weight.data + message["mlp l0 weight V"] = tf_layer.mlp.up_proj.weight.data + message["mlp l1 weight"] = tf_layer.mlp.down_proj.weight.data + + if args.add_bias_linear: + message["mlp l0 bias W"] = tf_layer.mlp.gate_proj.bias.data + message["mlp l0 bias V"] = tf_layer.mlp.up_proj.bias.data + message["mlp l1 bias"] = tf_layer.mlp.down_proj.bias.data + + +def set_hf_mlp_ckpt(message, model, layer_id, md, args): + assert args.swiglu is True + + tf_layer = model.model.layers[layer_id] + tf_layer.mlp.gate_proj.weight.data.copy_(message.pop("mlp l0 weight W")) + tf_layer.mlp.up_proj.weight.data.copy_(message.pop("mlp l0 weight V")) + tf_layer.mlp.down_proj.weight.data.copy_(message.pop("mlp l1 weight")) + + if md.add_bias_linear: + tf_layer.mlp.gate_proj.bias.data.copy_(message.pop("mlp l0 bias W")) + tf_layer.mlp.up_proj.bias.data.copy_(message.pop("mlp l0 bias V")) + tf_layer.mlp.down_proj.bias.data.copy_(message.pop("mlp l1 bias")) + + +def _get_parallel_size(args): + assert args.expert_model_parallel_size == 1 + return args.tensor_model_parallel_size, \ + args.pipeline_model_parallel_size, \ + args.expert_model_parallel_size, \ + args.virtual_pipeline_model_parallel_size or 1 + + +def get_attn_ckpt(message, models, layer_id, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # parallel tensor + qkv_weight = [] + qkv_bias = [] + proj_weight = [] + # non-parallel tensor + proj_bias = None + input_norm_weight = None + input_norm_bias = None + post_norm_weight = None + post_norm_bias = None + + assert len(models) == tp_size + for model in models: + tf_layer = model.decoder.layers[layer_id] + # weight + qkv_weight.append(tf_layer.self_attention.linear_qkv.weight.data) + proj_weight.append(tf_layer.self_attention.linear_proj.weight.data) + input_norm_weight = tf_layer.self_attention.linear_qkv.layer_norm_weight.data + post_norm_weight = tf_layer.mlp.linear_fc1.layer_norm_weight.data + # bias + if args.norm_has_bias: + input_norm_bias = tf_layer.self_attention.linear_qkv.layer_norm_bias.data + post_norm_bias = tf_layer.mlp.linear_fc1.layer_norm_bias.data + if args.add_qkv_bias or args.add_bias_linear: + qkv_bias.append(tf_layer.self_attention.linear_qkv.bias.data) + if args.add_bias_linear: + proj_bias = tf_layer.self_attention.linear_proj.bias.data + + # weight + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + message["proj weight"] = torch.cat(proj_weight, dim=1) + message["input norm weight"] = input_norm_weight + message["post norm weight"] = post_norm_weight + # bias + if args.norm_has_bias: + message["input norm bias"] = input_norm_bias + message["post norm bias"] = post_norm_bias + if args.add_qkv_bias or args.add_bias_linear: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + if args.add_bias_linear: + message["proj bias"] = proj_bias + + +def get_mlp_ckpt(message, models, layer_id, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # parallel tensor + l0_weight = [] + l0_bias = [] + l1_weight = [] + # non-parallel tensor + l1_bias = None + + assert len(models) == tp_size + for model in models: + tf_layer = model.decoder.layers[layer_id] + # weight + l0_weight.append(tf_layer.mlp.linear_fc1.weight.data) + 
l1_weight.append(tf_layer.mlp.linear_fc2.weight.data) + # bias + if args.add_bias_linear: + l0_bias.append(tf_layer.mlp.linear_fc1.bias.data) + l1_bias = tf_layer.mlp.linear_fc2.bias.data + + # weight + message["mlp l1 weight"] = torch.cat(l1_weight, dim=1) + if args.swiglu: + for tp_rank in range(tp_size): + l0_weight[tp_rank] = torch.chunk(l0_weight[tp_rank], 2, dim=0) + message["mlp l0 weight W"] = torch.cat([w[0] for w in l0_weight], dim=0) + message["mlp l0 weight V"] = torch.cat([w[1] for w in l0_weight], dim=0) + else: + message["mlp l0 weight"] = torch.cat(l0_weight, dim=0) + # bias + if args.add_bias_linear: + message["mlp l1 bias"] = l1_bias + if args.swiglu: + for tp_rank in range(tp_size): + l0_bias[tp_rank] = torch.chunk(l0_bias[tp_rank], 2, dim=0) + message["mlp l0 bias W"] = torch.cat([b[0] for b in l0_bias],dim=0) + message["mlp l0 bias V"] = torch.cat([b[1] for b in l0_bias],dim=0) + else: + message["mlp l0 bias"] = torch.cat(l0_bias, dim=0) + + +def set_attn_ckpt(message, models, layer_id, md, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # weight + qkv_weight = torch.chunk(message.pop("qkv weight"), tp_size, dim=0) + proj_weight = torch.chunk(message.pop("proj weight"), tp_size, dim=1) + input_norm_weight = message.pop("input norm weight") + post_norm_weight = message.pop("post norm weight") + # bias + if md.norm_has_bias: + input_norm_bias = message.pop("input norm bias") + post_norm_bias = message.pop("post norm bias") + if md.add_qkv_bias or md.add_bias_linear: + qkv_bias = torch.chunk(message.pop("qkv bias"), tp_size, dim=0) + if md.add_bias_linear: + proj_bias = message.pop("proj bias") + + # set data to transformer layer's self-attention + for tp_rank, model in enumerate(models): + tf_layer = model.decoder.layers[layer_id] + tf_layer.self_attention.linear_qkv.weight.data.copy_(qkv_weight[tp_rank]) + tf_layer.self_attention.linear_proj.weight.data.copy_(proj_weight[tp_rank]) + tf_layer.self_attention.linear_qkv.layer_norm_weight.data.copy_(input_norm_weight) + tf_layer.mlp.linear_fc1.layer_norm_weight.data.copy_(post_norm_weight) + if md.norm_has_bias: + tf_layer.self_attention.linear_qkv.layer_norm_bias.data.copy_(input_norm_bias) + tf_layer.mlp.linear_fc1.layer_norm_bias.data.copy_(post_norm_bias) + if md.add_qkv_bias or md.add_bias_linear: + tf_layer.self_attention.linear_qkv.bias.data.copy_(qkv_bias[tp_rank]) + if md.add_bias_linear: + tf_layer.self_attention.linear_proj.bias.data.copy_(proj_bias) + + +def set_mlp_ckpt(message, models, layer_id, md, args): + tp_size, _, _, _ = _get_parallel_size(args) + + # weight + l1_weight = torch.chunk(message.pop("mlp l1 weight"), tp_size, dim=1) + if md.swiglu: + l0_weight_W = torch.chunk(message.pop("mlp l0 weight W"), tp_size, dim=0) + l0_weight_V = torch.chunk(message.pop("mlp l0 weight V"), tp_size, dim=0) + l0_weight = [torch.cat(weights, dim=0) for weights in zip(l0_weight_W, l0_weight_V)] + else: + l0_weight = torch.chunk(message.pop("mlp l0 weight"), tp_size, dim=0) + # bias + if md.add_bias_linear: + l1_bias = message.pop("mlp l1 bias") + if md.swiglu: + l0_bias_W = torch.chunk(message.pop("mlp l0 bias W"), tp_size, dim=0) + l0_bias_V = torch.chunk(message.pop("mlp l0 bias V"), tp_size, dim=0) + l0_bias = [torch.cat(bias, dim=0) for bias in zip(l0_bias_W, l0_bias_V)] + else: + l0_bias = torch.chunk(message.pop("mlp l0 bias"), tp_size, dim=0) + + # set data to transformer layer for mlp + for tp_rank, model in enumerate(models): + tf_layer = model.decoder.layers[layer_id] +
tf_layer.mlp.linear_fc1.weight.data.copy_(l0_weight[tp_rank]) + tf_layer.mlp.linear_fc2.weight.data.copy_(l1_weight[tp_rank]) + + if md.add_bias_linear: + tf_layer.mlp.linear_fc1.bias.data.copy_(l0_bias[tp_rank]) + tf_layer.mlp.linear_fc2.bias.data.copy_(l1_bias) diff --git a/tools/checkpoint/aquila/model.py b/tools/checkpoint/aquila/model.py new file mode 100644 index 000000000..05b7a9a23 --- /dev/null +++ b/tools/checkpoint/aquila/model.py @@ -0,0 +1,41 @@ +import time +from megatron.core.enums import ModelType + +model_type = ModelType.encoder_or_decoder # Megatron's model_type + + +def get_hf_model(dtype, model_path=None, config=None): + try: + from .llama_model.modeling_llama import LlamaForCausalLM + except ImportError: + print("Failed to import LlamaForCausalLM from modeling_llama, please add the model of huggingface style.") + s_time = time.time() + if model_path and not config: + model = LlamaForCausalLM.from_pretrained( + model_path, device_map="cpu", trust_remote_code=True, torch_dtype=dtype + ) + elif not model_path and config: + import torch + from accelerate import init_empty_weights + from accelerate.utils import set_module_tensor_to_device + + with init_empty_weights(): + model = LlamaForCausalLM._from_config( + config=config, torch_dtype=dtype + ) + for name, param in model.named_parameters(): + set_module_tensor_to_device( + model, name, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + raise ValueError("Need one args, model_path or config, to build HF model.") + print("> build huggingface model elapsed time:", time.time() - s_time) + return model + + +def get_mg_model(dtype, pre_process, post_process): + from pretrain_gpt import model_provider + s_time = time.time() + model = model_provider(pre_process, post_process).to(dtype) + print("> build megatron model elapsed time:", time.time() - s_time) + return model diff --git a/tools/checkpoint/loader_mcore.py b/tools/checkpoint/loader_mcore.py index cee5d28ac..adf9e39d8 100644 --- a/tools/checkpoint/loader_mcore.py +++ b/tools/checkpoint/loader_mcore.py @@ -93,6 +93,17 @@ def _set_arg(arg_name): _set_arg("expert_model_parallel_size") _set_arg("num_experts") _set_arg("sequence_parallel") + + # for hetero + _set_arg("enable_hetero") + _set_arg("hetero_process_meshes") + _set_arg("hetero_pipeline_layer_split") + + # for hetero + if margs.hetero_process_meshes is not None: + margs.pipeline_model_parallel_size = sum(row[-1] for row in margs.hetero_process_meshes) + margs.data_parallel_size = 1 + margs.micro_batch_size = 1 # Arguments do sanity checks on the world size, but we don't care, # so trick it into thinking we are plenty of processes