Commit

remove duplicated code
Signed-off-by: xffxff <[email protected]>
xffxff committed Nov 21, 2024
1 parent 64449a1 commit 1a68b48
Showing 1 changed file with 0 additions and 168 deletions.
vllm/model_executor/models/aria.py (168 changes: 0 additions & 168 deletions)
@@ -316,174 +316,6 @@ def forward(self, x, attn_mask=None):
        return out


class FFN(nn.Module):
    """
    Feed-Forward Network module.
    Args:
        embed_dim (int): Input embedding dimension.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
    """

    def __init__(self, embed_dim, ff_dim, output_dim):
        super().__init__()
        self.linear_in = nn.Linear(embed_dim, ff_dim, bias=False)
        self.linear_out = nn.Linear(ff_dim, output_dim, bias=False)
        self.act = ACT2FN["gelu_new"]

    def forward(self, hidden_states):
        hidden_states = self.act(self.linear_in(hidden_states))
        hidden_states = self.linear_out(hidden_states)
        return hidden_states


class CrossAttention(nn.Module):
    """
    Cross-Attention module.
    Args:
        kv_dim (int): Dimension of key and value.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        drop_out_rate (float): Dropout rate. Default is 0.
    """

    def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0):
        super().__init__()
        self.num_heads = num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)

        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(drop_out_rate)

        self.layer_norm = nn.LayerNorm(embed_dim)
        self.ln_kv = nn.LayerNorm(kv_dim)

    def forward(self, x, hidden_states, attn_mask=None, add_residual=False):
        """
        Forward pass of the CrossAttention module.
        Args:
            x (torch.Tensor): Input tensor for key and value.
            hidden_states (torch.Tensor): Input tensor for query.
            attn_mask (torch.Tensor, optional): Attention mask. Default is None.
            add_residual (bool): Whether to add residual connection. Default is False.
        Returns:
            torch.Tensor: Output tensor after cross-attention.
        """
        normed_hidden_states = self.layer_norm(hidden_states)
        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)

        x = self.ln_kv(x)
        key = self.k_proj(x).permute(1, 0, 2)
        value = self.v_proj(x).permute(1, 0, 2)

        attn_output, _ = self.multihead_attn(query,
                                             key,
                                             value,
                                             attn_mask=attn_mask)

        attn_output = attn_output.permute(1, 0, 2)

        if add_residual:
            attn_output = hidden_states + self.dropout(
                self.linear(attn_output))
        else:
            attn_output = self.dropout(self.linear(attn_output))

        return attn_output


class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer, which projects ViT's outputs into MoE's inputs.
    Args:
        patch_to_query_dict (dict): Maps patch numbers to their corresponding query numbers,
            e.g., {1225: 128, 4900: 256}. This allows for different query sizes based on image resolution.
        embed_dim (int): Embedding dimension.
        num_heads (int): Number of attention heads.
        kv_dim (int): Dimension of key and value.
        ff_dim (int): Hidden dimension of the feed-forward network.
        output_dim (int): Output dimension.
        norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(
        self,
        patch_to_query_dict,
        embed_dim,
        num_heads,
        kv_dim,
        ff_dim,
        output_dim,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.patch_to_query_dict = patch_to_query_dict
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.query = nn.Parameter(
            torch.zeros(max(patch_to_query_dict.values()), self.embed_dim))

        trunc_normal_(self.query, std=0.02)

        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)

        self.ln_ffn = norm_layer(embed_dim)
        self.ffn = FFN(embed_dim, ff_dim, output_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        """
        Forward pass of the Projector module.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, num_patches, kv_dim).
            attn_mask (torch.Tensor, optional): Attention mask. Default is None.
        Returns:
            torch.Tensor: Output tensor of shape (batch_size, query_number, output_dim).
        """
        bs = x.shape[0]
        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)

        query_num = self.patch_to_query_dict.get(x.shape[1], None)
        assert (query_num is not None
                ), f"Query number for {x.shape[1]} patches is not provided"

        queries = queries[:, :query_num, :]

        if attn_mask is not None:
            attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)

        out = self.ffn(self.ln_ffn(attention_out))

        return out


class AriaMoELMConfig(LlamaConfig):
    """
    Configuration class for AriaMoE language model.
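For readers skimming the removed code above, here is a minimal sketch of how the deleted AriaProjector fits together, assuming the FFN, CrossAttention, and AriaProjector classes from the diff are in scope along with their dependencies (torch, torch.nn as nn, transformers' ACT2FN, and a trunc_normal_ initializer). All dimensions below are placeholder values chosen for illustration, not Aria's actual configuration.

import torch

# Hypothetical sizes for illustration only.
projector = AriaProjector(
    patch_to_query_dict={1225: 128, 4900: 256},  # example mapping from the docstring
    embed_dim=1024,
    num_heads=16,
    kv_dim=768,
    ff_dim=4096,
    output_dim=2048,
)

# A dummy ViT output: (batch_size, num_patches, kv_dim).
vit_out = torch.randn(2, 1225, 768)

# 1225 patches select 128 learned queries; cross-attention reads the patch
# features, then the FFN projects to the MoE input width.
moe_in = projector(vit_out)
print(moe_in.shape)  # torch.Size([2, 128, 2048])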
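The only shape-sensitive step in the removed AriaProjector.forward is the attention-mask expansion. The sketch below, again with made-up sizes, checks that the two mask operations produce the 3D layout nn.MultiheadAttention expects, namely (batch * num_heads, target_len, source_len).

import torch

batch, num_heads, num_patches, query_num = 2, 16, 1225, 128  # illustrative sizes

# A float mask of zeros is an additive no-op; shape (batch, num_patches).
attn_mask = torch.zeros(batch, num_patches)

# The same two operations as in the removed forward():
attn_mask = attn_mask.repeat_interleave(num_heads, 0)         # (batch * num_heads, num_patches)
attn_mask = attn_mask.unsqueeze(1).expand(-1, query_num, -1)  # (batch * num_heads, query_num, num_patches)

assert attn_mask.shape == (batch * num_heads, query_num, num_patches)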
