Skip to content

Commit

Permalink
update sharding
Browse files (browse the repository at this point in the history)
  • Loading branch information
ZhiyuLi-goog authored and lenscloth committed Dec 14, 2024
1 parent 945ee3d commit 2b62fe8
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
unsort_intermediate = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0)
reshaped_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok))
tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism
+ data_parallelism = self.config.ici_data_parallelism * self.config.dcn_data_parallelism
+ fsdp_parallelism = self.config.ici_fsdp_parallelism * self.config.dcn_fsdp_parallelism
+ batch_sharding = data_parallelism * fsdp_parallelism
reshaped_intermediate = jnp.reshape(
unsort_intermediate,
(-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism),
Expand All @@ -403,7 +406,7 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
reshaped_weights.astype(jnp.float32),
precision=matmul_precision,
)
- updated_batch = int(self.config.per_device_batch_size * jax.device_count() // self.config.ici_fsdp_parallelism)
+ updated_batch = int(self.config.per_device_batch_size * jax.device_count() // batch_sharding)
# inferencing hack
# prefill has BS =1 sequence length = max_prefill_length
# decode has BS = B, sequence_length= 1
Expand Down

0 comments on commit 2b62fe8

Please sign in to comment.