Skip to content

Commit

Permalink
feat(HookedTransformer) accelerate inference with flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
StarConnor authored Jul 3, 2024
1 parent d844b51 commit 0e3d268
Show file tree
Hide file tree
Showing 11 changed files with 289 additions and 34 deletions.
125 changes: 125 additions & 0 deletions TransformerLens/tests/integration/test_flash_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import einops
import torch

from transformer_lens.components import Attention, GroupedQueryAttention
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


def test_flash_attention_output_is_correct():
"""
Verify if flash attention output is correct.
"""
d_model = 512
d_head = 32
n_heads = 16
n_ctx = 128
n_key_value_heads = 4
n_layers = 1
dtype = torch.bfloat16
device = torch.device('cuda')

cfg_dict = {
'use_flash_attn': False,
'd_model': d_model,
'd_head': d_head,
'n_heads': n_heads,
'n_ctx': n_ctx,
'n_key_value_heads': n_key_value_heads,
'n_layers': n_layers,
'act_fn': "silu",
'dtype': torch.bfloat16,
}
regular_attention_cfg = HookedTransformerConfig.from_dict(cfg_dict)
cfg_dict['use_flash_attn'] = True
flash_attention_cfg = HookedTransformerConfig.from_dict(cfg_dict)
flash_gqa_attention_cfg = HookedTransformerConfig.from_dict(cfg_dict)

regular_attention = Attention(regular_attention_cfg)

assert not hasattr(regular_attention, 'flash_attn_func'), "AbstractAttention should not have 'flash_attn_func' if set `use_flash_attn=False`"

flash_attention = Attention(flash_attention_cfg)

assert hasattr(flash_attention, 'flash_attn_func'), "AbstractAttention should have 'flash_attn_func' if set `use_flash_attn=True`"

flash_gqa_attention = GroupedQueryAttention(flash_gqa_attention_cfg)

