
Commit

Changed so that flash-attn is installed by users themselves
StarConnor committed Jul 1, 2024
1 parent a12cc22 commit 5447661
Showing 4 changed files with 61 additions and 15 deletions.
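
In short, flash-attn is no longer installed by a PDM post_install hook; a user who wants the flash-attention path installs the package manually and opts in via use_flash_attn. A minimal sketch of the assumed workflow (the pip command and loading "gpt2" directly from the Hugging Face hub are assumptions for illustration, not part of this commit; flash-attn also requires a CUDA device):

# Assumed workflow after this commit (not part of the diff): install flash-attn
# manually, e.g. `pip install flash-attn` (exact build flags depend on the environment),
# then opt in when loading a model.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    "gpt2",
    use_flash_attn=True,           # only now is flash_attn imported (see abstract_attention.py below)
    device=torch.device("cuda"),   # flash-attn kernels require a CUDA device
    dtype=torch.bfloat16,
)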
28 changes: 18 additions & 10 deletions TransformerLens/transformer_lens/components/abstract_attention.py
@@ -9,8 +9,6 @@
 from better_abc import abstract_attribute
 from fancy_einsum import einsum
 from jaxtyping import Float, Int
-from flash_attn import flash_attn_func, flash_attn_varlen_func
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
 from transformers.utils import is_bitsandbytes_available
 
 from transformer_lens.FactoredMatrix import FactoredMatrix
@@ -115,10 +113,20 @@ def __init__(
         self.hook_v = HookPoint()  # [batch, pos, head_index, d_head]
 
         # Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked.
-        if not self.cfg.use_flash_attn:
+        if self.cfg.use_flash_attn:
+            from flash_attn import flash_attn_func, flash_attn_varlen_func
+            from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
+            self.flash_attn_func = flash_attn_func
+            self.flash_attn_varlen_func = flash_attn_varlen_func
+            self.fa_index_first_axis = index_first_axis
+            self.fa_pad_input = pad_input
+            self.fa_unpad_input = unpad_input
+        else:
             self.hook_z = HookPoint()  # [batch, pos, head_index, d_head]
             self.hook_attn_scores = HookPoint()  # [batch, head_index, query_pos, key_pos]
             self.hook_pattern = HookPoint()  # [batch, head_index, query_pos, key_pos]
 
         self.hook_result = HookPoint()  # [batch, pos, head_index, d_model]
 
         # See HookedTransformerConfig for more details.
@@ -228,7 +236,7 @@ def forward(
                 cu_seqlens_q, cu_seqlens_k = cu_seq_lens
                 max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
 
-                attn_output_unpad = flash_attn_varlen_func(
+                attn_output_unpad = self.flash_attn_varlen_func(
                     query_states,
                     key_states,
                     value_states,
@@ -239,9 +247,9 @@
                     causal=causal,
                 )
 
-                z = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+                z = self.fa_pad_input(attn_output_unpad, indices_q, batch_size, query_length)
             else:
-                z = flash_attn_func(q, k, v, causal=causal)
+                z = self.flash_attn_func(q, k, v, causal=causal)
         else:
             attn_scores = self.calculate_attention_scores(
                 q, k
@@ -704,14 +712,14 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query
         indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
         batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
 
-        key_layer = index_first_axis(
+        key_layer = self.fa_index_first_axis(
             key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
         )
-        value_layer = index_first_axis(
+        value_layer = self.fa_index_first_axis(
             value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
         )
         if query_length == kv_seq_len:
-            query_layer = index_first_axis(
+            query_layer = self.fa_index_first_axis(
                 query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
             )
             cu_seqlens_q = cu_seqlens_k
@@ -727,7 +735,7 @@
         else:
             # The -q_len: slice assumes left padding.
             attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = self.fa_unpad_input(query_layer, attention_mask)
 
         return (
             query_layer,
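
Because the imports now live inside __init__, a missing flash-attn package only surfaces as a ModuleNotFoundError when a model is constructed with use_flash_attn=True. A defensive variant of the same lazy import, with an explicit error message, might look like the sketch below (illustrative only; the commit imports directly, without a try/except):

def _import_flash_attn():
    """Lazily import flash-attn, failing with an actionable message if it is absent (sketch)."""
    try:
        from flash_attn import flash_attn_func, flash_attn_varlen_func
        from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
    except ImportError as err:
        raise ImportError(
            "use_flash_attn=True requires the flash-attn package; "
            "please install it yourself, as it is no longer installed automatically."
        ) from err
    return flash_attn_func, flash_attn_varlen_func, index_first_axis, pad_input, unpad_input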
3 changes: 0 additions & 3 deletions install_flash_attn.sh

This file was deleted.

2 changes: 0 additions & 2 deletions pyproject.toml
@@ -65,5 +65,3 @@ implicit_optional=true
 requires = ["pdm-pep517"]
 build-backend = "pdm.pep517.api"
 
-[tool.pdm.scripts]
-post_install = ["./install_flash_attn.sh"]
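
With the post_install hook removed, nothing verifies the optional dependency at install time. A quick standard-library check a user can run before enabling the flag (a sketch, not part of the repository):

import importlib.util

# True only if the user has installed flash-attn themselves.
flash_attn_available = importlib.util.find_spec("flash_attn") is not None
print(f"flash-attn installed: {flash_attn_available}")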
43 changes: 43 additions & 0 deletions tests/test_HookedTransformer.py
@@ -0,0 +1,43 @@
import pytest

from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model
import torch

MODEL_NAMES = {
    'gpt2': 'gpt2',
    'llama3-base': 'meta-llama/Meta-Llama-3-8B',
    'llama3-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct',
}
MODEL_PATHS = {
    'gpt2': '/remote-home/fkzhu/models/gpt2',
    'llama3-base': '/remote-home/share/models/llama3_hf/Meta-Llama-3-8B',
    'llama3-instruct': '/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct',
}


def test_hooked_transformer():
    # Smoke test: loading a local HF checkpoint into HookedTransformer should succeed without error.
    model_name = 'gpt2'
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dtype = torch.bfloat16
    hf_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATHS[model_name],
        trust_remote_code=True,
        local_files_only=True,
        torch_dtype=dtype,
    )

    hf_tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATHS[model_name],
        trust_remote_code=True,
        use_fast=True,
        add_bos_token=True,
    )
    model = HookedTransformer.from_pretrained(
        MODEL_NAMES[model_name],
        use_flash_attn=False,
        device=device,
        hf_model=hf_model,
        tokenizer=hf_tokenizer,
        dtype=dtype,
    )
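
The test above only exercises the non-flash path (use_flash_attn=False). A companion test for the new lazy-import path, added to the same file, might look like the following sketch; it assumes a CUDA machine, a user-installed flash-attn, and a reachable copy of gpt2, and pytest.importorskip skips it cleanly when the package is absent:

def test_hooked_transformer_flash_attn():
    pytest.importorskip("flash_attn")  # skip when the user has not installed flash-attn
    if not torch.cuda.is_available():
        pytest.skip("flash-attn kernels require a CUDA device")
    model = HookedTransformer.from_pretrained(
        MODEL_NAMES['gpt2'],
        use_flash_attn=True,  # exercises the lazy import added in abstract_attention.py
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
    )
    # The flash-attn callables are attached to each attention block in __init__.
    assert hasattr(model.blocks[0].attn, "flash_attn_func")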
