Fused attention: Switch to Flash Decoding (#656)
casper-hansen authored Nov 26, 2024
1 parent 167c780 commit dfe396a
Showing 8 changed files with 154 additions and 383 deletions.
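
For context on the switch: the decode path in awq/modules/fused/attn.py now calls flash-attn's flash_attn_with_kvcache (Flash Decoding) instead of awq_ft_ext.single_query_attention, and prefill goes through flash_attn_func. The snippet below is only an illustrative sketch of a single decode step with flash_attn_with_kvcache, not the module's code; it assumes flash-attn 2.x on a CUDA GPU with fp16 tensors, and all sizes (B, H, H_KV, D, MAX_SEQ) are made up.

import torch
from flash_attn import flash_attn_with_kvcache

B, H, H_KV, D, MAX_SEQ = 2, 8, 2, 64, 128    # hypothetical sizes
start_pos = 10                               # tokens already in the cache

# Pre-allocated KV cache: (batch, max_seq_len, n_kv_heads, head_dim)
k_cache = torch.zeros(B, MAX_SEQ, H_KV, D, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)

# Projections for the new single token: (batch, 1, heads, head_dim)
q = torch.randn(B, 1, H, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, 1, H_KV, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, 1, H_KV, D, dtype=torch.float16, device="cuda")

# Number of valid tokens per sequence already stored in the cache.
cache_seqlens = torch.full((B,), start_pos, dtype=torch.int32, device="cuda")

# Appends k/v into the cache at cache_seqlens and attends over cache + new token.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v,
    cache_seqlens=cache_seqlens, causal=True,
)
print(out.shape)  # (B, 1, H, D)
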
208 changes: 82 additions & 126 deletions awq/modules/fused/attn.py
@@ -8,11 +8,11 @@


try:
import awq_ft_ext
from flash_attn import flash_attn_func, flash_attn_with_kvcache

FT_INSTALLED = True
FA_INSTALLED = True
except:
FT_INSTALLED = False
FA_INSTALLED = False

HF_NEW_CACHE_FORMAT = False

@@ -28,6 +28,7 @@ class RoPE(nn.Module):
def __init__(self, head_dim, max_seq_len, device, rope_theta):
super(RoPE, self).__init__()

self.head_dim = head_dim
self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(head_dim, max_seq_len, rope_theta).to(device),
requires_grad=False,
@@ -49,7 +50,23 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
def forward(
self,
xq: torch.Tensor,
xk: torch.Tensor,
start_pos: int,
seqlen: int,
partial: bool = False,
):
if partial:
xq, xq_pass = (
xq[..., : self.head_dim],
xq[..., self.head_dim :],
)
xk, xk_pass = (
xk[..., : self.head_dim],
xk[..., self.head_dim :],
)
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
@@ -62,6 +79,10 @@ def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: in
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

if partial:
xq = torch.cat((xq, xq_pass), dim=-1)
xk = torch.cat((xk, xk_pass), dim=-1)

return xq_out.type_as(xq), xk_out.type_as(xk)
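
The new partial flag in RoPE.forward folds in what the attention forward pass used to do inline for partial rotary models: rotate only the first rotary_dim channels of each head and pass the remaining channels through untouched. A minimal sketch of that split-rotate-concatenate pattern (with apply_rope as a hypothetical stand-in for the complex-multiplication rotation above):

import torch

def apply_partial_rope(xq, xk, rotary_dim, apply_rope):
    # xq/xk: (batch, seqlen, n_heads, head_dim), rotary_dim <= head_dim
    xq_rot, xq_pass = xq[..., :rotary_dim], xq[..., rotary_dim:]
    xk_rot, xk_pass = xk[..., :rotary_dim], xk[..., rotary_dim:]
    xq_rot, xk_rot = apply_rope(xq_rot, xk_rot)   # rotate only the leading channels
    xq = torch.cat((xq_rot, xq_pass), dim=-1)     # re-attach the pass-through channels
    xk = torch.cat((xk_rot, xk_pass), dim=-1)
    return xq, xk
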


@@ -118,7 +139,7 @@ def __init__(
rope_theta=10000,
partial_rotary_factor=1.0,
head_dim=None,
attn_logit_softcapping=None,
attn_logit_softcapping=0.0,
**kwargs
):
super().__init__()
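
The default for attn_logit_softcapping moves from None to 0.0, which matches the convention of flash-attn's softcap argument: 0.0 means softcapping is disabled, while any positive value caps the attention logits. As a standalone reference (not the module's code), logit softcapping is the tanh-based transform that the removed Gemma2 branch further down in this diff also implemented:

import torch

def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # softcap == 0.0 means "disabled", matching flash-attn's convention.
    if softcap == 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)
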
@@ -147,18 +168,18 @@ def __init__(
# attention shapes for self attention
self.attention_shapes = get_attention_shapes(
attention_shapes,
max_seq_len,
self.cache_batch_size,
n_heads,
n_kv_heads,
self.head_dim,
)
# cache store that rolls cache
self.cache = WindowedCache(
self.attention_shapes["cache_v"],
self.attention_shapes["cache_k"],
self.max_seq_len,
dev,
cache_batch_size=self.cache_batch_size,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
head_dim=self.head_dim,
max_seq_len=self.max_seq_len,
device=dev,
)

if use_alibi:
@@ -174,13 +195,10 @@ def __init__(

if kwargs.get("is_neox") is not None:
self.is_neox = kwargs["is_neox"]

self.attn_logit_softcapping = attn_logit_softcapping
self.use_sdpa = kwargs.get("use_sdpa", False)

def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
):
def forward(self, hidden_states: torch.Tensor, *args, **kwargs):
bsz, seqlen, _ = hidden_states.shape

# Reallocate cache if batch size changes
@@ -196,21 +214,27 @@ def forward(
self.start_pos = 0

hf_is_generating = False
hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None
hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0
hf_is_first_forward = (
"past_key_value" in kwargs and kwargs["past_key_value"] is None
)
hf_is_new_cache_first_forward = (
"past_key_value" in kwargs
and isinstance(kwargs["past_key_value"], DynamicCache)
and kwargs["past_key_value"].get_seq_length() == 0
)

if self.is_hf_transformers and "use_cache" in kwargs:
hf_is_generating = kwargs["use_cache"]

# print(kwargs["past_key_value"].get_seq_length())

# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# which indicates that we are on the first step of `generate()`.
# This is only applicable for `transformers` integration
if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating):
if (
self.is_hf_transformers
and (hf_is_first_forward or hf_is_new_cache_first_forward)
) or (self.is_hf_transformers and not hf_is_generating):
self.start_pos = 0


xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
@@ -219,114 +243,47 @@ def forward(
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)

if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

if not self.use_alibi:
# Partial rotary embedding
if self.partial_rotary_factor < 1:
xq_rot, xq_pass = (
xq[..., : self.rotary_dim],
xq[..., self.rotary_dim :],
)
xk_rot, xk_pass = (
xk[..., : self.rotary_dim],
xk[..., self.rotary_dim :],
)
xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
xq = torch.cat((xq_rot, xq_pass), dim=-1)
xk = torch.cat((xk_rot, xk_pass), dim=-1)
else:
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

values_store = xv.transpose(2, 1)
keys_store = (
xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
.permute(0, 2, 3, 1, 4)
.contiguous()
if not self.use_alibi:
xq, xk = self.rope.forward(
xq, xk, self.start_pos, seqlen, partial=self.partial_rotary_factor < 1
)

self.cache.to(xq)
self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

# Only necessary to retrieve from cache when we are not processing context
if seqlen == 1:
xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

keys = xk
values = xv

if self.n_kv_groups != 0:
keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
values = torch.repeat_interleave(
values, dim=2, repeats=self.n_kv_groups
)

xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)

# Used in Gemma2
if self.attn_logit_softcapping is not None:
scores = scores / self.attn_logit_softcapping
scores = torch.tanh(scores)
scores = scores * self.attn_logit_softcapping

if self.use_sdpa:
causal_mask = attention_mask
if attention_mask is not None:
causal_mask = causal_mask[:, :, :, : keys.shape[-2]]
is_causal = True if causal_mask is None and seqlen > 1 else False
output = torch.nn.functional.scaled_dot_product_attention(
xq,
keys,
values,
attn_mask=causal_mask,
dropout_p=0.0,
is_causal=is_causal,
)
else:
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores = self.alibi.forward(scores, seqlen)

# When seqlen is 1, there is nothing else to attend to
if attention_mask is not None and seqlen > 1:
# For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
# need to slice it
if attention_mask.shape[-1] != seqlen:
attention_mask = attention_mask[:, :, :seqlen, :seqlen]

scores = (
scores + attention_mask
) # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)

attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
self.cache.to(xq)
self.cache.update_kv(
values_store=xv,
keys_store=xk,
batch_size=bsz,
start_pos=self.start_pos,
seqlen=seqlen,
)

if seqlen > 1:
output = flash_attn_func(
q=xq,
k=xk,
v=xv,
causal=True,
alibi_slopes=self.alibi.slopes if self.alibi is not None else None,
softcap=self.attn_logit_softcapping,
)
else:
xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = awq_ft_ext.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
self.rope_theta, # rotary embedding base
self.is_neox, # is neox
cache_seqlens = torch.full(
(bsz,), self.start_pos + seqlen, dtype=torch.int32, device=xq.device
)

output = flash_attn_with_kvcache(
q=xq,
k=xk,
k_cache=self.cache.k,
v=xv,
v_cache=self.cache.v,
cache_seqlens=cache_seqlens,
causal=True,
alibi_slopes=self.alibi.slopes if self.alibi is not None else None,
softcap=self.attn_logit_softcapping,
)
attention_weight = attention_weight.reshape(bsz, 1, -1)

attention_weight = output.view(bsz, seqlen, -1)
attn_output = self.o_proj(attention_weight)
self.start_pos += seqlen
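
Note that the explicit torch.repeat_interleave over KV heads from the old path is gone: the flash-attn kernels handle grouped-query attention directly, requiring only that the number of query heads be a multiple of the number of KV heads. An illustrative prefill-style call (again assuming flash-attn 2.x, a CUDA GPU, fp16, and made-up sizes):

import torch
from flash_attn import flash_attn_func

B, S, H, H_KV, D = 2, 16, 8, 2, 64   # hypothetical batch/seq/head sizes
q = torch.randn(B, S, H, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, S, H_KV, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, S, H_KV, D, dtype=torch.float16, device="cuda")

# No repeat_interleave: the kernel broadcasts the 2 KV heads across the 8 query heads.
out = flash_attn_func(q, k, v, causal=True)   # (B, S, H, D)
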

Expand All @@ -338,7 +295,6 @@ def forward(
# about past key length
past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]


if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
new_cache = DynamicCache()
new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)