diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index d06da67e3..e369639cc 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -729,14 +729,15 @@ def __call__(self, inputs): name="gate", matmul_precision=self.config.matmul_precision, )(inputs) - max_logging.log("Running MoE megablox implementation.") cfg = self.config w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, cfg.mlp_dim) if cfg.megablox: + max_logging.log("Running MoE megablox implementation.") if quantizations.in_serve_mode(self.quant): w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel ) return self.megablox(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) else: + max_logging.log("Running MoE matmul implementation.") return self.dense_matmul(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel)