From 33af761ac8d94a2c5eb7c4f5f766df83b1d3ef16 Mon Sep 17 00:00:00 2001
From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Date: Wed, 1 May 2024 12:35:52 +0200
Subject: [PATCH] FIX: Fix multiple generations for new HF cache format (#444)

---
 awq/modules/fused/attn.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py
index e334dd7f..49b1656a 100644
--- a/awq/modules/fused/attn.py
+++ b/awq/modules/fused/attn.py
@@ -189,16 +189,19 @@ def forward(
             self.start_pos = 0
 
         hf_is_generating = False
+        hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None
+        hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0
         if self.is_hf_transformers and "use_cache" in kwargs:
             hf_is_generating = kwargs["use_cache"]
+        # print(kwargs["past_key_value"].get_seq_length())
 
         # In case we re-generate, we need to refresh the starting position
         # to 0. We detect it by checking if `past_key_values` is set to None,
         # which indicates that we are on the first step of `generate()`.
         # This is only applicable for `transformers` integration
-        if (self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None) or (self.is_hf_transformers and not hf_is_generating):
+        if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating):
             self.start_pos = 0