Fix for the padding in the non-cutlass-fp8 case
Signed-off-by: luka <[email protected]>
ProExpertProg committed Dec 4, 2024
1 parent c92acb9 commit e5ded5c
Showing 4 changed files with 86 additions and 42 deletions.
13 changes: 9 additions & 4 deletions tests/compile/backend.py
@@ -11,21 +11,26 @@ class TestBackend:
This class provides a simple Inductor backend that can be used for testing.
It takes a list of custom passes and runs them after Inductor's passes.
It also saves the graph before and after the custom passes for inspection.
Inductor config can be modified directly by editing the inductor_config
property. This can be helpful for adding passes like the
'pre_grad_custom_pass' and the 'post_grad_custom_pre_pass'.
"""

def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph],
None]]):
self.custom_passes = list(passes)
from torch._inductor import config
self.current_config = config.shallow_copy_dict()
self.current_config['force_disable_caches'] = True
self.current_config['post_grad_custom_post_pass'] = self.post_pass
self.inductor_config = config.shallow_copy_dict()
self.inductor_config['force_disable_caches'] = True
self.inductor_config['post_grad_custom_post_pass'] = self.post_pass

def __call__(self, graph: fx.GraphModule, example_inputs):
self.graph_pre_compile = deepcopy(graph)
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.current_config)
config_patches=self.inductor_config)

def post_pass(self, graph: fx.Graph):
self.graph_pre_pass = deepcopy(graph)
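
The docstring change above notes that the Inductor config can now be edited via the inductor_config attribute. A minimal usage sketch, assuming the import path tests.compile.backend and a no-op custom pass (both illustrative, not part of this commit):

import torch
from torch import fx
from tests.compile.backend import TestBackend

def noop_pass(graph: fx.Graph) -> None:
    # Placeholder custom pass; TestBackend runs it as a post-grad post pass.
    pass

backend = TestBackend(noop_pass)
# Edit the Inductor config directly, e.g. to also register a pre-grad pass.
backend.inductor_config['pre_grad_custom_pass'] = noop_pass

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=backend)
out = compiled(torch.randn(4, 8))
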
89 changes: 54 additions & 35 deletions tests/compile/test_fusion.py
@@ -3,10 +3,11 @@
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, VllmConfig, CompilationLevel
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
@@ -16,8 +17,10 @@

class TestModel(torch.nn.Module):

def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
def __init__(self, hidden_size: int, eps: float, cutlass_fp8: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.cutlass_fp8 = cutlass_fp8
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
self.w = [
@@ -29,11 +32,19 @@ def forward(self, x):
resid = torch.relu(x)
y = self.norm[0](x)

x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
x2 = apply_fp8_linear(y,
self.w[0],
self.scale[0],
self.scale[1],
cutlass_fp8_supported=self.cutlass_fp8)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)

x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
x3 = apply_fp8_linear(y2,
self.w[1],
self.scale[2],
self.scale[3],
cutlass_fp8_supported=self.cutlass_fp8)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

@@ -42,50 +53,58 @@ def forward(self, x):
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize(
"cutlass_fp8",
[True, False] if envs.VLLM_TARGET_DEVICE == "cuda" else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps,
cutlass_fp8):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
torch.set_default_dtype(dtype)
torch.manual_seed(1)

if eps != 1e-5:
pytest.skip("Only test eps=1e-5 for now")

# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
with vllm.plugins.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
enable_reshape=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)

backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps, cutlass_fp8)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

result = model(x)
result = model(x)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)

# Check that it gives the same answer
torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
# Check that it gives the same answer
torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)

# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes

rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In pre-nodes, fp8 quant should be there and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)

# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
# In post-nodes, fused kernels should be there and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
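
For context (not stated explicitly in the diff): the test now builds and compiles the model inside vllm.plugins.set_current_vllm_config because apply_fp8_linear (see the w8a8_utils.py change below) reads the active compilation config to decide whether to pad. A minimal sketch of that dependency, with illustrative settings:

import vllm.plugins
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))

with vllm.plugins.set_current_vllm_config(vllm_config):
    # Inside this context, get_current_vllm_config() returns vllm_config, so
    # apply_fp8_linear sees level == PIECEWISE and skips num_token_padding.
    ...
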
18 changes: 16 additions & 2 deletions vllm/compilation/vllm_inductor_pass.py
@@ -30,8 +30,8 @@ def __init__(self, config: CompilationConfig.PassConfig):
self.config = config
self.pass_name = self.__class__.__name__

def dump_graph(self, graph: torch.fx.Graph, stage: str):
if stage in self.config.dump_graph_stages:
def dump_graph(self, graph: torch.fx.Graph, stage: str, always=False):
if stage in self.config.dump_graph_stages or always:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
@@ -51,3 +51,17 @@ def end_and_log(self):
self._end_time = time.perf_counter_ns()
duration_ms = float(self._end_time - self._start_time) / 1.0e6
logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms)


class PrinterInductorPass(VllmInductorPass):

def __init__(self,
name: str,
config: CompilationConfig.PassConfig,
always=False):
super().__init__(config)
self.name = name
self.always = always

def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, self.name, always=self.always)
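
A sketch of how the new PrinterInductorPass and the always flag might be used, assuming it is chained with other passes in a TestBackend-style setup (illustrative, not part of this commit):

from vllm.compilation.vllm_inductor_pass import PrinterInductorPass
from vllm.config import CompilationConfig

pass_config = CompilationConfig.PassConfig(enable_fusion=True,
                                           enable_reshape=True)

# Dump the graph under the stage name "before_fusion" even when that stage
# is not listed in pass_config.dump_graph_stages, thanks to always=True.
printer = PrinterInductorPass("before_fusion", pass_config, always=True)

# It can then be chained ahead of other passes, e.g.:
# backend = TestBackend(printer, reshape_pass, fusion_pass)
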
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -3,7 +3,9 @@
import torch

from vllm import _custom_ops as ops
from vllm.config import CompilationLevel
from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config

# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -122,10 +124,14 @@ def apply_fp8_linear(
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
config = get_current_vllm_config().compilation_config
do_pad = config.level < CompilationLevel.PIECEWISE
qinput, x_scale = ops.scaled_fp8_quant(
input_2d,
input_scale,
num_token_padding=17,
num_token_padding=17 if do_pad else None,
use_per_token_if_dynamic=use_per_token_if_dynamic)

per_tensor_weights = (weight_scale.numel() == 1)
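
The padding decision itself, as a standalone sketch (the helper name is illustrative; the logic mirrors the diff above):

from vllm.config import CompilationLevel
from vllm.plugins import get_current_vllm_config

def _should_pad_for_scaled_mm() -> bool:
    # Padding the batch dimension past 16 rows helps torch._scaled_mm
    # performance, but it breaks torch.compile with dynamic shapes, so it
    # is skipped once piecewise compilation is enabled.
    config = get_current_vllm_config().compilation_config
    return config.level < CompilationLevel.PIECEWISE

# Used as: num_token_padding=17 if _should_pad_for_scaled_mm() else None
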
