diff --git a/TransformerLens/transformer_lens/components/abstract_attention.py b/TransformerLens/transformer_lens/components/abstract_attention.py index 0bb5ac81..3c9d1503 100644 --- a/TransformerLens/transformer_lens/components/abstract_attention.py +++ b/TransformerLens/transformer_lens/components/abstract_attention.py @@ -103,27 +103,13 @@ def __init__( if self.layer_id is None: # keep mypy happy raise ValueError("Layer ID must be provided to scale attention scores") self.attn_scale *= self.layer_id + 1 - + self.hook_k = HookPoint() # [batch, pos, head_index, d_head] self.hook_q = HookPoint() # [batch, pos, head_index, d_head] self.hook_v = HookPoint() # [batch, pos, head_index, d_head] - - if self.cfg.use_flash_attn: - # If using FlashAttention, import flash-attn and create related class method. - from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input - self.flash_attn_func = flash_attn_func - self.flash_attn_varlen_func = flash_attn_varlen_func - self.fa_index_first_axis = index_first_axis - self.fa_pad_input = pad_input - self.fa_unpad_input = unpad_input - # Because of FlashAttention's characteristic, intermediate results (attention scores, pattern, z) are not supported to be hooked. - else: - self.hook_z = HookPoint() # [batch, pos, head_index, d_head] - self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] - self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] - - + self.hook_z = HookPoint() # [batch, pos, head_index, d_head] + self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] + self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] self.hook_result = HookPoint() # [batch, pos, head_index, d_model] # See HookedTransformerConfig for more details. @@ -222,23 +208,24 @@ def forward( self.apply_rotary(k, 0, attention_mask) ) # keys are cached so no offset - if self.cfg.dtype not in [torch.float32, torch.float64, torch.bfloat16]: + if self.cfg.dtype not in [torch.float32, torch.float64]: # If using 16 bits, increase the precision to avoid numerical instabilities q = q.to(torch.float32) k = k.to(torch.float32) - # use FlashAttentionV2 to accelerate inference. self.hook_attn_scores, self.hook_pattern, self.hook_z are not supported in this case. - if self.cfg.use_flash_attn: - # FlashAttention could only accept the dtype of bfp16 and fp16 - q = q.to(torch.bfloat16) - k = k.to(torch.bfloat16) - - # Contains at least one padding token in the sequence - causal = True if self.cfg.attention_dir == "causal" else False - if attention_mask is not None: - batch_size, query_length, _ = q.shape - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - q, k, v, attention_mask, q.shape[1] + attn_scores = self.calculate_attention_scores( + q, k + ) # [batch, head_index, query_pos, key_pos] + + if self.cfg.positional_embedding_type == "alibi": + query_ctx = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + key_ctx = attn_scores.size(-1) + + # only recompute when necessary to increase efficiency. + if self.alibi is None or key_ctx > self.alibi.size(-1): + self.alibi = AbstractAttention.create_alibi_bias( + self.cfg.n_heads, key_ctx, self.cfg.device ) attn_scores += self.alibi[ @@ -263,37 +250,16 @@ def forward( attn_scores = self.apply_causal_mask( attn_scores, kv_cache_pos_offset, attention_mask ) # [batch, head_index, query_pos, key_pos] - - if self.cfg.positional_embedding_type == "alibi": - query_ctx = attn_scores.size(-2) - # The key context length is the number of positions in the past - this includes all positions in the cache - key_ctx = attn_scores.size(-1) - - # only recompute when necessary to increase efficiency. - if self.alibi is None or key_ctx > self.alibi.size(-1): - self.alibi = AbstractAttention.create_alibi_bias( - self.cfg.n_heads, key_ctx, self.cfg.device - ) - - attn_scores += self.alibi[ - :, :query_ctx, :key_ctx - ] # [batch, head_index, query_pos, key_pos] - - if self.cfg.attention_dir == "causal": - # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. - attn_scores = self.apply_causal_mask( - attn_scores, kv_cache_pos_offset, attention_mask - ) # [batch, head_index, query_pos, key_pos] - if additive_attention_mask is not None: - attn_scores += additive_attention_mask - - attn_scores = self.hook_attn_scores(attn_scores) - pattern = F.softmax(attn_scores, dim=-1) - pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) - pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] - pattern = pattern.to(self.cfg.dtype) - pattern = pattern.to(v.device) - z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] + if additive_attention_mask is not None: + attn_scores += additive_attention_mask + + attn_scores = self.hook_attn_scores(attn_scores) + pattern = F.softmax(attn_scores, dim=-1) + pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) + pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] + pattern = pattern.to(self.cfg.dtype) + pattern = pattern.to(v.device) + z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] if not self.cfg.use_attn_result: if self.cfg.load_in_4bit: # call bitsandbytes method to dequantize and multiply @@ -700,67 +666,4 @@ def create_alibi_bias( # The ALiBi bias is then m * slope_matrix alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) - return alibi_bias - - def _upad_input( - self, - query_layer: Float[torch.Tensor, "batch key_pos head_index d_head"], - key_layer: Float[torch.Tensor, "batch key_pos head_index d_head"], - value_layer: Float[torch.Tensor, "batch key_pos head_index d_head"], - attention_mask: Optional[Float[torch.Tensor, "batch 1 1 pos"]], - query_length: int, - ): - """ - Refer to the implementation of flash attention of llama3 in package transformers: LlamaFlashAttention2. - The function is used when attention mask is not None and query length is not equal to key length. - """ - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = self.fa_index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - value_layer = self.fa_index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k - ) - if query_length == kv_seq_len: - query_layer = self.fa_index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. - attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = self.fa_unpad_input(query_layer, attention_mask) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - -def _get_unpad_data(attention_mask): - """ - From transformers.models.llama.modeling_llama - """ - seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return ( - indices, - cu_seqlens, - max_seqlen_in_batch, - ) \ No newline at end of file + return alibi_bias \ No newline at end of file diff --git a/TransformerLens/transformer_lens/components/transformer_block.py b/TransformerLens/transformer_lens/components/transformer_block.py index 6db16a19..83e7b250 100644 --- a/TransformerLens/transformer_lens/components/transformer_block.py +++ b/TransformerLens/transformer_lens/components/transformer_block.py @@ -82,7 +82,7 @@ def __init__(self, cfg: Union[Dict, HookedTransformerConfig], block_index): attn_type = self.cfg.attn_types[block_index] self.attn = attention(self.cfg, attn_type, block_index) if not self.cfg.attn_only: - self.mlp = MLPFactory.create_mlp(self.cfg) + self.mlp = MLPFactory.create_mlp(self.cfg).to(self.cfg.device).to(self.cfg.dtype) self.hook_attn_in = HookPoint() # [batch, pos, n_heads, d_model] self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model]