From 8d0da9a3089a5f409d545b1848c762a01e2aff2f Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 04:12:44 +0000 Subject: [PATCH 01/14] add support for aria model Signed-off-by: xffxff <1247714429@qq.com> --- docs/source/models/supported_models.rst | 6 + vllm/model_executor/models/aria.py | 754 ++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/transformers_utils/configs/aria.py | 47 ++ 4 files changed, 808 insertions(+) create mode 100644 vllm/model_executor/models/aria.py create mode 100644 vllm/transformers_utils/configs/aria.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e902d393f2f70..f02e8c232c7e3 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -424,6 +424,12 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`AriaForConditionalGeneration` + - Aria + - T + I + - :code:`rhymes-ai/Aria-sequential_mlp` + - + - ✅︎ * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py new file mode 100644 index 0000000000000..106d1cad318f9 --- /dev/null +++ b/vllm/model_executor/models/aria.py @@ -0,0 +1,754 @@ +import math +from typing import Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from transformers import LlamaConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, QuantizationConfig, VllmConfig +from vllm.inputs import INPUT_REGISTRY, token_inputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, + SamplingMetadata) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.idefics2_vision_model import ( + Idefics2VisionTransformer) +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.llama import (LlamaAttention, + LlamaDecoderLayer, LlamaMLP, + LlamaModel) +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + is_pp_missing_parameter, + make_layers, maybe_prefix, + merge_multimodal_embeddings) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, + AriaVisionConfig) + + +class AriaVisionTransformer(Idefics2VisionTransformer): + """ + AriaVisionTransformer is a modified version of Idefics2VisionTransformer + that replaces the post-layernorm with an identity layer. 
+ """ + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config, prefix) + self.post_layernorm = nn.Identity() + + +class AriaVisionModel(nn.Module): + config_class = AriaVisionConfig + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = AriaVisionTransformer( + config, + quant_config, + prefix=f"{prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + pixel_mask: Optional[torch.BoolTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]: + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + vit_oup = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + image_atts = self._create_image_attention_mask(patch_attention_mask) + + return vit_oup, image_atts + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _create_image_attention_mask(self, patch_attention_mask): + if patch_attention_mask is None: + return None + + flattened_mask = patch_attention_mask.flatten(1) + return torch.logical_not(flattened_mask) + + +class FFN(nn.Module): + + def __init__(self, embed_dim, ff_dim, output_dim): + super().__init__() + self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False) + self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False) + self.act = get_act_fn("gelu_new") + + def forward(self, hidden_states): + hidden_states, _ = self.linear_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_out(hidden_states) + return hidden_states + + +class CrossAttention(nn.Module): + + def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): + super().__init__() + self.num_heads = num_heads + self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=False) + self.kv_proj = MergedColumnParallelLinear(kv_dim, + [embed_dim, embed_dim], + bias=False) + + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + self.linear = RowParallelLinear(embed_dim, embed_dim) + self.dropout = nn.Dropout(drop_out_rate) + + self.layer_norm = nn.LayerNorm(embed_dim) + self.ln_kv = nn.LayerNorm(kv_dim) + + def forward(self, x, hidden_states, attn_mask=None, add_residual=False): + normed_hidden_states = self.layer_norm(hidden_states) + query = self.q_proj(normed_hidden_states)[0].permute(1, 0, 2) + + x = self.ln_kv(x) + key_value = self.kv_proj(x)[0].permute(1, 0, 2) + key, value = key_value.chunk(2, dim=-1) + + attn_output, _ = self.multihead_attn(query, + key, + value, + attn_mask=attn_mask) + + attn_output = attn_output.permute(1, 0, 2) + + if add_residual: + attn_output = hidden_states + self.dropout( + self.linear(attn_output)[0]) + else: + attn_output = self.dropout(self.linear(attn_output)[0]) + + return attn_output + + +class AriaProjector(nn.Module): + """ + A projection module with one cross attention layer and one FFN layer, which + projects ViT's outputs into MoE's inputs. 
+ + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding + query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes + based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__( + self, + patch_to_query_dict, + embed_dim, + num_heads, + kv_dim, + ff_dim, + output_dim, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.patch_to_query_dict = patch_to_query_dict + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter( + torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)) + + trunc_normal_(self.query, std=0.02) + + self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) + + self.ln_ffn = norm_layer(embed_dim) + self.ffn = FFN(embed_dim, ff_dim, output_dim) + + def forward(self, x, attn_mask=None): + bs = x.shape[0] + queries = self.query.unsqueeze(0).repeat(bs, 1, 1) + + query_num = self.patch_to_query_dict.get(x.shape[1], None) + assert (query_num is not None + ), f"Query number for {x.shape[1]} patches is not provided" + + queries = queries[:, :query_num, :] + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + + out = self.ffn(self.ln_ffn(attention_out)) + + return out + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".kv_proj", ".k_proj", 0), + (".kv_proj", ".v_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + return loaded_params + + +class MoELayer(nn.Module): + """ + Mixture of Experts (MoE) Layer for the AriaMoE model. + + This layer implements the MoE mechanism, which routes input tokens to + different experts based on a routing algorithm, processes them through the + experts, and then combines the outputs. 
+ """ + + def __init__( + self, + config: AriaMoELMConfig, + quant_config: Optional[QuantizationConfig], + lora_config: Optional[LoRAConfig], + ) -> None: + super().__init__() + self.config = config + + self.router_weight = nn.Parameter( + torch.empty( + (self.config.moe_num_experts, self.config.hidden_size))) + + self.experts = FusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_topk, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + quant_config=quant_config, + ) + self.shared_experts = LlamaMLP( + config.hidden_size, + config.moe_intermediate_size * config.moe_num_shared_experts, + "silu", + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, + sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + """ + + router_output = torch.nn.functional.linear(hidden_states, + self.router_weight) + + shared_expert_output = self.shared_experts(hidden_states) + sparse_expert_output = self.experts(hidden_states, router_output) + + return sparse_expert_output + shared_expert_output + + +class MoEDecoderLayer(LlamaDecoderLayer): + """ + Custom Decoder Layer for the AriaMoE model which modifies the standard + `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of + Experts (MoE) Layer. + """ + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = MoELayer(config, + quant_config=quant_config, + lora_config=lora_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + +class AriaMoELMModel(LlamaModel): + """ + Custom LlamaModel for the AriaMoE model which modifies the standard + LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # FIXME: this is a hack to disable the compilation of the model + self.do_not_compile = True + + self.layers = None + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MoEDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + # Adapted from FusedMoE.make_expert_params_mapping with the modification + # of changing the prefix of the weight names + def _make_expert_params_mapping( + self, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.experts.{expert_id}.{weight_name}.", expert_id, shard_id + ) for expert_id in range(num_experts) + for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + # Adapted from LlamaModel.load_weights with the modification of adding the + # expert_params_mapping + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = self._make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.moe_num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + # We have mlp.experts.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, we + # need to skip here BEFORE we update the name, otherwise name + # will be updated to mlp.experts.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping for + # mlp.experts.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts.experts." in name) + and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def build_mm_projector(config): + return AriaProjector( + patch_to_query_dict=config.projector_patch_to_query_dict, + embed_dim=config.vision_config.hidden_size, + num_heads=config.vision_config.num_attention_heads, + kv_dim=config.vision_config.hidden_size, + ff_dim=config.text_config.hidden_size, + output_dim=config.text_config.hidden_size, + ) + + +def get_max_multimodal_tokens(ctx): + return max(ctx.model_config.hf_config.image_size2tokens.values()) + + +def input_mapper_for_aria(ctx, data): + return MultiModalInputs(data) + + +def repeat_image_tokens(token_ids: list, image_token_id: int, + repeat_times: list) -> list: + """ + Repeats the image token in the token_ids list according to the repeat_times + list. + + Args: + token_ids (list): List of token IDs. + image_token_id (int): The token ID that represents an image. + repeat_times (list): List of integers specifying how many times to + repeat the image token. + + Returns: + list: A new list with the image token repeated as specified. 
+ + Example: + token_ids = [1, 2, 3, 4, 3, 5] + image_token_id = 3 + repeat_times = [2, 3] + result = repeat_image_tokens(token_ids, image_token_id, repeat_times) + # result will be [1, 2, 3, 3, 4, 3, 3, 3, 5] + """ + if len(repeat_times) != token_ids.count(image_token_id): + raise ValueError( + "The length of repeat_times is not equal to the number of images.") + + result = [] + repeat_iter = iter(repeat_times) + + for x in token_ids: + if x == image_token_id: + result.extend([image_token_id] * next(repeat_iter)) + else: + result.append(x) + + return result + + +def input_processor(ctx, llm_inputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + # if it is pure text input, use it as is + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + hf_config = model_config.hf_config + + # prepare image tokens, the max_image_size is used to determine the number + # of patch_size for every image + max_image_size = multi_modal_data.pop("max_image_size", 980) + _split_image = multi_modal_data.pop("split_image", False) + + assert isinstance(max_image_size, + (int, float)), "max_image_size should be float or int" + images = (multi_modal_data["image"] if isinstance( + multi_modal_data["image"], list) else [multi_modal_data["image"]]) + + image_inputs = image_processor.preprocess(images, + max_image_size=max_image_size, + split_image=_split_image, + return_tensors="pt").data + num_crops = image_inputs.pop("num_crops") + + prompt_token_ids = llm_inputs["prompt_token_ids"] + prompt_token_ids = repeat_image_tokens(prompt_token_ids, + hf_config.image_token_index, + num_crops) + + repeat_count = [hf_config.image_size2tokens[max_image_size] + ] * sum(num_crops).item() + new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens( + tokenizer, + None, + prompt_token_ids, + placeholder_token_id=hf_config.image_token_index, + repeat_count=repeat_count, + ) + + return token_inputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data={"image": image_inputs}, + ) + + +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) +@INPUT_REGISTRY.register_input_processor(input_processor) +class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): + """ + Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language + model to perform tasks that involve both image and text inputs. 
+ """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + # prepare the image_size to tokens mapping for the image preprocess, see + # input_processor + config.image_size2tokens = { + int(math.sqrt(k) * config.vision_config.patch_size): v + for k, v in config.projector_patch_to_query_dict.items() + } + self.config = config + self.vision_tower = AriaVisionModel(config.vision_config) + self.multi_modal_projector = build_mm_projector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AriaMoELMModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model.model"), + ) + self.pad_token_id = (self.config.pad_token_id + if self.config.pad_token_id is not None else -1) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.vocab_size, logit_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ): + # 1. Extra the input embeddings + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + pixel_values = kwargs.get("pixel_values", None) + pixel_mask = kwargs.get("pixel_mask", None) + + # 2. Merge text and images + if pixel_values is not None: + pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:]).to( + torch.bfloat16) + pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:]) + selected_image_feature, image_attn_mask = self.vision_tower( + pixel_values, + pixel_mask=pixel_mask, + ) + + image_features = self.multi_modal_projector( + selected_image_feature, attn_mask=image_attn_mask) + + inputs_embeds = inputs_embeds.to(image_features.dtype) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, image_features, + self.config.image_token_index) + + hidden_states = self.language_model( + input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "language_model", + "language_model.lm_head": "lm_head", + }, + orig_to_new_suffix={ + "router.weight": "router_weight", + }, + ) + + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 22c2e328bfb65..5b1ab7448dcc7 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -123,6 +123,7 @@ 
_MULTIMODAL_MODELS = { # [Decoder-only] + "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py new file mode 100644 index 0000000000000..d253da0d96a34 --- /dev/null +++ b/vllm/transformers_utils/configs/aria.py @@ -0,0 +1,47 @@ +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2VisionConfig) +from transformers.models.llama.configuration_llama import LlamaConfig + + +class AriaVisionConfig(Idefics2VisionConfig): + model_type = "aria_vision_model" + + +class AriaMoELMConfig(LlamaConfig): + """ + Configuration class for AriaMoE language model. + + This class extends the LlamaConfig to include additional parameters specific + to the Mixture of Experts (MoE) architecture. + """ + + model_type = "aria_moe_lm" + + def __init__( + self, + moe_intermediate_size: int = 4096, + moe_num_experts: int = 8, + moe_topk: int = 2, + moe_num_shared_experts: int = 2, + **kwargs, + ): + """ + Initialize the AriaMoELMConfig. + + Args: + moe_intermediate_size (int): The intermediate size for MoE layers. + Default is 4096. + moe_num_experts (int): The number of experts in the MoE layer. + Default is 8. + moe_topk (int): The number of top experts to route to for each + token. Default is 2. + moe_num_shared_experts (int): The number of shared experts. Default + is 2. + **kwargs: Additional keyword arguments to be passed to the parent + LlamaConfig. + """ + super().__init__(**kwargs) + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_num_shared_experts = moe_num_shared_experts From 4396ecd68f44098211cf45503dae1ec524fa822d Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 05:45:56 +0000 Subject: [PATCH 02/14] support prompts with different number of images Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 68 +++++++++++------------------- 1 file changed, 24 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 106d1cad318f9..c8ef55da60d11 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -545,44 +545,6 @@ def input_mapper_for_aria(ctx, data): return MultiModalInputs(data) -def repeat_image_tokens(token_ids: list, image_token_id: int, - repeat_times: list) -> list: - """ - Repeats the image token in the token_ids list according to the repeat_times - list. - - Args: - token_ids (list): List of token IDs. - image_token_id (int): The token ID that represents an image. - repeat_times (list): List of integers specifying how many times to - repeat the image token. - - Returns: - list: A new list with the image token repeated as specified. 
- - Example: - token_ids = [1, 2, 3, 4, 3, 5] - image_token_id = 3 - repeat_times = [2, 3] - result = repeat_image_tokens(token_ids, image_token_id, repeat_times) - # result will be [1, 2, 3, 3, 4, 3, 3, 3, 5] - """ - if len(repeat_times) != token_ids.count(image_token_id): - raise ValueError( - "The length of repeat_times is not equal to the number of images.") - - result = [] - repeat_iter = iter(repeat_times) - - for x in token_ids: - if x == image_token_id: - result.extend([image_token_id] * next(repeat_iter)) - else: - result.append(x) - - return result - - def input_processor(ctx, llm_inputs): multi_modal_data = llm_inputs.get("multi_modal_data") # if it is pure text input, use it as is @@ -613,9 +575,14 @@ def input_processor(ctx, llm_inputs): num_crops = image_inputs.pop("num_crops") prompt_token_ids = llm_inputs["prompt_token_ids"] - prompt_token_ids = repeat_image_tokens(prompt_token_ids, - hf_config.image_token_index, - num_crops) + if num_crops.sum().item() > 0: + _, prompt_token_ids, _ = repeat_and_pad_placeholder_tokens( + tokenizer, + None, + prompt_token_ids, + placeholder_token_id=hf_config.image_token_index, + repeat_count=num_crops, + ) repeat_count = [hf_config.image_size2tokens[max_image_size] ] * sum(num_crops).item() @@ -698,9 +665,22 @@ def forward( # 2. Merge text and images if pixel_values is not None: - pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:]).to( - torch.bfloat16) - pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:]) + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.view( + -1, *pixel_values.shape[-3:]).to(torch.bfloat16) + pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:]) + elif isinstance(pixel_values, list): + if not all(x.shape[-3:] == pixel_values[0].shape[-3:] + for x in pixel_values): + raise ValueError("All images must be the same size") + + pixel_values = [ + x.view(-1, *x.shape[-3:]).to(torch.bfloat16) + for x in pixel_values + ] + pixel_values = torch.cat(pixel_values, dim=0) + pixel_mask = [x.view(-1, *x.shape[-2:]) for x in pixel_mask] + pixel_mask = torch.cat(pixel_mask, dim=0) selected_image_feature, image_attn_mask = self.vision_tower( pixel_values, pixel_mask=pixel_mask, From 3eaaca478e9a26b48946c60c6899351f6ccfc792 Mon Sep 17 00:00:00 2001 From: zhou fan <1247714429@qq.com> Date: Fri, 22 Nov 2024 15:55:02 +0800 Subject: [PATCH 03/14] Update vllm/model_executor/models/aria.py Co-authored-by: Isotr0py <2037008807@qq.com> --- vllm/model_executor/models/aria.py | 33 +----------------------------- 1 file changed, 1 insertion(+), 32 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index c8ef55da60d11..e31383882dfd1 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -338,41 +338,10 @@ def __init__( lora_config: Optional[LoRAConfig] = None, prefix: str = "", ) -> None: - nn.Module.__init__(self) - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - if rope_scaling is not None and getattr( - config, "original_max_position_embeddings", None): - rope_scaling["original_max_position_embeddings"] = ( - config.original_max_position_embeddings) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # Support abacusai/Smaug-72B-v0.1 with attention_bias - # Support internlm/internlm-7b with bias - attention_bias = getattr(config, "attention_bias", False) or getattr( - config, "bias", False) - 
self.self_attn = LlamaAttention( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=getattr(config, "num_key_value_heads", - config.num_attention_heads), - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - quant_config=quant_config, - bias=attention_bias, - cache_config=cache_config, - prefix=f"{prefix}.self_attn", - ) + super().__init__(config, cache_config, quant_config, lora_config, prefix) self.mlp = MoELayer(config, quant_config=quant_config, lora_config=lora_config) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) class AriaMoELMModel(LlamaModel): From dabb3316e865310552ed855fa07b2099d54afcd9 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 07:52:22 +0000 Subject: [PATCH 04/14] support rhymes-ai/Aria instead of rhymes-ai/Aria-sequential_mlp Signed-off-by: xffxff <1247714429@qq.com> --- docs/source/models/supported_models.rst | 2 +- vllm/model_executor/models/aria.py | 106 ++++++++++++------------ 2 files changed, 53 insertions(+), 55 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f02e8c232c7e3..390c53b2fbc0f 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -427,7 +427,7 @@ Text Generation * - :code:`AriaForConditionalGeneration` - Aria - T + I - - :code:`rhymes-ai/Aria-sequential_mlp` + - :code:`rhymes-ai/Aria` - - ✅︎ * - :code:`Blip2ForConditionalGeneration` diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index e31383882dfd1..5dec536af8a7f 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -41,6 +41,7 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) +from vllm.distributed import get_tensor_model_parallel_rank class AriaVisionTransformer(Idefics2VisionTransformer): @@ -266,6 +267,35 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params +class AriaFusedMoE(FusedMoE): + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str) -> Set[str]: + # Override the weight_loader to handle the expert weights in the Aria + # model, which are already packed with experts, and merge the gate and + # up weights for each expert. 
+ # Note: Loading expert weights with quantization is not supported + tp_rank = get_tensor_model_parallel_rank() + if shard_id == 'w13': + # the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size) + if self.tp_size > 1: + up, gate = loaded_weight.chunk(2, dim=-1) + up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] + gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] + up_and_gate = torch.cat( + [up_current_rank, gate_current_rank], dim=-1 + ).transpose(1, 2) + param.data.copy_(up_and_gate) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + elif shard_id == 'w2': + # the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size) + if self.tp_size > 1: + down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[self.tp_rank] + param.data.copy_(down_current_rank.transpose(1, 2)) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + + class MoELayer(nn.Module): """ Mixture of Experts (MoE) Layer for the AriaMoE model. @@ -288,7 +318,7 @@ def __init__( torch.empty( (self.config.moe_num_experts, self.config.hidden_size))) - self.experts = FusedMoE( + self.experts = AriaFusedMoE( num_experts=config.moe_num_experts, top_k=config.moe_topk, hidden_size=config.hidden_size, @@ -393,8 +423,8 @@ def _make_expert_params_mapping( ] ] - # Adapted from LlamaModel.load_weights with the modification of adding the - # expert_params_mapping + # Adapted from LlamaModel.load_weights with the modification of adding + # the expert weights mapping to `stacked_params_mapping` def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ @@ -404,16 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str, (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), + ("experts.w13_weight", "experts.fc1.weight", 'w13'), + ("experts.w2_weight", "experts.fc2.weight", 'w2'), ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = self._make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.moe_num_experts) - params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: @@ -436,15 +459,6 @@ def load_weights(self, weights: Iterable[Tuple[str, for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - # We have mlp.experts.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, we - # need to skip here BEFORE we update the name, otherwise name - # will be updated to mlp.experts.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping for - # mlp.experts.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts.experts." in name) - and name not in params_dict): - continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: @@ -458,39 +472,21 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight, shard_id) break else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params @@ -696,6 +692,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): }, orig_to_new_suffix={ "router.weight": "router_weight", + # "experts.fc1.weight": "experts.w13_weight", + # "experts.fc2.weight": "experts.w2_weight", }, ) From a35ba852958b7cc3d0cfde70bbf4062fe6d08fcb Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 08:03:27 +0000 Subject: [PATCH 05/14] make format happy Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 53 +++++++++--------------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 5dec536af8a7f..7a3e3b25324c2 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -7,11 +7,11 @@ from transformers import LlamaConfig from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, LoRAConfig, QuantizationConfig, VllmConfig +from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_rank from vllm.inputs import INPUT_REGISTRY, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear) @@ -26,8 +26,7 @@ from vllm.model_executor.models.idefics2_vision_model import ( Idefics2VisionTransformer) from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.llama import (LlamaAttention, - LlamaDecoderLayer, LlamaMLP, +from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP, LlamaModel) from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, is_pp_missing_parameter, @@ -41,7 +40,6 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.aria 
import (AriaMoELMConfig, AriaVisionConfig) -from vllm.distributed import get_tensor_model_parallel_rank class AriaVisionTransformer(Idefics2VisionTransformer): @@ -269,28 +267,31 @@ def load_weights(self, weights: Iterable[Tuple[str, class AriaFusedMoE(FusedMoE): - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str) -> Set[str]: + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + shard_id: str) -> Set[str]: # Override the weight_loader to handle the expert weights in the Aria # model, which are already packed with experts, and merge the gate and # up weights for each expert. # Note: Loading expert weights with quantization is not supported tp_rank = get_tensor_model_parallel_rank() if shard_id == 'w13': - # the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size) + # the shape of loaded_weight is + # (num_experts, hidden_size, 2 * moe_intermediate_size) if self.tp_size > 1: up, gate = loaded_weight.chunk(2, dim=-1) up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] - up_and_gate = torch.cat( - [up_current_rank, gate_current_rank], dim=-1 - ).transpose(1, 2) + up_and_gate = torch.cat([up_current_rank, gate_current_rank], + dim=-1).transpose(1, 2) param.data.copy_(up_and_gate) else: param.data.copy_(loaded_weight.transpose(1, 2)) elif shard_id == 'w2': - # the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size) + # the shape of loaded_weight is + # (num_experts, moe_intermediate_size, hidden_size) if self.tp_size > 1: - down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[self.tp_rank] + down_current_rank = loaded_weight.chunk(self.tp_size, + dim=1)[self.tp_rank] param.data.copy_(down_current_rank.transpose(1, 2)) else: param.data.copy_(loaded_weight.transpose(1, 2)) @@ -309,7 +310,6 @@ def __init__( self, config: AriaMoELMConfig, quant_config: Optional[QuantizationConfig], - lora_config: Optional[LoRAConfig], ) -> None: super().__init__() self.config = config @@ -365,13 +365,10 @@ def __init__( config: LlamaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - lora_config: Optional[LoRAConfig] = None, prefix: str = "", ) -> None: - super().__init__(config, cache_config, quant_config, lora_config, prefix) - self.mlp = MoELayer(config, - quant_config=quant_config, - lora_config=lora_config) + super().__init__(config, cache_config, quant_config, prefix) + self.mlp = MoELayer(config, quant_config=quant_config) class AriaMoELMModel(LlamaModel): @@ -403,26 +400,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): prefix=f"{prefix}.layers", ) - # Adapted from FusedMoE.make_expert_params_mapping with the modification - # of changing the prefix of the weight names - def _make_expert_params_mapping( - self, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, - ckpt_up_proj_name: str, - num_experts: int) -> List[Tuple[str, str, int, str]]: - - return [ - # (param_name, weight_name, expert_id, shard_id) - ("experts.w13_" if weight_name - in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", - f"experts.experts.{expert_id}.{weight_name}.", expert_id, shard_id - ) for expert_id in range(num_experts) - for shard_id, weight_name in [ - ("w1", ckpt_gate_proj_name), - ("w2", ckpt_down_proj_name), - ("w3", ckpt_up_proj_name), - ] - ] - # Adapted from LlamaModel.load_weights with the modification of adding # the expert weights mapping to 
`stacked_params_mapping` def load_weights(self, weights: Iterable[Tuple[str, From 6413e28a97120b4952f1500d9255110c6cff564a Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 09:35:01 +0000 Subject: [PATCH 06/14] do not apply tp to cross attention module of projector --- vllm/model_executor/models/aria.py | 56 +++++++----------------------- 1 file changed, 12 insertions(+), 44 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 7a3e3b25324c2..c5649f74a8e52 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -8,7 +8,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_rank +from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE @@ -135,13 +135,12 @@ class CrossAttention(nn.Module): def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): super().__init__() self.num_heads = num_heads - self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=False) - self.kv_proj = MergedColumnParallelLinear(kv_dim, - [embed_dim, embed_dim], - bias=False) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False) self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) - self.linear = RowParallelLinear(embed_dim, embed_dim) + self.linear = nn.Linear(embed_dim, embed_dim) self.dropout = nn.Dropout(drop_out_rate) self.layer_norm = nn.LayerNorm(embed_dim) @@ -149,24 +148,20 @@ def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): def forward(self, x, hidden_states, attn_mask=None, add_residual=False): normed_hidden_states = self.layer_norm(hidden_states) - query = self.q_proj(normed_hidden_states)[0].permute(1, 0, 2) + query = self.q_proj(normed_hidden_states).permute(1, 0, 2) x = self.ln_kv(x) - key_value = self.kv_proj(x)[0].permute(1, 0, 2) - key, value = key_value.chunk(2, dim=-1) + key = self.k_proj(x).permute(1, 0, 2) + value = self.v_proj(x).permute(1, 0, 2) - attn_output, _ = self.multihead_attn(query, - key, - value, - attn_mask=attn_mask) + attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask) attn_output = attn_output.permute(1, 0, 2) if add_residual: - attn_output = hidden_states + self.dropout( - self.linear(attn_output)[0]) + attn_output = hidden_states + self.dropout(self.linear(attn_output)) else: - attn_output = self.dropout(self.linear(attn_output)[0]) + attn_output = self.dropout(self.linear(attn_output)) return attn_output @@ -237,33 +232,6 @@ def forward(self, x, attn_mask=None): return out - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".kv_proj", ".k_proj", 0), - (".kv_proj", ".v_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() - for name, loaded_weight in weights: - for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, 
loaded_weight, shard_id) - break - else: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - loaded_params.add(name) - return loaded_params - class AriaFusedMoE(FusedMoE): @@ -291,7 +259,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, # (num_experts, moe_intermediate_size, hidden_size) if self.tp_size > 1: down_current_rank = loaded_weight.chunk(self.tp_size, - dim=1)[self.tp_rank] + dim=1)[tp_rank] param.data.copy_(down_current_rank.transpose(1, 2)) else: param.data.copy_(loaded_weight.transpose(1, 2)) From d4d62daabdd2b815c0784eb7840ccc982a8c9c85 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 09:36:15 +0000 Subject: [PATCH 07/14] fix results when tp is enabled --- vllm/model_executor/models/aria.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index c5649f74a8e52..ea45948d69fa4 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -292,6 +292,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, quant_config=quant_config, + reduce_results=True, ) self.shared_experts = LlamaMLP( config.hidden_size, From 292688acfec81176d1852f553ebc4e49964ddfe9 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 09:38:15 +0000 Subject: [PATCH 08/14] make format happy Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index ea45948d69fa4..745601afed7bf 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -8,12 +8,11 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, QuantizationConfig, VllmConfig -from vllm.distributed import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_rank from vllm.inputs import INPUT_REGISTRY, token_inputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( @@ -154,12 +153,16 @@ def forward(self, x, hidden_states, attn_mask=None, add_residual=False): key = self.k_proj(x).permute(1, 0, 2) value = self.v_proj(x).permute(1, 0, 2) - attn_output, _ = self.multihead_attn(query, key, value, attn_mask=attn_mask) + attn_output, _ = self.multihead_attn(query, + key, + value, + attn_mask=attn_mask) attn_output = attn_output.permute(1, 0, 2) if add_residual: - attn_output = hidden_states + self.dropout(self.linear(attn_output)) + attn_output = hidden_states + self.dropout( + self.linear(attn_output)) else: attn_output = self.dropout(self.linear(attn_output)) From 1bd6d463d780c2b9295d8ce5d8a80bf8c5bb733c Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 09:52:30 +0000 Subject: [PATCH 09/14] remove unused code Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/model_executor/models/aria.py 
b/vllm/model_executor/models/aria.py index 745601afed7bf..878bd3eed6038 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -641,8 +641,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): }, orig_to_new_suffix={ "router.weight": "router_weight", - # "experts.fc1.weight": "experts.w13_weight", - # "experts.fc2.weight": "experts.w2_weight", }, ) From afbabd313a81b80e84661e842f555b8902be98e6 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Fri, 22 Nov 2024 10:55:20 +0000 Subject: [PATCH 10/14] refactor based on reviewer's feedback Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 122 ++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 878bd3eed6038..3961579f21d05 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union import torch import torch.nn as nn @@ -40,6 +40,18 @@ from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, AriaVisionConfig) +from .utils import flatten_bn + + +class AriaImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + pixel_mask: Optional[torch.Tensor] + """ + Shape: + pixel_values: `(batch_size * num_images, num_channels, height, width)` + pixel_mask: `(batch_size * num_images, height, width)` + """ + class AriaVisionTransformer(Idefics2VisionTransformer): """ @@ -486,6 +498,8 @@ def input_processor(ctx, llm_inputs): max_image_size=max_image_size, split_image=_split_image, return_tensors="pt").data + image_inputs['pixel_values'] = image_inputs['pixel_values'].to( + ctx.model_config.dtype) num_crops = image_inputs.pop("num_crops") prompt_token_ids = llm_inputs["prompt_token_ids"] @@ -563,6 +577,65 @@ def __init__( self.vocab_size, logit_scale) self.sampler = Sampler() + def _validate_image_sizes( + self, images: List[torch.Tensor]) -> List[torch.Tensor]: + if not all(img.shape == images[0].shape for img in images): + raise ValueError("All images must be the same size") + return images + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + pixel_mask = kwargs.pop("pixel_mask", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") + + pixel_values = self._validate_image_sizes(pixel_values) + pixel_values = flatten_bn(pixel_values, concat=True) + if pixel_mask is not None: + pixel_mask = flatten_bn(pixel_mask, concat=True) + + return AriaImagePixelInputs( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + ) + + def _process_image_input( + self, image_input: AriaImagePixelInputs + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input['pixel_values'] + pixel_mask = image_input['pixel_mask'] + + image_feature, image_attn_mask = self.vision_tower( + pixel_values, pixel_mask=pixel_mask) + return self.multi_modal_projector(image_feature, image_attn_mask) + + def process_mm_inputs(self, **kwargs): + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + vision_embeddings: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if vision_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -570,50 +643,23 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, - ): - # 1. Extra the input embeddings - inputs_embeds = self.language_model.get_input_embeddings(input_ids) - pixel_values = kwargs.get("pixel_values", None) - pixel_mask = kwargs.get("pixel_mask", None) - - # 2. 
Merge text and images - if pixel_values is not None: - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values.view( - -1, *pixel_values.shape[-3:]).to(torch.bfloat16) - pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:]) - elif isinstance(pixel_values, list): - if not all(x.shape[-3:] == pixel_values[0].shape[-3:] - for x in pixel_values): - raise ValueError("All images must be the same size") - - pixel_values = [ - x.view(-1, *x.shape[-3:]).to(torch.bfloat16) - for x in pixel_values - ] - pixel_values = torch.cat(pixel_values, dim=0) - pixel_mask = [x.view(-1, *x.shape[-2:]) for x in pixel_mask] - pixel_mask = torch.cat(pixel_mask, dim=0) - selected_image_feature, image_attn_mask = self.vision_tower( - pixel_values, - pixel_mask=pixel_mask, - ) - - image_features = self.multi_modal_projector( - selected_image_feature, attn_mask=image_attn_mask) - - inputs_embeds = inputs_embeds.to(image_features.dtype) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, image_features, - self.config.image_token_index) + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + vision_embeddings = self.process_mm_inputs(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model( input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds, ) From a4ffaba19aff2c4646576f7263f56b5f5d5206bd Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Mon, 25 Nov 2024 04:54:19 +0000 Subject: [PATCH 11/14] add examples Signed-off-by: xffxff <1247714429@qq.com> --- examples/offline_inference_vision_language.py | 18 +++++++++++++++++ ...e_inference_vision_language_multi_image.py | 20 +++++++++++++++++++ vllm/entrypoints/chat_utils.py | 2 ++ 3 files changed, 40 insertions(+) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 11af6880e1b5a..f08f22eec164a 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -402,6 +402,23 @@ def run_idefics3(question: str, modality: str): return llm, prompt, stop_token_ids +# Aria +def run_aria(question: str, modality: str): + assert modality == "image" + model_name = "rhymes-ai/Aria" + + llm = LLM(model=model_name, + tokenizer_mode="slow", + trust_remote_code=True, + dtype="bfloat16") + + prompt = (f"<|im_start|>user\n<|img|>\n{question}" + "<|im_end|>\n<|im_start|>assistant\n") + + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -423,6 +440,7 @@ def run_idefics3(question: str, modality: str): "molmo": run_molmo, "glm4v": run_glm4v, "idefics3": run_idefics3, + "aria": run_aria, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index dc12df8d78211..788b604cfd4a0 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -321,6 +321,25 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: ) +def load_aria(question, image_urls: List[str]) -> ModelRequestData: + model_name = "rhymes-ai/Aria" + llm = LLM(model=model_name, + tokenizer_mode="slow", + 
trust_remote_code=True, + dtype="bfloat16", + limit_mm_per_prompt={"image": len(image_urls)}) + placeholders = "<|img|>\n" * len(image_urls) + prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None) + + model_example_map = { "phi3_v": load_phi3v, "h2ovl_chat": load_h2onvl, @@ -330,6 +349,7 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: "qwen_vl_chat": load_qwenvl_chat, "mllama": load_mllama, "idefics3": load_idefics3, + "aria": load_aria, } diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index abee5ac46391c..c2054dcbfce0e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -412,6 +412,8 @@ def _placeholder_str(self, modality: ModalityStr, return "" if model_type == "idefics3": return "" + if model_type == "aria": + return "<|fim_prefix|><|img|><|fim_suffix|>" raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": From 9525e798459cf469c1886745fca24deeeb142742 Mon Sep 17 00:00:00 2001 From: xffxff <1247714429@qq.com> Date: Mon, 25 Nov 2024 05:04:13 +0000 Subject: [PATCH 12/14] refactor: follow the new interface Signed-off-by: xffxff <1247714429@qq.com> --- vllm/model_executor/models/aria.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 3961579f21d05..0356435e9c257 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -34,6 +34,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -617,22 +618,22 @@ def _process_image_input( pixel_values, pixel_mask=pixel_mask) return self.multi_modal_projector(image_feature, image_attn_mask) - def process_mm_inputs(self, **kwargs): + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: image_input = self._parse_and_validate_image_input(**kwargs) if image_input is None: return None - vision_embeddings = self._process_image_input(image_input) - return vision_embeddings + multimodal_embeddings = self._process_image_input(image_input) + return multimodal_embeddings def get_input_embeddings( self, input_ids: torch.Tensor, - vision_embeddings: Optional[torch.Tensor] = None, + multimodal_embeddings: Optional[NestedTensors] = None, ) -> torch.Tensor: inputs_embeds = self.language_model.get_input_embeddings(input_ids) - if vision_embeddings is not None: + if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, + input_ids, inputs_embeds, multimodal_embeddings, self.config.image_token_index) return inputs_embeds @@ -647,11 +648,11 @@ def forward( **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: if inputs_embeds is None: - vision_embeddings = self.process_mm_inputs(**kwargs) + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) # always pass the input via `inputs_embeds` # to make sure the computation graph is 
consistent inputs_embeds = self.get_input_embeddings(input_ids, - vision_embeddings) + multimodal_embeddings) input_ids = None hidden_states = self.language_model( From 8908a662bc2e23406187be3ecd4e46f47d1f102b Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 25 Nov 2024 23:30:12 +0800 Subject: [PATCH 13/14] add aria registry test Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/registry.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/registry.py b/tests/models/registry.py index 3848367b6126c..456e3ae1803f8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -43,6 +43,8 @@ class _HfExamplesInfo: trust_remote_code=True), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), + "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria", + trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", From 7e8e37f736e690a4a71ce60c469a8f57acf250fc Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 25 Nov 2024 23:48:21 +0800 Subject: [PATCH 14/14] Fix indent --- tests/models/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/registry.py b/tests/models/registry.py index 456e3ae1803f8..ab3ed0b523a23 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -44,7 +44,7 @@ class _HfExamplesInfo: "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria", - trust_remote_code=True), + trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat",
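
For anyone trying the series out locally, below is a minimal end-to-end sketch of the single-image path exercised by the new run_aria example. It builds the chat-style prompt with the <|img|> placeholder, passes the image via multi_modal_data, and reuses the stop token ids from the example. The model id, prompt template, and stop tokens come from the diffs above; the image path "example.jpg" and the sampling settings are illustrative assumptions, not part of the patches.

    # Minimal sketch of running the new Aria single-image path end to end.
    # Assumptions (not part of the patches): a local image at "example.jpg"
    # and the sampling settings below; model id, prompt template, and stop
    # token ids are taken from the run_aria example added in this series.
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(model="rhymes-ai/Aria",
              tokenizer_mode="slow",
              trust_remote_code=True,
              dtype="bfloat16")

    question = "What is shown in this image?"
    prompt = (f"<|im_start|>user\n<|img|>\n{question}"
              "<|im_end|>\n<|im_start|>assistant\n")

    sampling_params = SamplingParams(
        temperature=0.0,
        max_tokens=128,
        stop_token_ids=[93532, 93653, 944, 93421, 1019, 93653, 93519])

    # A single prompt dict carrying both the text prompt and the image.
    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": Image.open("example.jpg")},
        },
        sampling_params=sampling_params)

    print(outputs[0].outputs[0].text)

The same structure extends to the multi-image case shown in load_aria: repeat the <|img|> placeholder once per image, pass a list of images under "image", and set limit_mm_per_prompt accordingly when constructing the LLM.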