support video and multi images
Caozhou1995 committed Dec 9, 2024
1 parent d661ba3 commit b9b4717
Showing 9 changed files with 1,445 additions and 643 deletions.
91 changes: 91 additions & 0 deletions examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml
@@ -0,0 +1,91 @@
system:
  tensor_model_parallel_size: 2
  pipeline_model_parallel_size: 1
  disable_bias_linear: True
  use_flash_attn: True
  use_distributed_optimizer: True
  use_mcore_models: True
  transformer_impl: transformer_engine
  recompute_method: "uniform"
  recompute_granularity: "full"
  recompute_num_layers: 1
  use_te: True
  precision:
    bf16: True
    attention_softmax_in_fp32: True
  logging:
    log_interval: 1
    tensorboard_log_interval: 1
    wandb_project: "train-llava-ov"
    wandb_exp_name: "train-llava-ov"
    log_params_norm: True
    log_num_zeros_in_grad: True
  checkpoint:
    save_interval: 3000
    pretrained_checkpoint: xxxx
    dataloader_save: ${experiment.exp_dir}/checkpoints/dataloader
    use_dist_ckpt: False
    ckpt_format: torch
    async_save: False

model:
  num_layers: 28
  hidden_size: 3584
  ffn_hidden_size: 18944
  num_attention_heads: 28
  num_query_groups: 4
  seq_length: 32768
  max_position_embeddings: 32768
  swiglu: True
  normalization: RMSNorm
  init_method_std: 0.014
  attention_dropout: 0.0
  hidden_dropout: 0.0
  clip_grad: 1.0
  train_iters: 625
  eval_iters: 0
  micro_batch_size: 1
  global_batch_size: 320
  allow_missing_vision_projection_checkpoint: True
  apply_layernorm_1p: True
  group_query_attention: True
  no_masked_softmax_fusion: True
  untie_embeddings_and_output_weights: True
  position_embedding_type: rope
  rotary_percent: 1.0
  rotary_base: 1000000
  eod_mask_loss: True
  freeze_LM: False
  freeze_ViT: False
  patch_dim: 14
  img_h: 384
  img_w: 384
  language_model_type: qwen2_7b
  vision_model_type: siglip
  disable_vision_class_token: True
  image_grid_pinpoints: '(1x1),...,(6x6)'
  image_aspect_ratio: anyres_max_9
  mm_patch_merge_type: spatial_unpad
  seed: 42

  optimizer:
    weight_decay: 0.0
    adam_beta1: 0.9
    adam_beta2: 0.95
    lr_scheduler:
      lr: 1.0e-5
      lr_warmup_fraction: .03
      lr_decay_style: cosine

data:
  interleaved_dataset: True
  training_dataset_only: True
  data_path: xxxx
  dataloader_type: external
  split: 100,0,0
  tokenizer:
    tokenizer_type: Qwen2TokenizerFS
    tokenizer_path: xxxx
    vocab_size: 152064 # 7b
    # vocab_size: 151936 # 1.5b
    make_vocab_size_divisible_by: 64
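For quick inspection, the config above can be loaded directly with OmegaConf (the library underlying Hydra-style FlagScale configs). A minimal sketch, assuming the file sits at the path shown in the diff header and the system/model/data nesting shown above:

```python
# Minimal sketch: load and inspect the training YAML with OmegaConf.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml")

print(cfg.system.tensor_model_parallel_size)  # 2
print(cfg.model.language_model_type)          # qwen2_7b
print(cfg.data.tokenizer.vocab_size)          # 152064

# Note: ${experiment.exp_dir} in checkpoint.dataloader_save only resolves once the
# full experiment config is composed, so avoid resolving interpolations on this file alone.
```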
52 changes: 22 additions & 30 deletions flagscale/train/models/llava_onevision/dataset_helpers.py
@@ -9,7 +9,11 @@
import numpy as np
import torch

from megatron.energon import Batch, DefaultTaskEncoder, InterleavedSample
from megatron.energon import (
    Batch,
    DefaultTaskEncoder,
    InterleavedSample
)
from megatron.training import get_args


@@ -29,10 +33,9 @@ class AnyResTaskSample:
    input_ids_shape: torch.Tensor
    labels: torch.Tensor
    labels_shape: torch.Tensor
    images: torch.Tensor
    image_sizes: torch.Tensor
    modalities: torch.Tensor

    images: List[torch.Tensor]
    image_sizes: List[torch.Tensor]
    modalities: List[torch.Tensor]

# Typing for the resulting batch data after encode_batch()
@dataclass
@@ -49,16 +52,16 @@ class AnyResTaskBatch(Batch):
    modalities: torch.Tensor


class AnyResTaskEncoder(
    DefaultTaskEncoder[InterleavedSample, InterleavedSample, AnyResTaskBatch, dict]
):
class AnyResTaskEncoder(DefaultTaskEncoder[InterleavedSample, InterleavedSample, AnyResTaskBatch, dict]):
"""
A task encoder for anyres.
This encoder is just a wrapper around data that has already been made.
Production data can be referenced to LLaVA-NeXT and can be input into vision tower.
"""

    def __init__(self):
    def __init__(
        self
    ):
        # Specify the batch_type for default batching (batching is performed here "manually" by
        # overwriting the `batch` method)
        super().__init__()
@@ -67,9 +70,7 @@ def __init__(self):

    def encode_sample(self, sample: InterleavedSample):
        if not isinstance(sample, InterleavedSample):
            raise ValueError(
                f"This encoder only supports InterleavedSample, but got {type(sample)}."
            )
            raise ValueError(f"This encoder only supports InterleavedSample, but got {type(sample)}.")
        yield self.encode_interleaved(sample)

    def encode_interleaved(self, sample: InterleavedSample):
@@ -81,9 +82,7 @@ def encode_interleaved(self, sample: InterleavedSample):
        elif len(sample.sequence) == 5:
            input_ids, labels, images, image_sizes, modalities = sample.sequence
        else:
            assert ValueError(
                "The sequence must have 4 or 5 elements, but got {len(sample.sequence)}."
            )
            raise ValueError(f"The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.")

        # process modalities to tensor
        modalities_list = []
@@ -97,7 +96,7 @@ def encode_interleaved(self, sample: InterleavedSample):
                modalities_list.append(torch.tensor([2]))
            else:
                raise ValueError(f"Unsupported modality: {modalities}")

        modalities = modalities_list
        return AnyResTaskSample(
            __key__=sample.__key__,
@@ -108,7 +107,7 @@ def encode_interleaved(self, sample: InterleavedSample):
            labels_shape=torch.tensor(labels.shape),
            images=images,
            image_sizes=image_sizes,
            modalities=modalities,
            modalities=modalities
        )

    def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
@@ -117,18 +116,11 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
        labels = torch.cat([s.labels.flatten() for s in samples], dim=0)
        labels_shape = torch.stack([s.labels_shape for s in samples], dim=0)
        # Double loop
        images = torch.cat(
            [image.flatten() for s in samples for image in s.images], dim=0
        )
        split_image_sizes = torch.stack(
            [torch.tensor(image.shape) for s in samples for image in s.images], dim=0
        )
        image_sizes = torch.stack(
            [image_sizes for s in samples for image_sizes in s.image_sizes], dim=0
        )
        modalities = torch.stack(
            [s.modalities for s in samples for modalities in s.modalities], dim=0
        )
        images = torch.cat([image.flatten() for s in samples for image in s.images], dim=0)
        split_image_sizes = torch.stack([torch.tensor(image.shape) for s in samples for image in s.images], dim=0)
        # Adapt video data by decord
        image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0)
        modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0)

        batch = AnyResTaskBatch(
            __keys__=[s.__key__ for s in samples],
@@ -140,7 +132,7 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
            images=images,
            image_sizes=image_sizes,
            split_image_sizes=split_image_sizes,
            modalities=modalities,
            modalities=modalities
        )

        return batch
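As context for the flattening above (not part of the commit): `batch()` concatenates every image tensor into one 1-D buffer and records each original shape in `split_image_sizes`, so a downstream consumer can split and reshape the buffer back. A minimal sketch, with a hypothetical helper name:

```python
import torch

def unflatten_images(images: torch.Tensor, split_image_sizes: torch.Tensor):
    # split_image_sizes holds one row per image with its original shape,
    # so the element count per image is the product of that row.
    numels = [int(torch.prod(shape)) for shape in split_image_sizes]
    chunks = torch.split(images, numels)
    return [
        chunk.reshape(tuple(shape.tolist()))
        for chunk, shape in zip(chunks, split_image_sizes)
    ]
```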
121 changes: 121 additions & 0 deletions flagscale/train/models/llava_onevision/layer_spec.py
@@ -0,0 +1,121 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import torch

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

try:
    from megatron.core.transformer.custom_layers.transformer_engine import (
        TEColumnParallelLinear,
        TEDotProductAttention,
        TELayerNormColumnParallelLinear,
        TENorm,
        TERowParallelLinear,
    )

    HAVE_TE = True
except ImportError:
    HAVE_TE = False

try:
    import apex

    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm

    HAVE_APEX = True
    LNImpl = FusedLayerNorm
except ImportError:
    import warnings

    from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm

    warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm')
    LNImpl = WrappedTorchLayerNorm


def get_layer_spec(is_vit, normalization) -> ModuleSpec:
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal
    if normalization == "LayerNorm":
        norm = LNImpl
    elif normalization == "RMSNorm":
        norm = TENorm
    else:
        raise RuntimeError("unknown normalization", normalization)

    mlp = get_mlp_module_spec(use_te=False)  # doesn't include norm.

    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            input_layernorm=norm,
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=ColumnParallelLinear,
                    core_attention=DotProductAttention,
                    linear_proj=RowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=norm,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_layer_spec_te(is_vit=False) -> ModuleSpec:
    attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal

    mlp = get_norm_mlp_module_spec_te()
    return ModuleSpec(
        module=TransformerLayer,
        submodules=TransformerLayerSubmodules(
            self_attention=ModuleSpec(
                module=SelfAttention,
                params={"attn_mask_type": attn_mask_type},
                submodules=SelfAttentionSubmodules(
                    linear_qkv=TELayerNormColumnParallelLinear,
                    core_attention=TEDotProductAttention,
                    linear_proj=TERowParallelLinear,
                    q_layernorm=IdentityOp,
                    k_layernorm=IdentityOp,
                ),
            ),
            self_attn_bda=get_bias_dropout_add,
            pre_mlp_layernorm=IdentityOp,
            mlp=mlp,
            mlp_bda=get_bias_dropout_add,
        ),
    )


def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec:
    # Dense MLP w/ or w/o TE modules.
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
            linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
        ),
    )


def get_norm_mlp_module_spec_te() -> ModuleSpec:
    return ModuleSpec(
        module=MLP,
        submodules=MLPSubmodules(
            linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear
        ),
    )
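For orientation (not part of the diff): the two entry points above are typically chosen based on whether Transformer Engine is requested. A minimal sketch; the helper name `build_layer_spec` is hypothetical:

```python
# Hypothetical selector (illustration only): prefer the TE spec when available.
def build_layer_spec(use_te: bool, is_vit: bool, normalization: str) -> ModuleSpec:
    if use_te and HAVE_TE:
        # The TE spec fuses the pre-attention/pre-MLP norms into the TE linear
        # layers, which is why it sets pre_mlp_layernorm to IdentityOp.
        return get_layer_spec_te(is_vit=is_vit)
    return get_layer_spec(is_vit=is_vit, normalization=normalization)
```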

@@ -105,6 +105,7 @@ def __init__(
        self.language_model = None
        self.args = get_args()

        args = self.args
        # Init image_newline
        if "unpad" in args.mm_patch_merge_type:
            embed_std = 1 / torch.sqrt(
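The truncated hunk above adds `args = self.args` so the subsequent `image_newline` setup can read `args.mm_patch_merge_type`. For reference, a sketch of the usual LLaVA-NeXT-style initialization this code follows; the exact attribute names here are assumptions, not the commit's code:

```python
import torch
import torch.nn as nn

hidden_size = 3584  # model.hidden_size from the YAML above
# Scale like a normal embedding: std = 1/sqrt(hidden_size).
embed_std = 1 / torch.sqrt(torch.tensor(hidden_size, dtype=torch.float32))
image_newline = nn.Parameter(torch.randn(hidden_size) * embed_std)
```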
7 changes: 5 additions & 2 deletions flagscale/train/train_llava_onevision.py
@@ -28,7 +28,7 @@
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
)
from examples.multimodal.layer_specs import (
from flagscale.train.models.llava_onevision.layer_specs import (
    get_layer_spec,
    get_mlp_module_spec,
    get_layer_spec_te,
@@ -442,7 +444,10 @@ def add_multimodal_extra_args(parser):
"--add-faster-video", default=False, help="Whetehr add fatser video token"
)
    group.add_argument(
        "--mm-spatial-pool-mode", type=str, default=bilinear, help="Spatial pool mode"
        "--mm-spatial-pool-mode", type=str, default="bilinear", help="Spatial pool mode"
    )
    group.add_argument(
        "--mm-newline-position", type=str, default="grid", help="Newline position."
    )
    return parser

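To see how the new multimodal flags parse (illustration only, using a bare argparse parser rather than Megatron's full argument builder):

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("multimodal")
group.add_argument("--add-faster-video", default=False, help="Whether to add faster video token")
group.add_argument("--mm-spatial-pool-mode", type=str, default="bilinear", help="Spatial pool mode")
group.add_argument("--mm-newline-position", type=str, default="grid", help="Newline position.")

args = parser.parse_args([])
print(args.mm_spatial_pool_mode)  # bilinear
print(args.mm_newline_position)   # grid
# Note: --add-faster-video sets no type/action, so any value passed on the
# command line (even "False") arrives as a non-empty string and is truthy.
print(args.add_faster_video)      # False (the default)
```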