Skip to content

Commit

Permalink
Merge branch 'main' into quantize_megablox_squashed
Browse files Browse the repository at this point in the history
  • Loading branch information
lenscloth authored Dec 17, 2024
2 parents 41367a3 + 92c27a6 commit e9c3d46
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
15 changes: 8 additions & 7 deletions MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,18 +572,19 @@ def transpose_tuple(self, items: tuple[Any, Any, Any, Any], axis_order: AxisIdxe
def _get_cached_kv_dtype(self, dtype):
  """Return the dtype to store KV-cache entries in.

  When KV-cache quantization is enabled, the quantizer's dtype wins;
  otherwise the caller-provided dtype is used unchanged.
  """
  if self.kv_quant:
    return self.kv_quant.dtype
  return dtype

def _get_cache_scale_logical_shape(self, batch, heads, cache_length):
  """Return the logical shape of the KV-cache quantization scale tensor.

  Args:
    batch: batch dimension of the cache.
    heads: number of KV heads in the cache.
    cache_length: sequence-length dimension of the cache being scaled
      (prefill length or autoregressive cache length).

  Returns:
    A 4-tuple logical shape; axes that share a single scale are size 1.

  Raises:
    ValueError: if ``self.kv_quant.axis_cfg`` is not a recognized setting.
  """
  assert self.kv_quant
  if self.kv_quant.axis_cfg == "dkv":
    # One scale per (batch, position, head): only d_kv is reduced.
    return (batch, cache_length, heads, 1)
  if self.kv_quant.axis_cfg == "heads_and_dkv":
    # One scale per (batch, position): heads and d_kv both reduced.
    return (batch, cache_length, 1, 1)
  # BUG FIX: the original raised a bare f-string, which is itself a
  # TypeError in Python 3 ("exceptions must derive from BaseException").
  # Raise a proper exception carrying the same message.
  raise ValueError(f"Invalid config for kv_quant_axis:{self.kv_quant.axis_cfg}")

def _get_prefill_cache_vars(self, batch, heads, kv_head_size, model_mode):

cache_length = self.max_prefill_predict_length
dtype = self._get_cached_kv_dtype(self.dtype)
cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size)
cache_logical_shape = (batch, cache_length, heads, kv_head_size)

if model_mode == common_types.MODEL_MODE_PREFILL:
cache_logical_axis_names = self.prefill_cache_logical_axis_names
Expand Down Expand Up @@ -616,12 +617,12 @@ def _get_prefill_cache_vars(self, batch, heads, kv_head_size, model_mode):
"cache",
"cache_prefill_segment_id",
nn.with_logical_partitioning(jnp.zeros, segment_id_axis_names),
(cache_logical_shape[0], self.max_prefill_predict_length),
(cache_logical_shape[0], cache_length),
jnp.int32,
)

if self.kv_quant:
cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads, cache_length)
cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.prefill_cache_axis_order)
cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.prefill_cache_axis_order)

Expand Down Expand Up @@ -707,7 +708,7 @@ def _get_ar_cache_vars(self, batch, heads, kv_head_size, model_mode):
)

if self.kv_quant:
cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads)
cache_scale_logical_shape = self._get_cache_scale_logical_shape(batch, heads, cache_length)
cache_scale_axis_names = self.transpose_tuple(self.cache_scale_logical_axis_names, self.ar_cache_axis_order)
cache_scale_shape = self.transpose_tuple(cache_scale_logical_shape, self.ar_cache_axis_order)

Expand Down
11 changes: 11 additions & 0 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,24 @@ def validate_model_call_mode(s: str) -> None:
raise ValueError(f"Invalid model call mode {s}. Valid options are {valid_model_call_modes}")


def validate_prefill_and_target_lengths(max_prefill_length: int, max_target_length: int) -> None:
  """Check that the prefill and target lengths form a usable pair.

  Args:
    max_prefill_length: maximum prefill (prompt) length; must be positive.
    max_target_length: total sequence budget; must strictly exceed
      ``max_prefill_length`` so at least one output token fits.

  Raises:
    ValueError: if either constraint is violated.
  """
  prefill_is_positive = max_prefill_length > 0
  if not prefill_is_positive:
    raise ValueError(f"Invalid max_prefill_predict_length {max_prefill_length}, it should be a positive number")
  target_exceeds_prefill = max_target_length > max_prefill_length
  if not target_exceeds_prefill:
    raise ValueError(
        f"Invalid max_target_length {max_target_length}, this should be sum of "
        f"max_prefill_predict_length ({max_prefill_length}) and max output length expected."
    )


def validate_keys(keys):
validate_attention_kernel(keys["attention"])
validate_attention_type(keys["attention_type"])
validate_profiler_type(keys["profiler"])
validate_compute_axis_order(keys["compute_axis_order"])
validate_kv_quant_axis(keys["kv_quant_axis"], keys["quantize_kvcache"])
validate_model_call_mode(keys["model_call_mode"])
validate_prefill_and_target_lengths(keys["max_prefill_predict_length"], keys["max_target_length"])

assert (keys["load_parameters_path"] == "" and keys["load_full_state_path"] == "") or keys[
"enable_checkpointing"
Expand Down
2 changes: 1 addition & 1 deletion MaxText/tests/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUp(self):
model_name="mixtral-8x7b",
dtype="bfloat16",
megablox=False,
max_target_length=4,
max_target_length=80,
per_device_batch_size=1,
capacity_factor=2,
)
Expand Down

0 comments on commit e9c3d46

Please sign in to comment.