diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index e334dd7f..49b1656a 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -189,16 +189,19 @@ def forward( self.start_pos = 0 hf_is_generating = False + hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None + hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0 if self.is_hf_transformers and "use_cache" in kwargs: hf_is_generating = kwargs["use_cache"] + # print(kwargs["past_key_value"].get_seq_length()) # In case we re-generate, we need to refresh the starting position # to 0. We detect it by checking if `past_key_values` is set to None, # which indicates that we are on the first step of `generate()`. # This is only applicable for `transformers` integration - if (self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None) or (self.is_hf_transformers and not hf_is_generating): + if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating): self.start_pos = 0