diff --git a/fastchat/train/llama2_flash_attn_monkey_patch.py b/fastchat/train/llama2_flash_attn_monkey_patch.py
index 948355cdb..c1fe51c91 100644
--- a/fastchat/train/llama2_flash_attn_monkey_patch.py
+++ b/fastchat/train/llama2_flash_attn_monkey_patch.py
@@ -2,6 +2,7 @@
 from typing import Optional, Tuple
 
 import torch
+from flash_attn import __version__ as flash_attn_version
 from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import (
     flash_attn_func,
@@ -36,6 +37,7 @@ def forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
+    padding_mask: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     if output_attentions:
         warnings.warn(
@@ -58,50 +60,31 @@ def forward(
     kv_seq_len = k.shape[1]
     past_kv_len = 0
     if past_key_value is not None:
-        past_kv_len = past_key_value[0].shape[1]
+        past_kv_len = past_key_value[0].shape[2]
         kv_seq_len += past_kv_len
 
     cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
     q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
 
     if past_key_value is not None:
+        assert (
+            flash_attn_version >= "2.1.0"
+        ), "past_key_value support requires flash-attn >= 2.1.0"
         # reuse k, v
-        k = torch.cat([past_key_value[0], k], dim=1)
-        v = torch.cat([past_key_value[1], v], dim=1)
-
-    past_key_value = (k, v) if use_cache else None
-
-    key_padding_mask = attention_mask
-    # Ideally we could just do this:
-    #   q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask[:, -q_len:])
-    # but this does not work as Flash attention treats the q seq and kv seq as starting at index 0
-    # which then breaks the causality logic. Probably if q_len >> past_kv_len we should
-    # just skip flash attention. Leaving this in for now to demonstrate correctness of
-    # flash attention information even when q needs padding.
-    # TODO(siddartha): delegate back to original implementation on this condition.
-    if past_kv_len > 0:
-        q = torch.cat(
-            (
-                torch.full(
-                    (bsz, past_kv_len, self.num_heads, self.head_dim),
-                    0.0,
-                    dtype=q.dtype,
-                    device=q.device,
-                ),
-                q,
-            ),
-            dim=1,
-        )
+        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
+        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
+
+    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
 
-    if key_padding_mask is None:
+    if attention_mask is None:
         output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
-            bsz, q_len + past_kv_len, -1
+            bsz, q_len, -1
         )
     else:
-        q, indices, cu_q_lens, max_s = unpad_input(q, key_padding_mask)
+        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
         # We can skip concat and call unpad twice but seems better to call unpad only once.
         kv, _, cu_k_lens, max_k = unpad_input(
-            torch.stack((k, v), dim=2), key_padding_mask
+            torch.stack((k, v), dim=2), attention_mask
         )
         output_unpad = flash_attn_varlen_kvpacked_func(
             q,
@@ -115,11 +98,7 @@ def forward(
             causal=True,
         )
         output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
-        output = pad_input(output_unpad, indices, bsz, q_len + past_kv_len)
-
-        # Need to strip off the zero query outputs.
-        if past_kv_len > 0:
-            output = output[:, past_kv_len:, ...]
+        output = pad_input(output_unpad, indices, bsz, q_len)
 
     return self.o_proj(output), None, past_key_value
 
@@ -245,7 +224,7 @@ def test():
             use_cache=True,
         )
         parts.append(part)
-        past_kv_len = past_kv[0].shape[1]
+        past_kv_len = past_kv[0].shape[2]
 
     print(
         f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
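
A note on the cache layout this patch adapts to (a minimal sketch for illustration; the tensor sizes below are arbitrary and not taken from the patch): transformers stores each past_key_value entry as (batch, num_heads, seq_len, head_dim), while flash-attn operates on (batch, seq_len, num_heads, head_dim) tensors. That is why the cached length is now read from shape[2], and why the cache is transposed with transpose(1, 2) before concatenation and transposed back before being returned.

import torch

bsz, n_heads, head_dim = 2, 4, 64
past_len, new_len = 5, 3

# Cache layout used by transformers: (batch, num_heads, seq_len, head_dim)
cached_k = torch.randn(bsz, n_heads, past_len, head_dim)
# Fresh projection in the layout flash-attn expects: (batch, seq_len, num_heads, head_dim)
new_k = torch.randn(bsz, new_len, n_heads, head_dim)

past_kv_len = cached_k.shape[2]  # sequence length sits in dim 2 of the cache
k = torch.cat([cached_k.transpose(1, 2), new_k], dim=1)  # concatenate along the sequence dim
assert k.shape == (bsz, past_len + new_len, n_heads, head_dim)

# The cache handed back to the caller is transposed into the transformers layout again
assert k.transpose(1, 2).shape == (bsz, n_heads, past_len + new_len, head_dim)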