[Model] Add aquila 3b configuration file (#281)
1. Add the Aquila 3B configuration file
2. Add qk_norm applied over the hidden dim rather than the head dim (see the sketch below)
3. Support a shared embedding weight when the DP sizes differ
4. Fix bugs in the gradient-norm calculation when the DP sizes differ
5. Add `export PYTHONPATH=...:${PYTHONPATH}` to the generated run script
6. Support checkpoint conversion for hetero-training
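
A minimal sketch of item 2, to show what "qk_norm on the hidden dim, not the head dim" means in isolation; the RMSNorm class, tensor shapes, and sizes are illustrative assumptions, not the FlagScale/Megatron implementation:

import torch
from torch import nn

class RMSNorm(nn.Module):
    # Assumed normalization for qk_norm; the actual norm used may differ.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

hidden_size, num_heads = 1024, 16
head_dim = hidden_size // num_heads
q = torch.randn(2, 8, hidden_size)  # [batch, seq, hidden] query projection

# Head-dim qk_norm: normalize each head's slice of the projection separately.
q_per_head = RMSNorm(head_dim)(q.view(2, 8, num_heads, head_dim)).view(2, 8, hidden_size)

# Hidden-dim qk_norm (what this commit adds): normalize the whole projection at once.
q_full = RMSNorm(hidden_size)(q)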

---------

Co-authored-by: lzy-dev <[email protected]>
heavyrain-lzy and lzy-dev authored Dec 30, 2024
1 parent f2bc020 commit 3d81a6e
Showing 24 changed files with 725 additions and 48 deletions.
4 changes: 3 additions & 1 deletion examples/aquila/conf/config.yaml
@@ -1,5 +1,5 @@
defaults:
- train: demo
- train: demo
- _self_

experiment:
@@ -9,6 +9,8 @@ experiment:
type: train
backend: megatron
entrypoint: ./flagscale/train/train_aquila.py
# cmds:
# before_start: source /root/miniconda3/bin/activate flagscale
runner:
backend: torchrun
nnodes: 1
89 changes: 89 additions & 0 deletions examples/aquila/conf/train/train_aquila_3b.yaml
@@ -0,0 +1,89 @@
system:
reset_position_ids: True
reset_attention_mask: True
add_qkv_bias: True
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 2
disable_bias_linear: True
use_flash_attn: True
use_distributed_optimizer: True
precision:
bf16: True
initial_loss_scale: 522893
min_loss_scale: 1.0
attention_softmax_in_fp32: True
accumulate_allreduce_grads_in_fp32: True
logging:
log_interval: 1
log_throughput: True
tensorboard_log_interval: 1
wandb-log-model: False
wandb-log-model-interval: 1
wandb_project: "train-aquila-3B"
wandb_exp_name: "train-test-3B"
checkpoint:
load: outputs_llama3/checkpoint_mc
ckpt_format: torch
save_interval: 2385

# hetero:
# enable_hetero: True
# hetero_use_cpu_communication: False
# use_partial_reduce_for_shared_embedding: True
# # mesh format [tp1,cp1,ep1,dp1,pp1,(tp2,cp2...)]

# hetero_pipeline_layer_split: [12,12]
# hetero_process_meshes: [1,1,1,4,1, 1,1,1,4,1]
# hetero_device_types: ["A800", "A800"]

# standalone_embedding_stage: False
# hetero_current_device_type: "A800"
model:
transformer_impl: transformer_engine
num_layers: 24
hidden_size: 1024
num_attention_heads: 16
group_query_attention: True
num_query_groups: 2
seq_length: 4096
max_position_embeddings: 4096 # only for adding position embeddings
norm_epsilon: 1e-6
use_rotary_position_embeddings: true
no_position_embedding: true
rotary_base: 1000000
swiglu: true
multiple_of: 256
hidden_dim_multiplier: 2 # ffn_hidden_size 11008
normalization: RMSNorm
position_embedding_type: rope
untie_embeddings_and_output_weights: False
init_method_std: 0.02
attention_dropout: 0.0
hidden_dropout: 0.0
weight_decay: 0.1
clip_grad: 1.0
train_samples: 29297664 #120B tokens
eval_iters: 0
micro_batch_size: 2
global_batch_size: 1024
seed: 42

optimizer:
weight_decay: 0.1
adam_beta1: 0.9
adam_beta2: 0.95
lr_scheduler:
lr: 5.0e-3
min_lr: 5.0e-4
lr_warmup_samples: 10
lr_decay_style: cosine

data:
data_path: {data_path:??}
split: 1
no_mmap_bin_files: true
tokenizer:
tokenizer_type: QwenTokenizerFS
tokenizer_path: examples/aquila/qwentokenizer
vocab_size: 151851
make_vocab_size_divisible_by: 64
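
The commented-out hetero block above documents hetero_process_meshes as a flat list in [tp, cp, ep, dp, pp] order, one 5-tuple per device group. A hedged sketch of reading that format (split_process_meshes is a hypothetical helper, not FlagScale code):

from typing import List, Tuple

def split_process_meshes(flat: List[int]) -> List[Tuple[int, ...]]:
    # Split [tp1, cp1, ep1, dp1, pp1, tp2, cp2, ...] into per-mesh 5-tuples.
    assert len(flat) % 5 == 0, "each mesh needs exactly 5 entries: tp, cp, ep, dp, pp"
    return [tuple(flat[i:i + 5]) for i in range(0, len(flat), 5)]

# The example value from the config comment: two homogeneous A800 meshes.
for tp, cp, ep, dp, pp in split_process_meshes([1, 1, 1, 4, 1, 1, 1, 1, 4, 1]):
    print(f"tp={tp} cp={cp} ep={ep} dp={dp} pp={pp}")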
2 changes: 1 addition & 1 deletion flagscale/runner/runner_train.py
@@ -225,7 +225,7 @@ def _generate_run_script_train(
f.write(f"\n")
f.write(f"cd {root_dir}\n")
f.write(f"\n")
f.write(f"export PYTHONPATH={megatron_dir}:{root_dir}\n")
f.write(f"export PYTHONPATH={megatron_dir}:{root_dir}:${{PYTHONPATH}}\n")
f.write(f"\n")
f.write(f'cmd="{cmd}"\n')
f.write(f"\n")
12 changes: 12 additions & 0 deletions flagscale/train/arguments.py
@@ -110,6 +110,18 @@ def pre_validate_args(self):
'pipeline_model_parallel_split_rank not supported with process_meshes set!'
self.args.transformer_pipeline_model_parallel_size = self.args.pipeline_model_parallel_size

# if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree
if self.args.untie_embeddings_and_output_weights == False:
assert hetero_process_meshes_tp[0] == hetero_process_meshes_tp[-1], \
f"if untie_embeddings_and_output_weights is False, the first and last stage should have the same tp degree!"
assert self.args.hetero_use_cpu_communication == False, \
f"if untie_embeddings_and_output_weights is False, the hetero_use_cpu_communication should be False currently!"
if hetero_process_meshes_dp[0] != hetero_process_meshes_dp[-1]:
assert self.args.use_partial_reduce_for_shared_embedding == True, \
f"if untie_embeddings_and_output_weights is False and hetero_process_meshes_dp[0] and hetero_process_meshes_dp[-1] are different, "\
"the use_partial_reduce_for_shared_embedding should be True currently!"


# Virtual parallel size.
if self.args.enable_hetero:
assert self.args.num_layers_per_virtual_pipeline_stage == None, \
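
The new block in pre_validate_args enforces that tied input/output embeddings only work when the first and last pipeline-stage meshes share a TP degree, and that mismatched DP degrees require use_partial_reduce_for_shared_embedding. A standalone sketch of the same checks on made-up values (the plain variables stand in for the corresponding self.args fields):

# Hypothetical per-stage degrees pulled from hetero_process_meshes ([tp, cp, ep, dp, pp] per mesh).
hetero_process_meshes_tp = [1, 1]
hetero_process_meshes_dp = [4, 2]

untie_embeddings_and_output_weights = False
hetero_use_cpu_communication = False
use_partial_reduce_for_shared_embedding = True

if not untie_embeddings_and_output_weights:
    # The tied weight is all-reduced between the first and last stages,
    # so both must shard it with the same TP degree.
    assert hetero_process_meshes_tp[0] == hetero_process_meshes_tp[-1]
    assert not hetero_use_cpu_communication
    if hetero_process_meshes_dp[0] != hetero_process_meshes_dp[-1]:
        # Different DP degrees require the partial-reduce path added in this commit.
        assert use_partial_reduce_for_shared_embedding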
23 changes: 19 additions & 4 deletions flagscale/train/hetero/p2p_communication.py
@@ -240,7 +240,11 @@ def recv_backward_hetero(tensor_shape: Shape, config: ModelParallelConfig) -> to
config=config,
group=group,
)
output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced
if dp_end - dp_start != tensor_shape[1]:
dp_coef = float((dp_end - dp_start)) / float(tensor_shape[1])
output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced * dp_coef
else:
output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced
if config.timers is not None:
config.timers('backward-recv').stop()

@@ -330,6 +334,10 @@ def send_backward_hetero(input_tensor_grad: torch.Tensor, config: ModelParallelC
for tensor_slice in tensor_slices:
dst_rank, (dp_start, dp_end), (sp_start, sp_end), local_hidden_size = tensor_slice
input_tensor_grad_sliced = input_tensor_grad[sp_start:sp_end, dp_start:dp_end, :]
dp_coef = 1.0
if dp_end - dp_start != input_tensor_grad.shape[1]:
dp_coef = float(input_tensor_grad.shape[1]) / float((dp_end - dp_start))

group = None
pp_groups = para_ctx.get_pipeline_model_parallel_group()
for pp_group in pp_groups:
@@ -339,7 +347,7 @@
break
_communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad_sliced.contiguous() if "gloo" not in group.name() else input_tensor_grad_sliced.cpu(),
tensor_send_prev=input_tensor_grad_sliced.contiguous() * dp_coef if "gloo" not in group.name() else input_tensor_grad_sliced.cpu() * dp_coef,
recv_prev=False,
recv_next=False,
tensor_shape=None,
@@ -405,7 +413,11 @@ def send_forward_recv_backward_hetero(
config=config,
group=group,
)
output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced
if dp_end - dp_start != tensor_shape[1]:
dp_coef = float((dp_end - dp_start)) / float(tensor_shape[1])
output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced * dp_coef
else:
output_tensor_grad.data[sp_start:sp_end, dp_start:dp_end, :] = output_tensor_grad_sliced
if config.timers is not None:
config.timers('forward-send-backward-recv').stop()
if output_tensor_grad is not None and output_tensor_grad.device == torch.device("cpu"):
@@ -453,6 +465,9 @@ def send_backward_recv_forward_hetero(
dst_rank, (dp_start, dp_end), (sp_start, sp_end), local_hidden_size = tensor_slice
input_tensor_grad_sliced = input_tensor_grad[sp_start:sp_end, dp_start:dp_end, :]
tensor_shape_sliced = (sp_end - sp_start, dp_end - dp_start, local_hidden_size)
dp_coef = 1.0
if dp_end - dp_start != input_tensor_grad.shape[1]:
dp_coef = float(input_tensor_grad.shape[1]) / float((dp_end - dp_start))
group = None
for pp_group in pp_groups:
pp_group_ranks = torch.distributed.get_process_group_ranks(pp_group)
@@ -461,7 +476,7 @@
break
input_tensor_sliced, _, _ = _communicate(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad_sliced.contiguous() if "gloo" not in group.name() else input_tensor_grad_sliced.cpu(),
tensor_send_prev=input_tensor_grad_sliced.contiguous() * dp_coef if "gloo" not in group.name() else input_tensor_grad_sliced.cpu() * dp_coef,
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape_sliced,
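
The rescaling added above applies a coefficient of full_width / slice_width on the sending side and slice_width / full_width on the receiving side whenever a gradient slice does not span the local micro-batch dimension. A toy, single-process illustration of those two coefficients (all shapes are made up; the note on the net factor is my reading of the change, not a statement from the commit):

import torch

seq, hidden = 4, 16
sender_width, receiver_width, slice_width = 4, 8, 2  # hypothetical micro-batch widths

input_tensor_grad = torch.ones(seq, sender_width, hidden)
grad_slice = input_tensor_grad[:, 0:slice_width, :]

# send_backward_hetero: scale the outgoing slice by full width / slice width.
send_coef = float(sender_width) / float(slice_width)    # 2.0
sent = grad_slice.contiguous() * send_coef

# recv_backward_hetero: scale the incoming slice by slice width / full width
# before copying it into the local gradient buffer.
recv_coef = float(slice_width) / float(receiver_width)  # 0.25
output_tensor_grad = torch.zeros(seq, receiver_width, hidden)
output_tensor_grad[:, 0:slice_width, :] = sent * recv_coef

# Net factor sender_width / receiver_width (0.5 here), which appears to compensate
# for gradients being averaged over different data-parallel widths on each side.
print(output_tensor_grad[0, 0, 0].item())  # 0.5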
6 changes: 1 addition & 5 deletions flagscale/train/hetero/parallel_context.py
@@ -1003,11 +1003,7 @@ def get_embedding_group(self):
"""Get the embedding group the caller rank belongs to."""
groups = self._global_process_groups.get("embd", None)
assert groups is not None, 'embedding group is not initialized'
for group in groups:
if self._rank in self._global_process_group_to_ranks[group]:
embd_group = group
break
return embd_group
return groups

def get_position_embedding_group(self):
"""Get the position embedding group the caller rank belongs to."""
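
Because get_embedding_group now returns the list of embedding groups rather than the single group containing the caller, call sites (see the finalize_model_grads.py and setup_embeddings_and_output_layer changes below) must handle both shapes. A sketch of that caller-side pattern, mirroring the diff; the helper name is hypothetical:

import torch

def all_reduce_over_embedding_groups(tensor: torch.Tensor, embed_group) -> None:
    # Accept either a single process group or a list of groups.
    if not isinstance(embed_group, list):
        embed_group = [embed_group]
    if len(embed_group) == 1:
        torch.distributed.all_reduce(tensor, group=embed_group[0])
    else:
        # Restart from the same values for every group so the reductions
        # do not compound across groups.
        original = tensor.clone().detach()
        for group in embed_group:
            tensor.data.copy_(original)
            torch.distributed.all_reduce(tensor, group=group)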
42 changes: 37 additions & 5 deletions megatron/megatron/core/distributed/finalize_model_grads.py
@@ -115,24 +115,56 @@ def _allreduce_word_embedding_grads(model: List[torch.nn.Module], config: Transf
sync.
"""

if (
parallel_state.is_rank_in_embedding_group(ignore_virtual=True)
and torch.distributed.get_world_size(parallel_state.get_embedding_group()) > 1
):
if (parallel_state.is_rank_in_embedding_group(ignore_virtual=True)):
embed_group = parallel_state.get_embedding_group()
if not isinstance(embed_group, list):
embed_group = [embed_group]
else:
return

if (torch.distributed.get_world_size(embed_group[0]) > 1):
if parallel_state.is_pipeline_first_stage(ignore_virtual=True):
model_module = model[0]
elif parallel_state.is_pipeline_last_stage(ignore_virtual=True):
model_module = model[-1]
else: # We do not support an interleaved schedule for models with encoders yet.
model_module = model[0]

use_dist_opt = False
if hasattr(model_module, "ddp_config"):
use_dist_opt = model_module.ddp_config.use_distributed_optimizer
model_module = get_attr_wrapped_model(model_module, 'pre_process', return_model_obj=True)
if model_module.share_embeddings_and_output_weights:
weight = model_module.shared_embedding_or_output_weight()
grad_attr = "main_grad" if hasattr(weight, "main_grad") else "grad"
orig_grad = getattr(weight, grad_attr)
grad = _unshard_if_dtensor(orig_grad)
torch.distributed.all_reduce(grad, group=parallel_state.get_embedding_group())
if use_dist_opt:
if config.use_partial_reduce_for_shared_embedding:
dp_world_size = parallel_state.get_data_parallel_world_size()
dp_rank = parallel_state.get_data_parallel_rank()
assert grad.shape[0] % dp_world_size == 0, f"grad shape: {grad.shape[0]}, dp_world_size: {dp_world_size}"
per_partion_size = grad.shape[0] // dp_world_size
if len(embed_group) == 1:
offset = per_partion_size * dp_rank
torch.distributed.all_reduce(grad[offset:offset+per_partion_size, :], group=embed_group[0])
else:
group_idx = 0
per_partion_size = per_partion_size // len(embed_group)
for group in embed_group:
offset = per_partion_size * (dp_rank * len(embed_group) + group_idx)
torch.distributed.all_reduce(grad[offset : offset + per_partion_size, :], group=group)
group_idx += 1
else: # megatron default method
torch.distributed.all_reduce(grad, group=embed_group[0])
else:
if len(embed_group) == 1: # megatron default method
torch.distributed.all_reduce(grad, group=embed_group[0])
else:
original_grad_data = grad.clone().detach().data
for group in embed_group:
grad.data.copy_(original_grad_data)
torch.distributed.all_reduce(grad, group=group)
setattr(weight, grad_attr, _reshard_if_dtensor(grad, orig_grad))


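
With the distributed optimizer enabled and use_partial_reduce_for_shared_embedding on, each DP rank all-reduces only its own slab of the shared-embedding gradient, split further when there are several embedding groups. A small dry run of the offset arithmetic used above (sizes and counts are hypothetical):

vocab_rows, dp_world_size, num_embed_groups = 1024, 4, 2

assert vocab_rows % dp_world_size == 0
per_partition_size = vocab_rows // dp_world_size            # 256 rows per DP rank

for dp_rank in range(dp_world_size):
    if num_embed_groups == 1:
        offset = per_partition_size * dp_rank
        print(f"dp_rank {dp_rank}: reduce rows [{offset}, {offset + per_partition_size})")
    else:
        # Several embedding groups: split the rank's slab evenly across them.
        sub_size = per_partition_size // num_embed_groups    # 128 rows per group
        for group_idx in range(num_embed_groups):
            offset = sub_size * (dp_rank * num_embed_groups + group_idx)
            print(f"dp_rank {dp_rank}, group {group_idx}: rows [{offset}, {offset + sub_size})")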
Expand Up @@ -98,9 +98,16 @@ def setup_embeddings_and_output_layer(self) -> None:
if parallel_state.is_rank_in_embedding_group():
weight = self.shared_embedding_or_output_weight()
weight.data = weight.data.cuda()
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
embedding_group = parallel_state.get_embedding_group()
if not isinstance(embedding_group, list):
torch.distributed.all_reduce(
weight.data, group=parallel_state.get_embedding_group()
)
else:
original_weight = weight.clone().detach().data
for group in embedding_group:
weight.data.copy_(original_weight)
torch.distributed.all_reduce(weight.data, group=group)

elif not getattr(LanguageModule, "embedding_warning_printed", False):
logging.getLogger(__name__).warning(
10 changes: 5 additions & 5 deletions megatron/megatron/core/optimizer/clip_grads.py
@@ -141,9 +141,9 @@ def get_grad_norm_fp32(
# For cpu communication
tensor_device = get_device_type_for_comm(grad_stats_parallel_group)
if isinstance(grad_stats_parallel_group, list):
original_total_norm = total_norm
for group in grad_stats_parallel_group:
total_norm = original_total_norm
original_total_norm = total_norm.clone().detach()
for mp_group in grad_stats_parallel_group:
total_norm.data = original_total_norm.data.clone()
total_norm = total_norm.to(tensor_device)
torch.distributed.all_reduce(
total_norm, op=torch.distributed.ReduceOp.SUM, group=group
@@ -236,9 +236,9 @@ def count_zeros_fp32(
)
# Sum across all model-parallel GPUs.
if isinstance(grad_stats_parallel_group, list):
original_total_num_zeros = total_num_zeros
original_total_num_zeros = total_num_zeros.clone().detach()
for group in grad_stats_parallel_group:
total_num_zeros = original_total_num_zeros
total_num_zeros.data = original_total_num_zeros.data.clone()
torch.distributed.all_reduce(
total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=group
)
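
The clip_grads change replaces a plain rebinding (original_total_norm = total_norm) with a detached clone that is copied back before each group's all-reduce; without the clone, both names alias one tensor, so the first in-place reduction clobbers the saved value and later groups accumulate on top of it. A single-process demonstration of the aliasing problem, with an in-place add standing in for all_reduce:

import torch

# Buggy pattern: rebinding does not copy, so "original" aliases total_norm.
total_norm = torch.tensor(3.0)
original = total_norm
total_norm.add_(5.0)                  # stands in for an in-place all_reduce
print(original)                       # tensor(8.) -- the saved value was clobbered

# Fixed pattern from the diff: keep a detached clone and restore it per group.
total_norm = torch.tensor(3.0)
original = total_norm.clone().detach()
for _ in range(2):                    # one pass per grad-stats group
    total_norm.data = original.data.clone()
    total_norm.add_(5.0)              # each pass starts from 3.0 again
print(total_norm)                     # tensor(8.)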