Optimize for proper flash attn causal handling #2503

Merged 1 commit on Oct 2, 2023
53 changes: 16 additions & 37 deletions fastchat/train/llama2_flash_attn_monkey_patch.py
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple
 
 import torch
+from flash_attn import __version__ as flash_attn_version
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
@@ -36,6 +37,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
         warnings.warn(
@@ -58,50 +60,31 @@ def forward(
     kv_seq_len = k.shape[1]
     past_kv_len = 0
     if past_key_value is not None:
-        past_kv_len = past_key_value[0].shape[1]
+        past_kv_len = past_key_value[0].shape[2]
         kv_seq_len += past_kv_len
 
     cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
     q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
 
     if past_key_value is not None:
+        assert (
+            flash_attn_version >= "2.1.0"
+        ), "past_key_value support requires flash-attn >= 2.1.0"
         # reuse k, v
-        k = torch.cat([past_key_value[0], k], dim=1)
-        v = torch.cat([past_key_value[1], v], dim=1)
-
-    past_key_value = (k, v) if use_cache else None
-
-    key_padding_mask = attention_mask
-    # Ideally we could just do this:
-    #   q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
-    # but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
-    # which then breaks the causality logic. Probably if q_len >> past_kv_len we should
-    # just skip flash attention. Leaving this in for now to demonstrate correctness of
-    # flash attention information even when q needs padding.
-    # TODO(siddartha): delegate back to original implementation on this condition.
-    if past_kv_len > 0:
-        q = torch.cat(
-            (
-                torch.full(
-                    (bsz, past_kv_len, self.num_heads, self.head_dim),
-                    0.0,
-                    dtype=q.dtype,
-                    device=q.device,
-                ),
-                q,
-            ),
-            dim=1,
-        )
+        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
+        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
+
+    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
 
-    if key_padding_mask is None:
+    if attention_mask is None:
         output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
-            bsz, q_len + past_kv_len, -1
+            bsz, q_len, -1
         )
     else:
-        q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
+        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
         # We can skip concat and call unpad twice but seems better to call unpad only once.
         kv, _, cu_k_lens, max_k = unpad_input(
-            torch.stack((k, v), dim=2), key_padding_mask
+            torch.stack((k, v), dim=2), attention_mask
         )
         output_unpad = flash_attn_varlen_kvpacked_func(
             q,
@@ -115,11 +98,7 @@ def forward(
             causal=True,
         )
         output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
-        output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)
-
-    # Need to strip off the zero query outputs.
-    if past_kv_len > 0:
-        output = output[:, past_kv_len:, ...]
+        output = pad_input(output_unpad, indices, bsz, q_len)
 
     return self.o_proj(output), None, past_key_value
 
@@ -245,7 +224,7 @@ def test():
             use_cache=True,
         )
         parts.append(part)
-        past_kv_len = past_kv[0].shape[1]
+        past_kv_len = past_kv[0].shape[2]
 
     print(
         f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
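
For context on why the zero-query padding could be dropped: starting with flash-attn 2.1, causal=True aligns the causal mask to the bottom-right of the attention matrix when the query sequence is shorter than the key/value sequence, so each new query attends to the whole cache plus the new keys up to its own position (this is what the added flash_attn_version >= "2.1.0" assert guards). The sketch below is not part of the PR; it is a plain-PyTorch, CPU-only illustration of that mask alignment, with a made-up attn helper and arbitrary sizes, checking that attending with only the last q_len queries matches slicing a full-sequence causal forward.

# Illustrative sketch only (not from the PR): bottom-right-aligned causal mask,
# as flash-attn >= 2.1 applies it when q_len < kv_len. Helper name and sizes are made up.
import torch


def attn(q, k, v, mask):
    # q: (bsz, q_len, heads, dim); k, v: (bsz, kv_len, heads, dim); mask: (q_len, kv_len) bool
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.einsum("bhqk,bkhd->bqhd", scores.softmax(dim=-1), v)


bsz, heads, dim = 2, 4, 8
kv_len, q_len = 10, 3                 # 7 cached positions + 3 new tokens
past_kv_len = kv_len - q_len

q_full = torch.randn(bsz, kv_len, heads, dim)
k = torch.randn(bsz, kv_len, heads, dim)
v = torch.randn(bsz, kv_len, heads, dim)

# One-shot causal attention over the whole sequence.
full_mask = torch.tril(torch.ones(kv_len, kv_len, dtype=torch.bool))
oneshot = attn(q_full, k, v, full_mask)

# Incremental step: only the last q_len queries, with the mask anchored at the
# bottom-right, so query i sees the entire cache plus new keys up to position i.
step_mask = full_mask[past_kv_len:]   # shape (q_len, kv_len)
step = attn(q_full[:, past_kv_len:], k, v, step_mask)

print(torch.allclose(oneshot[:, past_kv_len:], step, atol=1e-6))  # True

Before 2.1, flash-attn anchored the causal mask at the top-left (the removed comment's "treats the q seq and kv seq as starting at index 0"), which is why the old patch had to left-pad q with past_kv_len zero queries and slice them back off; with the newer semantics both the padding and the output[:, past_kv_len:] slice disappear, and flash_attn_func / pad_input can work with plain q_len.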