[LLaVA OV] Add video and multi images support #262

Merged 3 commits on Dec 9, 2024
@@ -0,0 +1,91 @@
system:
tensor_model_parallel_size: 2
pipeline_model_parallel_size: 1
disable_bias_linear: True
use_flash_attn: True
use_distributed_optimizer: True
use_mcore_models: True
transformer_impl: transformer_engine
recompute_method: "uniform"
recompute_granularity: "full"
recompute_num_layers: 1
use_te: True
precision:
bf16: True
attention_softmax_in_fp32: True
logging:
log_interval: 1
tensorboard_log_interval: 1
wandb_project: "train-llava-ov"
wandb_exp_name: "train-llava-ov"
log_params_norm: True
log_num_zeros_in_grad: True
checkpoint:
save_interval: 3000
pretrained_checkpoint: xxxx
dataloader_save: ${experiment.exp_dir}/checkpoints/dataloader
use_dist_ckpt: False
ckpt_format: torch
async_save: False

model:
num_layers: 28
hidden_size: 3584
ffn_hidden_size: 18944
num_attention_heads: 28
num_query_groups: 4
seq_length: 32768
max_position_embeddings: 32768
swiglu: True
normalization: RMSNorm
init_method_std: 0.014
attention_dropout: 0.0
hidden_dropout: 0.0
clip_grad: 1.0
train_iters: 625
eval_iters: 0
micro_batch_size: 1
global_batch_size: 320
allow_missing_vision_projection_checkpoint: True
apply_layernorm_1p: True
group_query_attention: True
no_masked_softmax_fusion: True
untie_embeddings_and_output_weights: True
position_embedding_type: rope
rotary_percent: 1.0
rotary_base: 1000000
eod_mask_loss: True
freeze_LM: False
freeze_ViT: False
patch_dim: 14
img_h: 384
img_w: 384
language_model_type: qwen2_7b
vision_model_type: siglip
disable_vision_class_token: True
image_grid_pinpoints: '(1x1),...,(6x6)'
image_aspect_ratio: anyres_max_9
mm_patch_merge_type: spatial_unpad
seed: 42

optimizer:
weight_decay: 0.0
adam_beta1: 0.9
adam_beta2: 0.95
lr_scheduler:
lr: 1.0e-5
lr_warmup_fraction: .03
lr_decay_style: cosine

data:
interleaved_dataset: True
training_dataset_only: True
data_path: xxxx
dataloader_type: external
split: 100,0,0
tokenizer:
tokenizer_type: Qwen2TokenizerFS
tokenizer_path: xxxx
vocab_size: 152064 # 7b
# vocab_size: 151936 # 1.5b
make_vocab_size_divisible_by: 64
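
The `${experiment.exp_dir}` interpolation in `dataloader_save` suggests an OmegaConf/Hydra-style config. Below is a minimal sketch of loading and resolving a file like the one above; the file name `train_llava_ov.yaml`, the `experiment.exp_dir` value, and the use of OmegaConf are assumptions for illustration, not part of the PR.

# Minimal sketch (assumptions noted above): load the training YAML with OmegaConf
# and resolve the ${experiment.exp_dir} interpolation.
from omegaconf import OmegaConf

cfg = OmegaConf.load("train_llava_ov.yaml")  # assumed file name
# Supply the experiment.exp_dir value referenced by the interpolation.
cfg = OmegaConf.merge(OmegaConf.create({"experiment": {"exp_dir": "./outputs"}}), cfg)

print(cfg.model.seq_length)                   # 32768
print(cfg.system.tensor_model_parallel_size)  # 2
resolved = OmegaConf.to_container(cfg, resolve=True)  # plain dict with ${...} expanded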
42 changes: 22 additions & 20 deletions flagscale/train/models/llava_onevision/dataset_helpers.py
@@ -33,9 +33,9 @@ class AnyResTaskSample:
input_ids_shape: torch.Tensor
labels: torch.Tensor
labels_shape: torch.Tensor
images: torch.Tensor
image_sizes: torch.Tensor
modalities: torch.Tensor
images: List[torch.Tensor]
image_sizes: List[torch.Tensor]
modalities: List[torch.Tensor]

# Typing for the resulting batch data after encode_batch()
@dataclass
@@ -85,19 +85,19 @@ def encode_interleaved(self, sample: InterleavedSample):
raise ValueError(f"The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.")

# process modalities to tensor
if modalities is None:
modalities = "image"
# image, video, text to 0, 1, 2
if modalities == "image":
modalities = torch.tensor([0])
elif modalities == "video":
modalities = torch.tensor([1])
elif modalities == "text":
modalities = torch.tensor([2])
else:
raise ValueError(f"Unsupported modality: {modalities}")
modalities_list = []
for modality in modalities:
# image, video, text to 0, 1, 2
if modality == "image":
modalities_list.append(torch.tensor([0]))
elif modality == "video":
modalities_list.append(torch.tensor([1]))
elif modality == "text":
modalities_list.append(torch.tensor([2]))
else:
raise ValueError(f"Unsupported modality: {modalities}")


modalities = modalities_list
return AnyResTaskSample(
__key__=sample.__key__,
__subflavors__=sample.__subflavors__,
@@ -106,7 +106,7 @@ def encode_interleaved(self, sample: InterleavedSample):
labels=labels,
labels_shape=torch.tensor(labels.shape),
images=images,
image_sizes=torch.tensor(image_sizes),
image_sizes=image_sizes,
modalities=modalities
)

@@ -115,10 +115,12 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
input_ids_shape = torch.stack([s.input_ids_shape for s in samples], dim=0)
labels = torch.cat([s.labels.flatten() for s in samples], dim=0)
labels_shape = torch.stack([s.labels_shape for s in samples], dim=0)
images = torch.cat([s.images.flatten() for s in samples], dim=0)
split_image_sizes = torch.stack([torch.tensor(s.images.shape) for s in samples], dim=0)
image_sizes = torch.stack([s.image_sizes for s in samples], dim=0)
modalities = torch.stack([s.modalities for s in samples], dim=0)
# Flatten every image tensor from every sample (nested comprehension)
images = torch.cat([image.flatten() for s in samples for image in s.images], dim=0)
split_image_sizes = torch.stack([torch.tensor(image.shape) for s in samples for image in s.images], dim=0)
# Adapt video data decoded by decord: a scalar frame count becomes (1, num_frames)
image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0)
modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0)

batch = AnyResTaskBatch(
__keys__=[s.__key__ for s in samples],
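
To make the new flattening in batch() concrete, here is a standalone sketch of the same comprehensions applied to a single toy sample that holds one image and one video; every shape and value below is an illustrative assumption, not data from the PR.

# Standalone sketch of the batch() flattening for a mixed image/video sample.
import torch

images = [torch.zeros(5, 3, 384, 384),     # anyres image split into 5 tiles (assumed shape)
          torch.zeros(8, 3, 384, 384)]     # video with 8 decoded frames (assumed shape)
image_sizes = [torch.tensor([672, 1008]),  # image stores its (width, height)
               torch.tensor(8)]            # decord-decoded video stores a scalar frame count
modalities = [torch.tensor([0]), torch.tensor([1])]  # image=0, video=1

flat_images = torch.cat([img.flatten() for img in images], dim=0)
split_image_sizes = torch.stack([torch.tensor(img.shape) for img in images], dim=0)
# Scalar frame counts are promoted to (1, num_frames) so all entries stack to two values.
sizes = torch.stack(
    [s if len(s.shape) == 1 else torch.tensor((1, s.item())) for s in image_sizes], dim=0
)
mods = torch.stack(modalities, dim=0)
print(split_image_sizes.shape, sizes.shape, mods.shape)  # (2, 4), (2, 2), (2, 1)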
121 changes: 121 additions & 0 deletions flagscale/train/models/llava_onevision/layer_spec.py
@@ -0,0 +1,121 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
from megatron.core.transformer.custom_layers.transformer_engine import (
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelLinear,
)

HAVE_TE = True
except ImportError:
HAVE_TE = False

try:
import apex

from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings

from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
LNImpl = WrappedTorchLayerNorm


def get_layer_spec(is_vit, normalization) -> ModuleSpec:
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
if normalization == "LayerNorm":
norm = LNImpl
elif normalization == "RMSNorm":
norm = TENorm
else:
raise RuntimeError("unknown normalization", normalization)

mlp = get_mlp_module_spec(use_te=False) # doesn't include norm.

return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=norm,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=ColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=RowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=norm,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)


def get_layer_spec_te(is_vit=False) -> ModuleSpec:
attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal

mlp = get_norm_mlp_module_spec_te()
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": attn_mask_type},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=IdentityOp,
k_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=IdentityOp,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
),
)


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)


def get_norm_mlp_module_spec_te() -> ModuleSpec:
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
),
)
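
As a usage sketch (not part of the PR), the two entry points above could be selected from the `use_te` and `normalization` settings in the YAML config earlier in this diff. The helper `select_layer_specs` below is hypothetical; only `get_layer_spec` and `get_layer_spec_te` come from this file.

# Hypothetical helper (assumption): choose language/vision layer specs from config flags.
from flagscale.train.models.llava_onevision.layer_spec import (
    get_layer_spec,
    get_layer_spec_te,
)

def select_layer_specs(use_te: bool, language_norm: str = "RMSNorm"):
    """Return (language_spec, vision_spec) for the decoder and the ViT tower."""
    if use_te:
        # TE path: norms are fused into TELayerNormColumnParallelLinear.
        return get_layer_spec_te(is_vit=False), get_layer_spec_te(is_vit=True)
    # Non-TE path: explicit input/pre-MLP norms chosen by the normalization string.
    return (
        get_layer_spec(is_vit=False, normalization=language_norm),
        get_layer_spec(is_vit=True, normalization="LayerNorm"),
    )

language_spec, vision_spec = select_layer_specs(use_te=True)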
