diff --git a/nexa/transformers/README.md b/nexa/transformers/README.md new file mode 100644 index 00000000..c539b454 --- /dev/null +++ b/nexa/transformers/README.md @@ -0,0 +1,8 @@ +# transformers support for Nexa AI models + +``` +python run_omnivision.py +``` + +## Acknowledgements +We thank the [Hugging Face Transformers](https://github.com/huggingface/transformers) for their amazing work on the Transformers library. diff --git a/nexa/transformers/__init__.py b/nexa/transformers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nexa/transformers/omnivision/__init__.py b/nexa/transformers/omnivision/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nexa/transformers/omnivision/configuration.py b/nexa/transformers/omnivision/configuration.py new file mode 100644 index 00000000..d356a315 --- /dev/null +++ b/nexa/transformers/omnivision/configuration.py @@ -0,0 +1,130 @@ +# Copyright (c) 2024 Nexa AI Inc., Alibaba Group (Qwen team), and HuggingFace Inc. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Qwen2 model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging +from typing import Union +from transformers import PretrainedConfig +import os +from transformers.models.auto import CONFIG_MAPPING + +logger = logging.get_logger(__name__) + + +class SigLipVisionConfig(PretrainedConfig): + model_type = "siglip_vision_model" + def __init__( + self, + hidden_size=1152, + image_mean=(0.5, 0.5, 0.5), + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.image_mean = image_mean + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SigLipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + return cls.from_dict(config_dict, **kwargs) + + +""" Nexa AI model configuration""" +class OminiVLMConfig(PretrainedConfig): + model_type = "nano-omini-vlm" + + model_type = "omini_vlm" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vision_config=None, + text_config=None, + hidden_size=4096, + mm_hidden_size=1152, + mm_projector_lr=None, + mm_projector_type="mlp2x_gelu", + image_token_index=151655, + initializer_range=0.02, + **kwargs, + ): + self.hidden_size = hidden_size + self.mm_hidden_size = mm_hidden_size + self.mm_projector_lr = mm_projector_lr + self.mm_projector_type = mm_projector_type + self.image_token_index = image_token_index + self.initializer_range = initializer_range + if isinstance(vision_config, dict): + vision_config = SigLipVisionConfig(**vision_config) + elif vision_config is None: + vision_config = SigLipVisionConfig( + hidden_size=1152, + image_mean=(0.5, 0.5, 0.5), + intermediate_size=4304, + num_hidden_layers=27, + num_attention_heads=16, + num_channels=3, + image_size=384, + patch_size=14, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + ) + self.vision_config = vision_config + + if isinstance(text_config, dict): + text_config["model_type"] = ( + text_config["model_type"] if "model_type" in text_config else "qwen2" + ) + text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config) + elif text_config is None: + text_config = CONFIG_MAPPING["qwen2"]() + + self.text_config = text_config + + super().__init__(**kwargs) + \ No newline at end of file diff --git a/nexa/transformers/omnivision/modeling.py b/nexa/transformers/omnivision/modeling.py new file mode 100644 index 00000000..94bd67b1 --- /dev/null +++ b/nexa/transformers/omnivision/modeling.py @@ -0,0 +1,709 @@ +# Copyright (c) 2024 Nexa AI Inc., Alibaba Group (Qwen team), and HuggingFace Inc. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass +from transformers.activations import ACT2FN +import torch.utils.checkpoint +from torch import nn +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ModelOutput +from transformers import Qwen2ForCausalLM +from .configuration import SigLipVisionConfig, OminiVLMConfig + +# ======================================================================================== # +# vision tower # +# ======================================================================================== # +@dataclass +class SigLipVisionModelOutput(ModelOutput): + """ + Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class OminiVLMOutputWithPast(ModelOutput): + """ + Base class for Gemma2Audio causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + attention_mask (`torch.FloatTensor`, *optional*): + Attentions mask, used to update attention mask and position_ids. + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + attention_mask: Optional[torch.FloatTensor] = None + + +class SigLipVisionEmbeddings(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SigLipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim ** -0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SigLipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SigLipEncoderLayer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = SigLipAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SigLipPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = SigLipVisionConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + pass + + +class SigLipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`SigLipEncoderLayer`]. + + Args: + config: SigLipVisionConfig + """ + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class SigLipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SigLipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SigLipMLP(config) + + def forward(self, hidden_state): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SigLipVisionTransformer(nn.Module): + def __init__(self, config: SigLipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SigLipVisionEmbeddings(config) + self.encoder = SigLipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SigLipMultiheadAttentionPoolingHead(config) + + def get_dtype(self) -> torch.dtype: + return self.encoder.layers[0].mlp.fc2.weight.dtype + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.embeddings(pixel_values) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head(last_hidden_state) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SigLipVisionModel(SigLipPreTrainedModel): + config_class = SigLipVisionConfig + main_input_name = "pixel_values" + _no_split_modules = ["SigLipEncoderLayer"] + + def __init__(self, config: SigLipVisionConfig): + super().__init__(config) + self.vision_model = SigLipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# ======================================================================================== # +# Projector # +# ======================================================================================== # + +import re +def build_vision_projector(config, delay_load=False, **kwargs): + projector_type = getattr(config, 'mm_projector_type', 'mlp2x_gelu') + mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + if mlp_gelu_match: + mlp_depth = int(mlp_gelu_match.group(1)) + modules = [nn.Linear(config.mm_hidden_size*9, config.hidden_size)] + for _ in range(1, mlp_depth): + modules.append(nn.GELU()) + modules.append(nn.Linear(config.hidden_size, config.text_config.hidden_size)) + return nn.Sequential(*modules) + + +# ======================================================================================== # +# LLM # +# ======================================================================================== # +class OminiVLMPreTrainedModel(PreTrainedModel): + config_class = OminiVLMConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Qwen2DecoderLayer", "SigLipEncoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv3d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class OminiVLMForConditionalGeneration(OminiVLMPreTrainedModel): + def __init__(self, config: OminiVLMConfig): + super().__init__(config) + if isinstance(config.vision_config, dict): + vision_config = SigLipVisionConfig(**config.vision_config) + else: + vision_config = config.vision_config + self.vision_tower = SigLipVisionModel(vision_config) + self.multi_modal_projector = build_vision_projector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = Qwen2ForCausalLM( + config.text_config, + ) + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + self._padding_side = "right" # set it to left by default, user can use setter to change padding_sides + self.post_init() + + @property + def padding_side(self): + return self._padding_side + + @padding_side.setter + def padding_side(self, padding_side: str): + if padding_side not in ["left", "right"]: + raise ValueError(f"{padding_side} is not `left` or `right`.") + self._padding_side = padding_side + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None + ) -> nn.Embedding: + model_embeds = self.language_model.resize_token_embeddings( + new_num_tokens, pad_to_multiple_of + ) + # update vocab size + self.config.text_config.vocab_size = model_embeds.num_embeddings + self.vocab_size = model_embeds.num_embeddings + return model_embeds + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + num_new_tokens: int = 1, + ) -> Dict[str, Any]: + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + return model_kwargs + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + ) -> Union[Tuple, OminiVLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + target_device = self.vision_tower.device + + if pixel_values is not None: + pixel_values = pixel_values.to(target_device) + + if inputs_embeds is None: + # 1. Extract the input embeddings + inputs_embeds = self.get_input_embeddings()(input_ids) + + # 2. Merge text and vision features + if pixel_values is not None: + pixel_values = pixel_values.type(self.vision_tower.vision_model.get_dtype()) + image_embeds = self.vision_tower(pixel_values).last_hidden_state.to(pixel_values.dtype) + image_embeds = image_embeds.view(image_embeds.shape[0], 81, -1) + image_embeds = self.multi_modal_projector(image_embeds) + image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + + ## This is to intelligently replace the image tokens with the image features + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + if attention_mask is not None: + shift_attention_mask = attention_mask[..., 1:] + shift_logits = logits[..., :-1, :][ + shift_attention_mask.to(logits.device) != 0 + ].contiguous() + shift_labels = labels[..., 1:][ + shift_attention_mask.to(labels.device) != 0 + ].contiguous() + else: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), + shift_labels.view(-1).to(shift_logits.device), + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return OminiVLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + attention_mask=attention_mask, + ) \ No newline at end of file diff --git a/nexa/transformers/omnivision/processing.py b/nexa/transformers/omnivision/processing.py new file mode 100644 index 00000000..2bc3f008 --- /dev/null +++ b/nexa/transformers/omnivision/processing.py @@ -0,0 +1,201 @@ +# Copyright (c) 2024 Nexa AI Inc., Alibaba Group (Qwen team), and HuggingFace Inc. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +try: + from typing import Unpack +except ImportError: + from typing_extensions import Unpack + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ( + ProcessingKwargs, + ProcessorMixin, +) +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.utils import logging + + +logger = logging.get_logger(__name__) +NUM_IMAGE_TOKENS = 81 + +class NanoVLMProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + } + + +class NanoVLMProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + image_processor_class = "SiglipImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + if chat_template is None: + chat_template = self.default_chat_template + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + **kwargs: Unpack[NanoVLMProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Gemma2TokenizerFast's [`~Gemma2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Gemma2VLImageProcessor's [`~Gemma2VLImageProcessor.__call__`] if `vision_infos` is not `None`. + + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + """ + output_kwargs = self._merge_kwargs( + NanoVLMProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + # check the number of images is equal to the number of all image_pad tokens + assert len(images) == sum([t.count("<|image_pad|>") for t in text]), "The number of images must be equal to the number of all image_pad tokens in the text." + + if images is not None: + image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + else: + image_inputs = {} + + if not isinstance(text, list): + text = [text] + + if image_inputs is not None: + index = 0 + for i in range(len(text)): + while "<|image_pad|>" in text[i]: + text[i] = text[i].replace( + "<|image_pad|>", "<|placeholder|>" * NUM_IMAGE_TOKENS, 1 + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>") + + _ = output_kwargs["text_kwargs"].pop("padding_side", None) + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Gemma2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Gemma2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + + @property + def default_chat_template(self): + return ( + "{%- if tools %}" + "{{- '<|im_start|>system\n' }}" + "{%- if messages[0]['role'] == 'system' %}" + "{{- messages[0]['content'] }}" + "{%- else %}" + "{{- 'You are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.' }}" + "{%- endif %}" + "{{- \"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\" }}" + "{%- for tool in tools %}" + "{{- \"\n\" }}" + "{{- tool | tojson }}" + "{%- endfor %}" + "{{- \"\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\\\"name\\\": , \\\"arguments\\\": }\n<|im_end|>\n\" }}" + "{%- else %}" + "{%- if messages[0]['role'] == 'system' %}" + "{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}" + "{%- else %}" + "{{- '<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n' }}" + "{%- endif %}" + "{%- endif %}" + "{%- for message in messages %}" + "{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}" + "{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}" + "{%- elif message.role == \"assistant\" %}" + "{{- '<|im_start|>' + message.role }}" + "{%- if message.content %}" + "{{- '\n' + message.content }}" + "{%- endif %}" + "{%- for tool_call in message.tool_calls %}" + "{%- if tool_call.function is defined %}" + "{%- set tool_call = tool_call.function %}" + "{%- endif %}" + "{{- '\n\n{\"name\": \"' }}" + "{{- tool_call.name }}" + "{{- '\", \"arguments\": ' }}" + "{{- tool_call.arguments | tojson }}" + "{{- '}\n' }}" + "{%- endfor %}" + "{{- '<|im_end|>\n' }}" + "{%- elif message.role == \"tool\" %}" + "{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}" + "{{- '<|im_start|>user' }}" + "{%- endif %}" + "{{- '\n\n' }}" + "{{- message.content }}" + "{{- '\n' }}" + "{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}" + "{{- '<|im_end|>\n' }}" + "{%- endif %}" + "{%- endif %}" + "{%- endfor %}" + "{%- if add_generation_prompt %}" + "{{- '<|im_start|>assistant\n' }}" + "{%- endif %}" + ) \ No newline at end of file diff --git a/nexa/transformers/run_omnivision.py b/nexa/transformers/run_omnivision.py new file mode 100644 index 00000000..f81d1efe --- /dev/null +++ b/nexa/transformers/run_omnivision.py @@ -0,0 +1,92 @@ +from nexa.transformers.omnivision.processing import NanoVLMProcessor +from nexa.transformers.omnivision.modeling import OminiVLMForConditionalGeneration +import argparse +import torch + + +model_name = "NexaAIDev/omnivlm-dpo" +image_url = "https://public-storage.nexa4ai.com/public-images/cat.png" + + +def get_device(): + if torch.cuda.is_available(): + return "cuda" + elif torch.backends.mps.is_available(): + return "mps" + return "cpu" + + +def load_model_and_processor(model_path): + device = get_device() + proc_path = "nexa-collaboration/nano-vlm-processor" + processor = NanoVLMProcessor.from_pretrained(proc_path) + processor.tokenizer.pad_token = processor.tokenizer.eos_token + processor.tokenizer.padding_side = "right" + + model_kwargs = {} + # Adjust dtype based on device + dtype = torch.bfloat16 if device == "cuda" else torch.float32 + local_model = OminiVLMForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=dtype, + **model_kwargs + ) + local_model = local_model.to(device) + return local_model, processor + + +def process_single_image(processor, image_path, input_prompt=None): + text = f"<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_prompt}\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>" + # Changed from Image.open() to handle URLs + if image_path.startswith('http'): + from PIL import Image + import requests + from io import BytesIO + response = requests.get(image_path) + image = Image.open(BytesIO(response.content)).convert('RGB') + else: + image = Image.open(image_path).convert('RGB') + inputs = processor( + text=[text], + images=[image], + padding=True, + return_tensors="pt", + ) + return inputs.to(get_device()) + + +def generate_output(model, processor, inputs, max_tokens): + cur_ids = inputs['input_ids'] + cur_attention_mask = inputs['attention_mask'] + input_token_length = cur_ids.shape[-1] + for _ in range(max_tokens): + out = model( + cur_ids, + attention_mask=cur_attention_mask, + pixel_values=inputs['pixel_values'], + use_cache=False + ) + next_token = out.logits[:, -1].argmax() + next_word = processor.decode(next_token) + cur_ids = torch.cat([cur_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1) + cur_attention_mask = torch.cat([cur_attention_mask, torch.ones_like(next_token).unsqueeze(0).unsqueeze(0)], dim=-1) + if next_word in ("<|im_end|>"): + break + return processor.batch_decode(cur_ids[:, input_token_length:])[0] + +def main(args): + model, processor = load_model_and_processor(args.model_path) + inputs = process_single_image(processor, args.image_path, args.input_prompt) + output = generate_output(model, processor, inputs, args.max_tokens) + print("=== Inference Result ===\n", output) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Inference script for Nano-Omni-VLM") + parser.add_argument("--model_path", default=model_name, help="Path to the model checkpoint") + # Add image_path argument + parser.add_argument("--image_path", default=image_url, help="Path to input image or image URL") + parser.add_argument("--input_prompt", type=str, default="Describe this image for me", help="Input prompt for instruct task") + parser.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate") + + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e80e093e..7e9c6478 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,12 @@ convert = [ "nexa-gguf", ] +transformers = [ + "transformers", + "torch", + "pillow" +] + [project.urls] Homepage = "https://github.com/NexaAI/nexa-sdk" Issues = "https://github.com/NexaAI/nexa-sdk/issues" @@ -105,6 +111,7 @@ wheel.packages = [ "nexa.onnx.streamlit", "nexa.onnx.server", "nexa.eval", + "nexa.transformers", ] sdist.include = [ "CMakeLists.txt",