Commit d713a7d
remove cruft
Signed-off-by: Bill Nell <[email protected]>
bnellnm committed Nov 25, 2024
1 parent 9872079 commit d713a7d
Showing 3 changed files with 6 additions and 14 deletions.
9 changes: 0 additions & 9 deletions vllm/compilation/backends.py
@@ -26,9 +26,6 @@ def wrap_inductor(graph: fx.GraphModule,
                   do_logging: bool = False,
                   runtime_shape: Optional[int] = None,
                   use_inductor: bool = True):
-
-    print(f"WRAP_INDUCTOR {graph}")
-
     if not use_inductor:
         return graph

@@ -153,15 +150,12 @@ def call_module(self, target: torch.fx.node.Target,
         assert isinstance(target, str)
         output = super().call_module(target, args, kwargs)

-        print(f"TARGET {target}")
-
         if target in self.compile_submod_names:
             index = self.compile_submod_names.index(target)
             submod = self.fetch_attr(target)
             sym_shape_indices = [
                 i for i, x in enumerate(args) if isinstance(x, torch.SymInt)
             ]
-            print(f"COMPILE {target}")
             compiled_graph_for_general_shape = wrap_inductor(
                 submod,
                 args,
@@ -280,8 +274,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
             if not item.is_splitting_graph
         ]

-        print(f"submod_names_to_compile = {submod_names_to_compile}")
-
         # propagate the split graph to the piecewise backend,
         # compile submodules with symbolic shapes
         PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
@@ -421,7 +413,6 @@ def __call__(self, *args) -> Any:
         if entry.need_to_compile and not entry.compiled:
             entry.compiled = True
             # args are real arguments
-            print(f"COMPILE ENTRY")
             entry.runnable = wrap_inductor(
                 self.graph,
                 args,
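
Note: every line removed above is leftover print-style debugging. If tracing is ever needed here again, a minimal sketch of the more idiomatic route would use the vllm.logger helper that utils.py (below) already imports; compile_hook here is a hypothetical stand-in, not code from this commit:

    from vllm.logger import init_logger

    logger = init_logger(__name__)  # standard vLLM module-level pattern

    def compile_hook(target: str) -> None:
        # Unlike a bare print(), a debug record is filtered by the
        # configured log level, so it costs nothing unless enabled.
        logger.debug("compiling submodule %s", target)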
7 changes: 4 additions & 3 deletions vllm/compilation/collective_fusion.py
@@ -10,7 +10,8 @@
 from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem,
                                     last_node_in_match, use_cc_kernels)
 from vllm.config import CompilationConfig
-from vllm.distributed import (tensor_model_parallel_all_gather,
+from vllm.distributed import (model_parallel_is_initialized,
+                              tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import (
     get_group_from_group_name, get_tensor_model_parallel_world_size)
@@ -470,8 +471,8 @@ def find_min_index(match: Match) -> int:
                    for node in match.nodes)

     def __call__(self, graph: fx.Graph):
-        if not (model_parallel_is_initialized() and
-                get_tensor_model_parallel_world_size() > 1):
+        if not (model_parallel_is_initialized()
+                and get_tensor_model_parallel_world_size() > 1):
             return

         # TODO: disable if chunk prefill size is too small
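
Note: the guard above short-circuits, so the world-size query only runs once the initialization check passes; querying the TP group before it exists would presumably fail. A self-contained sketch of the pattern, with stand-in bodies rather than vLLM's real implementations:

    def model_parallel_is_initialized() -> bool:
        return False  # stand-in; the real check inspects process groups

    def get_tensor_model_parallel_world_size() -> int:
        raise RuntimeError("TP group not initialized")  # stand-in behavior

    def run_fusion_pass(graph) -> None:
        # `and` short-circuits: the world-size call is never reached
        # unless model parallelism is set up, and with a single rank
        # there is no collective to fuse, so the pass is a no-op.
        if not (model_parallel_is_initialized()
                and get_tensor_model_parallel_world_size() > 1):
            return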
4 changes: 2 additions & 2 deletions vllm/compilation/utils.py
@@ -7,11 +7,11 @@
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import Match

-from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
-# yapf: enable
 # yapf: disable
+from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
 from vllm.distributed import (
     get_tensor_model_parallel_world_size as get_tp_world_size)
+# yapf: enable
 from vllm.distributed import model_parallel_is_initialized as p_is_init
 from vllm.logger import init_logger

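Note: this hunk reorders the formatter pragmas so that # yapf: disable and # yapf: enable actually bracket both hand-formatted imports; the old ordering had enable before disable, leaving the aliased import outside the protected region. A small illustration of how the pragma pair behaves (example data, not from this commit):

    # yapf: disable
    LOOKUP = {
        'short':      1,   # hand-aligned; yapf would normally reflow this
        'longer_key': 2,
    }
    # yapf: enable

    def formatted_as_usual(x):
        return x + 1  # outside the markers, yapf formats normally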
