Commit b9b4717 (1 parent: d661ba3)
Showing 9 changed files with 1,445 additions and 643 deletions.
examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml (91 additions, 0 deletions)
@@ -0,0 +1,91 @@
system:
  tensor_model_parallel_size: 2
  pipeline_model_parallel_size: 1
  disable_bias_linear: True
  use_flash_attn: True
  use_distributed_optimizer: True
  use_mcore_models: True
  transformer_impl: transformer_engine
  recompute_method: "uniform"
  recompute_granularity: "full"
  recompute_num_layers: 1
  use_te: True
  precision:
    bf16: True
    attention_softmax_in_fp32: True
  logging:
    log_interval: 1
    tensorboard_log_interval: 1
    wandb_project: "train-llava-ov"
    wandb_exp_name: "train-llava-ov"
    log_params_norm: True
    log_num_zeros_in_grad: True
  checkpoint:
    save_interval: 3000
    pretrained_checkpoint: xxxx
    dataloader_save: ${experiment.exp_dir}/checkpoints/dataloader
    use_dist_ckpt: False
    ckpt_format: torch
    async_save: False

model:
  num_layers: 28
  hidden_size: 3584
  ffn_hidden_size: 18944
  num_attention_heads: 28
  num_query_groups: 4
  seq_length: 32768
  max_position_embeddings: 32768
  swiglu: True
  normalization: RMSNorm
  init_method_std: 0.014
  attention_dropout: 0.0
  hidden_dropout: 0.0
  clip_grad: 1.0
  train_iters: 625
  eval_iters: 0
  micro_batch_size: 1
  global_batch_size: 320
  allow_missing_vision_projection_checkpoint: True
  apply_layernorm_1p: True
  group_query_attention: True
  no_masked_softmax_fusion: True
  untie_embeddings_and_output_weights: True
  position_embedding_type: rope
  rotary_percent: 1.0
  rotary_base: 1000000
  eod_mask_loss: True
  freeze_LM: False
  freeze_ViT: False
  patch_dim: 14
  img_h: 384
  img_w: 384
  language_model_type: qwen2_7b
  vision_model_type: siglip
  disable_vision_class_token: True
  image_grid_pinpoints: '(1x1),...,(6x6)'
  image_aspect_ratio: anyres_max_9
  mm_patch_merge_type: spatial_unpad
  seed: 42

  optimizer:
    weight_decay: 0.0
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 1.0e-5
      lr_warmup_fraction: .03
      lr_decay_style: cosine

data:
  interleaved_dataset: True
  training_dataset_only: True
  data_path: xxxx
  dataloader_type: external
  split: 100,0,0
  tokenizer:
    tokenizer_type: Qwen2TokenizerFS
    tokenizer_path: xxxx
    vocab_size: 152064 # 7b
    # vocab_size: 151936 # 1.5b
    make_vocab_size_divisible_by: 64
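
As a quick sanity check, the config above can be loaded and inspected with OmegaConf. This is a minimal sketch, not part of the commit: it assumes the file is plain OmegaConf-compatible YAML; FlagScale's own launcher and the resolution of ${experiment.exp_dir} are outside this diff.

# Hypothetical check, not part of the commit: load the new config and read a few fields.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml")
print(cfg.system.tensor_model_parallel_size)        # 2
print(cfg.model.num_layers, cfg.model.seq_length)   # 28 32768
print(cfg.model.optimizer.lr_scheduler.lr)          # 1e-05
print(cfg.data.tokenizer.tokenizer_type)            # Qwen2TokenizerFS
# Note: cfg.system.checkpoint.dataloader_save interpolates ${experiment.exp_dir},
# which resolves only once the launcher merges in an `experiment` section.
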
@@ -0,0 +1,121 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    from megatron.core.transformer.custom_layers.transformer_engine import (
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn('Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


def get_layer_spec(is_vit, normalization) -> ModuleSpec:
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
    if normalization == "LayerNorm":
        norm = LNImpl
    elif normalization == "RMSNorm":
        norm = TENorm
    else:
        raise RuntimeError("unknown normalization", normalization)

    mlp = get_mlp_module_spec(use_te=False)  # doesn't include norm.

    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=norm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=norm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_layer_spec_te(is_vit=False) -> ModuleSpec:
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal

    mlp = get_norm_mlp_module_spec_te()
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )


def get_norm_mlp_module_spec_te() -> ModuleSpec:
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
        ),
    )
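
For orientation, a hedged usage sketch of the helpers defined above. This is hypothetical and not part of the commit; it uses only names from this module, and the model builder that actually consumes these specs lives elsewhere in this change.

# Hypothetical usage, not part of the commit: pick a layer spec per backend.
if HAVE_TE:
    # Transformer Engine path: norms fused into the TE linear layers, TE attention.
    language_layer_spec = get_layer_spec_te(is_vit=False)  # causal mask for the language model
    vision_layer_spec = get_layer_spec_te(is_vit=True)     # no attention mask for the ViT
else:
    # Fallback path: explicit norms with megatron-core attention and linear layers.
    language_layer_spec = get_layer_spec(is_vit=False, normalization="LayerNorm")
    vision_layer_spec = get_layer_spec(is_vit=True, normalization="LayerNorm")
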