From 5492ca5849a0dc3a07063eed60f988940d60df98 Mon Sep 17 00:00:00 2001
From: vaibhavad
Date: Tue, 30 Apr 2024 22:29:07 +0000
Subject: [PATCH] bug fix

---
 llm2vec/models/bidirectional_llama.py | 34 ++++++++++++++++-----------
 llm2vec/version.py                    |  2 +-
 2 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/llm2vec/models/bidirectional_llama.py b/llm2vec/models/bidirectional_llama.py
index 78037fb..2e91c9f 100644
--- a/llm2vec/models/bidirectional_llama.py
+++ b/llm2vec/models/bidirectional_llama.py
@@ -32,6 +32,14 @@ def is_transformers_attn_greater_or_equal_4_38():
         "4.38.0"
     )
 
+def is_transformers_attn_greater_or_equal_4_40():
+    if not _is_package_available("transformers"):
+        return False
+
+    return version.parse(importlib.metadata.version("transformers")) >= version.parse(
+        "4.40.0"
+    )
+
 
 class ModifiedLlamaAttention(LlamaAttention):
     def __init__(self, *args, **kwargs):
@@ -99,12 +107,20 @@ def __init__(self, config: LlamaConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
-    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, past_seen_tokens=None):
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
 
+        if is_transformers_attn_greater_or_equal_4_40() and self.config._attn_implementation == "sdpa":
+            # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
+            # in order to dispatch on Flash Attention 2.
+            if AttentionMaskConverter._ignore_causal_mask_sdpa(
+                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            ):
+                return None
+
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
@@ -116,7 +132,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             target_length = (
                 attention_mask.shape[-1]
                 if isinstance(attention_mask, torch.Tensor)
-                else cache_position[-1] + 1
+                else (cache_position[-1] + 1 if not is_transformers_attn_greater_or_equal_4_40() else past_seen_tokens + sequence_length + 1)
             )
 
             causal_mask = torch.zeros(
@@ -164,19 +180,9 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
             and attention_mask is not None
             and attention_mask.device.type == "cuda"
         ):
-            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
-            is_tracing = (
-                torch.jit.is_tracing()
-                or isinstance(input_tensor, torch.fx.Proxy)
-                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
+            causal_mask = AttentionMaskConverter._unmask_unattended(
+                causal_mask, min_dtype
             )
-            if not is_tracing and torch.any(attention_mask != 1):
-                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
-                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
-                # Details: https://github.com/pytorch/pytorch/issues/110213
-                causal_mask = AttentionMaskConverter._unmask_unattended(
-                    causal_mask, min_dtype
-                )
 
         return causal_mask
 
diff --git a/llm2vec/version.py b/llm2vec/version.py
index 3dc1f76..1276d02 100644
--- a/llm2vec/version.py
+++ b/llm2vec/version.py
@@ -1 +1 @@
-__version__ = "0.1.0"
+__version__ = "0.1.5"
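
Note (not part of the patch itself): the new is_transformers_attn_greater_or_equal_4_40 helper follows the same version-gating pattern as the existing is_transformers_attn_greater_or_equal_4_38 check, so the SDPA early-exit path is only taken when the installed transformers provides AttentionMaskConverter._ignore_causal_mask_sdpa. Below is a minimal standalone sketch of that pattern, assuming transformers and packaging are installed; the generic helper name is illustrative and does not come from the repo.

    import importlib.metadata

    from packaging import version


    def is_transformers_greater_or_equal(min_version: str) -> bool:
        """Return True if the installed transformers version is >= min_version."""
        try:
            installed = importlib.metadata.version("transformers")
        except importlib.metadata.PackageNotFoundError:
            # Same intent as the patch's _is_package_available guard:
            # if transformers is not installed, the gate is simply False.
            return False
        return version.parse(installed) >= version.parse(min_version)


    if __name__ == "__main__":
        # The patched _update_causal_mask only takes the SDPA early-exit path
        # when a check like this is True for "4.40.0".
        print(is_transformers_greater_or_equal("4.40.0"))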