diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index ddc3d3b08fdd0..676d15e2cd37b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -131,7 +131,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Fused Layernorm + Quant kernels ops.def( "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " - "Tensor weight, Tensor! scales, float epsilon, " + "Tensor weight, Tensor! scale, float epsilon, " "Tensor? scale_ub, Tensor!? residual) -> ()"); ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, &rms_norm_dynamic_per_token_quant); diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index d30ee82117866..fa1765d6ad84a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -52,11 +52,14 @@ def forward(self, x): reason="Only test on CUDA") def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): torch.set_default_device("cuda") - torch.set_default_dtype(torch.float16) + torch.set_default_dtype(dtype) + torch.manual_seed(1) # Reshape pass is needed for the fusion pass to work - config = CompilationConfig.PassConfig(enable_fusion=True, - enable_reshape=True) + config = CompilationConfig.PassConfig( + enable_fusion=True, + enable_reshape=True, + dump_graph_stages=["before_fusion", "after_fusion"]) reshape_pass = RedundantReshapesPass(config) fusion_pass = FusionPass.instance(config) @@ -73,8 +76,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): result2 = model2(x) # Check that it gives the same answer, higher tol for dynamic - ATOL, RTOL = (1e-3, 1e-3) if static else (1e-2, 1e-2) - torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) + ATOL, RTOL = (1e-3, 1e-3) if static else (2e-2, 2e-2) + torch.testing.assert_close(result.to(dtype=torch.float32), + result2.to(dtype=torch.float32), + atol=ATOL, + rtol=RTOL) # Check substitution worked pre_nodes = backend.graph_pre_pass.nodes @@ -85,8 +91,8 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default # noqa: E501 fp8_quant = torch.ops._C.static_scaled_fp8_quant.default else: - rms_quant = torch.ops._C.rms_norm_dynamic_fp8_quant.default - add_rms_quant = torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default # noqa: E501 + rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default + add_rms_quant = torch.ops._C.rms_norm_dynamic_per_token_quant.default # noqa: E501 fp8_quant = torch.ops._C.dynamic_scaled_fp8_quant.default # In pre-nodes, fp8 quant should be present and fused kernels should not diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py index ec3ad5ab5dd8b..15015063658ab 100644 --- a/tests/kernels/test_fused_quant_layernorm.py +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -8,8 +8,8 @@ DTYPES = [torch.bfloat16, torch.float] QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] -NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing -HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192, +NUM_TOKENS = [1, 7, 83, 2048, 4096] # Arbitrary values for testing +HIDDEN_SIZES = [1, 2, 3, 4, 16, 64, 67, 768, 2048, 5120, 5137, 8192, 8193] # Arbitrary values for testing HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases ADD_RESIDUAL = [False, True] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c78059f3eeb68..6b3ba9f704dbf 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -272,19 +272,14 @@ def rms_norm_dynamic_per_token_quant( # TODO is this necessary? @register_fake("_C::rms_norm_dynamic_per_token_quant") def _rms_norm_dynamic_per_token_quant_fake( - input: torch.Tensor, - weight: torch.Tensor, - epsilon: float, - quant_dtype: torch.dtype, - scale_ub: Optional[torch.Tensor] = None, - residual: Optional[torch.Tensor] = None -) -> Tuple[torch.Tensor, torch.Tensor]: - output = torch.empty_like(input, dtype=quant_dtype) - scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - - return output, scales + output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + epsilon: float, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None) -> None: + return None # quantization ops diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 1c87c7771b21d..e4661d552931d 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -60,19 +60,15 @@ def __call__(self, graph: torch.fx.Graph): elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 mutated_args = {1: 'result', 2: 'residual'} self.defunctionalize(graph, node, mutated_args) - elif at_target == torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default: # noqa: E501 - mutated_args = {1: 'result', 2: 'residual', 3: 'scale'} + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} self.defunctionalize(graph, node, mutated_args) - elif at_target in [ torch.ops._C.rms_norm.default, torch.ops._C.rms_norm_static_fp8_quant.default ]: mutated_args = {1: 'result'} self.defunctionalize(graph, node, mutated_args) - elif at_target == torch.ops._C.rms_norm_dynamic_fp8_quant.default: - mutated_args = {1: 'result', 2: 'scale'} - self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.silu_and_mul.default: mutated_args = {1: 'out'} diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index fa9b02ff13ea6..907f9ad2c8a7b 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -16,40 +16,6 @@ logger = init_logger(__name__) -# TODO temp -@torch.library.custom_op("_C::rms_norm_dynamic_fp8_quant", - mutates_args=("result", "scale")) -def rms_norm_dynamic_fp8_quant(result: torch.Tensor, input: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor, - epsilon: float) -> None: - # Last two are scale_ub, residual - torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, None) - - -@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor, epsilon: float): - return None - - -@torch.library.custom_op("_C::fused_add_rms_norm_dynamic_fp8_quant", - mutates_args=("result", "residual", "scale")) -def fused_add_rms_norm_dynamic_fp8_quant(result: torch.Tensor, - input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, - scale: torch.Tensor, - epsilon: float) -> None: - # Last two are scale_ub, residual - torch.ops._C.rms_norm_dynamic_per_token_quant(result, input, weight, scale, epsilon, None, residual) - - -@torch.library.register_fake("_C::rms_norm_dynamic_fp8_quant") -def _(result: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor, epsilon: float): - return None - - def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") @@ -372,12 +338,14 @@ def replacement(result: torch.Tensor, result_rms: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized( - torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, result=result, input=input, weight=weight, scale=scale, - epsilon=self.epsilon) + epsilon=self.epsilon, + scale_ub=None, + residual=None) # result, scale return at[1], at[2] @@ -413,7 +381,7 @@ def process(self): # The auto_fn node returns a tuple of (None, result, scale). # # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, ...) # noqa + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa # result_node_new = at[1] # scale_node_new = at[2] with self.inserting_after_match(): @@ -421,10 +389,12 @@ def process(self): # Scalars cannot be inputs to the pattern kwargs["epsilon"] = rms_node.kwargs["epsilon"] + kwargs["scale_ub"] = None # not used but required + kwargs["residual"] = None # not used but required del kwargs["result_rms"] # not used in the fused op fused_node = self.insert_auto_fn( - torch.ops._C.rms_norm_dynamic_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, kwargs=kwargs) getitem_nodes = self.insert_getitems(fused_node, (1, 2)) @@ -466,16 +436,17 @@ def replacement(result: torch.Tensor, input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor): at = auto_functionalized( - torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, result=result, input=input, - residual=residual, weight=weight, scale=scale, - epsilon=self.epsilon) + epsilon=self.epsilon, + scale_ub=None, + residual=residual) # result, residual, scale - return at[1], at[2], at[3] # TODO confirm signature + return at[1], at[3], at[2] inputs = [ empty_fp8(5, 4), # result @@ -508,22 +479,23 @@ def process(self): # The auto_fn node returns a tuple (None, result, scale, residual). # # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, ...) # noqa + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa # result_node_new = at[1] - # residual_node_new = at[2] - # scale_node_new = at[3] + # scale_node_new = at[2] + # residual_node_new = at[3] with self.inserting_after_match(): kwargs = self.match.kwargs.copy() # Scalars cannot be inputs to the pattern kwargs["epsilon"] = rms_node.kwargs["epsilon"] + kwargs["scale_ub"] = None # not used but required fused_node = self.insert_auto_fn( - torch.ops._C.fused_add_rms_norm_dynamic_fp8_quant.default, + torch.ops._C.rms_norm_dynamic_per_token_quant.default, kwargs=kwargs) getitem_ns = self.insert_getitems(fused_node, (1, 2, 3)) - result_node_new, residual_node_new, scale_node_new = getitem_ns + result_node_new, scale_node_new, residual_node_new = getitem_ns # Rebind the users of match getitem nodes to use the new nodes. # The old nodes will be removed by DCE at the end of the pass. @@ -588,12 +560,12 @@ def __init__(self, config: CompilationConfig.PassConfig): self.patterns, self.record_match) # Fuse rms_norm + dynamic_scaled_fp8_quant into - # rms_norm_dynamic_fp8_quant + # rms_norm_dynamic_per_token_quant RMSNormDynamicFP8QuantPattern(epsilon).register( self.patterns, self.record_match) # Fuse fused_add_rms_norm + dynamic_scaled_fp8_quant into - # fused_add_rms_norm_dynamic_fp8_quant + # rms_norm_dynamic_per_token_quant FusedAddRMSNormDynamicFP8QuantPattern(epsilon).register( self.patterns, self.record_match)