diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
index 4dbdd3638aad3..ef91f9f8eb529 100644
--- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
+++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py
@@ -52,7 +52,7 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
         torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
 
     # Quant
-    torch_out, _ = ops.scaled_int8_quant(torch_out)
+    torch_out, _, _ = ops.scaled_int8_quant(torch_out)
 
 
 def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 1532901bd089f..ddc3d3b08fdd0 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -130,7 +130,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // Fused Layernorm + Quant kernels
   ops.def(
-      "rms_norm_dynamic_per_token_quant(Tensor! out, Tensor input, "
+      "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
      "Tensor weight, Tensor! scales, float epsilon, "
      "Tensor? scale_ub, Tensor!? residual) -> ()");
   ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index b75249666d771..d30ee82117866 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -72,8 +72,9 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     model2 = torch.compile(model, backend=backend)
     result2 = model2(x)
 
-    # Check that it gives the same answer
-    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
+    # Check that it gives the same answer, higher tol for dynamic
+    ATOL, RTOL = (1e-3, 1e-3) if static else (1e-2, 1e-2)
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
 
     # Check substitution worked
     pre_nodes = backend.graph_pre_pass.nodes
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index fbe5ea38e4dde..c78059f3eeb68 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -22,6 +22,7 @@
 supports_moe_ops = False
 with contextlib.suppress(ImportError):
     import vllm._moe_C  # noqa: F401
+    supports_moe_ops = True
 
 
 # neuron has torch version that doesn't even have impl_abstract
@@ -241,7 +242,6 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                             paged_kv_indptr: torch.Tensor,
                             paged_kv_last_page_len: torch.Tensor,
                             block_table_bound: torch.Tensor) -> None:
-
     return torch.ops._C.advance_step_flashinfer(
         num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
         input_positions, seq_lens, slot_mapping, block_tables,
@@ -258,7 +258,6 @@ def rms_norm_dynamic_per_token_quant(
     scale_ub: Optional[torch.Tensor] = None,
     residual: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-
     output = torch.empty_like(input, dtype=quant_dtype)
     scales = torch.empty((input.numel() // input.shape[-1], 1),
                          device=input.device,
@@ -270,6 +269,24 @@ def rms_norm_dynamic_per_token_quant(
     return output, scales
 
 
+# TODO is this necessary?
+@register_fake("_C::rms_norm_dynamic_per_token_quant")
+def _rms_norm_dynamic_per_token_quant_fake(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    epsilon: float,
+    quant_dtype: torch.dtype,
+    scale_ub: Optional[torch.Tensor] = None,
+    residual: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    output = torch.empty_like(input, dtype=quant_dtype)
+    scales = torch.empty((input.numel() // input.shape[-1], 1),
+                         device=input.device,
+                         dtype=torch.float32)
+
+    return output, scales
+
+
 # quantization ops
 # awq
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
@@ -723,7 +740,7 @@ def scaled_fp8_quant(
     shape: Union[Tuple[int, int], torch.Size] = input.shape
     # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
     out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-            if current_platform.is_rocm() else torch.float8_e4m3fn
+        if current_platform.is_rocm() else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=out_dtype)
@@ -1006,9 +1023,9 @@ def register_graph_buffers(fa: int, handles: List[List[int]],
     # the case when users use `import __annotations__` to turn type
     # hints into strings.
     if isinstance(v, fn_type) \
-                and v.__code__.co_filename == __file__ \
-                and any(arg is torch.Tensor or arg == "torch.Tensor"
-                        for arg in v.__annotations__.values()):
+            and v.__code__.co_filename == __file__ \
+            and any(arg is torch.Tensor or arg == "torch.Tensor"
+                    for arg in v.__annotations__.values()):
         names_and_values_to_update[k] = hint_on_error(v)
 
 names_and_values.update(names_and_values_to_update)
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index e4c8c974623f7..fa9b02ff13ea6 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -22,9 +22,8 @@
 def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor,
                                weight: torch.Tensor, scale: torch.Tensor,
                                epsilon: float) -> None:
-    result_rms = torch.empty_like(input)
-    torch.ops._C.rms_norm(result_rms, input, weight, epsilon)
-    torch.ops._C.dynamic_scaled_fp8_quant(result, result_rms, scale)
+    # Last two are scale_ub, residual
+    torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, None)
 
 
 @torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
@@ -41,8 +40,8 @@ def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor,
                                          weight: torch.Tensor,
                                          scale: torch.Tensor,
                                          epsilon: float) -> None:
-    torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon)
-    torch.ops._C.dynamic_scaled_fp8_quant(result, input, scale)
+    # Last two are scale_ub, residual
+    torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, residual)
 
 
 @torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant")
diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py
index d5056b18fe968..217ece551925e 100644
--- a/vllm/plugins/__init__.py
+++ b/vllm/plugins/__init__.py
@@ -4,6 +4,7 @@
 from typing import TYPE_CHECKING, Optional
 
 import vllm.envs as envs
+from vllm.utils import print_warning_once
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -93,7 +94,7 @@ def get_current_vllm_config() -> "VllmConfig":
         # in ci, usually when we test custom ops/modules directly,
         # we don't set the vllm config. In that case, we set a default
         # config.
-        logger.warning("Current VLLM config is not set.")
+        print_warning_once("Current VLLM config is not set.")
         from vllm.config import VllmConfig
         return VllmConfig()
     return _current_vllm_config
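A minimal usage sketch of the fused op this patch wires up, assuming the Python wrapper ops.rms_norm_dynamic_per_token_quant keeps the signature mirrored by the fake registration above; the shapes, dtype, device, and epsilon below are illustrative assumptions, not taken from the patch:

    import torch
    from vllm import _custom_ops as ops

    # Toy per-token activations and RMSNorm weight (assumed shapes/dtype,
    # assumes a CUDA device with the _C extension built).
    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    weight = torch.ones(4096, dtype=torch.float16, device="cuda")

    # One kernel call does RMSNorm followed by dynamic per-token
    # quantization; it returns the quantized tensor plus one float32
    # scale per token (scale_ub and residual default to None).
    out, scales = ops.rms_norm_dynamic_per_token_quant(
        x, weight, epsilon=1e-6, quant_dtype=torch.float8_e4m3fn)

    assert out.dtype == torch.float8_e4m3fn
    assert scales.shape == (16, 1) and scales.dtype == torch.float32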