Reshape based on the original input shape
lenscloth committed Dec 14, 2024
1 parent 2b62fe8 commit 742019f
Showing 1 changed file with 7 additions and 15 deletions.
MaxText/layers/linears.py (22 changes: 7 additions & 15 deletions)
@@ -385,18 +385,14 @@ def permute(self, inputs, gate_logits):
     group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
     return sorted_inputs, sorted_selected_experts, weights, group_size

-  def unpermute(self, intermediate, sorted_selected_experts, weights):
+  def unpermute(self, intermediate, sorted_selected_experts, weights, batch_size: int, sequence_length: int):
     """Unpermute tokens to original order and combine weights."""

     unsort_intermediate = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0)
     reshaped_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok))
-    tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism
-    data_parallelism = self.config.ici_data_parallelism * self.config.dcn_data_parallelism
-    fsdp_parallelism = self.config.ici_fsdp_parallelism * self.config.dcn_fsdp_parallelism
-    batch_sharding = data_parallelism * fsdp_parallelism
     reshaped_intermediate = jnp.reshape(
         unsort_intermediate,
-        (-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism),
+        (reshaped_weights.shape[0], self.num_experts_per_tok, -1),
     )
     with jax.named_scope("weight_sum"):
       matmul_precision = lax.Precision(self.config.matmul_precision)
@@ -406,14 +402,7 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
           reshaped_weights.astype(jnp.float32),
           precision=matmul_precision,
       )
-    updated_batch = int(self.config.per_device_batch_size * jax.device_count() // batch_sharding)
-    # inferencing hack
-    # prefill has BS =1 sequence length = max_prefill_length
-    # decode has BS = B, sequence_length= 1
-    if output.shape[0] % updated_batch != 0:
-      updated_batch = 1
-
-    return output.reshape(updated_batch, -1, self.config.emb_dim // tensor_parallelism).astype(self.dtype)
+    return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)

   def megablox(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
     tile_size = (512, 1024, 1024)
@@ -472,6 +461,7 @@ def gmm(inputs, kernel, group_sizes):
         check_rep=False,
     )
     def wrapper(x, logits, w0, w1, wo):
+      batch_size, sequence_length, _ = x.shape
       x, sorted_selected_experts, weights, group_sizes = self.permute(x, logits)
       layer_w0 = gmm(x, w0, group_sizes)
       layer_w0 = checkpoint_name(layer_w0, "mlpwi_0")
@@ -484,7 +474,9 @@ def wrapper(x, logits, w0, w1, wo):
       tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism
       if tensor_parallelism > 1:
         intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True)
-      output = self.unpermute(intermediate_output, sorted_selected_experts, weights)
+      output = self.unpermute(
+          intermediate_output, sorted_selected_experts, weights, batch_size=batch_size, sequence_length=sequence_length
+      )
       return output, None

     return wrapper(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)
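
For context on why this reshape is safe: after top-k routing, every token contributes exactly num_experts_per_tok rows, so the weighted combine in unpermute produces one row per token, i.e. batch_size * sequence_length rows in total. Below is a minimal, self-contained sketch of that round trip; the toy sizes, identity experts, and helper names (flat, per_token, combined) are illustrative assumptions, not MaxText code.

import jax
import jax.numpy as jnp

batch_size, sequence_length, emb_dim = 2, 4, 8  # toy sizes (assumptions)
num_experts, num_experts_per_tok = 4, 2

key = jax.random.PRNGKey(0)
inputs = jax.random.normal(key, (batch_size, sequence_length, emb_dim))

# Stand-in for permute(): flatten tokens, pick top-k experts per token,
# replicate each token k times, and sort the copies by expert id.
flat = inputs.reshape(-1, emb_dim)                               # (B*S, emb)
logits = jax.random.normal(key, (flat.shape[0], num_experts))
weights, selected = jax.lax.top_k(logits, num_experts_per_tok)   # (B*S, k)
sort_idx = jnp.argsort(selected.reshape(-1))
sorted_inputs = jnp.repeat(flat, num_experts_per_tok, axis=0)[sort_idx]

# Stand-in for unpermute() with identity experts: undo the sort, group the
# k copies of each token, and combine them with the routing weights.
unsorted = jnp.take(sorted_inputs, jnp.argsort(sort_idx), axis=0)
per_token = unsorted.reshape(flat.shape[0], num_experts_per_tok, -1)
combined = jnp.einsum("BKE,BK->BE", per_token, jax.nn.softmax(weights, axis=-1))

# One row per token, so the original layout is recoverable from
# (batch_size, sequence_length) alone, with no sharding arithmetic.
output = combined.reshape(batch_size, sequence_length, -1)
assert output.shape == inputs.shape

The assert captures the invariant this commit relies on: the output layout is recoverable from the original input shape alone, so the reshape no longer needs the data/fsdp/tensor parallelism arithmetic that the removed lines derived from the config.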
