diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py
index 315ef1ff1..784bd5f28 100644
--- a/MaxText/layers/attentions.py
+++ b/MaxText/layers/attentions.py
@@ -572,18 +572,20 @@ def transpose_tuple(self, items: tuple[Any, Any, Any, Any], axis_order: AxisIdxe
   def _get_cached_kv_dtype(self, dtype):
     return self.kv_quant.dtype if self.kv_quant else dtype
 
-  def _get_cache_scale_logical_shape(self, batch, heads):
+  def _get_cache_scale_logical_shape(self, batch, heads, cache_length):
+    """Returns the 4-tuple logical shape of the KV cache scale for the configured kv_quant axis_cfg."""
     assert self.kv_quant
     if self.kv_quant.axis_cfg == "dkv":
-      return (batch, self.max_prefill_predict_length, heads, 1)
+      return (batch, cache_length, heads, 1)
     if self.kv_quant.axis_cfg == "heads_and_dkv":
-      return (batch, self.max_prefill_predict_length, 1, 1)
+      return (batch, cache_length, 1, 1)
-    raise f"Invalid config for kv_quant_axis:{self.kv_quant.axis_cfg}"
+    raise ValueError(f"Invalid config for kv_quant_axis:{self.kv_quant.axis_cfg}")
 
   def _get_prefill_cache_vars(self, batch, heads, kv_head_size, model_mode):
 
+    cache_length = self.max_prefill_predict_length
     dtype = self._get_cached_kv_dtype(self.dtype)
-    cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)
+    cache_logical_shape = (batch, cache_length, heads, kv_head_size)
 
     if model_mode == common_types.MODEL_MODE_PREFILL:
       cache_logical_axis_names = self.prefill_cache_logical_axis_names
@@ -616,12 +618,12 @@ def _get_prefill_cache_vars(self, batch, heads, kv_head_size, model_mode):
         "cache",
         "cache_prefill_segment_id",
         nn.with_logical_partitioning(jnp.zeros, segment_id_axis_names),
-        (cache_logical_shape[0], self.max_prefill_predict_length),
+        (cache_logical_shape[0], cache_length),
         jnp.int32,
     )
 
     if self.kv_quant:
-      cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
+      cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads, cache_length)
       cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.prefill_cache_axis_order)
       cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.prefill_cache_axis_order)
 
@@ -707,7 +709,7 @@ def _get_ar_cache_vars(self, batch, heads, kv_head_size, model_mode):
     )
 
     if self.kv_quant:
-      cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
+      cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads, cache_length)
       cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order)
       cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.ar_cache_axis_order)
 
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 9e48e27e2..ef934429f 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -91,6 +91,17 @@ def validate_model_call_mode(s: str) -> None:
     raise ValueError(f"Invalid model call mode {s}. Valid options are {valid_model_call_modes}")
 
 
+def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_length: int) -> None:
+  """Raises ValueError unless 0 < max_prefill_length < max_target_length."""
+  if max_prefill_length <= 0:
+    raise ValueError(f"Invalid max_prefill_predict_length {max_prefill_length}, it should be a positive number")
+  if max_target_length <= max_prefill_length:
+    raise ValueError(
+        f"Invalid max_target_length {max_target_length}, this should be sum of "
+        f"max_prefill_predict_length ({max_prefill_length}) and max output length expected."
+    )
+
+
 def validate_keys(keys):
   validate_attention_kernel(keys["attention"])
   validate_attention_type(keys["attention_type"])
@@ -98,6 +109,7 @@ def validate_keys(keys):
   validate_compute_axis_order(keys["compute_axis_order"])
   validate_kv_quant_axis(keys["kv_quant_axis"], keys["quantize_kvcache"])
   validate_model_call_mode(keys["model_call_mode"])
+  validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])
 
   assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
       "enable_checkpointing"
diff --git a/MaxText/tests/moe_test.py b/MaxText/tests/moe_test.py
index 4e8b522ed..4c132cf69 100644
--- a/MaxText/tests/moe_test.py
+++ b/MaxText/tests/moe_test.py
@@ -35,7 +35,7 @@ def setUp(self):
         model_name="mixtral-8x7b",
         dtype="bfloat16",
         megablox=False,
-        max_target_length=4,
+        max_target_length=80,
         per_device_batch_size=1,
         capacity_factor=2,
     )