Skip to content

Commit

Permalink
[Model][Pixtral] Use memory_efficient_attention for PixtralHFVision (v…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Oct 20, 2024
1 parent 5b59fe0 commit 962d2c6
Showing 1 changed file with 21 additions and 41 deletions.
62 changes: 21 additions & 41 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens)
from transformers.models.pixtral.modeling_pixtral import (
PixtralRotaryEmbedding, apply_rotary_pos_emb,
generate_block_attention_mask, position_ids_in_meshgrid)
PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

Expand Down Expand Up @@ -813,48 +812,30 @@ def __init__(self, config: PixtralVisionConfig):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attention_mask: BlockDiagonalMask,
position_embeddings: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch, patches, _ = hidden_states.size()

batch_size, patches, _ = hidden_states.size()

query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)

query_states = query_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.n_heads,
self.head_dim).transpose(1, 2)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)

# Transpose q and k to apply HF's Rotary Position Embedding
q = q.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(batch, patches, self.n_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states,
key_states,
cos,
sin,
unsqueeze_dim=0)

attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) * self.scale

if attention_mask is not None:
attn_weights = attn_weights + attention_mask
q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
# Transpose q and k back for attention
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.reshape(batch, patches, self.n_heads, self.head_dim)

attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, patches, -1)
out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)

return self.o_proj(attn_output)
return self.o_proj(out)


class PixtralHFTransformerBlock(nn.Module):
Expand All @@ -869,7 +850,7 @@ def __init__(self, config: PixtralVisionConfig):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
attention_mask: BlockDiagonalMask,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
r = self.attention.forward(self.attention_norm(hidden_states),
Expand All @@ -892,7 +873,7 @@ def __init__(self, config: PixtralVisionConfig):
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor,
attention_mask: BlockDiagonalMask,
position_embeddings: torch.Tensor,
) -> torch.Tensor:
for layer in self.layers:
Expand Down Expand Up @@ -953,9 +934,8 @@ def forward(

position_embedding = self.patch_positional_embedding(
patch_embeds, position_ids)
attention_mask = generate_block_attention_mask(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
attention_mask = BlockDiagonalMask.from_seqlens(
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
out = self.transformer(patch_embeds, attention_mask,
position_embedding)

Expand Down

0 comments on commit 962d2c6

Please sign in to comment.