# Variables started with `_` mean that the GQA key/value parameters
W_Q = torch.rand((n_heads, d_model, d_head), dtype=dtype)
b_Q = torch.rand((n_heads, d_head), dtype=dtype)
_W_K = torch.rand((n_key_value_heads, d_model, d_head), dtype=dtype)
W_K = torch.repeat_interleave(_W_K, dim=0, repeats=n_heads // n_key_value_heads)
_b_K = torch.rand((n_key_value_heads, d_head), dtype=dtype)
b_K = torch.repeat_interleave(_b_K, dim=0, repeats=n_heads // n_key_value_heads)
_W_V = torch.rand((n_key_value_heads, d_model, d_head), dtype=dtype)
W_V = torch.repeat_interleave(_W_V, dim=0, repeats=n_heads // n_key_value_heads)
_b_V = torch.rand((n_key_value_heads, d_head), dtype=dtype)
b_V = torch.repeat_interleave(_b_V, dim=0, repeats=n_heads // n_key_value_heads)
W_O = torch.rand((n_heads, d_head, d_model), dtype=dtype)
b_O = torch.rand(d_model, dtype=dtype)

regular_attention_state_dict = {
"W_Q": W_Q,
"b_Q": b_Q,
"W_O": W_O,
"b_O": b_O,
"W_K": W_K,
"b_K": b_K,
"W_V": W_V,
"b_V": b_V,
"mask": regular_attention.state_dict()["mask"],
"IGNORE": regular_attention.state_dict()["IGNORE"],
}
flash_attention_state_dict = {
"W_Q": W_Q,
"b_Q": b_Q,
"W_O": W_O,
"b_O": b_O,
"W_K": W_K,
"b_K": b_K,
"W_V": W_V,
"b_V": b_V,
"mask": flash_attention.state_dict()["mask"],
"IGNORE": flash_attention.state_dict()["IGNORE"],
}
flash_gqa_attention_state_dict = {
"W_Q": W_Q,
"b_Q": b_Q,
"W_O": W_O,
"b_O": b_O,
"_W_K": _W_K,
"_b_K": _b_K,
"_W_V": _W_V,
"_b_V": _b_V,
"mask": flash_attention.state_dict()["mask"],
"IGNORE": flash_attention.state_dict()["IGNORE"],
}

regular_attention.load_state_dict(regular_attention_state_dict)
regular_attention.to(device)
flash_attention.load_state_dict(flash_attention_state_dict)
flash_attention.to(device)
flash_gqa_attention.load_state_dict(flash_gqa_attention_state_dict)
flash_gqa_attention.to(device)

query_input = torch.rand((1, 5, d_model), dtype=dtype).to(device)
key_input = torch.rand((1, 5, d_model), dtype=dtype).to(device)
value_input = torch.rand((1, 5, d_model), dtype=dtype).to(device)

# Test regular attention and attention with FlashAttentionV2
regular_attn_output = regular_attention(query_input, key_input, value_input)
flash_attn_output = flash_attention(query_input, key_input, value_input)

assert torch.allclose(regular_attn_output, flash_attn_output, rtol=1e-2)

# Test FlashAttention behaves correctly when use_split_qkv_input is True
flash_gqa_attention.cfg.use_split_qkv_input = True
split_query_input = einops.repeat(query_input, "b n d -> b n h d", h=n_heads).clone()
split_key_input = einops.repeat(key_input, "b n d -> b n h d", h=n_key_value_heads).clone()
split_value_input = einops.repeat(value_input, "b n d -> b n h d", h=n_key_value_heads).clone()

split_flash_attn_output = flash_gqa_attention(
split_query_input, split_key_input, split_value_input
)

assert torch.allclose(regular_attn_output, split_flash_attn_output, rtol=1e-2)
2 changes: 2 additions & 0 deletions TransformerLens/transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,6 +1039,7 @@ def from_pretrained(
cls,
model_name: str,
fold_ln: bool = True,
use_flash_attn: bool = False,
center_writing_weights: bool = True,
center_unembed: bool = True,
refactor_factored_attn_matrices: bool = False,
Expand Down Expand Up @@ -1240,6 +1241,7 @@ def from_pretrained(
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
use_flash_attn=use_flash_attn,
device=device,
n_devices=n_devices,
default_prepend_bos=default_prepend_bos,
Expand Down
3 changes: 3 additions & 0 deletions TransformerLens/transformer_lens/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class HookedTransformerConfig:
custom config, if loading from pretrained then this is not needed.
use_local_attn (bool): whether to use local attention - ie each
destination token can only attend to source tokens a certain distance back.
use_flash_attn (bool): whether to use FlashAttention-2. Please refer to
https://github.com/Dao-AILab/flash-attention.
window_size (int, *optional*): the size of the window for local
attention
attn_types (List[str], *optional*): the types of attention to use for
Expand Down Expand Up @@ -177,6 +179,7 @@ class HookedTransformerConfig:
use_hook_mlp_in: bool = False
use_attn_in: bool = False
use_local_attn: bool = False
use_flash_attn: bool = False
original_architecture: Optional[str] = None
from_checkpoint: bool = False
checkpoint_index: Optional[int] = None
Expand Down
177 changes: 143 additions & 34 deletions TransformerLens/transformer_lens/components/abstract_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,27 @@ def __init__(
if self.cfg.scale_attn_by_inverse_layer_idx:
assert self.layer_id is not None # keep mypy happy
self.attn_scale *= self.layer_id + 1

self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]

if self.cfg.use_flash_attn:
# If using FlashAttention, import flash-attn and create related class method.
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
self.flash_attn_func = flash_attn_func
self.flash_attn_varlen_func = flash_attn_varlen_func
self.fa_index_first_axis = index_first_axis
self.fa_pad_input = pad_input
self.fa_unpad_input = unpad_input
# Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked.
else:
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]


self.hook_result = HookPoint() # [batch, pos, head_index, d_model]

# See HookedTransformerConfig for more details.
Expand Down Expand Up @@ -200,40 +214,72 @@ def forward(
q = q.to(torch.float32)
k = k.to(torch.float32)

attn_scores = self.calculate_attention_scores(
q, k
) # [batch, head_index, query_pos, key_pos]

if self.cfg.positional_embedding_type == "alibi":
query_ctx = attn_scores.size(-2)
# The key context length is the number of positions in the past - this includes all positions in the cache
key_ctx = attn_scores.size(-1)

# only recompute when necessary to increase efficiency.
if self.alibi is None or key_ctx > self.alibi.size(-1):
self.alibi = AbstractAttention.create_alibi_bias(
self.cfg.n_heads, key_ctx, self.cfg.device
# use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case.
if self.cfg.use_flash_attn:
# FlashAttention could only accept the dtype of bfp16 and fp16
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)

# Contains at least one padding token in the sequence
causal = True if self.cfg.attention_dir == "causal" else False
if attention_mask is not None:
batch_size, query_length, _ = q.shape
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
q, k, v, attention_mask, q.shape[1]
)

attn_scores += self.alibi[
:, :query_ctx, :key_ctx
] # [batch, head_index, query_pos, key_pos]
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = self.flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
causal=causal,
)

if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
attn_scores, kv_cache_pos_offset, attention_mask
z = self.fa_pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
z = self.flash_attn_func(q, k, v, causal=causal)
else:
attn_scores = self.calculate_attention_scores(
q, k
) # [batch, head_index, query_pos, key_pos]
if additive_attention_mask is not None:
attn_scores += additive_attention_mask

attn_scores = self.hook_attn_scores(attn_scores)
pattern = F.softmax(attn_scores, dim=-1)
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]

if self.cfg.positional_embedding_type == "alibi":
query_ctx = attn_scores.size(-2)
# The key context length is the number of positions in the past - this includes all positions in the cache
key_ctx = attn_scores.size(-1)

# only recompute when necessary to increase efficiency.
if self.alibi is None or key_ctx > self.alibi.size(-1):
self.alibi = AbstractAttention.create_alibi_bias(
self.cfg.n_heads, key_ctx, self.cfg.device
)

attn_scores += self.alibi[
:, :query_ctx, :key_ctx
] # [batch, head_index, query_pos, key_pos]

if self.cfg.attention_dir == "causal":
# If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask.
attn_scores = self.apply_causal_mask(
attn_scores, kv_cache_pos_offset, attention_mask
) # [batch, head_index, query_pos, key_pos]
if additive_attention_mask is not None:
attn_scores += additive_attention_mask

attn_scores = self.hook_attn_scores(attn_scores)
pattern = F.softmax(attn_scores, dim=-1)
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos]
pattern = pattern.to(self.cfg.dtype)
pattern = pattern.to(v.device)
z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head]
if not self.cfg.use_attn_result:
if self.cfg.load_in_4bit:
# call bitsandbytes method to dequantize and multiply
Expand Down Expand Up @@ -656,3 +702,66 @@ def create_alibi_bias(
alibi_bias = torch.einsum("ij,k->kij", slope, multipliers)

return alibi_bias

def _upad_input(
self,
query_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
key_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
value_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]],
query_length: int,
):
"""
Refer to the implementation of flash attention of llama3 in package transformers: LlamaFlashAttention2.
The function is used when attention mask is not None and query length is not equal to key length.
"""
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

