diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 464bc2af8fd6d..b0189725d427d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,14 +15,15 @@ from .counter import compilation_counter from .inductor_pass import InductorPass from .pass_manager import PostGradPassManager +from .utils import dump_graph logger = init_logger(__name__) -def wrap_inductor(graph, - example_inputs, - additional_inductor_config, - do_logging=False, +def wrap_inductor(graph: fx.GraphModule, + example_inputs: Sequence[Any], + additional_inductor_config: Optional[Dict] = None, + do_logging: bool = False, runtime_shape: Optional[int] = None, use_inductor: bool = True): if not use_inductor: @@ -37,6 +38,10 @@ def wrap_inductor(graph, logger.info("Compiling a graph for shape %s", runtime_shape) from torch._inductor import config + + # Enable support for symmetric memory ops in the inductor. + torch._inductor.config._micro_pipeline_tp = True + current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx @@ -248,9 +253,19 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_configs.init_during_runtime() self.configure_post_pass() + if ("before_split_graph" + in self.compilation_configs.pass_config.dump_graph_stages): + dump_graph(self.compilation_configs.pass_config, graph.graph, + "before_split_graph") + self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.splitting_ops) + if ("after_split_graph" + in self.compilation_configs.pass_config.dump_graph_stages): + dump_graph(self.compilation_configs.pass_config, + self.split_gm.graph, "after_split_graph") + from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) logger.debug("%s", lazy_format_graph_code("after split", diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py new file mode 100644 index 0000000000000..167b619135de5 --- /dev/null +++ b/vllm/compilation/collective_fusion.py @@ -0,0 +1,450 @@ +import operator +from typing import Callable, List, Optional, Tuple + +import torch +import torch.fx as fx +from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, + fwd_only, register_replacement) + +import vllm.envs as envs +from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, + last_node_in_match) +from vllm.config import CompilationConfig +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import ( + get_group_from_group_name, get_tensor_model_parallel_world_size) +from vllm.logger import init_logger +from vllm.utils import direct_register_custom_op + +from .vllm_inductor_pass import VllmInductorPass + +logger = init_logger(__name__) + +use_flux = False +if envs.VLLM_USE_FLUX: + try: + import flux + use_flux = True + logger.info("Using flux kernels for collective communication fusion.") + except ImportError: + logger.info("Attempting to use flux but flux not installed.") + use_flux = False + +# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h +# Can be 256 on sm80. 
+FLUX_TILE_SIZE: int = 128 + + +def use_cc_kernels(m_shape: int) -> bool: + n_slices = get_tensor_model_parallel_world_size() + return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0 + and m_shape >= FLUX_TILE_SIZE * n_slices) + + +def get_world_name() -> str: + return torch.distributed.group.WORLD.group_name + + +def residual_slice_shape(residual: torch.Tensor, rank: int) -> int: + n_slices = get_tensor_model_parallel_world_size() + chunk, rem = divmod(residual.shape[0], n_slices) + return chunk if rank < n_slices - 1 or rem == 0 else rem + + +def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int: + n_slices = get_tensor_model_parallel_world_size() + slices = torch.chunk(residual, n_slices, dim=0) + return slices[rank].shape[0] + + +def match_gemm_rs_ag_gemm( + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, + gemm_2_weights: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) + + return mm_2, new_residual + + +def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype, + gemm_1_weights: torch.Size, gemm_2_type: torch.dtype, + gemm_2_weights: torch.Size, + tp_group_name: str) -> Callable: + + group = get_group_from_group_name(tp_group_name) + device_group = group.device_group + rank = group.rank_in_group + + if use_flux: + gemm_rs_op = flux.GemmRS( + device_group, + 1, # One node + max_m, # max M + gemm_1_weights[0], # N + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. + gemm_1_type, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: bfloat16 requires fuse_reduction=False. + fuse_reduction=False, + ) + + ag_gemm_op = flux.AGKernel( + device_group, + 1, # One node + max_m, # max M + gemm_2_weights[0], # N + gemm_2_weights[1], # K + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. 
+ gemm_2_type, + gemm_2_type, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: if local_copy=True, I hit the following runtime error: + # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648 + # Check failed: 33554432((input.numel() * input.element_size())) + # == 139836453421056((this->chunk_size)) + local_copy=False, + ) + + def gemm_rs(act, wt): + return gemm_rs_op.forward(act, wt).squeeze(0) + + def ag_gemm(act, wt): + return ag_gemm_op.forward(act, wt) + + gemm_1_str = str(gemm_1_type).removeprefix("torch.") + gemm_2_str = str(gemm_2_type).removeprefix("torch.") + group_str = tp_group_name.replace(":", "_") + name = (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_weights[0]}_" + f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_" + f"{group_str}") + else: + world_group_name = get_world_name() + + def gemm_rs(act, wt): + return torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + act, wt.transpose(1, 0), 'avg', 0, world_group_name) + + def ag_gemm(act, wt): + return torch.ops.symm_mem.fused_all_gather_matmul.default( + act, [wt.transpose(1, 0)], 0, world_group_name)[1] + + group_str = tp_group_name.replace(":", "_") + name = f"gemm_rs_ag_gemm_{group_str}" + + def gemm_rs_ag_gemm( + residual: torch.Tensor, old_my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, gemm_2_weights: torch.Tensor, + first_layer: bool + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if first_layer and use_cc_kernels(residual.shape[0]): + slice_shape = residual_slice_shape(residual, rank) + residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape) + my_residual = residual_chunk[0] + else: + my_residual = residual + slice_shape = residual.shape[0] + + if not use_cc_kernels(residual.shape[0]): + output = torch.ops.aten.mm.default(gemm_1_activations, + gemm_1_weights.transpose(1, 0)) + reduced_output = tensor_model_parallel_all_reduce(output) + + torch.ops._C.fused_add_rms_norm.default(input=reduced_output, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) + + mm_2 = torch.ops.aten.mm.default(reduced_output, + gemm_2_weights.transpose(1, 0)) + + return mm_2, my_residual, my_residual.clone() + else: + output = gemm_rs(gemm_1_activations, gemm_1_weights) + + torch.ops._C.fused_add_rms_norm.default(input=output, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) + + residual_1 = residual if first_layer else old_my_residual + slice_scatter = torch.ops.aten.slice_scatter.default( + residual_1, my_residual, 0, 0, slice_shape) + split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape) + new_residual = split_2[0] + + mm_2 = ag_gemm(output, gemm_2_weights) + + return mm_2[0], new_residual, slice_scatter + + def gemm_rs_ag_gemm_fake( + residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, + gemm_2_weights: torch.Tensor, + first_layer: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if first_layer and use_cc_kernels(gemm_1_activations.shape[0]): + slice_shape = residual_slice_shape_fake(residual, rank) + split_1 = torch.ops.aten.split.Tensor(residual, slice_shape) + my_residual = split_1[0] + else: + my_residual = residual + + # TODO: verify the type is always correct + mm_res = torch.empty( + (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), + device=gemm_1_activations.device, + 
dtype=gemm_1_activations.dtype) + + return (mm_res, my_residual, residual) + + if not hasattr(torch.ops.vllm, name): + logger.info("registering torch.ops.vllm.%s", name) + direct_register_custom_op(name, + gemm_rs_ag_gemm, + mutates_args=[], + fake_impl=gemm_rs_ag_gemm_fake) + assert getattr(torch.ops.vllm, name) + + return getattr(torch.ops.vllm, name).default + + +def match_final( + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, +) -> torch.Tensor: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) + normalized = norm_res[1] + + return normalized + + +# Register this as a custom op since all gather cannot be torch.compiled yet. +def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, + gemm_1_weights.transpose(1, 0)) + + reduced = tensor_model_parallel_all_reduce(mm_1) + + if use_cc_kernels(gemm_1_activations.shape[0]): + wait_tensor = tensor_model_parallel_all_gather(my_residual) + else: + wait_tensor = my_residual + + torch.ops._C.fused_add_rms_norm.default(input=reduced, + residual=wait_tensor, + weight=rms_norm_weights, + epsilon=1e-05) + + return reduced + + +def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: + return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]], + dtype=my_residual.dtype, + device=my_residual.device) + + +direct_register_custom_op("gemm_ag_final", + gemm_ag_final, + mutates_args=[], + fake_impl=gemm_ag_final_fake) + + +class CollectiveFusionPass(VllmInductorPass): + + _instance: 'Optional[CollectiveFusionPass]' = None + + @classmethod + def instance(cls, config: CompilationConfig) -> "CollectiveFusionPass": + """ + Get the singleton instance of the CollectiveFusionPass. + If the instance exists, the config is updated but + initialization is not repeated. + """ + if cls._instance is None: + cls._instance = CollectiveFusionPass(config) + else: + cls._instance.config = config + return cls._instance + + def __init__(self, config: CompilationConfig): + assert self.__class__._instance is None, \ + "FusionPass singleton instance already exists" + super().__init__(config) + + self.gemm_rs_ag_gemm_pattern = PatternMatcherPass() + self.final_pattern = PatternMatcherPass() + self.matches: List[Match] = [] + + # Run in fake mode so that we don't call real functions + # when tracing the patterns. 
+ with torch._dynamo.utils.detect_fake_mode(): + x = torch.empty([4, 4], device='cuda', dtype=torch.float16) + w = torch.empty([4, 4], device='cuda', dtype=torch.float16) + resid = torch.empty([4, 4], device='cuda', dtype=torch.float16) + resid_w = torch.empty([4, 4], device='cuda', dtype=torch.float16) + x2 = torch.empty([4, 4], device='cuda', dtype=torch.float16) + inputs = [resid, x, w, resid_w, x2] + final_inputs = [x, w, resid, resid_w] + + register_replacement(match_gemm_rs_ag_gemm, + match_gemm_rs_ag_gemm, + inputs, + fwd_only, [self.gemm_rs_ag_gemm_pattern], + extra_check=lambda m: self.record_match(m)) + + register_replacement(match_final, torch.ops.vllm.gemm_ag_final, + final_inputs, fwd_only, [self.final_pattern]) + + def record_match(self, match: Match) -> bool: + # Hijack the extra_check to record the match and + # save it for post-processing. + self.matches.append(match) + + # Return False to prevent automatic replacement. + return False + + def find_max_m(self, matches: List[Match]) -> int: + max_m = 0 + for m in matches: + residual = m.kwargs["residual"].meta["val"] + max_m = max(max_m, residual.shape[1]) + assert max_m > 0 + return max_m + + def process_matches(self, graph: fx.Graph) -> None: + nodes = list(graph.nodes) + + def find_min_index(match: Match) -> int: + return min(match.nodes, key=lambda x: nodes.index(x)) + + # "sort" matches in topo order. + matches = sorted(self.matches, key=lambda x: find_min_index(x)) + + res_replacements: List[fx.Node] = [] + my_res_replacements: List[fx.Node] = [] + + max_m = self.find_max_m(matches) + logger.info("max m = %d", max_m) + + for match in matches: + last_node = last_node_in_match(match) + + with graph.inserting_after(last_node): + kwargs = match.kwargs + kwargs["first_layer"] = match == matches[0] + kwargs["residual"] = res_replacements[-1] if len( + res_replacements) > 0 else match.kwargs["residual"] + kwargs["old_my_residual"] = my_res_replacements[-1] if len( + my_res_replacements) > 0 else match.kwargs["residual"] + + gemm_1 = kwargs["gemm_1_weights"].meta.get("val") + gemm_2 = kwargs["gemm_2_weights"].meta.get("val") + if gemm_1 is None or gemm_2 is None: + raise ValueError("Missing 'val' in gemm weights meta data") + + # Extract group_name from matched code. Use to + # generate proper replacement code. 
+ ar_node = find_fn(match.nodes, + torch.ops.vllm.all_reduce.default) + assert ar_node is not None + tp_group_name = ar_node.args[1] + + fused_gemm_func = get_gemm_rs_ag_gemm( + use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, + gemm_2.shape, tp_group_name) + + fused_node = graph.call_function(fused_gemm_func, + kwargs=kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, + (fused_node, 0)) + residual_node_new = graph.call_function( + operator.getitem, (fused_node, 1)) + my_residual_node_new = graph.call_function( + operator.getitem, (fused_node, 2)) + res_replacements.append(residual_node_new) + my_res_replacements.append(my_residual_node_new) + + rms_node = find_auto_fn(reversed(match.nodes), + torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(reversed(match.nodes), + torch.ops.aten.mm.default) + assert rms_node is not None + assert gemm_node is not None + + assert len(rms_node.users) == 2 + assert len(gemm_node.users) == 1 or len(gemm_node.users) == 2 + + residual_getter_node = find_getitem(rms_node, 2) + assert residual_getter_node is not None + residual_getter_node.replace_all_uses_with(residual_node_new) + gemm_node.replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in matches + for node in match.nodes) + + def __call__(self, graph: fx.Graph): + # TODO: disable if chunk prefill size is too small + # or when doing decode. + self.dump_graph(graph, "before_collective_fusion") + self.gemm_rs_ag_gemm_pattern.apply(graph) + logger.info("fused gemm match count = %d", len(self.matches)) + + # Don't apply final pattern unless we've matched and replaced the + # gemm+collective ops. 
+ if len(self.matches) > 0: + count = self.final_pattern.apply(graph) + logger.info("final match count = %d", count) + self.process_matches(graph) + + self.dump_graph(graph, "after_collective_fusion") + self.matches.clear() diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index fb522ae053e97..833d177624af0 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,6 +5,7 @@ from vllm.config import CompilationConfig from vllm.logger import init_logger +from .collective_fusion import CollectiveFusionPass from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass from .inductor_pass import InductorPass @@ -47,6 +48,9 @@ def configure(self, pass_config: CompilationConfig.PassConfig): if pass_config.enable_fusion: self.passes += [FusionPass.instance(pass_config)] + if pass_config.enable_collective_fusion: + self.passes += [CollectiveFusionPass.instance(pass_config)] + self.fix_functionalization = FixFunctionalizationPass(pass_config) def add(self, pass_: InductorPass): diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py new file mode 100644 index 0000000000000..195e7c8e812fe --- /dev/null +++ b/vllm/compilation/utils.py @@ -0,0 +1,83 @@ +import operator +import os +from pathlib import Path +from typing import Dict, Iterable, Optional + +import torch.fx as fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import Match + +# yapf: enable +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +from vllm.logger import init_logger + +logger = init_logger(__name__) + +COUNTS: Dict[str, int] = {} + +# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h +# Can be 256 on sm80. 
+FLUX_TILE_SIZE: int = 128 + + +def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: + for node in nodes: + if node.op == "call_function" and node.target == op: + return node + return None + + +def find_auto_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: + for node in nodes: + if (node.op == "call_function" and node.target == auto_functionalized + and node.args[0] == op): + return node + return None + + +def find_getitem(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if (user.op == "call_function" and user.target == operator.getitem + and user.args[1] == idx): + return user + return None + + +def last_node_in_match(match: Match) -> fx.Node: + if len(match.nodes) > 0: + graph = match.nodes[0].graph + for n in reversed(graph.nodes): + if n in reversed(match.nodes): + return n + raise ValueError("No nodes in graph") + + +def dump_graph(pass_config, graph: fx.Graph, name: str) -> None: + global COUNTS + count = COUNTS.get(name, 0) + + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = Path(pass_config.dump_graph_dir) / f"{name}{rank}-{count}.py" + COUNTS[name] = count + 1 + + os.makedirs(pass_config.dump_graph_dir, exist_ok=True) + logger.info("%s printing graph to %s", name, filepath) + with open(filepath, "w") as f: + src = graph.owning_module.print_readable(print_output=False) + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + +# Note: this heuristic is unique to flux +def use_cc_kernels(m_shape: int, n_slices: Optional[int] = None) -> bool: + if n_slices is None: + n_slices = get_tp_world_size() + return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0 + and m_shape >= FLUX_TILE_SIZE * n_slices) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index dbf6b8f7789e1..e43a95863eee2 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -3,15 +3,10 @@ import torch from vllm.config import CompilationConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable from vllm.logger import init_logger from .inductor_pass import InductorPass +from .utils import dump_graph as utils_dump_graph logger = init_logger(__name__) @@ -32,17 +27,7 @@ def __init__(self, config: CompilationConfig.PassConfig): def dump_graph(self, graph: torch.fx.Graph, stage: str): if stage in self.config.dump_graph_stages: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("%s printing graph to %s", self.pass_name, filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) + utils_dump_graph(self.config, graph, stage) def begin(self): self._start_time = time.perf_counter_ns() diff --git a/vllm/config.py b/vllm/config.py index eae6f909e3933..c59d22f0799ed 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,6 +16,7 @@ import 
vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.compilation.utils import use_cc_kernels from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) @@ -2168,12 +2169,15 @@ class PassConfig(BaseModel): - dump_graph_stages: list of stages for which we want to dump the graph. Each pass defines its own stages (before, after, maybe in-between). - dump_graph_dir: directory to dump the graphs. Default is . + - enable_collective_fusion: whether to enable the custom collective + communication fusion pass. - enable_fusion: whether to enable the custom fusion pass. - enable_reshape: whether to enable the custom reshape elimination pass. TODO better pass enabling system. """ dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) + enable_collective_fusion: bool = True enable_fusion: bool = True enable_reshape: bool = True @@ -2184,8 +2188,9 @@ def uuid(self): Do not include dump_graph_* in the hash - they don't affect compilation. """ - dict_ = self.model_dump( - include={"enable_fusion", "enable_reshape"}) + dict_ = self.model_dump(include={ + "enable_collective_fusion", "enable_fusion", "enable_reshape" + }) encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).digest() @@ -2398,6 +2403,7 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True + self.compilation_config.pass_config.enable_collective_fusion = False self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE @@ -2416,6 +2422,21 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + if self.compilation_config.pass_config.enable_collective_fusion: + n_slices = self.parallel_config.world_size + max_tokens = self.scheduler_config.max_num_batched_tokens + if not use_cc_kernels(int(max_tokens / n_slices), n_slices): + logger.info( + ("Disabling collective fusion pass since chunked prefill " + "size %d is too small."), max_tokens) + self.compilation_config.pass_config.enable_collective_fusion = \ + False + if n_slices == 1: + logger.info("Disabling collective fusion pass since tensor " + "parallelism is not enabled.") + self.compilation_config.pass_config.enable_collective_fusion = \ + False + current_platform.check_and_update_config(self) def __str__(self): diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ccbe00386c5da..c4d2b42c18e97 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -32,13 +32,21 @@ import torch import torch.distributed -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend, ProcessGroup, _symmetric_memory import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op +has_flux = False +if envs.VLLM_USE_FLUX: + try: + import flux + has_flux = True + except ImportError: + has_flux = False + @dataclass class GraphCaptureContext: @@ -96,11 +104,16 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) -def all_reduce(tensor: torch.Tensor, 
group_name: str) -> torch.Tensor: +def get_group_from_group_name(group_name: str) -> "GroupCoordinator": assert group_name in _groups, f"Group {group_name} is not found." group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") + return group + + +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + group = get_group_from_group_name(group_name) return group._all_reduce_out_place(tensor) @@ -199,6 +212,10 @@ def __init__( self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator + if has_flux and torch.distributed.get_world_size( + self.device_group) > 1: + flux.init_flux_shm(self.device_group) + # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce) @@ -974,6 +991,9 @@ def init_distributed_environment( assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") + group_name = torch.distributed.group.WORLD.group_name + _symmetric_memory.enable_symm_mem_for_group(group_name) + def initialize_model_parallel( tensor_model_parallel_size: int = 1, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1551a9a998160..d860242d4674e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -171,7 +171,8 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + compilation_config: Optional[Union[int, Dict[str, Any], + CompilationConfig]] = None, **kwargs, ) -> None: ''' diff --git a/vllm/envs.py b/vllm/envs.py index c896770e5f6bc..852f101be2758 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -70,6 +70,7 @@ VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False + VLLM_USE_FLUX: bool = False def get_default_cache_root(): @@ -457,6 +458,10 @@ def get_default_config_root(): # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))), + + # If set, try to use the flux fused collective communication gemm kernels. + "VLLM_USE_FLUX": + lambda: bool(int(os.getenv("VLLM_USE_FLUX", "1"))), } # end-env-vars-definition
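
Usage sketch (illustrative, not part of the patch): how the knobs introduced above might be exercised together, i.e. the `enable_collective_fusion` pass-config flag, the new `before_split_graph`/`after_split_graph` dump stages, and the `VLLM_USE_FLUX` environment variable. The model name and dump directory are placeholders, and the construction assumes the nested `CompilationConfig.PassConfig` class and the `LLM(..., compilation_config=...)` parameter shown in the diff.

import os

from vllm import LLM
from vllm.config import CompilationConfig

# VLLM_USE_FLUX defaults to "1" per envs.py above; set to "0" before the
# engine starts to fall back to the torch.ops.symm_mem fused kernels.
os.environ["VLLM_USE_FLUX"] = "0"

compilation_config = CompilationConfig(
    level=3,  # CompilationLevel.PIECEWISE
    pass_config=CompilationConfig.PassConfig(
        enable_collective_fusion=True,
        # Dump stages wired up in backends.py and CollectiveFusionPass above.
        dump_graph_stages=[
            "before_split_graph",
            "after_split_graph",
            "before_collective_fusion",
            "after_collective_fusion",
        ],
        dump_graph_dir="./graph_dumps",  # placeholder directory
    ),
)

# Placeholder model; the fusion pass only applies with tensor parallelism
# enabled (config.py above disables it when world_size == 1).
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
          tensor_parallel_size=2,
          compilation_config=compilation_config)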
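
A self-contained sketch of the `use_cc_kernels()` gating heuristic added in vllm/compilation/utils.py, with the tensor-parallel world size passed explicitly so it can be run outside a distributed setup. The example values are illustrative; config.py above applies the same check to `max_num_batched_tokens / world_size` to decide whether the pass stays enabled.

FLUX_TILE_SIZE = 128  # per-arch tile size; can be 256 on sm80


def use_cc_kernels(m_shape: int, n_slices: int) -> bool:
    # The fused collective kernels are only used when the token dimension M
    # is a positive multiple of FLUX_TILE_SIZE * TP size.
    return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0
            and m_shape >= FLUX_TILE_SIZE * n_slices)


if __name__ == "__main__":
    # With TP=2 only multiples of 256 qualify:
    # 128 -> False, 256 -> True, 384 -> False, 512 -> True, 1024 -> True.
    for m in (128, 256, 384, 512, 1024):
        print(f"M={m:5d} tp=2 -> {use_cc_kernels(m, 2)}")

    # e.g. max_num_batched_tokens=2048 with world_size=2:
    # use_cc_kernels(2048 // 2, 2) -> True, so the pass remains enabled.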