bug fix
vaibhavad committed Apr 30, 2024
1 parent 4d5771a commit 5492ca5
Showing 2 changed files with 21 additions and 15 deletions.
34 changes: 20 additions & 14 deletions llm2vec/models/bidirectional_llama.py
@@ -32,6 +32,14 @@ def is_transformers_attn_greater_or_equal_4_38():
        "4.38.0"
    )

+def is_transformers_attn_greater_or_equal_4_40():
+    if not _is_package_available("transformers"):
+        return False
+
+    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
+        "4.40.0"
+    )


class ModifiedLlamaAttention(LlamaAttention):
    def __init__(self, *args, **kwargs):
@@ -99,12 +107,20 @@ def __init__(self, config: LlamaConfig):
        # Initialize weights and apply final processing
        self.post_init()

-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

+        if is_transformers_attn_greater_or_equal_4_40() and self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch on Flash Attention 2.
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            ):
+                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
@@ -116,7 +132,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
-            else cache_position[-1] + 1
+            else (cache_position[-1] + 1 if not is_transformers_attn_greater_or_equal_4_40() else past_seen_tokens + sequence_length + 1)
        )

        causal_mask = torch.zeros(
@@ -164,19 +180,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
-            )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(
-                    causal_mask, min_dtype
-                )
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
+            )

        return causal_mask

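Taken together, the changes keep _update_causal_mask compatible with both older and newer transformers releases: on 4.40+ the method accepts past_seen_tokens, can skip building an explicit mask via AttentionMaskConverter._ignore_causal_mask_sdpa so SDPA dispatches to its is_causal fast path, derives target_length from the cached token count instead of cache_position, and applies _unmask_unattended unconditionally rather than behind a tracing check. A minimal call-site sketch of the version-gated pattern this implies follows; the build_mask helper and the use of past_key_values.get_seq_length() are illustrative assumptions, not code from this commit.

# Hypothetical call-site sketch (not from this commit) showing how the new
# past_seen_tokens argument can be threaded through, gated on the installed
# transformers version.
from llm2vec.models.bidirectional_llama import is_transformers_attn_greater_or_equal_4_40


def build_mask(model, attention_mask, inputs_embeds, cache_position, past_key_values=None):
    if is_transformers_attn_greater_or_equal_4_40():
        # transformers >= 4.40: pass the number of already-cached tokens so the
        # SDPA fast path and the target_length computation can use it.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        return model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_seen_tokens
        )
    # Older transformers: past_seen_tokens stays at its None default and the
    # previous cache_position-based path is used.
    return model._update_causal_mask(attention_mask, inputs_embeds, cache_position)
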
2 changes: 1 addition & 1 deletion llm2vec/version.py
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "0.1.5"
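
After upgrading to this revision, the bump can be verified directly from the module changed above (a trivial check; it assumes the package is installed from this commit):

# Quick sanity check of the installed package version.
from llm2vec.version import __version__

assert __version__ == "0.1.5"
print(__version__)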
