diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py
index ff02efd5c0830..b2e41ea47b240 100644
--- a/vllm/model_executor/models/aria.py
+++ b/vllm/model_executor/models/aria.py
@@ -316,174 +316,6 @@ def forward(self, x, attn_mask=None):
         return out
 
 
-class FFN(nn.Module):
-    """
-    Feed-Forward Network module.
-
-    Args:
-        embed_dim (int): Input embedding dimension.
-        ff_dim (int): Hidden dimension of the feed-forward network.
-        output_dim (int): Output dimension.
-    """
-
-    def __init__(self, embed_dim, ff_dim, output_dim):
-        super().__init__()
-        self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False)
-        self.linear_out = nn.Linear(ff_dim, output_dim, bias=False)
-        self.act = ACT2FN["gelu_new"]
-
-    def forward(self, hidden_states):
-        hidden_states = self.act(self.linear_in(hidden_states))
-        hidden_states = self.linear_out(hidden_states)
-        return hidden_states
-
-
-class CrossAttention(nn.Module):
-    """
-    Cross-Attention module.
-
-    Args:
-        kv_dim (int): Dimension of key and value.
-        embed_dim (int): Embedding dimension.
-        num_heads (int): Number of attention heads.
-        drop_out_rate (float): Dropout rate. Default is 0.
-    """
-
-    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
-        super().__init__()
-        self.num_heads = num_heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
-        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-
-        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.linear = nn.Linear(embed_dim, embed_dim)
-        self.dropout = nn.Dropout(drop_out_rate)
-
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.ln_kv = nn.LayerNorm(kv_dim)
-
-    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
-        """
-        Forward pass of the CrossAttention module.
-
-        Args:
-            x (torch.Tensor): Input tensor for key and value.
-            hidden_states (torch.Tensor): Input tensor for query.
-            attn_mask (torch.Tensor, optional): Attention mask. Default is None.
-            add_residual (bool): Whether to add residual connection. Default is False.
-
-        Returns:
-            torch.Tensor: Output tensor after cross-attention.
-        """
-        normed_hidden_states = self.layer_norm(hidden_states)
-        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
-
-        x = self.ln_kv(x)
-        key = self.k_proj(x).permute(1, 0, 2)
-        value = self.v_proj(x).permute(1, 0, 2)
-
-        attn_output, _ = self.multihead_attn(query,
-                                             key,
-                                             value,
-                                             attn_mask=attn_mask)
-
-        attn_output = attn_output.permute(1, 0, 2)
-
-        if add_residual:
-            attn_output = hidden_states + self.dropout(
-                self.linear(attn_output))
-        else:
-            attn_output = self.dropout(self.linear(attn_output))
-
-        return attn_output
-
-
-class AriaProjector(nn.Module):
-    """
-    A projection module with one cross attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs.
-
-    Args:
-        patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers,
-            e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution.
-        embed_dim (int): Embedding dimension.
-        num_heads (int): Number of attention heads.
-        kv_dim (int): Dimension of key and value.
-        ff_dim (int): Hidden dimension of the feed-forward network.
-        output_dim (int): Output dimension.
-        norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
-
-    Outputs:
-        A tensor with the shape of (batch_size, query_number, output_dim)
-    """
-
-    def __init__(
-        self,
-        patch_to_query_dict,
-        embed_dim,
-        num_heads,
-        kv_dim,
-        ff_dim,
-        output_dim,
-        norm_layer=nn.LayerNorm,
-    ):
-        super().__init__()
-        self.patch_to_query_dict = patch_to_query_dict
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
-
-        self.query = nn.Parameter(
-            torch.zeros(max(patch_to_query_dict.values()), self.embed_dim))
-
-        trunc_normal_(self.query, std=0.02)
-
-        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
-
-        self.ln_ffn = norm_layer(embed_dim)
-        self.ffn = FFN(embed_dim, ff_dim, output_dim)
-
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-
-    def forward(self, x, attn_mask=None):
-        """
-        Forward pass of the Projector module.
-
-        Args:
-            x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim).
-            attn_mask (torch.Tensor, optional): Attention mask. Default is None.
-
-        Returns:
-            torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim).
-        """
-        bs = x.shape[0]
-        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
-
-        query_num = self.patch_to_query_dict.get(x.shape[1], None)
-        assert (query_num is not None
-                ), f"Query number for {x.shape[1]} patches is not provided"
-
-        queries = queries[:, :query_num, :]
-
-        if attn_mask is not None:
-            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
-            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)
-
-        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
-
-        out = self.ffn(self.ln_ffn(attention_out))
-
-        return out
-
-
 class AriaMoELMConfig(LlamaConfig):
     """
     Configuration class for AriaMoE language model.