Move config into tests, explicitly disable caching in TestBackend
Signed-off-by: luka <[email protected]>
ProExpertProg committed Nov 13, 2024
1 parent 2a5f547 commit b051eb2
Showing 3 changed files with 14 additions and 16 deletions.
5 changes: 3 additions & 2 deletions tests/compile/backend.py
@@ -11,10 +11,11 @@ class TestBackend:
     It also saves the graph before and after the custom passes for inspection.
     """
 
-    def __init__(self, *args: Callable[[torch.fx.Graph], None]):
-        self.custom_passes = args
+    def __init__(self, *passes: Callable[[torch.fx.Graph], None]):
+        self.custom_passes = passes
         from torch._inductor import config
         self.current_config = config.shallow_copy_dict()
+        self.current_config['force_disable_caches'] = True
         self.current_config['post_grad_custom_post_pass'] = self.post_pass
 
     def __call__(self, graph: torch.fx.GraphModule, example_inputs):
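
For context, TestBackend is handed directly to torch.compile as a custom backend, so the inductor config copied in __init__ (now with force_disable_caches=True) is what the test compiles under. A minimal usage sketch, assuming only the surface visible in this diff; the no-op pass and Linear model are illustrative:

import torch

from .backend import TestBackend  # tests/compile/backend.py


def noop_pass(graph: torch.fx.Graph) -> None:
    """A do-nothing custom pass, just to show the variadic *passes signature."""


backend = TestBackend(noop_pass)  # accepts any number of graph passes
compiled = torch.compile(torch.nn.Linear(8, 8), backend=backend)
compiled(torch.randn(4, 8))
# With force_disable_caches=True, inductor cannot serve a cached artifact,
# so the custom passes actually run (and the inspection graphs are saved)
# on every compilation instead of being skipped by the FX graph cache.
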
13 changes: 6 additions & 7 deletions tests/compile/test_functionalization.py
@@ -3,7 +3,7 @@
 
 import vllm.envs as envs
 from vllm import LLM, SamplingParams
-from vllm.compilation.config import CompilationConfig
+from vllm.compilation.config import PassConfig
 from vllm.compilation.functionalization import FixFunctionalizationPass
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
@@ -12,11 +12,6 @@
 
 from .backend import TestBackend
 
-# Init does pattern registration, which can only happen once
-config = CompilationConfig(enable_fusion=True)
-reshape_pass = RedundantReshapesPass(config.pass_config)
-fusion_pass = FusionPass.instance(config.pass_config)
-
 OPS_IN_MODEL = [
     torch.ops._C.fused_add_rms_norm.default,
     torch.ops._C.silu_and_mul.default,
@@ -47,8 +42,12 @@
 def test_fix_functionalization(model: str, do_fusion: bool):
     torch.set_default_device("cuda")
 
+    config = PassConfig(enable_fusion=do_fusion, enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = FusionPass.instance(config)
+
     passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass]
-    func_pass = FixFunctionalizationPass(config.pass_config)
+    func_pass = FixFunctionalizationPass(config)
     backend_func = TestBackend(*passes, func_pass)
     backend_no_func = TestBackend(*passes)
 
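
Note that even with the config now built per-test, the test still goes through FusionPass.instance(config) rather than FusionPass(config). The reason is the constraint stated in the deleted comment: init does pattern registration, which can only happen once per process. A rough sketch of that singleton idiom, not vLLM's actual implementation:

class SingletonPass:
    _instance = None

    def __init__(self, config):
        self.config = config
        # inductor pattern registration would happen here, exactly once

    @classmethod
    def instance(cls, config):
        # First call constructs the pass and registers its patterns; later
        # calls (e.g. other parametrized test cases) reuse the object and
        # only swap in the new config, avoiding re-registration.
        if cls._instance is None:
            cls._instance = cls(config)
        else:
            cls._instance.config = config
        return cls._instance


# Usage: both calls return the same object, reconfigured with the new config.
assert SingletonPass.instance({"enable_fusion": True}) is \
       SingletonPass.instance({"enable_fusion": False})
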
12 changes: 5 additions & 7 deletions tests/compile/test_fusion.py
@@ -3,7 +3,7 @@
 from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
-from vllm.compilation.config import CompilationConfig
+from vllm.compilation.config import PassConfig
 from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                      find_auto_fn_maybe)
 from vllm.compilation.reshapes import RedundantReshapesPass
@@ -38,12 +38,6 @@ def forward(self, x):
         return y3
 
 
-# Init does pattern registration, which can only happen once
-config = CompilationConfig(enable_fusion=True)
-reshape_pass = RedundantReshapesPass(config.pass_config)
-fusion_pass = FusionPass.instance(config.pass_config)
-
-
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@@ -58,6 +52,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
         pytest.skip("Only test eps=1e-5 for now")
 
     # Reshape pass is needed for the fusion pass to work
+    config = PassConfig(enable_fusion=True, enable_reshape=True)
+    reshape_pass = RedundantReshapesPass(config)
+    fusion_pass = FusionPass.instance(config)
+
     backend = TestBackend(reshape_pass, fusion_pass)
     model = TestModel(hidden_size, eps)
 
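
After the compiled model runs, tests like this one inspect the graphs the backend saved. A sketch of the kind of assertion this enables, assuming TestBackend exposes the pre-/post-pass graphs mentioned in its docstring; the attribute names and the fused-op name below are assumptions, not confirmed by this diff:

import torch

from vllm.compilation.fusion import find_auto_fn, find_auto_fn_maybe

# `backend` is the TestBackend from the test body above, after one forward
# pass through the compiled model.
pre_nodes = list(backend.graph_pre_pass.nodes)    # attribute name assumed
post_nodes = list(backend.graph_post_pass.nodes)  # attribute name assumed

# Before the passes, the separate rms_norm op is still in the graph.
assert find_auto_fn_maybe(pre_nodes, torch.ops._C.rms_norm.default) is not None

# After FusionPass, the rms_norm + quant pair should be replaced by a single
# fused custom op (op name assumed for illustration).
fused_op = torch.ops._C.rms_norm_static_fp8_quant.default
assert find_auto_fn(post_nodes, fused_op) is not None
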
