diff --git a/nexa/transformers/README.md b/nexa/transformers/README.md
new file mode 100644
index 00000000..c539b454
--- /dev/null
+++ b/nexa/transformers/README.md
@@ -0,0 +1,8 @@
+# Transformers support for Nexa AI models
+
+A reference implementation of Nexa AI's OmniVLM model (SigLip vision tower + Qwen2 language backbone) on top of Hugging Face Transformers. Run the example inference script with:
+
+```bash
+python run_omnivision.py
+```
+
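+The script accepts a few optional flags; the defaults shown below are taken from `run_omnivision.py`:
+
+```bash
+python run_omnivision.py \
+    --model_path NexaAIDev/omnivlm-dpo \
+    --image_path https://public-storage.nexa4ai.com/public-images/cat.png \
+    --input_prompt "Describe this image for me" \
+    --max_tokens 512
+```
+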
+## Acknowledgements
+We thank the Hugging Face team for their amazing work on the [Transformers](https://github.com/huggingface/transformers) library.
diff --git a/nexa/transformers/__init__.py b/nexa/transformers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/nexa/transformers/omnivision/__init__.py b/nexa/transformers/omnivision/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/nexa/transformers/omnivision/configuration.py b/nexa/transformers/omnivision/configuration.py
new file mode 100644
index 00000000..d356a315
--- /dev/null
+++ b/nexa/transformers/omnivision/configuration.py
@@ -0,0 +1,130 @@
+# Copyright (c) 2024 Nexa AI Inc., Alibaba Group (Qwen team), and HuggingFace Inc.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Qwen2 model configuration"""
+
+import os
+from typing import Union
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.models.auto import CONFIG_MAPPING
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class SigLipVisionConfig(PretrainedConfig):
+ model_type = "siglip_vision_model"
+ def __init__(
+ self,
+ hidden_size=1152,
+ image_mean=(0.5, 0.5, 0.5),
+ intermediate_size=4304,
+ num_hidden_layers=27,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=384,
+ patch_size=14,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_channels = num_channels
+ self.patch_size = patch_size
+ self.image_size = image_size
+ self.attention_dropout = attention_dropout
+ self.layer_norm_eps = layer_norm_eps
+ self.hidden_act = hidden_act
+ self.image_mean = image_mean
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
+ cls._set_token_in_kwargs(kwargs)
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # get the vision config dict if we are loading from SigLipConfig
+ if config_dict.get("model_type") == "siglip":
+ config_dict = config_dict["vision_config"]
+
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+ return cls.from_dict(config_dict, **kwargs)
+
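+# Illustrative usage (defaults mirror the constructor arguments above):
+#   vision_config = SigLipVisionConfig()   # 27-layer SigLip ViT, 384x384 input, 14x14 patches -> 729 patch tokens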
+
+""" Nexa AI model configuration"""
+class OminiVLMConfig(PretrainedConfig):
+ model_type = "nano-omini-vlm"
+
+ model_type = "omini_vlm"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vision_config=None,
+ text_config=None,
+ hidden_size=4096,
+ mm_hidden_size=1152,
+ mm_projector_lr=None,
+ mm_projector_type="mlp2x_gelu",
+ image_token_index=151655,
+ initializer_range=0.02,
+ **kwargs,
+ ):
+ self.hidden_size = hidden_size
+ self.mm_hidden_size = mm_hidden_size
+ self.mm_projector_lr = mm_projector_lr
+ self.mm_projector_type = mm_projector_type
+ self.image_token_index = image_token_index
+ self.initializer_range = initializer_range
+ if isinstance(vision_config, dict):
+ vision_config = SigLipVisionConfig(**vision_config)
+ elif vision_config is None:
+ vision_config = SigLipVisionConfig(
+ hidden_size=1152,
+ image_mean=(0.5, 0.5, 0.5),
+ intermediate_size=4304,
+ num_hidden_layers=27,
+ num_attention_heads=16,
+ num_channels=3,
+ image_size=384,
+ patch_size=14,
+ hidden_act="gelu_pytorch_tanh",
+ layer_norm_eps=1e-6,
+ attention_dropout=0.0,
+ )
+ self.vision_config = vision_config
+
+ if isinstance(text_config, dict):
+ text_config["model_type"] = (
+ text_config["model_type"] if "model_type" in text_config else "qwen2"
+ )
+ text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+ elif text_config is None:
+ text_config = CONFIG_MAPPING["qwen2"]()
+
+ self.text_config = text_config
+
+ super().__init__(**kwargs)
+
\ No newline at end of file
diff --git a/nexa/transformers/omnivision/modeling.py b/nexa/transformers/omnivision/modeling.py
new file mode 100644
index 00000000..94bd67b1
--- /dev/null
+++ b/nexa/transformers/omnivision/modeling.py
@@ -0,0 +1,709 @@
+# Copyright (c) 2024 Nexa AI Inc., Alibaba Group (Qwen team), and HuggingFace Inc.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers import Qwen2ForCausalLM
+from transformers.activations import ACT2FN
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import ModelOutput
+
+from .configuration import SigLipVisionConfig, OminiVLMConfig
+
+# ======================================================================================== #
+# vision tower #
+# ======================================================================================== #
+@dataclass
+class SigLipVisionModelOutput(ModelOutput):
+ """
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
+
+ Args:
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
+ The image embeddings obtained by applying the projection layer to the pooler_output.
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Sequence of hidden-states at the output of the last layer of the model.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ """
+ image_embeds: Optional[torch.FloatTensor] = None
+ last_hidden_state: torch.FloatTensor = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+@dataclass
+class OminiVLMOutputWithPast(ModelOutput):
+ """
+    Base class for OminiVLM causal language model (or autoregressive) outputs.
+
+ Args:
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
+ Language modeling loss (for next-token prediction).
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
+
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+ sequence_length)`.
+
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
+ heads.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Attentions mask, used to update attention mask and position_ids.
+ """
+ loss: Optional[torch.FloatTensor] = None
+ logits: torch.FloatTensor = None
+ past_key_values: Optional[List[torch.FloatTensor]] = None
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
+ attention_mask: Optional[torch.FloatTensor] = None
+
+
+class SigLipVisionEmbeddings(nn.Module):
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.image_size = config.image_size
+ self.patch_size = config.patch_size
+
+ self.patch_embedding = nn.Conv2d(
+ in_channels=config.num_channels,
+ out_channels=self.embed_dim,
+ kernel_size=self.patch_size,
+ stride=self.patch_size,
+ padding="valid",
+ )
+
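+        # with the default config: (384 // 14) ** 2 = 27 * 27 = 729 patch positions per image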
+ self.num_patches = (self.image_size // self.patch_size) ** 2
+ self.num_positions = self.num_patches
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
+
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+ patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
+
+ embeddings = embeddings + self.position_embedding(self.position_ids)
+ return embeddings
+
+
+class SigLipAttention(nn.Module):
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embed_dim = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_heads
+ if self.head_dim * self.num_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+ f" {self.num_heads})."
+ )
+ self.scale = self.head_dim ** -0.5
+ self.dropout = config.attention_dropout
+
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ k_v_seq_len = key_states.shape[-2]
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights
+
+
+class SigLipMLP(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.activation_fn = ACT2FN[config.hidden_act]
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = self.fc2(hidden_states)
+ return hidden_states
+
+
+class SigLipEncoderLayer(nn.Module):
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.embed_dim = config.hidden_size
+ self.self_attn = SigLipAttention(config)
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+ self.mlp = SigLipMLP(config)
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.FloatTensor]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`):
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
+ attention_mask (`torch.FloatTensor`):
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ """
+ residual = hidden_states
+
+ hidden_states = self.layer_norm1(hidden_states)
+ hidden_states, attn_weights = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ )
+ hidden_states = residual + hidden_states
+
+ residual = hidden_states
+ hidden_states = self.layer_norm2(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
+class SigLipPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = SigLipVisionConfig
+ base_model_prefix = "siglip"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ pass
+
+
+class SigLipEncoder(nn.Module):
+ """
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
+ [`SigLipEncoderLayer`].
+
+ Args:
+ config: SigLipVisionConfig
+ """
+
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.config = config
+ self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ inputs_embeds,
+ attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ hidden_states = inputs_embeds
+ for encoder_layer in self.layers:
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class SigLipMultiheadAttentionPoolingHead(nn.Module):
+ """Multihead Attention Pooling."""
+
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.mlp = SigLipMLP(config)
+
+ def forward(self, hidden_state):
+ batch_size = hidden_state.shape[0]
+ probe = self.probe.repeat(batch_size, 1, 1)
+
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+ residual = hidden_state
+ hidden_state = self.layernorm(hidden_state)
+ hidden_state = residual + self.mlp(hidden_state)
+
+ return hidden_state[:, 0]
+
+
+class SigLipVisionTransformer(nn.Module):
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__()
+ self.config = config
+ embed_dim = config.hidden_size
+
+ self.embeddings = SigLipVisionEmbeddings(config)
+ self.encoder = SigLipEncoder(config)
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+ self.head = SigLipMultiheadAttentionPoolingHead(config)
+
+ def get_dtype(self) -> torch.dtype:
+ return self.encoder.layers[0].mlp.fc2.weight.dtype
+
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ r"""
+ Returns:
+
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.embeddings(pixel_values)
+
+ encoder_outputs = self.encoder(
+ inputs_embeds=hidden_states,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ last_hidden_state = encoder_outputs[0]
+ last_hidden_state = self.post_layernorm(last_hidden_state)
+
+ pooled_output = self.head(last_hidden_state)
+
+ if not return_dict:
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPooling(
+ last_hidden_state=last_hidden_state,
+ pooler_output=pooled_output,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SigLipVisionModel(SigLipPreTrainedModel):
+ config_class = SigLipVisionConfig
+ main_input_name = "pixel_values"
+ _no_split_modules = ["SigLipEncoderLayer"]
+
+ def __init__(self, config: SigLipVisionConfig):
+ super().__init__(config)
+ self.vision_model = SigLipVisionTransformer(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self) -> nn.Module:
+ return self.vision_model.embeddings.patch_embedding
+
+ def forward(
+ self,
+ pixel_values,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ return self.vision_model(
+ pixel_values=pixel_values,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
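+# Illustrative standalone use of the vision tower (hypothetical tensor shapes, default config):
+#   tower = SigLipVisionModel(SigLipVisionConfig())
+#   out = tower(pixel_values)            # pixel_values: (batch, 3, 384, 384)
+#   out.last_hidden_state.shape          # -> (batch, 729, 1152)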
+
+# ======================================================================================== #
+# Projector #
+# ======================================================================================== #
+
+def build_vision_projector(config, delay_load=False, **kwargs):
+ projector_type = getattr(config, 'mm_projector_type', 'mlp2x_gelu')
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
+ if mlp_gelu_match:
+ mlp_depth = int(mlp_gelu_match.group(1))
+ modules = [nn.Linear(config.mm_hidden_size*9, config.hidden_size)]
+ for _ in range(1, mlp_depth):
+ modules.append(nn.GELU())
+ modules.append(nn.Linear(config.hidden_size, config.text_config.hidden_size))
+ return nn.Sequential(*modules)
+
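+# For the default "mlp2x_gelu" projector type this builds (sketch with the default config values):
+#   nn.Sequential(
+#       nn.Linear(config.mm_hidden_size * 9, config.hidden_size),       # 1152 * 9 = 10368 -> 4096
+#       nn.GELU(),
+#       nn.Linear(config.hidden_size, config.text_config.hidden_size),  # -> Qwen2 embedding size
+#   )
+# Unrecognized projector types fall through and return None.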
+
+# ======================================================================================== #
+# LLM #
+# ======================================================================================== #
+class OminiVLMPreTrainedModel(PreTrainedModel):
+ config_class = OminiVLMConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Qwen2DecoderLayer", "SigLipEncoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_cache_class = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv3d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class OminiVLMForConditionalGeneration(OminiVLMPreTrainedModel):
+ def __init__(self, config: OminiVLMConfig):
+ super().__init__(config)
+ if isinstance(config.vision_config, dict):
+ vision_config = SigLipVisionConfig(**config.vision_config)
+ else:
+ vision_config = config.vision_config
+ self.vision_tower = SigLipVisionModel(vision_config)
+ self.multi_modal_projector = build_vision_projector(config)
+ self.vocab_size = config.text_config.vocab_size
+ self.language_model = Qwen2ForCausalLM(
+ config.text_config,
+ )
+ self.pad_token_id = (
+ self.config.pad_token_id if self.config.pad_token_id is not None else -1
+ )
+        self._padding_side = "right"  # defaults to right padding; use the setter to change padding_side
+ self.post_init()
+
+ @property
+ def padding_side(self):
+ return self._padding_side
+
+ @padding_side.setter
+ def padding_side(self, padding_side: str):
+ if padding_side not in ["left", "right"]:
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
+ self._padding_side = padding_side
+
+ def get_input_embeddings(self):
+ return self.language_model.get_input_embeddings()
+
+ def set_input_embeddings(self, value):
+ self.language_model.set_input_embeddings(value)
+
+ def get_output_embeddings(self):
+ return self.language_model.get_output_embeddings()
+
+ def set_output_embeddings(self, new_embeddings):
+ self.language_model.set_output_embeddings(new_embeddings)
+
+ def set_decoder(self, decoder):
+ self.language_model.set_decoder(decoder)
+
+ def get_decoder(self):
+ return self.language_model.get_decoder()
+
+ def tie_weights(self):
+ return self.language_model.tie_weights()
+
+ def resize_token_embeddings(
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None
+ ) -> nn.Embedding:
+ model_embeds = self.language_model.resize_token_embeddings(
+ new_num_tokens, pad_to_multiple_of
+ )
+ # update vocab size
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
+ self.vocab_size = model_embeds.num_embeddings
+ return model_embeds
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ num_new_tokens: int = 1,
+ ) -> Dict[str, Any]:
+ model_kwargs = super()._update_model_kwargs_for_generation(
+ outputs=outputs,
+ model_kwargs=model_kwargs,
+ is_encoder_decoder=is_encoder_decoder,
+ num_new_tokens=num_new_tokens,
+ )
+ return model_kwargs
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, OminiVLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+            [`OminiVLMOutputWithPast`] or a plain `tuple` when `return_dict=False`.
+        """
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ target_device = self.vision_tower.device
+
+ if pixel_values is not None:
+ pixel_values = pixel_values.to(target_device)
+
+ if inputs_embeds is None:
+ # 1. Extract the input embeddings
+ inputs_embeds = self.get_input_embeddings()(input_ids)
+
+ # 2. Merge text and vision features
+ if pixel_values is not None:
+ pixel_values = pixel_values.type(self.vision_tower.vision_model.get_dtype())
+ image_embeds = self.vision_tower(pixel_values).last_hidden_state.to(pixel_values.dtype)
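+                # the SigLip tower emits 729 patch embeddings (27 x 27); group every 9 of them into one of
+                # 81 image tokens so the flattened width matches the projector's mm_hidden_size * 9 input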
+ image_embeds = image_embeds.view(image_embeds.shape[0], 81, -1)
+ image_embeds = self.multi_modal_projector(image_embeds)
+ image_mask = (
+ (input_ids == self.config.image_token_index)
+ .unsqueeze(-1)
+ .expand_as(inputs_embeds)
+ .to(inputs_embeds.device)
+ )
+ image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+
+ ## This is to intelligently replace the image tokens with the image features
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask.to(inputs_embeds.device)
+
+ outputs = self.language_model(
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ logits = outputs[0]
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ if attention_mask is not None:
+ shift_attention_mask = attention_mask[..., 1:]
+ shift_logits = logits[..., :-1, :][
+ shift_attention_mask.to(logits.device) != 0
+ ].contiguous()
+ shift_labels = labels[..., 1:][
+ shift_attention_mask.to(labels.device) != 0
+ ].contiguous()
+ else:
+ shift_logits = logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss_fct = nn.CrossEntropyLoss()
+ loss = loss_fct(
+ shift_logits.view(-1, shift_logits.size(-1)),
+ shift_labels.view(-1).to(shift_logits.device),
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return OminiVLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ attention_mask=attention_mask,
+ )
\ No newline at end of file
diff --git a/nexa/transformers/omnivision/processing.py b/nexa/transformers/omnivision/processing.py
new file mode 100644
index 00000000..2bc3f008
--- /dev/null
+++ b/nexa/transformers/omnivision/processing.py
@@ -0,0 +1,201 @@
+# Copyright (c) 2024 Nexa AI Inc., Alibaba Group (Qwen team), and HuggingFace Inc.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Union
+
+try:
+ from typing import Unpack
+except ImportError:
+ from typing_extensions import Unpack
+
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput, VideoInput
+from transformers.processing_utils import (
+ ProcessingKwargs,
+ ProcessorMixin,
+)
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.utils import logging
+
+
+logger = logging.get_logger(__name__)
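+# Each image is expanded to 81 placeholder tokens, matching the 81 grouped SigLip patch tokens
+# (729 patches, 9 per token) produced by the modeling code.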
+NUM_IMAGE_TOKENS = 81
+
+class NanoVLMProcessorKwargs(ProcessingKwargs, total=False):
+ _defaults = {
+ "text_kwargs": {
+ "padding": False,
+ },
+ }
+
+
+class NanoVLMProcessor(ProcessorMixin):
+ attributes = ["image_processor", "tokenizer"]
+ valid_kwargs = ["chat_template"]
+ image_processor_class = "SiglipImageProcessor"
+ tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
+
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
+ if chat_template is None:
+ chat_template = self.default_chat_template
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
+
+ def __call__(
+ self,
+ images: ImageInput = None,
+ text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
+ **kwargs: Unpack[NanoVLMProcessorKwargs],
+ ) -> BatchFeature:
+ """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `images` and `kwargs` arguments to
+        SiglipImageProcessor's [`~SiglipImageProcessor.__call__`] if `images` is not `None`.
+
+ Args:
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+ tensor. Both channels-first and channels-last formats are supported.
+ text (`str`, `List[str]`, `List[List[str]]`):
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
+ If set, will return tensors of a particular framework. Acceptable values are:
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
+ - `'np'`: Return NumPy `np.ndarray` objects.
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
+
+ Returns:
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
+
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
+ `None`).
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
+ """
+ output_kwargs = self._merge_kwargs(
+ NanoVLMProcessorKwargs,
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+ **kwargs,
+ )
+
+        if not isinstance(text, list):
+            text = [text]
+
+        if images is not None:
+            # the number of images must match the number of <|image_pad|> tokens across the text batch
+            assert len(images) == sum(t.count("<|image_pad|>") for t in text), (
+                "The number of images must be equal to the number of all image_pad tokens in the text."
+            )
+            image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
+        else:
+            image_inputs = {}
+
+        if images is not None:
+            # temporarily expand each <|image_pad|> into NUM_IMAGE_TOKENS placeholders, then restore the token name
+            for i in range(len(text)):
+                while "<|image_pad|>" in text[i]:
+                    text[i] = text[i].replace(
+                        "<|image_pad|>", "<|placeholder|>" * NUM_IMAGE_TOKENS, 1
+                    )
+                text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
+
+ _ = output_kwargs["text_kwargs"].pop("padding_side", None)
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
+
+ return BatchFeature(data={**text_inputs, **image_inputs})
+
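+    # Illustrative call (hypothetical prompt string; <|image_pad|> marks where image features are inserted):
+    #   inputs = processor(
+    #       text=["<|im_start|>user\nDescribe this image\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>"],
+    #       images=[image],
+    #       return_tensors="pt",
+    #   )
+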
+ def batch_decode(self, *args, **kwargs):
+ """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
+ refer to the docstring of this method for more information.
+ """
+ return self.tokenizer.batch_decode(*args, **kwargs)
+
+ def decode(self, *args, **kwargs):
+ """
+        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
+ the docstring of this method for more information.
+ """
+ return self.tokenizer.decode(*args, **kwargs)
+
+ @property
+ def model_input_names(self):
+ tokenizer_input_names = self.tokenizer.model_input_names
+ image_processor_input_names = self.image_processor.model_input_names
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
+
+
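+    # ChatML-style template adapted from the Qwen2 convention, with the Nano-Omni-VLM system prompt as the default.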
+ @property
+ def default_chat_template(self):
+ return (
+ "{%- if tools %}"
+ "{{- '<|im_start|>system\n' }}"
+ "{%- if messages[0]['role'] == 'system' %}"
+ "{{- messages[0]['content'] }}"
+ "{%- else %}"
+ "{{- 'You are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.' }}"
+ "{%- endif %}"
+ "{{- \"\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n\" }}"
+ "{%- for tool in tools %}"
+ "{{- \"\n\" }}"
+ "{{- tool | tojson }}"
+ "{%- endfor %}"
+ "{{- \"\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\\\"name\\\": , \\\"arguments\\\": }\n<|im_end|>\n\" }}"
+ "{%- else %}"
+ "{%- if messages[0]['role'] == 'system' %}"
+ "{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}"
+ "{%- else %}"
+ "{{- '<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n' }}"
+ "{%- endif %}"
+ "{%- endif %}"
+ "{%- for message in messages %}"
+ "{%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}"
+ "{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}"
+ "{%- elif message.role == \"assistant\" %}"
+ "{{- '<|im_start|>' + message.role }}"
+ "{%- if message.content %}"
+ "{{- '\n' + message.content }}"
+ "{%- endif %}"
+ "{%- for tool_call in message.tool_calls %}"
+ "{%- if tool_call.function is defined %}"
+ "{%- set tool_call = tool_call.function %}"
+ "{%- endif %}"
+ "{{- '\n\n{\"name\": \"' }}"
+ "{{- tool_call.name }}"
+ "{{- '\", \"arguments\": ' }}"
+ "{{- tool_call.arguments | tojson }}"
+ "{{- '}\n' }}"
+ "{%- endfor %}"
+ "{{- '<|im_end|>\n' }}"
+ "{%- elif message.role == \"tool\" %}"
+ "{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}"
+ "{{- '<|im_start|>user' }}"
+ "{%- endif %}"
+ "{{- '\n\n' }}"
+ "{{- message.content }}"
+ "{{- '\n' }}"
+ "{%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}"
+ "{{- '<|im_end|>\n' }}"
+ "{%- endif %}"
+ "{%- endif %}"
+ "{%- endfor %}"
+ "{%- if add_generation_prompt %}"
+ "{{- '<|im_start|>assistant\n' }}"
+ "{%- endif %}"
+ )
\ No newline at end of file
diff --git a/nexa/transformers/run_omnivision.py b/nexa/transformers/run_omnivision.py
new file mode 100644
index 00000000..f81d1efe
--- /dev/null
+++ b/nexa/transformers/run_omnivision.py
@@ -0,0 +1,92 @@
+import argparse
+from io import BytesIO
+
+import requests
+import torch
+from PIL import Image
+
+from nexa.transformers.omnivision.modeling import OminiVLMForConditionalGeneration
+from nexa.transformers.omnivision.processing import NanoVLMProcessor
+
+
+model_name = "NexaAIDev/omnivlm-dpo"
+image_url = "https://public-storage.nexa4ai.com/public-images/cat.png"
+
+
+def get_device():
+ if torch.cuda.is_available():
+ return "cuda"
+ elif torch.backends.mps.is_available():
+ return "mps"
+ return "cpu"
+
+
+def load_model_and_processor(model_path):
+ device = get_device()
+ proc_path = "nexa-collaboration/nano-vlm-processor"
+ processor = NanoVLMProcessor.from_pretrained(proc_path)
+ processor.tokenizer.pad_token = processor.tokenizer.eos_token
+ processor.tokenizer.padding_side = "right"
+
+ model_kwargs = {}
+ # Adjust dtype based on device
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
+ local_model = OminiVLMForConditionalGeneration.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ **model_kwargs
+ )
+ local_model = local_model.to(device)
+ return local_model, processor
+
+
+def process_single_image(processor, image_path, input_prompt=None):
+ text = f"<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{input_prompt}\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>"
+    # Accept both URLs and local file paths
+    if image_path.startswith('http'):
+        response = requests.get(image_path)
+        image = Image.open(BytesIO(response.content)).convert('RGB')
+    else:
+        image = Image.open(image_path).convert('RGB')
+ inputs = processor(
+ text=[text],
+ images=[image],
+ padding=True,
+ return_tensors="pt",
+ )
+ return inputs.to(get_device())
+
+
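+# Minimal greedy decoding loop: it recomputes the full prefix on every step (use_cache=False), which keeps the
+# example simple at the cost of speed, and stops once the model emits <|im_end|>.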
+def generate_output(model, processor, inputs, max_tokens):
+ cur_ids = inputs['input_ids']
+ cur_attention_mask = inputs['attention_mask']
+ input_token_length = cur_ids.shape[-1]
+ for _ in range(max_tokens):
+ out = model(
+ cur_ids,
+ attention_mask=cur_attention_mask,
+ pixel_values=inputs['pixel_values'],
+ use_cache=False
+ )
+ next_token = out.logits[:, -1].argmax()
+ next_word = processor.decode(next_token)
+ cur_ids = torch.cat([cur_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
+ cur_attention_mask = torch.cat([cur_attention_mask, torch.ones_like(next_token).unsqueeze(0).unsqueeze(0)], dim=-1)
+        if next_word == "<|im_end|>":
+ break
+ return processor.batch_decode(cur_ids[:, input_token_length:])[0]
+
+def main(args):
+ model, processor = load_model_and_processor(args.model_path)
+ inputs = process_single_image(processor, args.image_path, args.input_prompt)
+ output = generate_output(model, processor, inputs, args.max_tokens)
+ print("=== Inference Result ===\n", output)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Inference script for Nano-Omni-VLM")
+ parser.add_argument("--model_path", default=model_name, help="Path to the model checkpoint")
+ # Add image_path argument
+ parser.add_argument("--image_path", default=image_url, help="Path to input image or image URL")
+ parser.add_argument("--input_prompt", type=str, default="Describe this image for me", help="Input prompt for instruct task")
+ parser.add_argument("--max_tokens", type=int, default=512, help="Maximum number of tokens to generate")
+
+ args = parser.parse_args()
+ main(args)
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index e80e093e..7e9c6478 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,12 @@ convert = [
"nexa-gguf",
]
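+# Optional extra with the dependencies needed by the Hugging Face Transformers reference code in nexa/transformers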
+transformers = [
+ "transformers",
+ "torch",
+ "pillow"
+]
+
[project.urls]
Homepage = "https://github.com/NexaAI/nexa-sdk"
Issues = "https://github.com/NexaAI/nexa-sdk/issues"
@@ -105,6 +111,7 @@ wheel.packages = [
"nexa.onnx.streamlit",
"nexa.onnx.server",
"nexa.eval",
+ "nexa.transformers",
]
sdist.include = [
"CMakeLists.txt",