diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 7947b3ef1..c0bbe6440 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -94,7 +94,8 @@ def validate_model_call_mode(s: str) -> None: def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_length: int) -> None: if max_prefill_length <= 0: raise ValueError(f"Invalid max_prefill_predict_length {max_prefill_length}, it should be a positive number") - if max_target_length <= max_prefill_length: + if max_target_length < max_prefill_length: + # valid max_target_length = max_prefill_length for existing logit checks raise ValueError( f"Invalid max_target_length {max_target_length}, this should be sum of " f"max_prefill_predict_length ({max_prefill_length}) and max output length expected."