Commit c8c86bd

fix and clean the code after first review

StarConnor committed Jul 2, 2024
1 parent 5447661 commit c8c86bd
Showing 5 changed files with 101 additions and 201 deletions.
@@ -1,7 +1,5 @@
import pytest

from transformer_lens import HookedTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Model
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_NAMES = {
@@ -10,9 +8,9 @@
'llama3-instruct':'meta-llama/Meta-Llama-3-8B-Instruct',
}
MODEL_PATHS = {
'gpt2':'/remote-home/fkzhu/models/gpt2',
'llama3':'/remote-home/share/models/llama3_hf/Meta-Llama-3-8B',
'llama3-instruct':'/remote-home/share/models/llama3_hf/Meta-Llama-3-8B-Instruct',
'gpt2':'path/to/gpt2',
'llama3':'path/to/llama3-base',
'llama3-instruct':'path/to/llama3-instruct',
}


@@ -41,3 +39,5 @@ def test_hooked_transformer():
tokenizer=hf_tokenizer,
dtype=dtype,
)

assert not hasattr(model.blocks[0].attn, 'flash_attn_func'), "AbstractAttention should not have 'flash_attn_func' when `use_flash_attn=False` is set"
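
For context, a minimal sketch of what the new assertion exercises. The `use_flash_attn` flag is the one added by this fork; the remaining keyword arguments follow the standard HookedTransformer.from_pretrained API, and the model path is a placeholder matching the scrubbed MODEL_PATHS above.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer

# Load the HF model and tokenizer from a local path (placeholder path).
hf_model = AutoModelForCausalLM.from_pretrained("path/to/gpt2")
hf_tokenizer = AutoTokenizer.from_pretrained("path/to/gpt2")

model = HookedTransformer.from_pretrained(
    "gpt2",
    use_flash_attn=False,   # flash-attn disabled (this fork's flag)
    hf_model=hf_model,
    tokenizer=hf_tokenizer,
    dtype=torch.bfloat16,
)

# With the flag off, AbstractAttention keeps its regular hook points and
# never binds the flash-attn kernels as attributes.
assert not hasattr(model.blocks[0].attn, "flash_attn_func")
assert hasattr(model.blocks[0].attn, "hook_attn_scores")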
TransformerLens/transformer_lens/components/abstract_attention.py (54 changes: 37 additions & 17 deletions)
@@ -21,17 +21,6 @@
import bitsandbytes as bnb
from bitsandbytes.nn.modules import Params4bit

# From transformers/models/llama/modeling_llama.py
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)

class AbstractAttention(ABC, nn.Module):
alibi: Union[torch.Tensor, None]
@@ -112,15 +101,16 @@ def __init__(
self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
self.hook_v = HookPoint() # [batch, pos, head_index, d_head]

# Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked.
if self.cfg.use_flash_attn:
# If using FlashAttention, import flash-attn and bind its functions as attributes.
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
self.flash_attn_func = flash_attn_func
self.flash_attn_varlen_func = flash_attn_varlen_func
self.fa_index_first_axis = index_first_axis
self.fa_pad_input = pad_input
self.fa_unpad_input = unpad_input
# Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked.
else:
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
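
Since the two branches above decide which hook points exist, here is a hedged sketch of how the difference could be checked at runtime. run_with_cache is the standard TransformerLens API; which keys appear when `use_flash_attn=True` is an expectation based on this diff, not something verified here.

from transformer_lens import HookedTransformer

def report_block0_attn_hooks(model: HookedTransformer, prompt: str = "Hello world") -> None:
    """Print which block-0 attention activations end up in the cache."""
    _, cache = model.run_with_cache(prompt)
    for name in ("hook_q", "hook_k", "hook_v", "hook_attn_scores", "hook_pattern", "hook_z"):
        key = f"blocks.0.attn.{name}"
        print(f"{key}: {'cached' if key in cache.keys() else 'not cached'}")

# Expectation from the branch above: with use_flash_attn=True only q/k/v are
# cached; with the default eager attention all six activations appear.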
@@ -219,12 +209,17 @@ def forward(
self.apply_rotary(k, 0, attention_mask)
) # keys are cached so no offset

if self.cfg.dtype not in [torch.float32, torch.float64] and self.cfg.dtype != torch.bfloat16:
if self.cfg.dtype not in [torch.float32, torch.float64]:
# If using 16 bits, increase the precision to avoid numerical instabilities
q = q.to(torch.float32)
k = k.to(torch.float32)

# use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case.
if self.cfg.use_flash_attn:
# use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case.
# FlashAttention only accepts bf16 and fp16 dtypes
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)

# Contains at least one padding token in the sequence
causal = True if self.cfg.attention_dir == "causal" else False
if attention_mask is not None:
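
The remainder of this hunk (collapsed here) dispatches to the flash-attn kernels. As a reference point, a hedged sketch of the two call shapes involved, assuming flash-attn v2 is installed and the tensors are bf16 on a CUDA device:

import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

batch, seqlen, n_heads, d_head = 2, 16, 12, 64
q = torch.randn(batch, seqlen, n_heads, d_head, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# No padding: dense call, [batch, pos, head_index, d_head] in and out.
z = flash_attn_func(q, k, v, causal=True)

# With padding: q/k/v are first unpadded and packed to [total_tokens, head_index, d_head]
# (that is what _upad_input below does), then the varlen kernel is called:
# z_packed = flash_attn_varlen_func(
#     q_packed, k_packed, v_packed,
#     cu_seqlens_q, cu_seqlens_k,
#     max_seqlen_q, max_seqlen_k,
#     causal=True,
# )
# and fa_pad_input scatters z_packed back to [batch, pos, head_index, d_head].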
@@ -708,7 +703,18 @@ def create_alibi_bias(

return alibi_bias

def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
def _upad_input(
self,
query_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
key_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
value_layer: Float[torch.Tensor, "batch key_pos head_index d_head"],
attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]],
query_length: int,
):
"""
Adapted from the Llama FlashAttention implementation in the transformers package (LlamaFlashAttention2).
Used when the attention mask is not None and the query length differs from the key length.
"""
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

@@ -744,4 +750,18 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
)

def _get_unpad_data(attention_mask):
"""
From transformers.models.llama.modeling_llama
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
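
To make the varlen path concrete, a small worked example of what _get_unpad_data produces for a padded batch. This simply re-runs the body of the function above on a toy mask; plain PyTorch, no flash-attn required.

import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0],   # sequence of length 3
                               [1, 1, 0, 0]])  # sequence of length 2

seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)            # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()                         # 3
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))

print(indices)              # tensor([0, 1, 2, 4, 5]) -> positions of real tokens
print(cu_seqlens)           # tensor([0, 3, 5], dtype=torch.int32) -> packed sequence boundaries
print(max_seqlen_in_batch)  # 3

indices is what fa_index_first_axis uses to gather the real tokens, cu_seqlens and max_seqlen feed flash_attn_varlen_func, and fa_pad_input scatters the packed output back to [batch, pos, head_index, d_head].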
pyproject.toml (7 changes: 1 addition & 6 deletions)
@@ -59,9 +59,4 @@ check_untyped_defs=true
exclude=[".venv/", "examples", "TransformerLens", "tests", "exp"]
ignore_missing_imports=true
allow_redefinition=true
implicit_optional=true

[build-system]
requires = ["pdm-pep517"]
build-backend = "pdm.pep517.api"

implicit_optional=true
tests/conftest.py (31 changes: 0 additions & 31 deletions)

This file was deleted.
