Commit 061abd8

[MoE] fix typo and add normalization for top_k_weights
ZhiyuLi-goog committed Nov 12, 2024
1 parent 5ae00e8 commit 061abd8
Showing 1 changed file with 2 additions and 1 deletion.
MaxText/layers/linears.py: 2 additions & 1 deletion
@@ -433,7 +433,7 @@ def wrapper(x, logits, w0, w1, wo):
      layer_w0 = gmm(x, w0, group_sizes)
      layer_w0 = checkpoint_name(layer_w0, "mlpwi_0")
      layer_w1 = gmm(x, w1, group_sizes)
-     layer_w1 = checkpoint_name(layer_w0, "mlpwi_1")
+     layer_w1 = checkpoint_name(layer_w1, "mlpwi_1")
      layer_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
      intermediate_layer = jnp.multiply(layer_act, layer_w1)
      intermediate_output = gmm(intermediate_layer, wo, group_sizes)
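
For context, a minimal standalone sketch (not the MaxText code) of what the first hunk fixes: checkpoint_name is an identity on values that tags an intermediate so JAX rematerialization policies can refer to it by name, and the old line tagged layer_w0 a second time under the "mlpwi_1" name while reassigning it to layer_w1, so the gated multiply used layer_w0 on both sides. The gated_ffn helper, plain matmuls, and silu activation below are illustrative stand-ins for the grouped gmm calls and the configured activation function.

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def gated_ffn(x, w0, w1):
    # checkpoint_name does not change the value; it only names the
    # intermediate for remat policies such as save_only_these_names.
    layer_w0 = checkpoint_name(x @ w0, "mlpwi_0")
    layer_w1 = checkpoint_name(x @ w1, "mlpwi_1")  # fixed: tag layer_w1, not layer_w0
    return jax.nn.silu(layer_w0) * layer_w1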
@@ -544,6 +544,7 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
     softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
     # shape of top_k_weights & top_k_indices: (batch, sequence, num_experts_per_tok)
     top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, self.num_experts_per_tok)
+    top_k_weights /= top_k_weights.sum(-1, keepdims=True)
     matmul_precision = lax.Precision(self.config.matmul_precision)

     if self.config.capacity_factor > 0:
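
The added line makes the per-token routing weights sum to 1 after top-k selection: jax.nn.softmax produces probabilities over all experts, so keeping only the num_experts_per_tok largest entries leaves a sum below 1, and dividing by the per-token sum renormalizes them. A minimal sketch with illustrative shapes and values (batch=1, sequence=1, 4 experts, top-2 routing):

import jax
import jax.numpy as jnp

gate_logits = jnp.array([[[2.0, 1.0, 0.5, -1.0]]])

# Router probabilities over all experts, shape (batch, sequence, num_experts)
softmax_probs = jax.nn.softmax(gate_logits, axis=-1)

# Keep the top-2 experts per token, shape (batch, sequence, 2)
top_k_weights, top_k_indices = jax.lax.top_k(softmax_probs, 2)
print(top_k_weights.sum(-1))  # ~0.83: the dropped experts' mass is missing

# The added normalization: each token's routing weights now sum to 1
top_k_weights /= top_k_weights.sum(-1, keepdims=True)
print(top_k_weights.sum(-1))  # 1.0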