key_layer = self.fa_index_first_axis(
key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
value_layer = self.fa_index_first_axis(
value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
)
if query_length == kv_seq_len:
query_layer = self.fa_index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = self.fa_unpad_input(query_layer, attention_mask)

return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)

def _get_unpad_data(attention_mask):
"""
From transformers.models.llama.modeling_llama
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
5 changes: 5 additions & 0 deletions TransformerLens/transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,7 @@ def get_pretrained_model_config(
checkpoint_index: Optional[int] = None,
checkpoint_value: Optional[int] = None,
fold_ln: bool = False,
use_flash_attn: bool = False,
device: Optional[Union[str, torch.device]] = None,
n_devices: int = 1,
default_prepend_bos: bool = True,
Expand Down Expand Up @@ -1251,6 +1252,8 @@ def get_pretrained_model_config(
fold_ln (bool, optional): Whether to fold the layer norm into the
subsequent linear layers (see HookedTransformer.fold_layer_norm for
details). Defaults to False.
use_flash_attn (bool): whether to use FlashAttention-2. Please refer to
https://github.com/Dao-AILab/flash-attention. Defaults to False.
device (str, optional): The device to load the model onto. By
default will load to CUDA if available, else CPU.
n_devices (int, optional): The number of devices to split the model across. Defaults to 1.
Expand Down Expand Up @@ -1310,6 +1313,8 @@ def get_pretrained_model_config(
cfg_dict["normalization_type"] = "RMSPre"
else:
logging.warning("Cannot fold in layer norm, normalization_type is not LN.")
if use_flash_attn:
cfg_dict["use_flash_attn"] = True

if checkpoint_index is not None or checkpoint_value is not None:
checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(
Expand Down
1 change: 1 addition & 0 deletions examples/configuration/analyze.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ exp_result_dir = "results"
[lm]
model_name = "gpt2"
d_model = 768
use_flash_attn = false

[dataset]
dataset_path = "openwebtext"
Expand Down
Loading

0 comments on commit 0e3d268

Please sign in to comment.