diff --git a/examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml b/examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml new file mode 100644 index 000000000..7fab3f902 --- /dev/null +++ b/examples/llava_onevision/conf/train/train_llava_onevision_7b.yaml @@ -0,0 +1,91 @@ +system: + tensor_model_parallel_size: 2 + pipeline_model_parallel_size: 1 + disable_bias_linear: True + use_flash_attn: True + use_distributed_optimizer: True + use_mcore_models: True + transformer_impl: transformer_engine + recompute_method: "uniform" + recompute_granularity: "full" + recompute_num_layers: 1 + use_te: True + precision: + bf16: True + attention_softmax_in_fp32: True + logging: + log_interval: 1 + tensorboard_log_interval: 1 + wandb_project: "train-llava-ov" + wandb_exp_name: "train-llava-ov" + log_params_norm: True + log_num_zeros_in_grad: True + checkpoint: + save_interval: 3000 + pretrained_checkpoint: xxxx + dataloader_save: ${experiment.exp_dir}/checkpoints/dataloader + use_dist_ckpt: False + ckpt_format: torch + async_save: False + +model: + num_layers: 28 + hidden_size: 3584 + ffn_hidden_size: 18944 + num_attention_heads: 28 + num_query_groups: 4 + seq_length: 32768 + max_position_embeddings: 32768 + swiglu: True + normalization: RMSNorm + init_method_std: 0.014 + attention_dropout: 0.0 + hidden_dropout: 0.0 + clip_grad: 1.0 + train_iters: 625 + eval_iters: 0 + micro_batch_size: 1 + global_batch_size: 320 + allow_missing_vision_projection_checkpoint: True + apply_layernorm_1p: True + group_query_attention: True + no_masked_softmax_fusion: True + untie-embeddings-and-output-weights: True + position_embedding_type: rope + rotary_percent: 1.0 + rotary_base: 1000000 + eod_mask_loss: True + freeze_LM: False + freeze_ViT: False + patch_dim: 14 + img_h: 384 + img_w: 384 + language_model_type: qwen2_7b + vision_model_type: siglip + disable_vision_class_token: True + image_grid_pinpoints: '(1x1),...,(6x6)' + image_aspect_ratio: anyres_max_9 + mm_patch_merge_type: spatial_unpad + seed: 42 + + optimizer: + weight_decay: 0.0 + adam_beta1: 0.9 + adam_beta2: 0.95 + lr_scheduler: + lr: 1.0e-5 + lr_warmup_fraction: .03 + lr_decay_style: cosine + +data: + interleaved_dataset: True + training_dataset_only: True + data_path: xxxx + dataloader_type: external + split: 100,0,0 + tokenizer: + tokenizer_type: Qwen2TokenizerFS + tokenizer_path: xxxx + vocab_size: 152064 # 7b + # vocab_size: 151936 # 1.5b + make_vocab_size_divisible_by: 64 diff --git a/flagscale/train/models/llava_onevision/dataset_helpers.py b/flagscale/train/models/llava_onevision/dataset_helpers.py index 65687db62..0ed334953 100644 --- a/flagscale/train/models/llava_onevision/dataset_helpers.py +++ b/flagscale/train/models/llava_onevision/dataset_helpers.py @@ -33,9 +33,9 @@ class AnyResTaskSample: input_ids_shape: torch.Tensor labels: torch.Tensor labels_shape: torch.Tensor - images: torch.Tensor - image_sizes: torch.Tensor - modalities: torch.Tensor + images: List[torch.Tensor] + image_sizes: List[torch.Tensor] + modalities: List[torch.Tensor] # Typing for the resulting batch data after encode_batch() @dataclass @@ -85,19 +85,19 @@ def encode_interleaved(self, sample: InterleavedSample): assert ValueError("The sequence must have 4 or 5 elements, but got {len(sample.sequence)}.") # process modalities to tensor - if modalities is None: - modalities = "image" - # image, video, text to 0, 1, 2 - if modalities == "image": - modalities = torch.tensor([0]) - elif modalities == "video": - modalities = torch.tensor([1]) - elif 
modalities == "text":
-            modalities = torch.tensor([2])
-        else:
-            raise ValueError(f"Unsupported modality: {modalities}")
+        modalities_list = []
+        for modality in modalities:
+            # image, video, text to 0, 1, 2
+            if modality == "image":
+                modalities_list.append(torch.tensor([0]))
+            elif modality == "video":
+                modalities_list.append(torch.tensor([1]))
+            elif modality == "text":
+                modalities_list.append(torch.tensor([2]))
+            else:
+                raise ValueError(f"Unsupported modality: {modality}")
-
+        modalities = modalities_list
         return AnyResTaskSample(
             __key__=sample.__key__,
             __subflavors__=sample.__subflavors__,
@@ -106,7 +106,7 @@ def encode_interleaved(self, sample: InterleavedSample):
             labels=labels,
             labels_shape=torch.tensor(labels.shape),
             images=images,
-            image_sizes=torch.tensor(image_sizes),
+            image_sizes=image_sizes,
             modalities=modalities
         )

@@ -115,10 +115,12 @@ def batch(self, samples: List[AnyResTaskSample]) -> AnyResTaskBatch:
         input_ids_shape = torch.stack([s.input_ids_shape for s in samples], dim=0)
         labels = torch.cat([s.labels.flatten() for s in samples], dim=0)
         labels_shape = torch.stack([s.labels_shape for s in samples], dim=0)
-        images = torch.cat([s.images.flatten() for s in samples], dim=0)
-        split_image_sizes = torch.stack([torch.tensor(s.images.shape) for s in samples], dim=0)
-        image_sizes = torch.stack([s.image_sizes for s in samples], dim=0)
-        modalities = torch.stack([s.modalities for s in samples], dim=0)
+        # Each sample may carry several images or video frames; flatten them all into one packed buffer.
+        images = torch.cat([image.flatten() for s in samples for image in s.images], dim=0)
+        split_image_sizes = torch.stack([torch.tensor(image.shape) for s in samples for image in s.images], dim=0)
+        # Video samples decoded with decord store a 0-dim size; expand it to a 2-element size so it stacks with image sizes.
+        image_sizes = torch.stack([image_sizes if len(image_sizes.shape) == 1 else torch.tensor((1, image_sizes.item())) for s in samples for image_sizes in s.image_sizes], dim=0)
+        modalities = torch.stack([modalities for s in samples for modalities in s.modalities], dim=0)

         batch = AnyResTaskBatch(
             __keys__=[s.__key__ for s in samples],
diff --git a/flagscale/train/models/llava_onevision/layer_specs.py b/flagscale/train/models/llava_onevision/layer_specs.py
new file mode 100644
index 000000000..2355a7eac
--- /dev/null
+++ b/flagscale/train/models/llava_onevision/layer_specs.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
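+#
+# NOTE: Adapted from Megatron's examples/multimodal/layer_specs.py (the import in
+# train_llava_onevision.py below is switched from that module to this one).
+# get_layer_spec() builds a TransformerLayer spec from local, non-Transformer-Engine
+# attention/linear modules and wires the input and pre-MLP norms explicitly, while
+# get_layer_spec_te() uses Transformer Engine's fused LayerNorm+Linear modules, so its
+# separate pre_mlp_layernorm slot collapses to IdentityOp. is_vit=True selects the
+# unmasked attention used for the SigLIP vision tower; is_vit=False selects the causal
+# mask used for the Qwen2 language model.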
+import torch + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +try: + from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + import apex + + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + import warnings + + from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm + + warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm') + LNImpl = WrappedTorchLayerNorm + + +def get_layer_spec(is_vit, normalization) -> ModuleSpec: + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + if normalization == "LayerNorm": + norm = LNImpl + elif normalization == "RMSNorm": + norm = TENorm + else: + raise RuntimeError("unknown normalization", normalization) + + mlp = get_mlp_module_spec(use_te=False) # doesn't include norm. + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=norm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=norm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_layer_spec_te(is_vit=False) -> ModuleSpec: + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + + mlp = get_norm_mlp_module_spec_te() + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + # Dense MLP w/ or w/o TE modules. 
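+    # With use_te=False this returns plain ColumnParallelLinear/RowParallelLinear
+    # layers and no fused input norm, so get_layer_spec() above supplies the norm
+    # separately through pre_mlp_layernorm; with use_te=True the TE linear layers are
+    # used instead. The fused-norm variant consumed by get_layer_spec_te() is
+    # get_norm_mlp_module_spec_te() below.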
+ return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +def get_norm_mlp_module_spec_te() -> ModuleSpec: + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ) + diff --git a/flagscale/train/models/llava_onevision/llava_onevision_model.py b/flagscale/train/models/llava_onevision/llava_onevision_model.py index 1c316fb44..5ef82617f 100644 --- a/flagscale/train/models/llava_onevision/llava_onevision_model.py +++ b/flagscale/train/models/llava_onevision/llava_onevision_model.py @@ -103,8 +103,9 @@ def __init__( self.vision_model = None self.vision_projection = None self.language_model = None - args = get_args() + self.args = get_args() + args = self.args # Init image_newline if "unpad" in args.mm_patch_merge_type: embed_std = 1 / torch.sqrt( @@ -261,11 +262,12 @@ def prepare_inputs_labels_for_multimodal( past_key_values=None, ): """This function is modified from LLaVA-NeXT.""" - args = get_args() + args = self.args vision_tower = self.vision_model - # Micro batch size must be 1 when in the mixture modalities mode. + if vision_tower is None or images is None or input_ids.shape[1] == 1: - input_ids = self.embed_tokens(input_ids) + # [BugFix]: comment out the embed_tokens + # input_ids = self.embed_tokens(input_ids) loss_mask = torch.where( labels == IGNORE_INDEX, torch.tensor(0), torch.tensor(1) ) @@ -274,6 +276,8 @@ def prepare_inputs_labels_for_multimodal( if isinstance(modalities, str): modalities = [modalities] + # Comment out the code because we're starting to support it + """ text_modality = False image_or_video_modality = False for modality in modalities: @@ -286,6 +290,7 @@ def prepare_inputs_labels_for_multimodal( raise ValueError( "Text and image/video modalities cannot be mixed in the same batch." ) + """ if type(images) is list or images.ndim == 5: if type(images) is list: @@ -302,14 +307,14 @@ def prepare_inputs_labels_for_multimodal( images_list.append(image) else: images_list.append(image.unsqueeze(0)) - raise ValueError( - "Video not supported yet. In the future, we will support video." - ) + # Comment out the code because we're starting to support it + # raise ValueError( + # "Video not supported yet. In the future, we will support video." + # ) concat_images = torch.cat([image for image in images_list], dim=0) split_sizes = [image.shape[0] for image in images_list] - split_sizes = [image.shape[0] for image in images_list] encoded_image_features = self.encode_images(concat_images) # Get every sample image features @@ -317,30 +322,72 @@ def prepare_inputs_labels_for_multimodal( image_features = [] for idx, image_feat in enumerate(encoded_image_features): if idx in video_idx_in_batch: - raise ValueError( - "Video not supported yet. In the future, we will support video." - ) + image_features.append(self.get_2dPool(image_feat)) + # Comment out the code because we're starting to support it + # raise ValueError( + # "Video not supported yet. In the future, we will support video." 
+ # ) else: image_features.append(image_feat) mm_patch_merge_type = args.mm_patch_merge_type - assert mm_patch_merge_type in [ - "flat", - "spatial_unpad", - ], f"Unexpected mm_patch_merge_type: {mm_patch_merge_type}" + # mm_patch_merge_type all values are supported + image_aspect_ratio = args.image_aspect_ratio if image_aspect_ratio != "square": assert ( "anyres" in image_aspect_ratio ), f"Unexpected image_aspect_ratio: {image_aspect_ratio}" + mm_newline_position = args.mm_newline_position + assert mm_newline_position in [ + "one_token", + "grid", + "frame", + "no_token", + ], f"Unexpected mm_newline_position: {mm_newline_position}" + if mm_patch_merge_type == "flat": image_features = [x.flatten(0, 1) for x in image_features] elif mm_patch_merge_type.startswith("spatial"): new_image_features = [] for image_idx, image_feature in enumerate(image_features): - if image_feature.shape[0] > 1: + if image_idx in video_idx_in_batch: # video operations + if mm_newline_position == "grid": + # Grid-wise + image_feature = self.add_token_per_grid(image_feature) + assert not self.args.add_faster_video + # No all_faster_video_features variable because encode_multimodals is not called. + # So we didn't add any code about add_faster_video + new_image_features.append(image_feature) + elif mm_newline_position == "frame": + # Frame-wise + image_feature = self.add_token_per_frame(image_feature) + + new_image_features.append(image_feature.flatten(0, 1)) + + elif mm_newline_position == "one_token": + # one-token + image_feature = image_feature.flatten(0, 1) + if "unpad" in mm_patch_merge_type: + image_feature = torch.cat( + ( + image_feature, + self.image_newline[None].to( + image_feature.device + ), + ), + dim=0, + ) + new_image_features.append(image_feature) + elif mm_newline_position == "no_token": + new_image_features.append(image_feature.flatten(0, 1)) + else: + raise ValueError( + f"Unexpected mm_newline_position: {mm_newline_position}" + ) + elif image_feature.shape[0] > 1: # Raw image features base_image_feature = image_feature[0] # Patch iamge features @@ -379,7 +426,14 @@ def prepare_inputs_labels_for_multimodal( num_patch_height, num_patch_width, height, width, -1 ) - if ( + if "maxpool2x2" in mm_patch_merge_type: + image_feature = image_feature.permute( + 4, 0, 2, 1, 3 + ).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = nn.functional.max_pool2d(image_feature, 2) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + elif ( "unpad" in mm_patch_merge_type and "anyres_max" in image_aspect_ratio and matched_anyres_max_num_patches @@ -642,7 +696,6 @@ def prepare_inputs_labels_for_multimodal( if _position_ids is None: position_ids = None - # This branch is not used in the onevision training. 
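+        # Position skipping (the LLaVA-NeXT / LongVA recipe): during training the
+        # RoPE position ids are shifted by random offsets drawn from
+        # --pos-skipping-range so the model sees positions beyond the actual
+        # sequence length.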
         if args.use_pos_skipping and args.training:
             position_ids = (
                 torch.arange(new_input_embeds.size(1), device=new_input_embeds.device)
@@ -705,6 +758,73 @@ def embed_tokens(self, input_ids):
         ).contiguous()  # [b, text_seq_len, h_language]
         return language_embeddings

+    def get_2dPool(self, image_feature, stride=2):
+        args = self.args
+        height = width = args.img_h // args.patch_dim
+        num_frames, num_tokens, num_dim = image_feature.shape
+        image_feature = image_feature.view(num_frames, height, width, -1)
+        image_feature = image_feature.permute(0, 3, 1, 2).contiguous()
+        if args.mm_spatial_pool_mode == "average":
+            image_feature = torch.nn.functional.avg_pool2d(image_feature, stride)
+        elif args.mm_spatial_pool_mode == "max":
+            image_feature = torch.nn.functional.max_pool2d(image_feature, stride)
+        elif args.mm_spatial_pool_mode == "bilinear":
+            height, width = image_feature.shape[2:]
+            scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)]
+            image_feature = torch.nn.functional.interpolate(
+                image_feature, size=scaled_shape, mode="bilinear"
+            )
+
+        else:
+            raise ValueError(
+                f"Unexpected mm_spatial_pool_mode: {args.mm_spatial_pool_mode}"
+            )
+        image_feature = image_feature.permute(0, 2, 3, 1)
+        image_feature = image_feature.view(num_frames, -1, num_dim)
+        return image_feature
+
+    def add_token_per_grid(self, image_feature):
+        resize_h = int(math.sqrt(image_feature.shape[1]))
+        num_frames = image_feature.shape[0]
+        feature_dim = image_feature.shape[-1]
+
+        image_feature = image_feature.view(num_frames, 1, resize_h, resize_h, -1)
+        image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous()
+        image_feature = image_feature.flatten(1, 2).flatten(2, 3)
+        image_feature = torch.cat(
+            (
+                image_feature,
+                self.image_newline[:, None, None]
+                .expand(*image_feature.shape[:-1], 1)
+                .to(image_feature.device),
+            ),
+            dim=-1,
+        )
+        if self.args.add_faster_video:
+            # (3584, 832, 14) -> (3584, 64, 13, 14)
+            image_feature = image_feature.view(feature_dim, num_frames, resize_h, -1)
+            # (3584, 64, 13, 14) -> (64, 13, 14, 3584)
+            image_feature = image_feature.permute(1, 2, 3, 0).contiguous()
+            # (64, 13, 14, 3584) -> (64, 13*14, 3584)
+            image_feature = image_feature.flatten(1, 2)
+            return image_feature
+        image_feature = image_feature.flatten(1, 2).transpose(0, 1)
+        return image_feature
+
+    def add_token_per_frame(self, image_feature):
+        image_feature = image_feature.permute(2, 0, 1).contiguous()
+        image_feature = torch.cat(
+            (
+                image_feature,
+                self.image_newline[:, None, None]
+                .expand(*image_feature.shape[:-1], 1)
+                .to(image_feature.device),
+            ),
+            dim=-1,
+        )
+        image_feature = image_feature.permute(1, 2, 0).contiguous()
+        return image_feature
+

 def _load_state_dict_hook_ignore_param_names(
     param_names: List[str], module: torch.nn.Module, incompatible_keys: namedtuple
diff --git a/flagscale/train/train_llava_onevision.py b/flagscale/train/train_llava_onevision.py
index fb18f805e..36386a278 100644
--- a/flagscale/train/train_llava_onevision.py
+++ b/flagscale/train/train_llava_onevision.py
@@ -28,7 +28,7 @@
     IGNORE_INDEX,
     IMAGE_TOKEN_INDEX,
 )
-from examples.multimodal.layer_specs import (
+from flagscale.train.models.llava_onevision.layer_specs import (
     get_layer_spec,
     get_mlp_module_spec,
     get_layer_spec_te,
@@ -438,6 +438,15 @@ def add_multimodal_extra_args(parser):
     group.add_argument(
         "--pos-skipping-range", type=int, default=4096, help="Position skipping range"
     )
+    group.add_argument(
+        "--add-faster-video", default=False, help="Whether to add the faster video token"
+    )
+    group.add_argument(
+ "--mm-spatial-pool-mode", type=str, default="bilinear", help="Spatial pool mode" + ) + group.add_argument( + "--mm-newline-position", type=str, default="grid", help="Newline position." + ) return parser diff --git a/tools/checkpoint/llava_onevision/convert_to_fs_qwen2.5_7b.py b/tools/checkpoint/llava_onevision/convert_to_fs_qwen2.5_7b.py new file mode 100644 index 000000000..1cec67b2d --- /dev/null +++ b/tools/checkpoint/llava_onevision/convert_to_fs_qwen2.5_7b.py @@ -0,0 +1,507 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import argparse +import os +import json + +import torch + +from safetensors.torch import load_file + + +def convert(input_path, output_path, tensor_parallel_size, use_te): + device = "cuda" + # index.json + index_path = None + for file in os.listdir(input_path): + if file.endswith("index.json"): + index_path = os.path.join(input_path, file) + break + assert index_path is not None, "index.json not found in input path" + + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + + caches = {} + for name in weight_map: + file_name = weight_map[name] + if file_name not in caches: + caches[file_name] = load_file( + os.path.join(input_path, file_name), device=device + ) + + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] + + # Process language model + # Indices from mapping pytorch multihead attention to megatron. + hidden_dim = 3584 + num_heads = 28 + assert hidden_dim % num_heads == 0 + kv_channels = hidden_dim // num_heads + # GQA Process + num_query_groups = 4 + kv_projection_size = kv_channels * num_query_groups + indices = [] + assert num_heads % num_query_groups == 0 + for i in range(num_query_groups): + lb = i * kv_channels + ub = (i + 1) * kv_channels + indices.append( + torch.arange( + num_heads // num_query_groups * kv_channels * i, + num_heads // num_query_groups * kv_channels * (i + 1), + dtype=torch.int, + ) + ) + indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int)) + indices.append( + torch.arange( + (hidden_dim + kv_projection_size) + lb, + (hidden_dim + kv_projection_size) + ub, + dtype=torch.int, + ) + ) + + indices = torch.cat(indices) + + gate_up_indices = [] + ffn_hidden_size = 18944 + assert ffn_hidden_size % tensor_parallel_size == 0 + interval = ffn_hidden_size // tensor_parallel_size + for i in range(tensor_parallel_size): + lb = i * interval + ub = (i + 1) * interval + gate_up_indices.append(torch.arange(lb, ub, dtype=torch.int)) + gate_up_indices.append( + torch.arange(ffn_hidden_size + lb, ffn_hidden_size + ub, dtype=torch.int) + ) + gate_up_indices = torch.cat(gate_up_indices) + + for name in weight_map: + file_name = weight_map[name] + tensor = caches[file_name][name] + + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + if new_tensor.dtype == torch.float16: + new_tensor = new_tensor.to(torch.float32) + + # This is used for chunking some tensors to target tensor parallel size. 
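+        # chunk_dim = 0 marks column-parallel weights (split along the output dim:
+        # embeddings, lm_head, fused QKV, fused gate/up), chunk_dim = 1 marks
+        # row-parallel weights (o_proj / down_proj, split along the input dim), and
+        # chunk_dim = None leaves the tensor replicated on every TP rank (e.g. norms).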
+ chunk_dim = None + + qkv_params = set() + gate_up_params = set() + if "model.embed_tokens.weight" in name: + new_name = "language_model.embedding.word_embeddings.weight" + chunk_dim = 0 + elif "model.image_newline" in name: + new_name = "image_newline" + elif "lm_head.weight" in name: + new_name = "language_model.output_layer.weight" + chunk_dim = 0 + # the norm after last layer + elif "model.norm.weight" in name: + new_name = "language_model.decoder.final_layernorm.weight" + elif "model.layers" not in name: + continue + elif "model.layers" in name: + layer_idx = name.split(".")[2] + base = f"language_model.decoder.layers.{layer_idx}" + if ( + "self_attn.q_proj.weight" in name + or "self_attn.k_proj.weight" in name + or "self_attn.v_proj.weight" in name + ): + new_name = f"{base}.self_attention.linear_qkv.weight" + if new_name not in qkv_params: + # q_proj, k_proj, v_proj + split_name = name.split(".") + split_name[-2] = "q_proj" + q_name = ".".join(split_name) + file_name = weight_map[q_name] + q_tensor = caches[file_name][q_name] + + split_name[-2] = "k_proj" + k_name = ".".join(split_name) + file_name = weight_map[k_name] + k_tensor = caches[file_name][k_name] + + split_name[-2] = "v_proj" + v_name = ".".join(split_name) + file_name = weight_map[v_name] + v_tensor = caches[file_name][v_name] + + # concat and dim = 0 + # q,k,v concat in the first dim + new_tensor = torch.cat([q_tensor, k_tensor, v_tensor], dim=0) + + # reorder + new_tensor = new_tensor[indices] + + chunk_dim = 0 + qkv_params.add(new_name) + else: + continue + + elif ( + "self_attn.q_proj.bias" in name + or "self_attn.k_proj.bias" in name + or "self_attn.v_proj.bias" in name + ): + new_name = f"{base}.self_attention.linear_qkv.bias" + if new_name not in qkv_params: + # q_proj, k_proj, v_proj + split_name = name.split(".") + split_name[-2] = "q_proj" + q_name = ".".join(split_name) + file_name = weight_map[q_name] + q_tensor = caches[file_name][q_name] + + split_name[-2] = "k_proj" + k_name = ".".join(split_name) + file_name = weight_map[k_name] + k_tensor = caches[file_name][k_name] + + split_name[-2] = "v_proj" + v_name = ".".join(split_name) + file_name = weight_map[v_name] + v_tensor = caches[file_name][v_name] + + # concat and dim = 0 + new_tensor = torch.cat([q_tensor, k_tensor, v_tensor], dim=0) + + # reorder + new_tensor = new_tensor[indices] + + chunk_dim = 0 + qkv_params.add(new_name) + else: + continue + elif "self_attn.o_proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + chunk_dim = 1 + elif "input_layernorm.weight" in name: + new_name = f"{base}.input_layernorm.weight" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight" + elif "mlp.gate_proj.weight" in name or "mlp.up_proj.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + if new_name not in gate_up_params: + # gate, up + split_name = name.split(".") + split_name[-2] = "gate_proj" + gate_name = ".".join(split_name) + file_name = weight_map[gate_name] + gate_tensor = caches[file_name][gate_name] + + split_name = name.split(".") + split_name[-2] = "up_proj" + up_name = ".".join(split_name) + file_name = weight_map[up_name] + up_tensor = caches[file_name][up_name] + + # concat and dim = 0 + new_tensor = torch.cat([gate_tensor, up_tensor], dim=0) + new_tensor = new_tensor[gate_up_indices] + gate_up_params.add(new_name) + chunk_dim = 0 + else: + continue + elif "mlp.down_proj.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "post_attention_layernorm.weight" 
in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_weight" + + assert new_name != "", f"unexpected layer name {name}" + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ( + "linear_qkv", + "linear_proj", + "linear_fc1", + "linear_fc2", + ) + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. + new_state_dicts[i]["model"][extra_state_name] = None + + # Process vision tower + # Indices from mapping pytorch multihead attention to megatron. + hidden_dim = 1152 + num_heads = 16 + kv_channels = hidden_dim // num_heads + indices = [] + for i in range(num_heads): + lb = i * kv_channels + ub = (i + 1) * kv_channels + indices.append(torch.arange(lb, ub, dtype=torch.int)) + indices.append(torch.arange(hidden_dim + lb, hidden_dim + ub, dtype=torch.int)) + indices.append( + torch.arange(2 * hidden_dim + lb, 2 * hidden_dim + ub, dtype=torch.int) + ) + + indices = torch.cat(indices) + + for name in weight_map: + file_name = weight_map[name] + tensor = caches[file_name][name] + + if "model.vision_tower.vision_tower.vision_model" not in name: + continue + + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + if new_tensor.dtype == torch.float16: + new_tensor = new_tensor.to(torch.float32) + + # This is used for chunking some tensors to target tensor parallel size. 
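+        # Unlike the Qwen2 language model above, the SigLIP vision tower has no GQA
+        # (num_query_groups == num_heads), so the reorder indices built above simply
+        # interleave one q/k/v slice of kv_channels rows per attention head.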
+ chunk_dim = None + + qkv_params = set() + if "position_embedding" in name: + new_name = "vision_model.position_embeddings.weight" + elif "post_layernorm.weight" in name: + new_name = "vision_model.ln_post.weight" + elif "post_layernorm.bias" in name: + new_name = "vision_model.ln_post.bias" + elif "patch_embedding.weight" in name: + new_name = "vision_model.conv1.weight" + elif "patch_embedding.bias" in name: + new_name = "vision_model.conv1.bias" + elif "encoder.layers" in name: + layer_idx = name.split(".")[6] + base = f"vision_model.decoder.layers.{layer_idx}" + if "encoder.layers.26" in name: + print(f"{name} skipped due to the last layer") + continue + + if ( + "self_attn.q_proj.weight" in name + or "self_attn.k_proj.weight" in name + or "self_attn.v_proj.weight" in name + ): + new_name = f"{base}.self_attention.linear_qkv.weight" + if new_name not in qkv_params: + # q_proj, k_proj, v_proj + split_name = name.split(".") + + split_name[-2] = "q_proj" + q_name = ".".join(split_name) + file_name = weight_map[q_name] + q_tensor = caches[file_name][q_name] + + split_name[-2] = "k_proj" + k_name = ".".join(split_name) + file_name = weight_map[k_name] + k_tensor = caches[file_name][k_name] + + split_name[-2] = "v_proj" + v_name = ".".join(split_name) + file_name = weight_map[v_name] + v_tensor = caches[file_name][v_name] + + # concat and dim = 0 + # q,k,v concat in the first dim + new_tensor = torch.cat([q_tensor, k_tensor, v_tensor], dim=0) + + # reorder + new_tensor = new_tensor[indices] + + chunk_dim = 0 + qkv_params.add(new_name) + else: + continue + + elif ( + "self_attn.q_proj.bias" in name + or "self_attn.k_proj.bias" in name + or "self_attn.v_proj.bias" in name + ): + new_name = f"{base}.self_attention.linear_qkv.bias" + if new_name not in qkv_params: + # q_proj, k_proj, v_proj + split_name = name.split(".") + + split_name[-2] = "q_proj" + q_name = ".".join(split_name) + file_name = weight_map[q_name] + q_tensor = caches[file_name][q_name] + + split_name[-2] = "k_proj" + k_name = ".".join(split_name) + file_name = weight_map[k_name] + k_tensor = caches[file_name][k_name] + + split_name[-2] = "v_proj" + v_name = ".".join(split_name) + file_name = weight_map[v_name] + v_tensor = caches[file_name][v_name] + + # concat and dim = 0 + new_tensor = torch.cat([q_tensor, k_tensor, v_tensor], dim=0) + + # reorder + new_tensor = new_tensor[indices] + + chunk_dim = 0 + qkv_params.add(new_name) + else: + continue + elif "attn.out_proj.weight" in name: + new_name = f"{base}.self_attention.linear_proj.weight" + chunk_dim = 1 + elif "attn.out_proj.bias" in name: + new_name = f"{base}.self_attention.linear_proj.bias" + elif "layer_norm1.weight" in name: + new_name = f"{base}.input_layernorm.weight" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_weight" + elif "layer_norm1.bias" in name: + new_name = f"{base}.input_layernorm.bias" + if use_te: + new_name = f"{base}.self_attention.linear_qkv.layer_norm_bias" + elif "mlp.fc1.weight" in name: + new_name = f"{base}.mlp.linear_fc1.weight" + chunk_dim = 0 + elif "mlp.fc1.bias" in name: + new_name = f"{base}.mlp.linear_fc1.bias" + chunk_dim = 0 + elif "mlp.fc2.weight" in name: + new_name = f"{base}.mlp.linear_fc2.weight" + chunk_dim = 1 + elif "mlp.fc2.bias" in name: + new_name = f"{base}.mlp.linear_fc2.bias" + elif "layer_norm2.weight" in name: + new_name = f"{base}.pre_mlp_layernorm.weight" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_weight" + elif "layer_norm2.bias" in name: + new_name = 
f"{base}.pre_mlp_layernorm.bias" + if use_te: + new_name = f"{base}.mlp.linear_fc1.layer_norm_bias" + + assert new_name != "", f"unexpected layer name {name}" + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. + extra_state_layers = ( + "linear_qkv", + "linear_proj", + "linear_fc1", + "linear_fc2", + ) + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) + if use_te and is_extra_state_layer: + layer = new_name.split(".")[-2] + if layer in extra_state_layers: + extra_state_name = ( + new_name[: new_name.rfind(".") + 1] + "_extra_state" + ) # Replace the weight name. + new_state_dicts[i]["model"][extra_state_name] = None + + + # Process projection + for name in weight_map: + file_name = weight_map[name] + tensor = caches[file_name][name] + # Map parameter names to ones used in megatron. + new_name = "" + new_tensor = tensor + chunk_dim = None + if "model.mm_projector" not in name: + continue + # This is used for chunking some tensors to target tensor parallel size. + if name == "model.mm_projector.0.weight": + new_name = "vision_projection.encoder.linear_fc1.weight" + chunk_dim = 0 + elif name == "model.mm_projector.0.bias": + new_name = "vision_projection.encoder.linear_fc1.bias" + chunk_dim = 0 + elif name == "model.mm_projector.2.weight": + new_name = "vision_projection.encoder.linear_fc2.weight" + chunk_dim = 1 + elif name == "model.mm_projector.2.bias": + new_name = "vision_projection.encoder.linear_fc2.bias" + + assert new_name != "", f"unexpected name {name}" + + if chunk_dim is None: + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] + else: + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) + + for i in range(tensor_parallel_size): + # chunk() creates a view of a bigger tensor. clone() is used here to avoid excessive storage. + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() + + + for i in range(tensor_parallel_size): + output_dir_tp = os.path.join(output_path, "iter_0000001", f"mp_rank_0{i}") + os.makedirs(output_dir_tp, exist_ok=True) + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") + torch.save(new_state_dicts[i], output_path_tp) + + latest_checkpointed_iteration = os.path.join( + output_path, "latest_checkpointed_iteration.txt" + ) + + with open(latest_checkpointed_iteration, "w") as f: + f.write("1") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" +Convert Qwen weights to megatron format. 
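+
+The --input folder must hold a Hugging Face LLaVA-OneVision (Qwen2 7B) checkpoint as
+*.safetensors shards plus the *index.json weight map. The converted weights are saved
+as iter_0000001/mp_rank_0<i>/model_optim_rng.pt under --output, one file per tensor
+parallel rank, and latest_checkpointed_iteration.txt is written as 1.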
+ +Example usage: +python convert_to_fs_qwen2.5_7b.py --input /some/input/folder --output /some/output/folder --tensor-parallel-size 4 --use-te +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--input", type=str, required=True, help="hf folder") + parser.add_argument( + "--output", + type=str, + required=True, + help="output directory for megatron state dict file(s)", + ) + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" + ) + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") + + args = parser.parse_args() + + print(args.input, args.output, args.tensor_parallel_size, args.use_te) + + convert(args.input, args.output, args.tensor_parallel_size, args.use_te) + + print("done.") diff --git a/tools/checkpoint/llava_onevision/convert_to_hf_qwen2.5_7b.py b/tools/checkpoint/llava_onevision/convert_to_hf_qwen2.5_7b.py new file mode 100644 index 000000000..2d4627b27 --- /dev/null +++ b/tools/checkpoint/llava_onevision/convert_to_hf_qwen2.5_7b.py @@ -0,0 +1,490 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import argparse +import os +import json + +import torch + +from safetensors.torch import load_file, save_file + + +def check_model(name, model): + assert name in model, f"unexpected name {name}" + +def check_model_file(name, hf_models): + """ + Check which .safetensors file contains the specified name. + + :param file_dict: dict, where key is the file name and value is the data loaded from the .safetensors file. + :param name: str, the name of the weight to search for. + :return: str, the file name that contains the specified weight. + :raises: Exception if the specified name is not found in any file. + """ + for file_name, data in hf_models.items(): + if name in data.keys(): + return file_name + if "image_newline" not in name: + raise Exception(f"unexpected name {name}") + else: + import warnings + warnings.warn(f"unexpected name {name}") + return None + +def is_chunked(name, chunk_names): + for chunk_name in chunk_names: + if chunk_name in name: + return True + return False + + +def convert(input_path, output_path, use_te, tensor_parallel_size=2): + mc_models = [] + for i in range(tensor_parallel_size): + mc_model = torch.load( + os.path.join(input_path, f"mp_rank_{i:02d}", "model_optim_rng.pt"), + map_location="cpu", + ) + mc_models.append(mc_model) + mc_model = {} + mc_model_0 = mc_models[0] + llm_chunk_dim_0 = [ + "embedding.word_embeddings.weight", + "self_attention.linear_qkv.weight", + "self_attention.linear_qkv.bias", + "mlp.linear_fc1.weight", + "output_layer.weight", + ] + llm_chunk_dim_1 = ["self_attention.linear_proj.weight", "mlp.linear_fc2.weight"] + vision_chunk_dim_0 = [ + "self_attention.linear_qkv.weight", + "self_attention.linear_qkv.bias", + "mlp.linear_fc1.weight", + "mlp.linear_fc1.bias", + ] + vision_chunk_dim_1 = ["self_attention.linear_proj.weight", "mlp.linear_fc2.weight"] + mlp_chunk_dim_0 = ["encoder.linear_fc1.weight", "encoder.linear_fc1.bias"] + mlp_chunk_dim_1 = ["encoder.linear_fc2.weight"] + for name, param in mc_model_0["model"].items(): + # column parallel + if "_extra_state" in name: + continue + params = [model["model"][name] for model in mc_models] + if ( + is_chunked(name, llm_chunk_dim_0) + or is_chunked(name, vision_chunk_dim_0) + or is_chunked(name, mlp_chunk_dim_0) + ): + print(f"{name} concat in dim 0") + + mc_model[name] = torch.cat(params, dim=0) + elif ( + is_chunked(name, llm_chunk_dim_1) + or 
is_chunked(name, vision_chunk_dim_1) + or is_chunked(name, mlp_chunk_dim_1) + ): + print(f"{name} concat in dim 1") + mc_model[name] = torch.cat(params, dim=1) + else: + print(f"{name} without concat") + mc_model[name] = param + + # index.json + index_path = None + for file in os.listdir(output_path): + if file.endswith("index.json"): + index_path = os.path.join(output_path, file) + break + assert index_path is not None, "index.json not found in output path" + + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + + hf_models = {} + for name in weight_map: + file_name = weight_map[name] + if file_name not in hf_models: + hf_models[file_name] = load_file( + os.path.join(output_path, file_name), device="cpu" + ) + + mc_args = mc_model_0["args"] + hidden_dim = mc_args.hidden_size + ffn_hidden_size = mc_args.ffn_hidden_size + num_heads = mc_args.num_attention_heads + kv_channels = hidden_dim // num_heads + num_query_groups = mc_args.num_query_groups + kv_projection_size = kv_channels * num_query_groups + + assert hidden_dim % num_heads == 0 + assert kv_channels == mc_args.kv_channels + assert num_heads % num_query_groups == 0 + + indices = [] + # Q + start = 0 + interval = kv_channels * num_heads // num_query_groups + 2 * kv_channels + for i in range(num_query_groups): + offset = i * interval + indices.append( + torch.arange( + start + offset, + start + offset + kv_channels * num_heads // num_query_groups, + dtype=torch.int, + ) + ) + # K + start = kv_channels * num_heads // num_query_groups + for i in range(num_query_groups): + offset = i * interval + indices.append( + torch.arange( + start + offset, + start + offset + kv_channels, + dtype=torch.int, + ) + ) + # V + start = kv_channels * num_heads // num_query_groups + kv_channels + for i in range(num_query_groups): + offset = i * interval + indices.append( + torch.arange( + start + offset, + start + offset + kv_channels, + dtype=torch.int, + ) + ) + indices = torch.cat(indices) + deorder_indices = indices + + gate_up_indices = [] + # Gate + start = 0 + for i in range(tensor_parallel_size): + offset = i * (ffn_hidden_size // tensor_parallel_size * 2) + gate_up_indices.append( + torch.arange( + start + offset, + start + offset + ffn_hidden_size // tensor_parallel_size, + dtype=torch.int, + ) + ) + + # UP + start = ffn_hidden_size // tensor_parallel_size + for i in range(tensor_parallel_size): + offset = i * (ffn_hidden_size // tensor_parallel_size * 2) + gate_up_indices.append( + torch.arange( + start + offset, + start + offset + ffn_hidden_size // tensor_parallel_size, + dtype=torch.int, + ) + ) + gate_up_indices = torch.cat(gate_up_indices) + deorder_gate_up_indices = gate_up_indices + + input_layer_norm_weight = ( + "input_layernorm.weight" + if not use_te + else "self_attention.linear_qkv.layer_norm_weight" + ) + input_layer_norm_bias = ( + "input_layernorm.bias" + if not use_te + else "self_attention.linear_qkv.layer_norm_bias" + ) + + post_attention_layer_norm_weight = ( + "pre_mlp_layernorm.weight" if not use_te else "mlp.linear_fc1.layer_norm_weight" + ) + + layer_norm_2_weight = ( + "pre_mlp_layernorm.weight" if not use_te else "mlp.linear_fc1.layer_norm_weight" + ) + layer_norm_2_bias = ( + "pre_mlp_layernorm.bias" if not use_te else "mlp.linear_fc1.layer_norm_bias" + ) + + for mc_name in mc_model: + print("mc_layer:", mc_name) + mc_tensor = mc_model[mc_name] + + # Language model mappings + if "image_newline" in mc_name: + file_name = check_model_file("model.image_newline", hf_models) + if file_name != None: + 
hf_models[file_name]["model.image_newline"] = mc_tensor + + if "language_model.embedding.word_embeddings.weight" in mc_name: + file_name = check_model_file("model.embed_tokens.weight", hf_models) + hf_models[file_name]["model.embed_tokens.weight"] = mc_tensor + elif "language_model.output_layer.weight" in mc_name: + file_name = check_model_file("lm_head.weight", hf_models) + hf_models[file_name]["lm_head.weight"] = mc_tensor + elif "language_model.decoder.final_layernorm.weight" in mc_name: + file_name = check_model_file("model.norm.weight", hf_models) + hf_models[file_name]["model.norm.weight"] = mc_tensor + elif "language_model.decoder.layers" in mc_name: + layer_idx = mc_name.split(".")[3] + base = f"model.layers.{layer_idx}" + if "self_attention.linear_qkv.weight" in mc_name: + # deorder_indices + mc_tensor = mc_tensor[deorder_indices] + qkv_weight = torch.split( + mc_tensor, [hidden_dim, kv_projection_size, kv_projection_size] + ) + file_name = check_model_file(f"{base}.self_attn.q_proj.weight", hf_models) + file_name = check_model_file(f"{base}.self_attn.k_proj.weight", hf_models) + file_name = check_model_file(f"{base}.self_attn.v_proj.weight", hf_models) + hf_models[file_name][f"{base}.self_attn.q_proj.weight"] = qkv_weight[0] + hf_models[file_name][f"{base}.self_attn.k_proj.weight"] = qkv_weight[1] + hf_models[file_name][f"{base}.self_attn.v_proj.weight"] = qkv_weight[2] + elif "self_attention.linear_qkv.bias" in mc_name: + # deorder_indices + mc_tensor = mc_tensor[deorder_indices] + qkv_bias = torch.split( + mc_tensor, [hidden_dim, kv_projection_size, kv_projection_size] + ) + file_name = check_model_file(f"{base}.self_attn.q_proj.bias", hf_models) + file_name = check_model_file(f"{base}.self_attn.k_proj.bias", hf_models) + file_name = check_model_file(f"{base}.self_attn.v_proj.bias", hf_models) + hf_models[file_name][f"{base}.self_attn.q_proj.bias"] = qkv_bias[0] + hf_models[file_name][f"{base}.self_attn.k_proj.bias"] = qkv_bias[1] + hf_models[file_name][f"{base}.self_attn.v_proj.bias"] = qkv_bias[2] + elif "self_attention.linear_proj.weight" in mc_name: + file_name = check_model_file(f"{base}.self_attn.o_proj.weight", hf_models) + hf_models[file_name][f"{base}.self_attn.o_proj.weight"] = mc_tensor + elif "self_attention.linear_proj.bias" in mc_name: + file_name = check_model_file(f"{base}.self_attn.o_proj.bias", hf_models) + hf_models[file_name][f"{base}.self_attn.o_proj.bias"] = mc_tensor + elif input_layer_norm_weight in mc_name: + file_name = check_model_file(f"{base}.input_layernorm.weight", hf_models) + hf_models[file_name][f"{base}.input_layernorm.weight"] = mc_tensor + elif "mlp.linear_fc1.weight" in mc_name: + mc_tensor = mc_tensor[deorder_gate_up_indices] + gate_up_weight = torch.split( + mc_tensor, [ffn_hidden_size, ffn_hidden_size] + ) + file_name = check_model_file(f"{base}.mlp.gate_proj.weight", hf_models) + file_name = check_model_file(f"{base}.mlp.up_proj.weight", hf_models) + hf_models[file_name][f"{base}.mlp.gate_proj.weight"] = gate_up_weight[0] + hf_models[file_name][f"{base}.mlp.up_proj.weight"] = gate_up_weight[1] + elif "mlp.linear_fc2.weight" in mc_name: + file_name = check_model_file(f"{base}.mlp.down_proj.weight", hf_models) + hf_models[file_name][f"{base}.mlp.down_proj.weight"] = mc_tensor + + elif post_attention_layer_norm_weight in mc_name: + file_name = check_model_file(f"{base}.post_attention_layernorm.weight", hf_models) + hf_models[file_name][f"{base}.post_attention_layernorm.weight"] = mc_tensor + + else: + raise ValueError(f"{name} is not 
converted.") + + # Indices from mapping pytorch multihead attention to megatron. + hidden_dim = 1152 + num_heads = 16 + kv_channels = hidden_dim // num_heads + # Because the visual tower does not have GQA, num_query_groups=num_ heads + num_query_groups = num_heads + kv_projection_size = kv_channels * num_query_groups + indices = [] + # Q + start = 0 + interval = kv_channels * 3 + for i in range(num_query_groups): + offset = interval * i + indices.append( + torch.arange( + start + offset, + start + offset + kv_channels, + dtype=torch.int, + ) + ) + # K + start = kv_channels + for i in range(num_query_groups): + offset = interval * i + indices.append( + torch.arange( + start + offset, + start + offset + kv_channels, + dtype=torch.int, + ) + ) + # V + start = kv_channels * 2 + for i in range(num_query_groups): + offset = interval * i + indices.append( + torch.arange( + start + offset, + start + offset + kv_channels, + dtype=torch.int, + ) + ) + indices = torch.cat(indices) + deorder_indices = indices + + for mc_name in mc_model: + print("mc_layer:", mc_name) + mc_tensor = mc_model[mc_name] + + # vision_model + hf_base_name = "model.vision_tower.vision_tower.vision_model" + + if "vision_model.position_embeddings.weight" in mc_name: + file_name = check_model_file( + f"{hf_base_name}.embeddings.position_embedding.weight", hf_models + ) + hf_models[file_name][f"{hf_base_name}.embeddings.position_embedding.weight"] = mc_tensor + + elif "vision_model.ln_post.weight" in mc_name: + file_name = check_model_file(f"{hf_base_name}.post_layernorm.weight", hf_models) + hf_models[file_name][f"{hf_base_name}.post_layernorm.weight"] = mc_tensor + + elif "vision_model.ln_post.bias" in mc_name: + file_name = check_model_file(f"{hf_base_name}.post_layernorm.bias", hf_models) + hf_models[file_name][f"{hf_base_name}.post_layernorm.bias"] = mc_tensor + + elif "vision_model.conv1.weight" in mc_name: + file_name = check_model_file(f"{hf_base_name}.embeddings.patch_embedding.weight", hf_models) + hf_models[file_name][f"{hf_base_name}.embeddings.patch_embedding.weight"] = mc_tensor + + elif "vision_model.conv1.bias" in mc_name: + file_name = check_model_file(f"{hf_base_name}.embeddings.patch_embedding.bias", hf_models) + hf_models[file_name][f"{hf_base_name}.embeddings.patch_embedding.bias"] = mc_tensor + + elif "vision_model.decoder.layers" in mc_name: + layer_idx = mc_name.split(".")[3] + base = f"model.vision_tower.vision_tower.vision_model.encoder.layers.{layer_idx}" + + if "self_attention.linear_qkv.weight" in mc_name: + mc_tensor = mc_tensor[deorder_indices] + qkv_weight = torch.split( + mc_tensor, [hidden_dim, kv_projection_size, kv_projection_size] + ) + file_name = check_model_file(f"{base}.self_attn.q_proj.weight", hf_models) + file_name = check_model_file(f"{base}.self_attn.k_proj.weight", hf_models) + file_name = check_model_file(f"{base}.self_attn.v_proj.weight", hf_models) + hf_models[file_name][f"{base}.self_attn.q_proj.weight"] = qkv_weight[0] + hf_models[file_name][f"{base}.self_attn.k_proj.weight"] = qkv_weight[1] + hf_models[file_name][f"{base}.self_attn.v_proj.weight"] = qkv_weight[2] + + elif "self_attention.linear_qkv.bias" in mc_name: + mc_tensor = mc_tensor[deorder_indices] + qkv_bias = torch.split( + mc_tensor, [hidden_dim, kv_projection_size, kv_projection_size] + ) + file_name = check_model_file(f"{base}.self_attn.q_proj.bias", hf_models) + file_name = check_model_file(f"{base}.self_attn.k_proj.bias", hf_models) + file_name = check_model_file(f"{base}.self_attn.v_proj.bias", hf_models) + 
hf_models[file_name][f"{base}.self_attn.q_proj.bias"] = qkv_bias[0] + hf_models[file_name][f"{base}.self_attn.k_proj.bias"] = qkv_bias[1] + hf_models[file_name][f"{base}.self_attn.v_proj.bias"] = qkv_bias[2] + + elif "self_attention.linear_proj.weight" in mc_name: + file_name = check_model_file(f"{base}.self_attn.out_proj.weight", hf_models) + hf_models[file_name][f"{base}.self_attn.out_proj.weight"] = mc_tensor + + elif "self_attention.linear_proj.bias" in mc_name: + file_name = check_model_file(f"{base}.self_attn.out_proj.bias", hf_models) + hf_models[file_name][f"{base}.self_attn.out_proj.bias"] = mc_tensor + + elif input_layer_norm_weight in mc_name: + file_name = check_model_file(f"{base}.layer_norm1.weight", hf_models) + hf_models[file_name][f"{base}.layer_norm1.weight"] = mc_tensor + + elif input_layer_norm_bias in mc_name: + file_name = check_model_file(f"{base}.layer_norm1.bias", hf_models) + hf_models[file_name][f"{base}.layer_norm1.bias"] = mc_tensor + + elif "mlp.linear_fc1.weight" in mc_name: + file_name = check_model_file(f"{base}.mlp.fc1.weight", hf_models) + hf_models[file_name][f"{base}.mlp.fc1.weight"] = mc_tensor + + elif "mlp.linear_fc1.bias" in mc_name: + file_name = check_model_file(f"{base}.mlp.fc1.bias", hf_models) + hf_models[file_name][f"{base}.mlp.fc1.bias"] = mc_tensor + + elif "mlp.linear_fc2.weight" in mc_name: + file_name = check_model_file(f"{base}.mlp.fc2.weight", hf_models) + hf_models[file_name][f"{base}.mlp.fc2.weight"] = mc_tensor + + elif "mlp.linear_fc2.bias" in mc_name: + file_name = check_model_file(f"{base}.mlp.fc2.bias", hf_models) + hf_models[file_name][f"{base}.mlp.fc2.bias"] = mc_tensor + + elif layer_norm_2_weight in mc_name: + file_name = check_model_file(f"{base}.layer_norm2.weight", hf_models) + hf_models[file_name][f"{base}.layer_norm2.weight"] = mc_tensor + + elif layer_norm_2_bias in mc_name: + file_name = check_model_file(f"{base}.layer_norm2.bias", hf_models) + hf_models[file_name][f"{base}.layer_norm2.bias"] = mc_tensor + + else: + raise ValueError(f"{name} is not converted.") + + # vision_projection + file_name = check_model_file(f"model.mm_projector.0.weight", hf_models) + hf_models[file_name]["model.mm_projector.0.weight"] = mc_model[ + "vision_projection.encoder.linear_fc1.weight" + ] + file_name = check_model_file(f"model.mm_projector.0.bias", hf_models) + hf_models[file_name]["model.mm_projector.0.bias"] = mc_model[ + "vision_projection.encoder.linear_fc1.bias" + ] + file_name = check_model_file(f"model.mm_projector.2.weight", hf_models) + hf_models[file_name]["model.mm_projector.2.weight"] = mc_model[ + "vision_projection.encoder.linear_fc2.weight" + ] + file_name = check_model_file(f"model.mm_projector.2.bias", hf_models) + hf_models[file_name]["model.mm_projector.2.bias"] = mc_model[ + "vision_projection.encoder.linear_fc2.bias" + ] + + metadata = {"format": "pt"} + # Ensure the output directory exists + os.makedirs(output_path, exist_ok=True) + + # Iterate through hf_models and save each value with metadata + for file_name, data in hf_models.items(): + file_path = os.path.join(output_path, file_name) + save_file(data, file_path, metadata=metadata) # save_file is assumed to accept metadata + print(f"Saved {file_name} to {file_path} with metadata: {metadata}") + + print(f"All files saved successfully with metadata in {output_path}") + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description=""" +Convert Qwen2 7b weights to hugging face format. 
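+
+Note that --output must already contain the original Hugging Face checkpoint
+(*.safetensors shards plus the *index.json weight map): the script loads those shards,
+overwrites the mapped tensors with the Megatron weights read from --input, and saves
+the shard files back in place.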
+ + +Example usage: +python convert_to_hf_qwen2.5_7b.py --input /some/input/folder --output /some/output/folder --tensor-parallel-size 4 +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument("--input", type=str, required=True, help="megatron ckpt folder") + parser.add_argument( + "--output", + type=str, + required=True, + help="output directory for hugging face state dict file(s)", + ) + parser.add_argument( + "--tensor-parallel-size", type=int, default=1, help="model tensor parallel size" + ) + parser.add_argument("--use-te", action="store_true", help="Use Transformer Engine") + + args = parser.parse_args() + + convert(args.input, args.output, args.use_te, args.tensor_parallel_size) + + print("done.") diff --git a/tools/datasets/llava_onevision/llava_ov_wds.py b/tools/datasets/llava_onevision/llava_ov_wds.py index 8fa0f5d50..bf2e9cc27 100755 --- a/tools/datasets/llava_onevision/llava_ov_wds.py +++ b/tools/datasets/llava_onevision/llava_ov_wds.py @@ -114,9 +114,10 @@ class ModelArguments: use_pos_skipping: Optional[bool] = field(default=False) pos_skipping_range: Optional[int] = field(default=4096) - mm_newline_position: Optional[str] = field(default="one_token") - + add_faster_video: Optional[bool] = field(default=False) + faster_token_stride: Optional[int] = field(default=10) + delay_load: Optional[bool] = field(default=True) @dataclass class DataArguments: @@ -133,7 +134,8 @@ class DataArguments: video_folder: Optional[str] = field(default=None) video_fps: Optional[int] = field(default=1) frames_upbound: Optional[int] = field(default=0) - + add_time_instruction: Optional[bool] = field(default=False) + force_sample: Optional[bool] = field(default=False) @dataclass class TrainingArguments(transformers.TrainingArguments): @@ -570,7 +572,7 @@ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_im tokenizer.add_tokens([""], special_tokens=True) image_token_index = tokenizer.convert_tokens_to_ids("") - im_start, im_end = tokenizer.additional_special_tokens_ids + im_start, im_end = tokenizer.additional_special_tokens_ids[:2] # unmask_tokens = ["<|im_start|>", "<|im_start|>", "\n"] unmask_tokens_idx = [198, im_start, im_end] nl_tokens = tokenizer("\n").input_ids @@ -614,9 +616,7 @@ def preprocess_qwen(sources, tokenizer: transformers.PreTrainedTokenizer, has_im target += [IGNORE_INDEX] * len(encode_id) else: target += encode_id - - assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" for idx, encode_id in enumerate(input_id): if encode_id in unmask_tokens_idx: @@ -700,8 +700,6 @@ def safe_tokenizer_llama3(text): target += [IGNORE_INDEX] * len(encode_id) else: target += encode_id - - assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}" for idx, encode_id in enumerate(input_id): @@ -1042,7 +1040,9 @@ def _get_item(self, sources) -> Dict[str, torch.Tensor]: print("File {} not exist!".format(video_file)) try: - if "shareVideoGPTV" in video_file: + print("video_file: ", video_file) + if "shareVideoGPTV" in video_file or "M4-Instruct-Videos" in video_file: + print("video_file: ", video_file) frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))] frame_files.sort() # Ensure the frames are sorted if they are named sequentially @@ -1121,7 +1121,6 @@ def _get_item(self, sources) -> Dict[str, torch.Tensor]: return data_dict - class LazySupervisedDataset(Dataset): def __init__(self, data_path: str, tokenizer: 
transformers.PreTrainedTokenizer, data_args: DataArguments): super(LazySupervisedDataset, self).__init__() @@ -1329,7 +1328,7 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]: print("File {} not exist!".format(video_file)) try: - if "shareVideoGPTV" in video_file: + if "shareVideoGPTV" in video_file or "M4-Instruct-Videos" in video_file: frame_files = [os.path.join(video_file, f) for f in os.listdir(video_file) if os.path.isfile(os.path.join(video_file, f))] frame_files.sort() # Ensure the frames are sorted if they are named sequentially @@ -1349,7 +1348,7 @@ def _get_item(self, i) -> Dict[str, torch.Tensor]: except IOError: print(f"Failed to read frame at path: {frame_path}") else: - video = process_video_with_decord(video_file, self.data_args) + video, video_time, frame_time, num_frames_to_sample = process_video_with_decord(video_file, self.data_args) processor = self.data_args.image_processor image = processor.preprocess(video, return_tensors="pt")["pixel_values"] @@ -1850,6 +1849,7 @@ def make_inputs_require_grad(module, input, output): start_time = time.time() with wds.ShardWriter(os.path.join(output, f'llava-ov-{dist.get_rank()}-%d.tar'), maxcount=10000) as shard_writer: dataloader = trainer.get_train_dataloader() + print(f"sample num: {len(dataloader)}") global_id = 0 for entry in tqdm(dataloader): if global_id == 0: @@ -1860,22 +1860,59 @@ def make_inputs_require_grad(module, input, output): sequence = [] sequence.append(entry['input_ids'][0].cpu()) sequence.append(entry['labels'][0].cpu()) - modalities = 'image' - if 'images' in entry and entry['images'] is not None: - modalities = 'image' - images = entry['images'][0].cpu() - images_shape = list(images.shape) - if len(images_shape) == 3: - images = images.reshape([1, *images_shape]) - sequence.append(images) - sequence.append(torch.tensor(entry['image_sizes'][0])) - else: - modalities = 'text' - sequence.append(torch.tensor([0])) - sequence.append(torch.tensor([0])) - # awk(5) - sequence.append(modalities) + assert 'images' in entry + # single image or video + multi_images = False + if len(entry['images']) > 1: + assert len(entry['images']) == len(entry['image_sizes']) + assert len(entry['images']) == len(entry['modalities']) + multi_images = True + + if not multi_images: + images = entry['images'][0].cpu() + images_shape = list(images.shape) + if len(images_shape) == 3: + images = images.reshape([1, *images_shape]) + sequence.append([images]) + sequence.append([torch.tensor(entry['image_sizes'][0])]) + sequence.append([entry['modalities'][0]]) + if entry['modalities'][0] == "video": + print(f"Processing video and image_sizes: {entry['image_sizes'][0]}, {images.shape}") + elif entry['modalities'][0] == "text": + print("Processing text.") + elif entry['modalities'][0] == "image": + print("Processing single image.") + else: + raise ValueError() + else: + # Process images + images = [] + each_image_shape = None + for image in entry['images']: + image_cpu = image.cpu() + image_shape = list(image_cpu.shape) + if not each_image_shape: + each_image_shape = image_shape + # Image shape should be the same when in multi images scene + assert each_image_shape == image_shape + if len(image_shape) == 3: + image_cpu = image_cpu.reshape([1, *image_shape]) + images.append(image_cpu) + + # Process image_sizes + image_sizes = [] + for image_size in entry['image_sizes']: + image_sizes.append(torch.tensor(image_size)) + + # Process modalities + modalities = [] + for modality in entry['modalities']: + modalities.append(modality) + + 
sequence.append(images) + sequence.append(image_sizes) + sequence.append(modalities) sample = { "__key__": str(global_id), @@ -1885,7 +1922,7 @@ def make_inputs_require_grad(module, input, output): shard_writer.write(sample) global_id += 1 - rank0_print(f"Datasets saved to {training_args.output_dir}") + print(f"rank {dist.get_rank()} datasets saved to {training_args.output_dir}") if __name__ == "__main__": diff --git a/tools/datasets/llava_onevision/make_llava_ov_wds.sh b/tools/datasets/llava_onevision/make_llava_ov_wds.sh index 1f2b977b9..76f7fbe98 100644 --- a/tools/datasets/llava_onevision/make_llava_ov_wds.sh +++ b/tools/datasets/llava_onevision/make_llava_ov_wds.sh @@ -1,3 +1,4 @@ +#!/bin/bash export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_DEVICE_MAX_CONNECTIONS=1 export NCCL_SOCKET_IFNAME=eth0 @@ -14,8 +15,8 @@ LLaVA-NeXT-HOME="Path_Of_LLaVA-NeXT" VISION_MODEL_PATH="Path_Of_VISION_MODEL" PROMPT_VERSION="qwen_1_5" -# stage1 -image_aspect_ratio=square +# # stage1 +# image_aspect_ratio=square # other stages image_aspect_ratio=anyres_max_9 @@ -33,7 +34,7 @@ CKPT_PATH="./checkpoints" mkdir -p $CKPT_PATH mkdir -p $EXPNAME_PATH LOGFILE=$EXPNAME_PATH/exp.log -i=0 +i=1 NNodes=`wc -l ${HOSTFILE} | cut -d " " -f1` MASTER_ADDR=`head -n 1 ${HOSTFILE} | cut -d " " -f1` echo "Master node: ${MASTER_ADDR}" @@ -46,10 +47,12 @@ do echo "Starting node ${i}/${NNodes}: ${ip}" ssh $ip \ "cd ${PWD} && \ + sysctl fs.inotify.max_user_watches=524288 && \ export WANDB_MODE=offline && \ export ACCELERATE_CPU_AFFINITY=1 && \ - export PYTHONPATH=$LLaVA-NeXT-HOME:$PYTHONPATH && \ - torchrun --nproc_per_node=4 --nnodes=${NNodes} --node_rank=${i} --master_addr=${MASTER_ADDR} --master_port=29513 llava_ov_wds.py \ + export PYTHONPATH=$LLaVA_NeXT_HOME:$PYTHONPATH && \ + source /root/miniconda3/bin/activate flagscale && \ + torchrun --nproc_per_node=8 --nnodes=${NNodes} --node_rank=${i} --master_addr=${MASTER_ADDR} --master_port=13888 llava_ov_wds.py \ --model_name_or_path ${CKPT_PATH} \ --version ${PROMPT_VERSION} \ --data_path $DATA_PATH \ @@ -92,6 +95,6 @@ do --dataloader_drop_last True \ --seed 42 \ --do_train False \ - --frames_upbound 32 1>>$LOGFILE.$ip 2>&1" & + --frames_upbound 32 1>$LOGFILE.$ip 2>&1" & i=`expr $i + 1` done
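The batching change in dataset_helpers.py above packs a variable number of images or video frames per sample by flattening every tensor into a single 1-D buffer and recording each original shape in split_image_sizes. The consumer side is not part of this diff; the sketch below, with hypothetical helper names and plain PyTorch only, illustrates how such a packed buffer can be split and reshaped back.

import torch

def pack(images):
    # Same convention as the batch() method above: flatten each image/frame tensor
    # and remember its original shape so the packed buffer can be undone later.
    flat = torch.cat([img.flatten() for img in images], dim=0)
    shapes = torch.stack([torch.tensor(img.shape) for img in images], dim=0)
    return flat, shapes

def unpack(flat, shapes):
    # Split by element count per image, then restore the recorded shapes.
    counts = [int(s.prod()) for s in shapes]
    return [chunk.view(*s.tolist()) for chunk, s in zip(torch.split(flat, counts), shapes)]

# Round-trip check with two anyres image tensors holding different patch counts.
imgs = [torch.randn(5, 3, 384, 384), torch.randn(1, 3, 384, 384)]
flat, shapes = pack(imgs)
assert all(torch.equal(a, b) for a, b in zip(imgs, unpack(flat, shapes)))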