Optimize for proper flash attn causal handling (#2503)
siddartha-RE authored Oct 2, 2023
1 parent 6b4fc64 commit 46e5207
Showing 1 changed file with 16 additions and 37 deletions.
fastchat/train/llama2_flash_attn_monkey_patch.py (16 additions & 37 deletions)
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple
 
 import torch
+from flash_attn import __version__ as flash_attn_version
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
@@ -36,6 +37,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
         warnings.warn(
@@ -58,50 +60,31 @@ def forward(
     kv_seq_len = k.shape[1]
     past_kv_len = 0
     if past_key_value is not None:
-        past_kv_len = past_key_value[0].shape[1]
+        past_kv_len = past_key_value[0].shape[2]
         kv_seq_len += past_kv_len
 
     cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
     q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
 
     if past_key_value is not None:
+        assert (
+            flash_attn_version >= "2.1.0"
+        ), "past_key_value support requires flash-attn >= 2.1.0"
         # reuse k, v
-        k = torch.cat([past_key_value[0], k], dim=1)
-        v = torch.cat([past_key_value[1], v], dim=1)
-
-    past_key_value = (k, v) if use_cache else None
-
-    key_padding_mask = attention_mask
-    # Ideally we could just do this:
-    #   q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
-    # but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
-    # which then breaks the causality logic. Probably if q_len >> past_kv_len we should
-    # just skip flash attention. Leaving this in for now to demonstrate correctness of
-    # flash attention information even when q needs padding.
-    # TODO(siddartha): delegate back to original implementation on this condition.
-    if past_kv_len > 0:
-        q = torch.cat(
-            (
-                torch.full(
-                    (bsz, past_kv_len, self.num_heads, self.head_dim),
-                    0.0,
-                    dtype=q.dtype,
-                    device=q.device,
-                ),
-                q,
-            ),
-            dim=1,
-        )
+        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
+        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
+
+    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
 
-    if key_padding_mask is None:
+    if attention_mask is None:
         output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
-            bsz, q_len + past_kv_len, -1
+            bsz, q_len, -1
         )
     else:
-        q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
+        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
         # We can skip concat and call unpad twice but seems better to call unpad only once.
         kv, _, cu_k_lens, max_k = unpad_input(
-            torch.stack((k, v), dim=2), key_padding_mask
+            torch.stack((k, v), dim=2), attention_mask
         )
         output_unpad = flash_attn_varlen_kvpacked_func(
             q,
@@ -115,11 +98,7 @@ def forward(
             causal=True,
         )
         output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
-        output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)
-
-    # Need to strip off the zero query outputs.
-    if past_kv_len > 0:
-        output = output[:, past_kv_len:, ...]
+        output = pad_input(output_unpad, indices, bsz, q_len)
 
     return self.o_proj(output), None, past_key_value
 
@@ -245,7 +224,7 @@ def test():
                 use_cache=True,
             )
             parts.append(part)
-            past_kv_len = past_kv[0].shape[1]
+            past_kv_len = past_kv[0].shape[2]
 
         print(
             f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
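The .shape[2] and .transpose(1, 2) changes above are layout bookkeeping: flash-attn works on tensors shaped (bsz, seqlen, n_heads, head_dim), while the HF-style past_key_value cache stores (bsz, n_heads, seqlen, head_dim). A minimal sketch of that relationship (illustrative only, not part of the commit):

import torch

bsz, seqlen, n_heads, head_dim = 2, 5, 4, 8

# Layout used inside forward() and expected by flash-attn: (bsz, seqlen, n_heads, head_dim).
k_flash = torch.randn(bsz, seqlen, n_heads, head_dim)

# Layout stored in the HF-style past_key_value cache: (bsz, n_heads, seqlen, head_dim).
k_cache = k_flash.transpose(1, 2)

assert k_cache.shape[2] == seqlen  # hence past_kv_len = past_key_value[0].shape[2]
assert torch.equal(k_cache.transpose(1, 2), k_flash)  # cached k/v are transposed back before concat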

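The deleted zero-padding of q (and the stripping of its outputs afterwards) was a workaround for flash-attn versions that aligned the causal mask to the top-left of the attention matrix; since flash-attn 2.1 the mask is aligned to the bottom-right when the query is shorter than the key/value sequence, which is what the new assert on flash_attn_version guards. A minimal sketch of that property, assuming flash-attn >= 2.1 and a CUDA device (shapes and sizes here are illustrative, not taken from the commit):

import torch
from flash_attn.flash_attn_interface import flash_attn_func

bsz, n_heads, head_dim = 1, 4, 64
prefix_len, new_len = 7, 3
total_len = prefix_len + new_len

# flash_attn_func convention: (bsz, seqlen, n_heads, head_dim), fp16/bf16 on GPU.
q = torch.randn(bsz, total_len, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# One-shot causal attention over the whole sequence.
oneshot = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True)

# Incremental step: only the last new_len queries, but the full k/v, as if the
# first prefix_len positions came from past_key_value. With bottom-right causal
# alignment these rows match the one-shot result, so q no longer needs zero-padding.
step = flash_attn_func(q[:, prefix_len:], k, v, 0.0, softmax_scale=None, causal=True)

print(torch.max(torch.abs(oneshot[:, prefix_len:] - step)))  # expected to be ~0 (fp16 noise)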