diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d951dd190e139..a936c73810721 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -446,8 +446,8 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, @register_fake("_C::machete_mm") def machete_mm_fake( a: torch.Tensor, - b_q: torch. - Tensor, # Should be the tensor returned by machete_prepack_B + # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, b_type: ScalarType, out_type: Optional[torch.dtype] = None, b_group_scales: Optional[torch.Tensor] = None,