diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 315ef1ff1..573b32802 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -651,13 +651,18 @@ def _get_ar_cache_vars(self, batch, heads, kv_head_size, model_mode):
     dtype = self._get_cached_kv_dtype(self.dtype)
     cache_length = self.max_target_length - self.max_prefill_predict_length
-    cache_logical_shape = (batch, cache_length, heads, kv_head_size)
 
     if model_mode == common_types.MODEL_MODE_PREFILL:
       cache_logical_axis_names = self.prefill_cache_logical_axis_names
+      # TODO: Find a better way to avoid initializing the ar cache during prefill.
+      # The current Engine insert API implementation requires that the prefill
+      # result cache have the same pytree def as the decode state cache, so we
+      # still initialize the ar cache in prefill, but with length 1.
+      cache_length = 1
     else:
       cache_logical_axis_names = self.cache_logical_axis_names
+    cache_logical_shape = (batch, cache_length, heads, kv_head_size)
 
     cache_axis_names = self.transpose_tuple(cache_logical_axis_names, self.ar_cache_axis_order)
     cache_shape = self.transpose_tuple(cache_logical_shape, self.ar_cache_axis_order)
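For context on the pytree constraint the TODO comment refers to: below is a minimal standalone sketch, not MaxText code. The dict keys (`prefill_key`, `ar_key`), shapes, and lengths are invented for illustration; it only shows why a prefill cache that drops the ar entry entirely would have a different pytree def than the decode state cache, and why a dummy length-1 ar buffer is enough to keep the tree defs equal.

```python
import jax
import jax.numpy as jnp

batch, heads, kv_head_size = 2, 4, 8
prefill_len, ar_len = 16, 48

# Decode state cache: holds both a prefill segment and an autoregressive segment.
decode_cache = {
    "prefill_key": jnp.zeros((batch, prefill_len, heads, kv_head_size)),
    "ar_key": jnp.zeros((batch, ar_len, heads, kv_head_size)),
}

# If prefill skipped the ar entry, the tree defs would differ, so any
# structure-preserving insert op over the two trees would fail.
prefill_cache_missing_ar = {
    "prefill_key": jnp.zeros((batch, prefill_len, heads, kv_head_size)),
}
assert (jax.tree_util.tree_structure(prefill_cache_missing_ar)
        != jax.tree_util.tree_structure(decode_cache))

# The workaround in this diff: keep the ar entry but give it length 1.
# Pytree structure compares keys, not array shapes, so the tree defs match
# while the dummy buffer stays tiny.
prefill_cache = {
    "prefill_key": jnp.zeros((batch, prefill_len, heads, kv_head_size)),
    "ar_key": jnp.zeros((batch, 1, heads, kv_head_size)),  # dummy, length 1
}
assert (jax.tree_util.tree_structure(prefill_cache)
        == jax.tree_util.tree_structure(decode_cache))

# With matching tree defs, mapping over both caches together succeeds:
_ = jax.tree_util.tree_map(lambda p, d: (p.shape, d.shape),
                           prefill_cache, decode_cache)
```

Since the ar buffer's contents are never read during prefill, shrinking it to length 1 preserves the structural contract with the Engine insert API at negligible memory cost.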