Skip to content

Commit

Permalink
Fix moe logging to differentiate dense and megablox runs
Browse files Browse the repository at this point in the history
  • Loading branch information
lenscloth committed Dec 14, 2024
1 parent 4546d68 commit 945ee3d
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion MaxText/layers/linears.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,14 +729,15 @@ def __call__(self, inputs):
name="gate",
matmul_precision=self.config.matmul_precision,
)(inputs)
max_logging.log("Running MoE megablox implementation.")
cfg = self.config
w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, cfg.mlp_dim)
if cfg.megablox:
max_logging.log("Running MoE megablox implementation.")
if quantizations.in_serve_mode(self.quant):
w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight(
inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel
)
return self.megablox(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)
else:
max_logging.log("Running MoE matmul implementation.")
return self.dense_matmul(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)

0 comments on commit 945ee3d

Please sign in to comment.