diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 82adcd0f3..2fa6f42bf 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -385,18 +385,14 @@ def permute(self, inputs, gate_logits): group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts) return sorted_inputs, sorted_selected_experts, weights, group_size - def unpermute(self, intermediate, sorted_selected_experts, weights): + def unpermute(self, intermediate, sorted_selected_experts, weights, batch_size: int, sequence_length: int): """Unpermute tokens to original order and combine weights.""" unsort_intermediate = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0) reshaped_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok)) - tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism - data_parallelism = self.config.ici_data_parallelism * self.config.dcn_data_parallelism - fsdp_parallelism = self.config.ici_fsdp_parallelism * self.config.dcn_fsdp_parallelism - batch_sharding = data_parallelism * fsdp_parallelism reshaped_intermediate = jnp.reshape( unsort_intermediate, - (-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism), + (reshaped_weights.shape[0], self.num_experts_per_tok, -1), ) with jax.named_scope("weight_sum"): matmul_precision = lax.Precision(self.config.matmul_precision) @@ -406,14 +402,7 @@ def unpermute(self, intermediate, sorted_selected_experts, weights): reshaped_weights.astype(jnp.float32), precision=matmul_precision, ) - updated_batch = int(self.config.per_device_batch_size * jax.device_count() // batch_sharding) - # inferencing hack - # prefill has BS =1 sequence length = max_prefill_length - # decode has BS = B, sequence_length= 1 - if output.shape[0] % updated_batch != 0: - updated_batch = 1 - - return output.reshape(updated_batch, -1, self.config.emb_dim // tensor_parallelism).astype(self.dtype) + return output.reshape(batch_size, sequence_length, -1).astype(self.dtype) def megablox(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): tile_size = (512, 1024, 1024) @@ -472,6 +461,7 @@ def gmm(inputs, kernel, group_sizes): check_rep=False, ) def wrapper(x, logits, w0, w1, wo): + batch_size, sequence_length, _ = x.shape x, sorted_selected_experts, weights, group_sizes = self.permute(x, logits) layer_w0 = gmm(x, w0, group_sizes) layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") @@ -484,7 +474,9 @@ def wrapper(x, logits, w0, w1, wo): tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism if tensor_parallelism > 1: intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True) - output = self.unpermute(intermediate_output, sorted_selected_experts, weights) + output = self.unpermute( + intermediate_output, sorted_selected_experts, weights, batch_size=batch_size, sequence_length=sequence_length + ) return output, None return wrapper(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)