From 61c79b364b6408fa10d3a78305fc1b645ed5c735 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 12 Jun 2024 21:07:08 +0000 Subject: [PATCH 01/72] Prototype integration of bytedance Flux kernels Signed-off-by: Bill Nell --- flux_env.sh | 17 +++ vllm/distributed/parallel_state.py | 5 + vllm/model_executor/layers/linear.py | 152 ++++++++++++++++++++++++--- vllm/model_executor/models/llama.py | 55 ++++++++-- vllm/model_executor/models/utils.py | 20 +++- 5 files changed, 225 insertions(+), 24 deletions(-) create mode 100644 flux_env.sh diff --git a/flux_env.sh b/flux_env.sh new file mode 100644 index 0000000000000..8979ce0858d0c --- /dev/null +++ b/flux_env.sh @@ -0,0 +1,17 @@ +#Point to the directory containing the flux .so files: +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/nm-vllm/flux_experiment/lib + +export NVSHMEM_BOOTSTRAP_MPI_PLUGIN=nvshmem_bootstrap_torch.so + +# Env variables for symmetric heap allocation. +# These are needed for supporting CUDA_VISIBLE DEVICES +# This is big enough for llama3 8b, but should be set correctly +export NVSHMEM_SYMMETRIC_SIZE=$((8*1024**3)) +export NVSHMEM_DISABLE_CUDA_VMM=1 # moving from cpp to shell + +# Not sure if these are needed +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export BYTED_TORCH_BYTECCL=O0 +export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:=23} +export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:=3} +export NVSHMEM_IB_GID_INDEX=3 diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ccbe00386c5da..bbec1f57f665a 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -30,6 +30,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch +import flux import torch import torch.distributed from torch.distributed import Backend, ProcessGroup @@ -199,6 +200,10 @@ def __init__( self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator + # Initialize pynvshmem + if torch.distributed.get_world_size(self.device_group) > 1: + flux.init_flux_shm(self.device_group) + # lazy import to avoid documentation build error from vllm.distributed.device_communicators.custom_all_reduce import ( CustomAllreduce) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 46ef11e7d02c6..19f0de158089a 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,6 +2,7 @@ from abc import abstractmethod from typing import Dict, List, Optional, Tuple +import flux import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter @@ -11,6 +12,7 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -137,6 +139,104 @@ def apply(self, return F.linear(x, layer.weight, bias) +class GemmRS(LinearMethodBase): + #Fused Gemm-ReduceScatter without quantization. 
+ + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + self.gemm_rs_op = flux.GemmRS( + get_tp_group().device_group, + 1, # One node + 8192, # Max M. TODO: Pass in correctly. + output_size, # N + # TODO: Pass in input dtype correctly. + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. + torch.float16, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: bfloat16 requires fuse_reduction=False. + fuse_reduction=False, + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None + + output = self.gemm_rs_op.forward(x, layer.weight) + output = output.squeeze(0) + + return output + + +class AGCook(LinearMethodBase): + #Fused AllGather-Gemm without quantization. + + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + self.ag_gemm_op = flux.AGKernel( + get_tp_group().device_group, + 1, # One node + 8192, # Max M. TODO: Pass in correctly. + weight.shape[0], # N + weight.shape[1], # K + # TODO: Pass in input dtype correctly. + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. + torch.float16, + torch.float16, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: if local_copy=True, I hit the following runtime error: + # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648 + # Check failed: 33554432((input.numel() * input.element_size())) + # == 139836453421056((this->chunk_size)) + local_copy=False, + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None + + output = self.ag_gemm_op.forward(x, layer.weight) + + return output + + class LinearBase(torch.nn.Module): """Base linear layer. 
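Both fused methods above rely on the rewrite spelled out in the llama.py changes further down: the all-reduce that normally follows a row-parallel GEMM is replaced by a reduce-scatter fused into that GEMM, and the matching all-gather is deferred into the next layer's column-parallel GEMM. The following plain-collective sketch is not part of the patch; it assumes an already-initialized process group (e.g. under torchrun) and only shows why the two paths compute the same tensor:

# Illustration only: reduce-scatter now plus all-gather later equals the
# usual tensor-parallel all-reduce, provided the row count splits evenly.
import torch
import torch.distributed as dist

def allreduce_path(partial: torch.Tensor) -> torch.Tensor:
    # Baseline RowParallelLinear epilogue: sum the partial GEMM outputs
    # so that every rank ends up holding the full activation.
    out = partial.clone()
    dist.all_reduce(out)
    return out

def rs_then_ag_path(partial: torch.Tensor) -> torch.Tensor:
    # Fused-path semantics: keep only this rank's row shard of the sum ...
    world = dist.get_world_size()
    assert partial.shape[0] % world == 0  # same divisibility constraint the later patches gate on
    shard = torch.empty_like(partial.chunk(world, dim=0)[0])
    dist.reduce_scatter_tensor(shard, partial)
    # ... and rebuild the full tensor only where the next layer needs it.
    full = torch.empty_like(partial)
    dist.all_gather_into_tensor(full, shard)
    return full  # equal to allreduce_path(partial)
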
@@ -157,6 +257,8 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + fuse_gemm_rs: bool = False, + fuse_ag_gemm: bool = False, ): super().__init__() @@ -167,9 +269,15 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - if quant_config is None: - self.quant_method: Optional[ - QuantizeMethodBase] = UnquantizedLinearMethod() + + if fuse_gemm_rs: + assert (quant_config is None) + self.quant_method: Optional[QuantizeMethodBase] = GemmRS() + elif fuse_ag_gemm: + assert (quant_config is None) + self.quant_method = AGCook() + elif quant_config is None: + self.quant_method = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) @@ -282,9 +390,15 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, - prefix: str = ""): - super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, prefix) + prefix: str = "", + fuse_ag_gemm: bool = False): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + fuse_ag_gemm=fuse_ag_gemm) self.gather_output = gather_output @@ -419,7 +533,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + fuse_ag_gemm: bool = False): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -430,7 +545,8 @@ def __init__(self, skip_bias_add=skip_bias_add, params_dtype=params_dtype, quant_config=quant_config, - prefix=prefix) + prefix=prefix, + fuse_ag_gemm=fuse_ag_gemm) def weight_loader(self, param: Parameter, @@ -667,7 +783,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + fuse_ag_gemm: bool = False): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -700,7 +817,8 @@ def __init__(self, skip_bias_add=skip_bias_add, params_dtype=params_dtype, quant_config=quant_config, - prefix=prefix) + prefix=prefix, + fuse_ag_gemm=fuse_ag_gemm) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { @@ -995,12 +1113,20 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): - super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, prefix) + prefix: str = "", + fuse_gemm_rs: bool = False): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + fuse_gemm_rs=fuse_gemm_rs) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results + if fuse_gemm_rs: + self.reduce_results = False # Divide the weight matrix along the last dimension. 
self.tp_rank = get_tensor_model_parallel_rank() diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 355b2f3ef8b28..d19a2966bd84a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,7 +30,8 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -69,6 +70,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, prefix: str = "", + last_layer: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -77,13 +79,15 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", - ) + fuse_ag_gemm=True) + self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.down_proj", + fuse_gemm_rs=(not last_layer), ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -105,6 +109,7 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, + first_layer: bool, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, @@ -146,14 +151,14 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", - ) - + fuse_ag_gemm=(not first_layer)) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.o_proj", + fuse_gemm_rs=True, ) is_neox_style = True @@ -198,6 +203,11 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, + # Hack: pass in whether this is the first/last layer + # so we know if we can rewrite AllReduce -> ReduceScatter + AllGather, + # and then propagate the AllGather to the next layer. 
+ first_layer: bool, + last_layer: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -222,6 +232,7 @@ def __init__( num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads), + first_layer=first_layer, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -237,12 +248,16 @@ def __init__( quant_config=quant_config, bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", + last_layer=last_layer, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.first_layer = first_layer + self.last_layer = last_layer + def forward( self, positions: torch.Tensor, @@ -256,17 +271,37 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: + assert (hidden_states.shape == residual.shape) hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn(positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata) + + # Partition residual + if self.first_layer: + n_slices = get_tensor_model_parallel_world_size() + residual_slices = torch.chunk(residual, n_slices, dim=0) + my_residual = residual_slices[get_tensor_model_parallel_rank()] + else: + my_residual = residual + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + assert (hidden_states.shape == my_residual.shape) + hidden_states, my_residual = self.post_attention_layernorm( + hidden_states, my_residual) hidden_states = self.mlp(hidden_states) + + if self.last_layer: + residual = tensor_model_parallel_all_gather(my_residual, 0) + else: + residual = my_residual + + assert (hidden_states.shape == residual.shape) return hidden_states, residual diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index dcfd2cb7d2622..b3ee6bc5d99c7 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -501,14 +501,32 @@ def make_layers( """Make a list of layers with the given layer function, taking pipeline parallelism into account. 
""" + import inspect + from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.utils import get_pp_indices start_layer, end_layer = get_pp_indices(num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) + + # Determine if layer_fn accepts first/last args by inspecting its signature + sig = inspect.signature(layer_fn) + has_firstlast_args = ('first_layer' + in sig.parameters) and ('last_layer' + in sig.parameters) + + def make_one_layer(idx, start_layer, end_layer): + if has_firstlast_args: + return maybe_offload_to_cpu( + layer_fn(prefix=f"{prefix}.{idx}", + first_layer=(idx == start_layer), + last_layer=(idx == end_layer - 1))) + else: + return maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + modules = torch.nn.ModuleList( [PPMissingLayer() for _ in range(start_layer)] + [ - maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + make_one_layer(idx, start_layer, end_layer) for idx in range(start_layer, end_layer) ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules From 62b5ab6d81ec994cf52c4d5a7704b1a4fca8fa89 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 7 Oct 2024 18:35:43 +0000 Subject: [PATCH 02/72] wip Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 28 +++++-- vllm/model_executor/layers/linear.py | 105 +++++++++++++++++++++++++-- vllm/model_executor/models/llama.py | 1 + 3 files changed, 121 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index bbec1f57f665a..4ab2285bd8fff 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -30,10 +30,15 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from unittest.mock import patch -import flux +try: + import flux + has_flux = True +except ImportError: + has_flux = False + import torch import torch.distributed -from torch.distributed import Backend, ProcessGroup +from torch.distributed import Backend, ProcessGroup, _symmetric_memory import vllm.envs as envs from vllm.logger import init_logger @@ -41,6 +46,9 @@ from vllm.utils import direct_register_custom_op, supports_custom_op +torch._inductor.config._micro_pipeline_tp = True + + @dataclass class GraphCaptureContext: stream: torch.cuda.Stream @@ -144,6 +152,7 @@ class GroupCoordinator: rank_in_group: int # rank inside the group cpu_group: ProcessGroup # group for CPU communication device_group: ProcessGroup # group for device communication + #device_mesh: DeviceMesh use_pynccl: bool # a hint of whether to use PyNccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce # communicators are only created for world size > 1 @@ -201,7 +210,7 @@ def __init__( self.use_xpu_communicator = use_xpu_communicator # Initialize pynvshmem - if torch.distributed.get_world_size(self.device_group) > 1: + if has_flux and torch.distributed.get_world_size(self.device_group) > 1: flux.init_flux_shm(self.device_group) # lazy import to avoid documentation build error @@ -932,12 +941,12 @@ def graph_capture(): logger = init_logger(__name__) -_ENABLE_CUSTOM_ALL_REDUCE = True +_ENABLE_CUSTOM_ALL_REDUCE = False #True def set_custom_all_reduce(enable: bool): global _ENABLE_CUSTOM_ALL_REDUCE - _ENABLE_CUSTOM_ALL_REDUCE = enable + _ENABLE_CUSTOM_ALL_REDUCE = False #enable def init_distributed_environment( @@ -961,6 +970,7 @@ def init_distributed_environment( init_method=distributed_init_method, world_size=world_size, rank=rank) + print(f"INIT 
{backend}, {distributed_init_method}, {world_size}, {rank}") # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -979,6 +989,10 @@ def init_distributed_environment( assert _WORLD.world_size == torch.distributed.get_world_size(), ( "world group already initialized with a different world size") + group_name = torch.distributed.group.WORLD.group_name + print(f"WORLD! {group_name}") + _symmetric_memory.enable_symm_mem_for_group(group_name) + def initialize_model_parallel( tensor_model_parallel_size: int = 1, @@ -1039,6 +1053,10 @@ def initialize_model_parallel( use_message_queue_broadcaster=True, group_name="tp") + print(f"ENABLE! {_TP.device_group.group_name}, {backend}") + _symmetric_memory.enable_symm_mem_for_group(_TP.device_group.group_name) + + # Build the pipeline model-parallel groups. num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 19f0de158089a..1d3d65e03a561 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,9 +2,15 @@ from abc import abstractmethod from typing import Dict, List, Optional, Tuple -import flux +try: + import flux + has_flux = True +except ImportError: + has_flux = False + import torch import torch.nn.functional as F +import torch.distributed as D from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -12,7 +18,7 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tp_group +from vllm.distributed.parallel_state import get_tp_group, get_world_group from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -139,7 +145,7 @@ def apply(self, return F.linear(x, layer.weight, bias) -class GemmRS(LinearMethodBase): +class FluxGemmRS(LinearMethodBase): #Fused Gemm-ReduceScatter without quantization. def __init__(self, separate_bias_add: bool = False): @@ -186,7 +192,7 @@ def apply(self, return output -class AGCook(LinearMethodBase): +class FluxAGCook(LinearMethodBase): #Fused AllGather-Gemm without quantization. def __init__(self, separate_bias_add: bool = False): @@ -237,6 +243,86 @@ def apply(self, return output +class GemmRS(LinearMethodBase): + #Fused Gemm-ReduceScatter without quantization. 
+ + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + print(f"out_partitions={output_partition_sizes}, input_size={input_size}, output_size={output_size}") + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None + + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + + print(f"MATMUL_RS {group_name}") + + output = torch.ops.symm_mem.fused_matmul_reduce_scatter( + x, + layer.weight, + "avg", # ? + scatter_dim=0, # rows + group_name=group_name + ) + + return output + + +class AGCook(LinearMethodBase): + #Fused AllGather-Gemm without quantization. + + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None + + group_name = torch.distributed.group.WORLD.group_name + + print(f"AG_MATMUL {group_name}") + + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( + x, + [layer.weight], #? + gather_dim=1, # cols + group_name=group_name, + ) + + return mm_outputs + + class LinearBase(torch.nn.Module): """Base linear layer. 
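For reference, the two experimental symm_mem ops wired in above can be exercised on their own. The sketch below is not part of the patch; it assumes a recent PyTorch build that exposes torch.ops.symm_mem.fused_matmul_reduce_scatter and fused_all_gather_matmul, and a launch via torchrun with at least two CUDA ranks. The "avg" reduction and the group registration mirror the calls made in GemmRS/AGCook above; the shapes are arbitrary.

# Standalone sketch of the fused symm_mem ops used by GemmRS/AGCook above.
# Assumes torchrun with >1 CUDA ranks and a PyTorch build shipping these
# experimental operators.
import torch
import torch.distributed as dist
from torch.distributed import _symmetric_memory

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
world = dist.get_world_size()

group_name = dist.group.WORLD.group_name
_symmetric_memory.enable_symm_mem_for_group(group_name)

M, K, N, P = 256, 512, 1024, 512
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
w1 = torch.randn(K, N, device="cuda", dtype=torch.float16)
w2 = torch.randn(N, P, device="cuda", dtype=torch.float16)

# GEMM fused with a reduce-scatter over rows ("avg" mirrors the call above).
rs_out = torch.ops.symm_mem.fused_matmul_reduce_scatter(
    x, w1, "avg", scatter_dim=0, group_name=group_name)   # (M // world, N)

# All-gather over rows fused with the next GEMM.
ag_out, mm_outs = torch.ops.symm_mem.fused_all_gather_matmul(
    rs_out, [w2], gather_dim=0, group_name=group_name)    # mm_outs[0]: (M, P)

print(rank, rs_out.shape, mm_outs[0].shape)
dist.destroy_process_group()
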
@@ -269,13 +355,16 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype + self.quant_method: Optional[QuantizeMethodBase] = None - if fuse_gemm_rs: + tp_size = get_tensor_model_parallel_world_size() + + if fuse_gemm_rs and tp_size > 1: assert (quant_config is None) - self.quant_method: Optional[QuantizeMethodBase] = GemmRS() - elif fuse_ag_gemm: + self.quant_method = FluxGemmRS() if has_flux else GemmRS() + elif fuse_ag_gemm and tp_size > 1: assert (quant_config is None) - self.quant_method = AGCook() + self.quant_method = FluxAGCook() if has_flux else AGCook() elif quant_config is None: self.quant_method = UnquantizedLinearMethod() else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d19a2966bd84a..5c352a54126ed 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -297,6 +297,7 @@ def forward( hidden_states = self.mlp(hidden_states) if self.last_layer: + print("GOT HERE") residual = tensor_model_parallel_all_gather(my_residual, 0) else: residual = my_residual From 93fe660e1cfd7aad40a1995f70011ce5e6c6dbe2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 8 Oct 2024 20:37:34 +0000 Subject: [PATCH 03/72] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/linear.py | 88 ++++++++++++++++++++-------- vllm/model_executor/models/llama.py | 54 ++++++++++++----- 2 files changed, 104 insertions(+), 38 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1d3d65e03a561..417c6952c64d0 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -243,7 +243,7 @@ def apply(self, return output -class GemmRS(LinearMethodBase): +class MatmulRS(LinearMethodBase): #Fused Gemm-ReduceScatter without quantization. def __init__(self, separate_bias_add: bool = False): @@ -261,7 +261,38 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - print(f"out_partitions={output_partition_sizes}, input_size={input_size}, output_size={output_size}") + print(f"inpp={input_size_per_partition}, output_part_siz={output_partition_sizes}, input_size={input_size}, output_size={output_size}") + + def apply_old(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + assert bias is None + + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + + print(f"MATMUL_RS {group_name} {x.shape}, {layer.weight.transpose(1,0).shape}") + + if x.shape[0] % 2 != 0: + res = torch.matmul(x, layer.weight.transpose(1,0)) + output = D._symmetric_memory._SymmetricMemory.empty_strided_p2p(res.shape, + res.stride(), + res.dtype, + res.device, + group_name).copy_(res) + else: + output = torch.ops.symm_mem.fused_matmul_reduce_scatter( + x, + layer.weight.transpose(1, 0), + "avg", + scatter_dim=0, # ? + group_name=group_name + ) + + print(f"MATMUL_RS DONE {output.shape}") + + return output + def apply(self, layer: torch.nn.Module, @@ -271,20 +302,25 @@ def apply(self, group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - print(f"MATMUL_RS {group_name}") + print(f"MATMUL_RS {get_tp_group().rank} {x.shape}, {layer.weight.transpose(1,0).shape}") - output = torch.ops.symm_mem.fused_matmul_reduce_scatter( - x, - layer.weight, - "avg", # ? 
- scatter_dim=0, # rows - group_name=group_name - ) + if True or x.shape[0] % 2 != 0 or x.shape[0] < 128: + output = torch.matmul(x, layer.weight.transpose(1, 0)) + else: + output = torch.ops.symm_mem.fused_matmul_reduce_scatter( + x, + layer.weight.transpose(1, 0), + "avg", + scatter_dim=0, # ? + group_name=group_name + ) + + print(f"MATMUL_RS DONE {get_tp_group().rank} {output.shape}") return output -class AGCook(LinearMethodBase): +class AGMatmul(LinearMethodBase): #Fused AllGather-Gemm without quantization. def __init__(self, separate_bias_add: bool = False): @@ -311,16 +347,22 @@ def apply(self, group_name = torch.distributed.group.WORLD.group_name - print(f"AG_MATMUL {group_name}") + print(f"AG_MATMUL {get_tp_group().rank}, {x.shape}, {layer.weight.transpose(1,0).shape}") - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( - x, - [layer.weight], #? - gather_dim=1, # cols - group_name=group_name, - ) + if True or x.shape[0] % 2 != 0 or x.shape[0] < 128: + output = torch.matmul(x, layer.weight.transpose(1,0)) + else: + ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( + x, + [layer.weight.transpose(1,0)], + gather_dim=0, + group_name=group_name, + ) + output = mm_outputs[0] - return mm_outputs + print(f"AG_MATMUL DONE {get_tp_group().rank}, {output.shape}") + + return output class LinearBase(torch.nn.Module): @@ -359,12 +401,12 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() - if fuse_gemm_rs and tp_size > 1: + if False and fuse_gemm_rs and tp_size > 1: assert (quant_config is None) - self.quant_method = FluxGemmRS() if has_flux else GemmRS() - elif fuse_ag_gemm and tp_size > 1: + self.quant_method = FluxGemmRS() if has_flux else MatmulRS() + elif False and fuse_ag_gemm and tp_size > 1: assert (quant_config is None) - self.quant_method = FluxAGCook() if has_flux else AGCook() + self.quant_method = FluxAGCook() if has_flux else AGMatmul() elif quant_config is None: self.quant_method = UnquantizedLinearMethod() else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5c352a54126ed..4537f7d2d7cba 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -31,7 +31,7 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather) + tensor_model_parallel_all_gather, get_tp_group) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -275,34 +275,58 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - # Partition residual - if self.first_layer: + def slices(residual) -> bool: + return [residual] + if residual.shape[0] < 128: + print(f"SLICES TOO SMALL {[residual.shape]}") + return [residual] + n_slices = get_tensor_model_parallel_world_size() residual_slices = torch.chunk(residual, n_slices, dim=0) - my_residual = residual_slices[get_tensor_model_parallel_rank()] + if all(r.shape == residual_slices[0].shape for r in residual_slices): + print(f"SLICES SAME {[r.shape for r in residual_slices]}") + return residual_slices + else: + print(f"SLICES TAIL {[residual.shape]}") + return [residual] + + # Partition residual + if self.first_layer: + residual_slices = slices(residual) + if len(residual_slices) > 1: + my_residual = 
residual_slices[get_tensor_model_parallel_rank()] + else: + my_residual = residual else: my_residual = residual - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) # Fully Connected - assert (hidden_states.shape == my_residual.shape) + assert (hidden_states.shape == my_residual.shape), f"{hidden_states.shape} != {my_residual.shape}" hidden_states, my_residual = self.post_attention_layernorm( hidden_states, my_residual) hidden_states = self.mlp(hidden_states) - if self.last_layer: - print("GOT HERE") - residual = tensor_model_parallel_all_gather(my_residual, 0) + if self.last_layer and len(slices(residual)) > 1: + print(f"GOT HERE {my_residual.shape}") + if True: + residual = tensor_model_parallel_all_gather(my_residual, 0) + else: + residual = torch.ops._c10d_functional.all_gather_into_tensor( + my_residual.contiguous(), + get_tp_group().world_size, + torch.distributed.group.WORLD.group_name + ) + + print(f"GOT HERE2 {my_residual.shape}, {residual.shape}") else: residual = my_residual - assert (hidden_states.shape == residual.shape) + assert (hidden_states.shape == residual.shape), f"{hidden_states.shape} != {residual.shape}" return hidden_states, residual From c3176100008441c008eeb12ab78881b91c3644fd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Oct 2024 18:01:39 +0000 Subject: [PATCH 04/72] working naive Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 7 ++-- vllm/model_executor/layers/linear.py | 58 ++++++++-------------------- vllm/model_executor/models/llama.py | 31 +++++++-------- 3 files changed, 33 insertions(+), 63 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 4ab2285bd8fff..45df60fb91a02 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -152,7 +152,6 @@ class GroupCoordinator: rank_in_group: int # rank inside the group cpu_group: ProcessGroup # group for CPU communication device_group: ProcessGroup # group for device communication - #device_mesh: DeviceMesh use_pynccl: bool # a hint of whether to use PyNccl use_custom_allreduce: bool # a hint of whether to use CustomAllreduce # communicators are only created for world size > 1 @@ -941,7 +940,7 @@ def graph_capture(): logger = init_logger(__name__) -_ENABLE_CUSTOM_ALL_REDUCE = False #True +_ENABLE_CUSTOM_ALL_REDUCE = False # True def set_custom_all_reduce(enable: bool): @@ -1053,8 +1052,8 @@ def initialize_model_parallel( use_message_queue_broadcaster=True, group_name="tp") - print(f"ENABLE! {_TP.device_group.group_name}, {backend}") - _symmetric_memory.enable_symm_mem_for_group(_TP.device_group.group_name) + #print(f"ENABLE! {_TP.device_group.group_name}, {backend}") + #_symmetric_memory.enable_symm_mem_for_group(_TP.device_group.group_name) # Build the pipeline model-parallel groups. 
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 417c6952c64d0..e29651483b3dd 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -243,6 +243,12 @@ def apply(self, return output +# This check is a hack +def should_slice(shape) -> bool: + n_slices = get_tensor_model_parallel_world_size() + return False and (shape[0] % n_slices == 0 and shape[0] >= 128) + + class MatmulRS(LinearMethodBase): #Fused Gemm-ReduceScatter without quantization. @@ -263,53 +269,24 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight, extra_weight_attrs) print(f"inpp={input_size_per_partition}, output_part_siz={output_partition_sizes}, input_size={input_size}, output_size={output_size}") - def apply_old(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert bias is None - - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - - print(f"MATMUL_RS {group_name} {x.shape}, {layer.weight.transpose(1,0).shape}") - - if x.shape[0] % 2 != 0: - res = torch.matmul(x, layer.weight.transpose(1,0)) - output = D._symmetric_memory._SymmetricMemory.empty_strided_p2p(res.shape, - res.stride(), - res.dtype, - res.device, - group_name).copy_(res) - else: - output = torch.ops.symm_mem.fused_matmul_reduce_scatter( - x, - layer.weight.transpose(1, 0), - "avg", - scatter_dim=0, # ? - group_name=group_name - ) - - print(f"MATMUL_RS DONE {output.shape}") - - return output - - def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: assert bias is None - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - print(f"MATMUL_RS {get_tp_group().rank} {x.shape}, {layer.weight.transpose(1,0).shape}") - if True or x.shape[0] % 2 != 0 or x.shape[0] < 128: + if not should_slice(x.shape): + print("MATMUL_RS naive") output = torch.matmul(x, layer.weight.transpose(1, 0)) + # total hack + output = tensor_model_parallel_all_reduce(output) else: + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup output = torch.ops.symm_mem.fused_matmul_reduce_scatter( x, - layer.weight.transpose(1, 0), + layer.weight.transpose(1, 0).contiguous(), "avg", scatter_dim=0, # ? 
group_name=group_name @@ -345,16 +322,15 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: assert bias is None - group_name = torch.distributed.group.WORLD.group_name - print(f"AG_MATMUL {get_tp_group().rank}, {x.shape}, {layer.weight.transpose(1,0).shape}") - if True or x.shape[0] % 2 != 0 or x.shape[0] < 128: + if not should_slice(x.shape): output = torch.matmul(x, layer.weight.transpose(1,0)) else: + group_name = torch.distributed.group.WORLD.group_name ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( x, - [layer.weight.transpose(1,0)], + [layer.weight.transpose(1,0).contiguous()], gather_dim=0, group_name=group_name, ) @@ -401,10 +377,10 @@ def __init__( tp_size = get_tensor_model_parallel_world_size() - if False and fuse_gemm_rs and tp_size > 1: + if fuse_gemm_rs and tp_size > 1: assert (quant_config is None) self.quant_method = FluxGemmRS() if has_flux else MatmulRS() - elif False and fuse_ag_gemm and tp_size > 1: + elif fuse_ag_gemm and tp_size > 1: assert (quant_config is None) self.quant_method = FluxAGCook() if has_flux else AGMatmul() elif quant_config is None: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4537f7d2d7cba..2e7fcf95ec6a6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -36,7 +36,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - RowParallelLinear) + RowParallelLinear, + should_slice) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig @@ -275,28 +276,22 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) + print(f"RESIDUAL SHAPE = {residual.shape}") + def slices(residual) -> bool: - return [residual] - if residual.shape[0] < 128: + if not should_slice(residual.shape): print(f"SLICES TOO SMALL {[residual.shape]}") - return [residual] + return [] n_slices = get_tensor_model_parallel_world_size() residual_slices = torch.chunk(residual, n_slices, dim=0) - if all(r.shape == residual_slices[0].shape for r in residual_slices): - print(f"SLICES SAME {[r.shape for r in residual_slices]}") - return residual_slices - else: - print(f"SLICES TAIL {[residual.shape]}") - return [residual] + print(f"SLICES {[r.shape for r in residual_slices]}") + return residual_slices # Partition residual - if self.first_layer: - residual_slices = slices(residual) - if len(residual_slices) > 1: - my_residual = residual_slices[get_tensor_model_parallel_rank()] - else: - my_residual = residual + residual_slices = slices(residual) if self.first_layer else [] + if len(residual_slices) > 0: + my_residual = residual_slices[get_tensor_model_parallel_rank()] else: my_residual = residual @@ -311,8 +306,8 @@ def slices(residual) -> bool: hidden_states, my_residual) hidden_states = self.mlp(hidden_states) - if self.last_layer and len(slices(residual)) > 1: - print(f"GOT HERE {my_residual.shape}") + if self.last_layer and len(residual_slices) > 0: + print(f"FINAL REDUCE {my_residual.shape}") if True: residual = tensor_model_parallel_all_gather(my_residual, 0) else: From 57b3e748c544f0861cec314ac0b36c858947667a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Oct 2024 18:28:52 +0000 Subject: [PATCH 05/72] working real Signed-off-by: Bill Nell --- 
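Note for this patch: should_slice now uses the real heuristic instead of being forced off, so the residual stream is actually partitioned across ranks between the first and last decoder layers. Below is a condensed sketch of the scheme patches 03-05 converge on; the helper names are illustrative, not the patch's exact code.

# Condensed sketch of the residual-slicing scheme (illustrative names).
import torch
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)

def should_slice(shape) -> bool:
    # Hack from linear.py: only slice sequences that split evenly across
    # ranks and are long enough for the fused kernels to pay off.
    n_slices = get_tensor_model_parallel_world_size()
    return shape[0] % n_slices == 0 and shape[0] >= 128

def shard_residual(residual: torch.Tensor) -> torch.Tensor:
    # First decoder layer: each rank keeps only its row chunk.
    if not should_slice(residual.shape):
        return residual
    chunks = torch.chunk(residual, get_tensor_model_parallel_world_size(), dim=0)
    return chunks[get_tensor_model_parallel_rank()]

def unshard_residual(my_residual: torch.Tensor, orig_shape) -> torch.Tensor:
    # Last decoder layer: gather the full residual back before the head.
    if not should_slice(orig_shape):
        return my_residual
    return tensor_model_parallel_all_gather(my_residual, 0)
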
vllm/model_executor/layers/linear.py | 2 +- vllm/model_executor/models/llama.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e29651483b3dd..9478308254436 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -246,7 +246,7 @@ def apply(self, # This check is a hack def should_slice(shape) -> bool: n_slices = get_tensor_model_parallel_world_size() - return False and (shape[0] % n_slices == 0 and shape[0] >= 128) + return (shape[0] % n_slices == 0 and shape[0] >= 128) class MatmulRS(LinearMethodBase): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2e7fcf95ec6a6..afe9b4a22cde9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -288,6 +288,8 @@ def slices(residual) -> bool: print(f"SLICES {[r.shape for r in residual_slices]}") return residual_slices + orig_residual = residual + # Partition residual residual_slices = slices(residual) if self.first_layer else [] if len(residual_slices) > 0: @@ -306,7 +308,8 @@ def slices(residual) -> bool: hidden_states, my_residual) hidden_states = self.mlp(hidden_states) - if self.last_layer and len(residual_slices) > 0: + print(f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}") + if self.last_layer and len(slices(orig_residual)) > 0: print(f"FINAL REDUCE {my_residual.shape}") if True: residual = tensor_model_parallel_all_gather(my_residual, 0) From 296d65d74524e51b8338ec349a223427b0810e32 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Oct 2024 18:29:42 +0000 Subject: [PATCH 06/72] working real Signed-off-by: Bill Nell --- vllm/model_executor/models/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index afe9b4a22cde9..6734890a9d4ed 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -288,7 +288,7 @@ def slices(residual) -> bool: print(f"SLICES {[r.shape for r in residual_slices]}") return residual_slices - orig_residual = residual + orig_residual_shape = residual.shape # Partition residual residual_slices = slices(residual) if self.first_layer else [] @@ -309,7 +309,7 @@ def slices(residual) -> bool: hidden_states = self.mlp(hidden_states) print(f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}") - if self.last_layer and len(slices(orig_residual)) > 0: + if self.last_layer and should_slice(orig_residual_shape) > 0: print(f"FINAL REDUCE {my_residual.shape}") if True: residual = tensor_model_parallel_all_gather(my_residual, 0) From 7c430683c3249e0f72310f773bb71aab2adc47d9 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Oct 2024 20:05:37 +0000 Subject: [PATCH 07/72] work w/torch.compile Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 1 + vllm/model_executor/layers/linear.py | 17 +++++++++++------ vllm/model_executor/models/llama.py | 22 ++++++++++++++-------- 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 464bc2af8fd6d..a994a69b761ce 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -37,6 +37,7 @@ def wrap_inductor(graph, logger.info("Compiling a graph for shape %s", runtime_shape) from torch._inductor import config + torch._inductor.config._micro_pipeline_tp = True current_config = config.shallow_copy_dict() from 
torch._inductor.compile_fx import compile_fx diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 9478308254436..858358911dcc1 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -41,6 +41,11 @@ ] +def pprint(x): + #print(x) + pass + + def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) if marlin_tile_size is None: @@ -267,7 +272,7 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - print(f"inpp={input_size_per_partition}, output_part_siz={output_partition_sizes}, input_size={input_size}, output_size={output_size}") + pprint(f"inpp={input_size_per_partition}, output_part_siz={output_partition_sizes}, input_size={input_size}, output_size={output_size}") def apply(self, layer: torch.nn.Module, @@ -275,10 +280,10 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: assert bias is None - print(f"MATMUL_RS {get_tp_group().rank} {x.shape}, {layer.weight.transpose(1,0).shape}") + pprint(f"MATMUL_RS {get_tp_group().rank} {x.shape}, {layer.weight.transpose(1,0).shape}") if not should_slice(x.shape): - print("MATMUL_RS naive") + pprint("MATMUL_RS naive") output = torch.matmul(x, layer.weight.transpose(1, 0)) # total hack output = tensor_model_parallel_all_reduce(output) @@ -292,7 +297,7 @@ def apply(self, group_name=group_name ) - print(f"MATMUL_RS DONE {get_tp_group().rank} {output.shape}") + pprint(f"MATMUL_RS DONE {get_tp_group().rank} {output.shape}") return output @@ -322,7 +327,7 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: assert bias is None - print(f"AG_MATMUL {get_tp_group().rank}, {x.shape}, {layer.weight.transpose(1,0).shape}") + pprint(f"AG_MATMUL {get_tp_group().rank}, {x.shape}, {layer.weight.transpose(1,0).shape}") if not should_slice(x.shape): output = torch.matmul(x, layer.weight.transpose(1,0)) @@ -336,7 +341,7 @@ def apply(self, ) output = mm_outputs[0] - print(f"AG_MATMUL DONE {get_tp_group().rank}, {output.shape}") + pprint(f"AG_MATMUL DONE {get_tp_group().rank}, {output.shape}") return output diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 6734890a9d4ed..0a8da41f02a35 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -61,6 +61,11 @@ maybe_prefix) +def pprint(x): + #print(x) + pass + + class LlamaMLP(nn.Module): def __init__( @@ -276,16 +281,16 @@ def forward( hidden_states, residual = self.input_layernorm( hidden_states, residual) - print(f"RESIDUAL SHAPE = {residual.shape}") + pprint(f"RESIDUAL SHAPE = {residual.shape}") def slices(residual) -> bool: if not should_slice(residual.shape): - print(f"SLICES TOO SMALL {[residual.shape]}") + pprint(f"SLICES TOO SMALL {[residual.shape]}") return [] n_slices = get_tensor_model_parallel_world_size() residual_slices = torch.chunk(residual, n_slices, dim=0) - print(f"SLICES {[r.shape for r in residual_slices]}") + pprint(f"SLICES {[r.shape for r in residual_slices]}") return residual_slices orig_residual_shape = residual.shape @@ -308,10 +313,10 @@ def slices(residual) -> bool: hidden_states, my_residual) hidden_states = self.mlp(hidden_states) - print(f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}") - if self.last_layer and should_slice(orig_residual_shape) > 0: - print(f"FINAL REDUCE 
{my_residual.shape}") - if True: + pprint(f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}") + if self.last_layer and should_slice(orig_residual_shape): + pprint(f"FINAL REDUCE {my_residual.shape}") + if False: residual = tensor_model_parallel_all_gather(my_residual, 0) else: residual = torch.ops._c10d_functional.all_gather_into_tensor( @@ -319,8 +324,9 @@ def slices(residual) -> bool: get_tp_group().world_size, torch.distributed.group.WORLD.group_name ) + residual = torch.ops._c10d_functional.wait_tensor(residual) - print(f"GOT HERE2 {my_residual.shape}, {residual.shape}") + pprint(f"GOT HERE2 {my_residual.shape}, {residual.shape}") else: residual = my_residual From aa61f876dd71b8cd4f71ba837da34dfa27f8dcc7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 14 Oct 2024 20:08:28 +0000 Subject: [PATCH 08/72] work w/torch.compile Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 45df60fb91a02..e23a05457eaae 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -46,7 +46,7 @@ from vllm.utils import direct_register_custom_op, supports_custom_op -torch._inductor.config._micro_pipeline_tp = True +#torch._inductor.config._micro_pipeline_tp = True @dataclass @@ -940,12 +940,12 @@ def graph_capture(): logger = init_logger(__name__) -_ENABLE_CUSTOM_ALL_REDUCE = False # True +_ENABLE_CUSTOM_ALL_REDUCE = True def set_custom_all_reduce(enable: bool): global _ENABLE_CUSTOM_ALL_REDUCE - _ENABLE_CUSTOM_ALL_REDUCE = False #enable + _ENABLE_CUSTOM_ALL_REDUCE = enable def init_distributed_environment( From 17020db04f00ea9b4c769d9187f6629eeb0f7dd0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Oct 2024 17:24:05 +0000 Subject: [PATCH 09/72] add fuse_gemms flag to turn it on/off Signed-off-by: Bill Nell --- vllm/distributed/parallel_state.py | 7 ------- vllm/model_executor/models/llama.py | 19 +++++++++++++------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e23a05457eaae..c2fbe70d3e7d9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -46,9 +46,6 @@ from vllm.utils import direct_register_custom_op, supports_custom_op -#torch._inductor.config._micro_pipeline_tp = True - - @dataclass class GraphCaptureContext: stream: torch.cuda.Stream @@ -1052,10 +1049,6 @@ def initialize_model_parallel( use_message_queue_broadcaster=True, group_name="tp") - #print(f"ENABLE! {_TP.device_group.group_name}, {backend}") - #_symmetric_memory.enable_symm_mem_for_group(_TP.device_group.group_name) - - # Build the pipeline model-parallel groups. 
num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0a8da41f02a35..e493efde0e6a4 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -77,6 +77,7 @@ def __init__( bias: bool = False, prefix: str = "", last_layer: bool = False, + fuse_gemms = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -85,7 +86,7 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", - fuse_ag_gemm=True) + fuse_ag_gemm=fuse_gemms) self.down_proj = RowParallelLinear( input_size=intermediate_size, @@ -93,7 +94,7 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.down_proj", - fuse_gemm_rs=(not last_layer), + fuse_gemm_rs=(not last_layer) and fuse_gemms, ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -123,6 +124,7 @@ def __init__( bias: bool = False, cache_config: Optional[CacheConfig] = None, prefix: str = "", + fuse_gemms=True, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -157,14 +159,14 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", - fuse_ag_gemm=(not first_layer)) + fuse_ag_gemm=(not first_layer) and fuse_gemms) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.o_proj", - fuse_gemm_rs=True, + fuse_gemm_rs=fuse_gemms, ) is_neox_style = True @@ -217,8 +219,10 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + fuse_gemms = True, ) -> None: super().__init__() + self.fuse_gemms = fuse_gemms self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -246,6 +250,7 @@ def __init__( bias=attention_bias, cache_config=cache_config, prefix=f"{prefix}.self_attn", + fuse_gemms=self.fuse_gemms, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -255,6 +260,7 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", last_layer=last_layer, + fuse_gemms=self.fuse_gemms, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -284,7 +290,7 @@ def forward( pprint(f"RESIDUAL SHAPE = {residual.shape}") def slices(residual) -> bool: - if not should_slice(residual.shape): + if not self.fuse_gemms or not should_slice(residual.shape): pprint(f"SLICES TOO SMALL {[residual.shape]}") return [] @@ -314,7 +320,7 @@ def slices(residual) -> bool: hidden_states = self.mlp(hidden_states) pprint(f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}") - if self.last_layer and should_slice(orig_residual_shape): + if self.fuse_gemms and self.last_layer and should_slice(orig_residual_shape): pprint(f"FINAL REDUCE {my_residual.shape}") if False: residual = tensor_model_parallel_all_gather(my_residual, 0) @@ -343,6 +349,7 @@ def __init__(self, prefix: str = "", layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): super().__init__() + fuse_gemms = False #True config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config From 1f5fe34a5e58e19c1770a740c7ee7d8d1e18954e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Oct 2024 22:07:50 +0000 Subject: [PATCH 10/72] pattern wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 90 
++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index a994a69b761ce..10648dbee95ca 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -18,6 +18,96 @@ logger = init_logger(__name__) +aten = torch.ops.aten + +FILENO=0 + +def match_gemm_rs_ag_gemm_orig(): + permute_2 = torch.ops.aten.permute(arg7_1, [1, 0]) + mm_1 = torch.ops.aten.mm(getitem_22, permute_2) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce, tensor=mm_1, group_name='tp:0') + getitem_25 = auto_functionalized_4[1] + auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=getitem_25, residual=getitem_1, weight=arg8_1, epsilon=1e-05) + getitem_27 = auto_functionalized_5[1] + getitem_28 = auto_functionalized_5[2] + permute_3 = torch.ops.aten.permute(arg9_1, [1, 0]) + mm_2 = torch.ops.aten.mm(getitem_27, permute_3) + return mm_2 + + +def match_gemm_rs_ag_gemm_small(arg7_1, getitem_22): + permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? + getitem_25 = auto_functionalized_4[1] + return getitem_25 + + +def match_gemm_rs_ag_gemm_med(arg7_1, getitem_22, getitem_1, arg8_1): + permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? + getitem_25 = auto_functionalized_4[1] + auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = getitem_1, weight = arg8_1, epsilon = 1e-05) + getitem_27 = auto_functionalized_5[1] + getitem_28 = auto_functionalized_5[2] + return getitem_27, getitem_28 + #permute_3 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + #mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) + #return mm_2 + + +def match_gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1): + permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? 
+ getitem_25 = auto_functionalized_4[1] + auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = getitem_1, weight = arg8_1, epsilon = 1e-05) + getitem_27 = auto_functionalized_5[1] + getitem_28 = auto_functionalized_5[2] + permute_3 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) + return mm_2, getitem_28 + + +def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_26, arg9_1): + permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) + fused_matmul_reduce_scatter = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_24, clone, 'avg', 0, '0') + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = fused_matmul_reduce_scatter, residual = getitem_26, weight = arg8_1, epsilon = 1e-05) + getitem_29 = auto_functionalized_4[1] + getitem_30 = auto_functionalized_4[2] + slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, 2048) + split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, 2048) + getitem_31 = split_2[0] + permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) + fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, '0') + getitem_34 = fused_all_gather_matmul[1] + getitem_35 = getitem_34[0] + return getitem_34, getitem_35 + + +my_patterns = PatternMatcherPass() +x = torch.empty([4,4], device='cuda') +w = torch.empty([4,4], device='cuda') +resid = torch.empty([4,4], device='cuda') +resid_w = torch.empty([4,4], device='cuda') +x2 = torch.empty([4,4], device='cuda') +inputs = [x, w, resid, resid_w, x2] +inputs_small = inputs[0:2] +inputs_med = inputs[0:4] +register_replacement(match_gemm_rs_ag_gemm, + replace_gemm_rs_ag_gemm, + inputs, + fwd_only, + [my_patterns]) + +def async_rewrite(graph: fx.Graph): + count = my_patterns.apply(graph) + print(f"match count = {count}") + return graph + def wrap_inductor(graph, example_inputs, From 25400bf6d696fe12be886841d0b780567e0f0761 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 15 Oct 2024 22:16:51 +0000 Subject: [PATCH 11/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 10648dbee95ca..2a98045e0a4a7 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -70,7 +70,10 @@ def match_gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1): return mm_2, getitem_28 -def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_26, arg9_1): +# getitem_1 full residual +def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): + split_1 = torch.ops.aten.split.Tensor(getitem_1, 2048) + getitem_26 = split_1[0]; split_1 = None permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) fused_matmul_reduce_scatter = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_24, clone, 'avg', 0, '0') @@ -79,13 +82,13 @@ def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_26, arg9_1): getitem_30 = auto_functionalized_4[2] slice_scatter_2 = 
torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, 2048) split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, 2048) - getitem_31 = split_2[0] + getitem_31 = split_2[0] # local residual permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, '0') getitem_34 = fused_all_gather_matmul[1] getitem_35 = getitem_34[0] - return getitem_34, getitem_35 + return getitem_35, getitem_31 my_patterns = PatternMatcherPass() From 1d3b3aa1d5abfa9cae1caaf4e7c99d72cf8a418b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 16 Oct 2024 01:33:29 +0000 Subject: [PATCH 12/72] final pattern Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 2a98045e0a4a7..4e785b6e0bc5b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -91,7 +91,30 @@ def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): return getitem_35, getitem_31 +def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): + permute_128 = torch.ops.aten.permute.default(arg227_1, [1, 0]) + mm_127 = torch.ops.aten.mm.default(getitem_1022, permute_128) + auto_functionalized_224 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_127, group_name = 'tp:0') + getitem_1024 = auto_functionalized_224[1] + auto_functionalized_225 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1024, residual = getitem_1020, weight = arg228_1, epsilon = 1e-05) + getitem_1026 = auto_functionalized_225[1] + return getitem_1026 + + +def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): + permute_254 = torch.ops.aten.permute.default(arg227_1, [1, 0]) + mm_1 = torch.ops.aten.mm.default(getitem_1215, permute_254) + auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') + getitem_1217 = auto_functionalized_161[1] + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, 2, '0') + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) + auto_functionalized_162 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1217, residual = wait_tensor, weight = arg228_1, epsilon = 1e-05) + getitem_1219 = auto_functionalized_162[1] + return getitem_1219 + + my_patterns = PatternMatcherPass() +my_patterns2 = PatternMatcherPass() x = torch.empty([4,4], device='cuda') w = torch.empty([4,4], device='cuda') resid = torch.empty([4,4], device='cuda') @@ -100,15 +123,25 @@ def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): inputs = [x, w, resid, resid_w, x2] inputs_small = inputs[0:2] inputs_med = inputs[0:4] + register_replacement(match_gemm_rs_ag_gemm, replace_gemm_rs_ag_gemm, inputs, fwd_only, [my_patterns]) +final_inputs = [x, w, resid, resid_w] +register_replacement(match_final, + replace_final, + final_inputs, + fwd_only, + [my_patterns2]) + def async_rewrite(graph: fx.Graph): count = my_patterns.apply(graph) - print(f"match count = {count}") + print(f"fused gemm match count = {count}") + count = my_patterns2.apply(graph) + 
print(f"final match count = {count}") return graph From 91462a86f55f84c28f615ce338643bad744efb89 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 18 Oct 2024 21:17:55 +0000 Subject: [PATCH 13/72] progress Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 195 ++++++++++++++++++++++++---- vllm/model_executor/models/llama.py | 2 +- 2 files changed, 169 insertions(+), 28 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 4e785b6e0bc5b..9220d3a9c5e9e 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -6,6 +6,7 @@ import torch import torch.fx as fx +from typing import Tuple, List, Optional import vllm.envs as envs from vllm.config import CompilationConfig @@ -22,6 +23,12 @@ FILENO=0 + +def pprint(x): + #print(x) + pass + + def match_gemm_rs_ag_gemm_orig(): permute_2 = torch.ops.aten.permute(arg7_1, [1, 0]) mm_1 = torch.ops.aten.mm(getitem_22, permute_2) @@ -71,7 +78,7 @@ def match_gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1): # getitem_1 full residual -def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): +def replace_gemm_rs_ag_gemm_orig(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): split_1 = torch.ops.aten.split.Tensor(getitem_1, 2048) getitem_26 = split_1[0]; split_1 = None permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) @@ -90,11 +97,60 @@ def replace_gemm_rs_ag_gemm(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): getitem_35 = getitem_34[0] return getitem_35, getitem_31 +def slices(residual) -> List[torch.Tensor]: + n_slices = get_tensor_model_parallel_world_size() + residual_slices = torch.chunk(residual, n_slices, dim=0) + #pprint(f"SLICES {[r.shape for r in residual_slices]}") + return residual_slices + +@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) +def gemm_rs_ag_gemm(arg7_1: torch.Tensor, getitem_22: torch.Tensor, arg8_1: torch.Tensor, getitem_1: torch.Tensor, arg9_1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + + # this is terrible + if True: + res_slices = slices(getitem_1) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + print(f"SLICE_SIZE = {slice_size}, orig_shape={getitem_1.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") + else: + slice_size = 2048 + + split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) # XXXXXXXXXXX + getitem_26 = split_1[0]; split_1 = None + permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) + fused_matmul_reduce_scatter = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_22, clone, 'avg', 0, group_name) # XXXXXXXXXXXXX + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = fused_matmul_reduce_scatter, residual = getitem_26, weight = arg8_1, epsilon = 1e-05) + getitem_29 = auto_functionalized_4[1] + getitem_30 = auto_functionalized_4[2] + slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, slice_size) # XXXXXXXXXXXXXXX + split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) # XXXXXXXXXXX + getitem_31 = split_2[0] # local residual + permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) + fused_all_gather_matmul = 
torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) # XXXXXXXXXXX + getitem_34 = fused_all_gather_matmul[1] + getitem_35 = getitem_34[0] + return getitem_35, getitem_31 # matmul, residual + +# this is wrong +@torch.library.register_fake("vllm::gemm_rs_ag_gemm") +def gemm_rs_ag_gemm_fake(arg7_1: torch.Tensor, getitem_22: torch.Tensor, arg8_1: torch.Tensor, getitem_1: torch.Tensor, arg9_1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + mm_res = torch.empty((getitem_22.shape[0], arg9_1.shape[0]), device=getitem_22.device, dtype=getitem_22.dtype) #??? + resid = torch.empty_like(getitem_1) + return (mm_res, resid) + +def replace_gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1): + results = torch.ops.vllm.gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1) + getitem_34 = results[0] + getitem_35 = results[1] + return getitem_34, getitem_35 + def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): permute_128 = torch.ops.aten.permute.default(arg227_1, [1, 0]) mm_127 = torch.ops.aten.mm.default(getitem_1022, permute_128) - auto_functionalized_224 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_127, group_name = 'tp:0') + auto_functionalized_224 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_127, group_name = 'tp:0') # TODO: not same as group name getitem_1024 = auto_functionalized_224[1] auto_functionalized_225 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1024, residual = getitem_1020, weight = arg228_1, epsilon = 1e-05) getitem_1026 = auto_functionalized_225[1] @@ -102,46 +158,131 @@ def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): + group_name = torch.distributed.group.WORLD.group_name # factor out? 
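# A sketch, assuming torch.distributed is initialized and the residual is
# row-sliced across TP ranks: the all_gather_into_tensor + wait_tensor pair used
# below can also be written with the functional-collectives wrapper, which
# returns a tensor that is waited on at first use. Helper name is illustrative.

import torch
import torch.distributed._functional_collectives as funcol

def _gather_residual(local_residual: torch.Tensor, tp_group) -> torch.Tensor:
    # concatenate the per-rank residual slices back along dim 0
    return funcol.all_gather_tensor(local_residual, gather_dim=0, group=tp_group)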
permute_254 = torch.ops.aten.permute.default(arg227_1, [1, 0]) mm_1 = torch.ops.aten.mm.default(getitem_1215, permute_254) - auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') + auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # TODO: not same as group name getitem_1217 = auto_functionalized_161[1] - all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, 2, '0') + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, 2, group_name) wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) auto_functionalized_162 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1217, residual = wait_tensor, weight = arg228_1, epsilon = 1e-05) getitem_1219 = auto_functionalized_162[1] return getitem_1219 -my_patterns = PatternMatcherPass() -my_patterns2 = PatternMatcherPass() -x = torch.empty([4,4], device='cuda') -w = torch.empty([4,4], device='cuda') -resid = torch.empty([4,4], device='cuda') -resid_w = torch.empty([4,4], device='cuda') -x2 = torch.empty([4,4], device='cuda') -inputs = [x, w, resid, resid_w, x2] -inputs_small = inputs[0:2] -inputs_med = inputs[0:4] - -register_replacement(match_gemm_rs_ag_gemm, - replace_gemm_rs_ag_gemm, - inputs, - fwd_only, - [my_patterns]) - -final_inputs = [x, w, resid, resid_w] -register_replacement(match_final, - replace_final, - final_inputs, - fwd_only, - [my_patterns2]) +my_patterns: Optional[PatternMatcherPass] = None +my_patterns2: Optional[PatternMatcherPass] = None + +def get_matches(): + global my_patterns + global my_patterns2 + matches = [] + matches2 = [] + + def record_match_fn(match: Match): + matches.append(match) + return False + + def record_match_fn2(match: Match): + matches2.append(match) + return False + + if not my_patterns: + my_patterns = PatternMatcherPass() + my_patterns2 = PatternMatcherPass() + + x = torch.empty([4,4], device='cuda') + w = torch.empty([4,4], device='cuda') + resid = torch.empty([4,4], device='cuda') + resid_w = torch.empty([4,4], device='cuda') + x2 = torch.empty([4,4], device='cuda') + inputs = [x, w, resid, resid_w, x2] + inputs_small = inputs[0:2] + inputs_med = inputs[0:4] + + register_replacement(match_gemm_rs_ag_gemm, + replace_gemm_rs_ag_gemm, + inputs, + fwd_only, + [my_patterns], + extra_check=record_match_fn) + + final_inputs = [x, w, resid, resid_w] + register_replacement(match_final, + replace_final, + final_inputs, + fwd_only, + [my_patterns2]) + + return matches, matches2 + +def process_matches(graph: fx.Graph, matches): + print(f"len = {len(matches)}") + for match in matches: + last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) + + with graph.inserting_after(last_node_in_match): + fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=match.kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) + residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) + + # find the output and the residual + def find_fn(op): + for node in reversed(match.nodes): + if node.op == "call_function" and node.target == op: + return node + return None + + def find_auto_fn(op): + for node in reversed(match.nodes): + if node.op == 
"call_function" and node.target == auto_functionalized and node.args[0] == op: + return node + return None + + def find_getitem(node, idx): + for user in reversed(node.users): + if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: + return user + return None + + rms_node = find_auto_fn(torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(torch.ops.aten.mm.default) + assert rms_node is not None + assert gemm_node is not None + + #assert len(rms_node.users) == 2 + #assert len(gemm_node.users) == 1 + + # meta["val"] is used by de-functionalization + rms_val = rms_node.meta["val"] + gemm_val = gemm_node.meta["val"] + fused_node.meta["val"] = (gemm_val, rms_val[2]) + + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + gemm_node.replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in matches for node in match.nodes) + +def process_matches2(graph: fx.Graph, matches2): + print(f"len2 = {len(matches2)}") def async_rewrite(graph: fx.Graph): + matches, matches2 = get_matches() + matches.clear() + matches2.clear() + count = my_patterns.apply(graph) print(f"fused gemm match count = {count}") count = my_patterns2.apply(graph) print(f"final match count = {count}") + + process_matches(graph, matches) + process_matches2(graph, matches2) + return graph diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e493efde0e6a4..f7644768f2842 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -289,7 +289,7 @@ def forward( pprint(f"RESIDUAL SHAPE = {residual.shape}") - def slices(residual) -> bool: + def slices(residual) -> List[torch.Tensor]: if not self.fuse_gemms or not should_slice(residual.shape): pprint(f"SLICES TOO SMALL {[residual.shape]}") return [] From ab68b65a8eb2b1bfe91b18204accc992b3a7649a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 24 Oct 2024 02:55:44 +0000 Subject: [PATCH 14/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 290 +++++++++++++++++----------- vllm/model_executor/models/llama.py | 3 +- 2 files changed, 174 insertions(+), 119 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 9220d3a9c5e9e..d93cff0d7f4c9 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -29,119 +29,148 @@ def pprint(x): pass -def match_gemm_rs_ag_gemm_orig(): - permute_2 = torch.ops.aten.permute(arg7_1, [1, 0]) - mm_1 = torch.ops.aten.mm(getitem_22, permute_2) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce, tensor=mm_1, group_name='tp:0') - getitem_25 = auto_functionalized_4[1] - auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=getitem_25, residual=getitem_1, weight=arg8_1, epsilon=1e-05) - getitem_27 = auto_functionalized_5[1] - getitem_28 = auto_functionalized_5[2] - permute_3 = torch.ops.aten.permute(arg9_1, [1, 0]) - mm_2 = torch.ops.aten.mm(getitem_27, permute_3) - return mm_2 - - -def match_gemm_rs_ag_gemm_small(arg7_1, getitem_22): - permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) - mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? 
- getitem_25 = auto_functionalized_4[1] - return getitem_25 - - -def match_gemm_rs_ag_gemm_med(arg7_1, getitem_22, getitem_1, arg8_1): - permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) - mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? - getitem_25 = auto_functionalized_4[1] - auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = getitem_1, weight = arg8_1, epsilon = 1e-05) - getitem_27 = auto_functionalized_5[1] - getitem_28 = auto_functionalized_5[2] - return getitem_27, getitem_28 - #permute_3 = torch.ops.aten.permute.default(arg9_1, [1, 0]) - #mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) - #return mm_2 +# This check is a hack, copied from linear.py +def should_slice(shape) -> bool: + n_slices = get_tensor_model_parallel_world_size() + return (shape[0] % n_slices == 0 and shape[0] >= 128) -def match_gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1): +# getitem_1 = residual (mutation) +# arg7_1 = first gemm weights +# getitem_24 = first gemm activation +# arg8_1 = rms norm weights +# getitem_1a = residual +# arg9_1 = second gemm weights +def match_gemm_rs_ag_gemm(getitem_1, arg7_1, getitem_22, arg8_1, arg9_1): permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? getitem_25 = auto_functionalized_4[1] auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = getitem_1, weight = arg8_1, epsilon = 1e-05) getitem_27 = auto_functionalized_5[1] - getitem_28 = auto_functionalized_5[2] + getitem_28 = auto_functionalized_5[2] # new residual permute_3 = torch.ops.aten.permute.default(arg9_1, [1, 0]) mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) return mm_2, getitem_28 -# getitem_1 full residual -def replace_gemm_rs_ag_gemm_orig(arg7_1, getitem_24, arg8_1, getitem_1, arg9_1): - split_1 = torch.ops.aten.split.Tensor(getitem_1, 2048) +def slices(residual) -> List[torch.Tensor]: + n_slices = get_tensor_model_parallel_world_size() + residual_slices = torch.chunk(residual, n_slices, dim=0) + #pprint(f"SLICES {[r.shape for r in residual_slices]}") + return residual_slices + + +@torch.library.custom_op("vllm::gemm_rs_ag_gemm_orig", mutates_args=()) +def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual + getitem_28: torch.Tensor, # my residual + arg7_1: torch.Tensor, # first gemm weights + getitem_22: torch.Tensor, # first gemm activation + arg8_1: torch.Tensor, # rms norm weights + arg9_1: torch.Tensor, # second gemm weights + ) -> Tuple[torch.Tensor, torch.Tensor]: + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + slice_size = 2048 + split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) getitem_26 = split_1[0]; split_1 = None permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) - fused_matmul_reduce_scatter = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_24, clone, 'avg', 0, '0') + fused_matmul_reduce_scatter = 
torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_22, clone, 'avg', 0, group_name) auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = fused_matmul_reduce_scatter, residual = getitem_26, weight = arg8_1, epsilon = 1e-05) getitem_29 = auto_functionalized_4[1] getitem_30 = auto_functionalized_4[2] - slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, 2048) - split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, 2048) - getitem_31 = split_2[0] # local residual + slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, slice_size) + split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) + getitem_31 = split_2[0] permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) - fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, '0') + fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) # XXXXXXXXXXX getitem_34 = fused_all_gather_matmul[1] getitem_35 = getitem_34[0] return getitem_35, getitem_31 -def slices(residual) -> List[torch.Tensor]: - n_slices = get_tensor_model_parallel_world_size() - residual_slices = torch.chunk(residual, n_slices, dim=0) - #pprint(f"SLICES {[r.shape for r in residual_slices]}") - return residual_slices + +# First split only on first occurrence! +# need to introduce splits since original graph does not have them. @torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) -def gemm_rs_ag_gemm(arg7_1: torch.Tensor, getitem_22: torch.Tensor, arg8_1: torch.Tensor, getitem_1: torch.Tensor, arg9_1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup +def gemm_rs_ag_gemm(getitem_1: torch.Tensor, + getitem_28: torch.Tensor, + arg7_1: torch.Tensor, + getitem_22: torch.Tensor, + arg8_1: torch.Tensor, + arg9_1: torch.Tensor, + first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor]: + print(f"CUSTOM {getitem_1.shape}, should_slice={should_slice(getitem_1.shape)}, first={first_layer}") # this is terrible if True: res_slices = slices(getitem_1) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - print(f"SLICE_SIZE = {slice_size}, orig_shape={getitem_1.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") else: slice_size = 2048 + print(f"SLICE_SIZE = {slice_size}, orig_shape={getitem_1.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") - split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) # XXXXXXXXXXX - getitem_26 = split_1[0]; split_1 = None - permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) - clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) - fused_matmul_reduce_scatter = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_22, clone, 'avg', 0, group_name) # XXXXXXXXXXXXX - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = fused_matmul_reduce_scatter, residual = getitem_26, weight = arg8_1, epsilon = 1e-05) - getitem_29 = auto_functionalized_4[1] - getitem_30 = auto_functionalized_4[2] - slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, slice_size) # XXXXXXXXXXXXXXX - split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) # XXXXXXXXXXX - 
getitem_31 = split_2[0] # local residual - permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) - clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) - fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) # XXXXXXXXXXX - getitem_34 = fused_all_gather_matmul[1] - getitem_35 = getitem_34[0] - return getitem_35, getitem_31 # matmul, residual - -# this is wrong + if should_slice(getitem_1.shape) and first_layer: + print(f"FIRST! rank={get_tensor_model_parallel_rank}") + split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) + getitem_26 = split_1[0]; split_1 = None + else: + getitem_26 = getitem_1 + + if not should_slice(getitem_1.shape): + print("NAIVE") + permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + output = torch.matmul(getitem_22, permute_3) + # all reduce? + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=arg8_1, epsilon=1e-05) + getitem_29 = auto_functionalized_4[1] + getitem_30 = auto_functionalized_4[2] + permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + getitem_35 = torch.matmul(getitem_29, permute_5) + getitem_1 = getitem_26 + + print(f"DONE CUSTOM NAIVE {getitem_30.shape}") + return getitem_35, getitem_30 + else: + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) + output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_22, clone, 'avg', 0, group_name) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=arg8_1, epsilon=1e-05) + getitem_29 = auto_functionalized_4[1] + getitem_30 = auto_functionalized_4[2] + slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_28, getitem_30, 0, 0, slice_size) + split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) + getitem_31 = split_2[0] + getitem_1 = getitem_31 + permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) + fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) + getitem_34 = fused_all_gather_matmul[1] + getitem_35 = getitem_34[0] + + print(f"DONE CUSTOM {getitem_31.shape}") + return getitem_35, getitem_31 # matmul, residual ##### + + +# this is wrong? do we need it? @torch.library.register_fake("vllm::gemm_rs_ag_gemm") -def gemm_rs_ag_gemm_fake(arg7_1: torch.Tensor, getitem_22: torch.Tensor, arg8_1: torch.Tensor, getitem_1: torch.Tensor, arg9_1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def gemm_rs_ag_gemm_fake(getitem_1: torch.Tensor, + getitem_28: torch.Tensor, + arg7_1: torch.Tensor, + getitem_22: torch.Tensor, + arg8_1: torch.Tensor, + arg9_1: torch.Tensor, + first_layer: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: mm_res = torch.empty((getitem_22.shape[0], arg9_1.shape[0]), device=getitem_22.device, dtype=getitem_22.dtype) #??? 
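# A minimal sketch with a toy op, not the vLLM kernel: every torch.library.custom_op
# needs a fake (meta) implementation like the one above so torch.compile can
# propagate shapes and dtypes without running the real kernel. The pairing looks
# like this; "editorial::toy_gemm" is an illustrative name.

import torch

@torch.library.custom_op("editorial::toy_gemm", mutates_args=())
def toy_gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w.t()

@torch.library.register_fake("editorial::toy_gemm")
def _toy_gemm_fake(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # only shapes and dtypes matter here; no real computation is done
    return torch.empty((x.shape[0], w.shape[0]), dtype=x.dtype, device=x.device)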
resid = torch.empty_like(getitem_1) return (mm_res, resid) -def replace_gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1): - results = torch.ops.vllm.gemm_rs_ag_gemm(arg7_1, getitem_22, arg8_1, getitem_1, arg9_1) + +def replace_gemm_rs_ag_gemm(getitem_1, getitem_28, arg7_1, getitem_22, arg8_1, arg9_1): + results = torch.ops.vllm.gemm_rs_ag_gemm(getitem_1, getitem_28, arg7_1, getitem_22, arg8_1, arg9_1) getitem_34 = results[0] getitem_35 = results[1] return getitem_34, getitem_35 @@ -158,13 +187,20 @@ def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): - group_name = torch.distributed.group.WORLD.group_name # factor out? + tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name + permute_254 = torch.ops.aten.permute.default(arg227_1, [1, 0]) mm_1 = torch.ops.aten.mm.default(getitem_1215, permute_254) - auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # TODO: not same as group name + auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = tp_group_name) getitem_1217 = auto_functionalized_161[1] - all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, 2, group_name) - wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) + + if should_slice(getitem_1209.shape): + group_name = torch.distributed.group.WORLD.group_name # factor out? + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, 2, group_name) + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) + else: + wait_tensor = getitem_1209 + auto_functionalized_162 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1217, residual = wait_tensor, weight = arg228_1, epsilon = 1e-05) getitem_1219 = auto_functionalized_162[1] return getitem_1219 @@ -174,19 +210,13 @@ def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): my_patterns2: Optional[PatternMatcherPass] = None def get_matches(): - global my_patterns - global my_patterns2 + global my_patterns, my_patterns2 matches = [] - matches2 = [] def record_match_fn(match: Match): matches.append(match) return False - def record_match_fn2(match: Match): - matches2.append(match) - return False - if not my_patterns: my_patterns = PatternMatcherPass() my_patterns2 = PatternMatcherPass() @@ -196,7 +226,7 @@ def record_match_fn2(match: Match): resid = torch.empty([4,4], device='cuda') resid_w = torch.empty([4,4], device='cuda') x2 = torch.empty([4,4], device='cuda') - inputs = [x, w, resid, resid_w, x2] + inputs = [resid, resid, x, w, resid_w, x2] inputs_small = inputs[0:2] inputs_med = inputs[0:4] @@ -214,41 +244,59 @@ def record_match_fn2(match: Match): fwd_only, [my_patterns2]) - return matches, matches2 + return matches + + +# find the output and the residual +def find_fn(nodes, op): + for node in reversed(nodes): + if node.op == "call_function" and node.target == op: + return node + return None + +def find_auto_fn(nodes, op): + for node in reversed(nodes): + if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: + return node + return None + +def find_getitem(node, idx): + for user in reversed(node.users): + if user.op == "call_function" and user.target == 
operator.getitem and user.args[1] == idx: + return user + return None def process_matches(graph: fx.Graph, matches): print(f"len = {len(matches)}") + first_layer = True # hacky + + nodes = list(graph.nodes) + first_match = None + min_node = None + for match in matches: + first_node_in_match = min(match.nodes, key=lambda x: nodes.index(x)) + if not min_node or nodes.index(first_node_in_match) < min_node: + min_node = nodes.index(first_node_in_match) + first_match = match + for match in matches: last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): - fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=match.kwargs) + kwargs = match.kwargs + kwargs["first_layer"] = match == first_match + kwargs["getitem_28"] = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) + first_layer = False + fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - # find the output and the residual - def find_fn(op): - for node in reversed(match.nodes): - if node.op == "call_function" and node.target == op: - return node - return None - - def find_auto_fn(op): - for node in reversed(match.nodes): - if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: - return node - return None - - def find_getitem(node, idx): - for user in reversed(node.users): - if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: - return user - return None - - rms_node = find_auto_fn(torch.ops._C.fused_add_rms_norm.default) - gemm_node = find_fn(torch.ops.aten.mm.default) + rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) + if gemm_node is None: + gemm_node = find_fn(match.nodes, torch.ops.symm_mem.fused_all_gather_matmul.default) assert rms_node is not None assert gemm_node is not None @@ -267,21 +315,27 @@ def find_getitem(node, idx): graph.eliminate_dead_code() assert all(node not in graph.nodes for match in matches for node in match.nodes) -def process_matches2(graph: fx.Graph, matches2): - print(f"len2 = {len(matches2)}") + +def dump_graph(graph: torch.fx.Graph, stage: str): + logger.info("Printing graph to %s", f"{stage}.py") + with open(f"{stage}.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) + def async_rewrite(graph: fx.Graph): - matches, matches2 = get_matches() + matches = get_matches() matches.clear() - matches2.clear() count = my_patterns.apply(graph) - print(f"fused gemm match count = {count}") - count = my_patterns2.apply(graph) - print(f"final match count = {count}") - - process_matches(graph, matches) - process_matches2(graph, matches2) + print(f"fused gemm match count = {len(matches)}") + + # a bit hacky + if len(matches) > 0: + print("FINAL MATCH") + count = my_patterns2.apply(graph) + print(f"final match count = {count}") + print("FINAL MATCH DONE") + process_matches(graph, matches) return graph diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f7644768f2842..5e4ffbd6b2fa7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -22,6 +22,7 @@ """Inference-only LLaMA model compatible with HuggingFace 
weights.""" from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union +import os import torch from torch import nn from transformers import LlamaConfig @@ -349,7 +350,7 @@ def __init__(self, prefix: str = "", layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): super().__init__() - fuse_gemms = False #True + fuse_gemms = bool(os.environ.get("VLLM_FUSE_GEMMS", "0") == "1") config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config From f51643144972c112ec4541d67f590417ec2f8004 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 24 Oct 2024 02:58:31 +0000 Subject: [PATCH 15/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d93cff0d7f4c9..1bae4a8130af5 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -41,7 +41,12 @@ def should_slice(shape) -> bool: # arg8_1 = rms norm weights # getitem_1a = residual # arg9_1 = second gemm weights -def match_gemm_rs_ag_gemm(getitem_1, arg7_1, getitem_22, arg8_1, arg9_1): +def match_gemm_rs_ag_gemm(getitem_1, # residual + arg7_1, # first gemm weight + getitem_22, # first gemm activation + arg8_1, # rms norm weight + arg9_1, # second gemm weight + ): permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? @@ -169,8 +174,9 @@ def gemm_rs_ag_gemm_fake(getitem_1: torch.Tensor, return (mm_res, resid) -def replace_gemm_rs_ag_gemm(getitem_1, getitem_28, arg7_1, getitem_22, arg8_1, arg9_1): - results = torch.ops.vllm.gemm_rs_ag_gemm(getitem_1, getitem_28, arg7_1, getitem_22, arg8_1, arg9_1) +# doesn't matter, only needed for signature +def replace_gemm_rs_ag_gemm(getitem_1, arg7_1, getitem_22, arg8_1, arg9_1): + results = torch.ops.vllm.gemm_rs_ag_gemm(getitem_1, getitem_1, arg7_1, getitem_22, arg8_1, arg9_1) getitem_34 = results[0] getitem_35 = results[1] return getitem_34, getitem_35 From a5c9f8d0fb13cda1d8df6dc0bbc96ed184b5a239 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 24 Oct 2024 03:37:42 +0000 Subject: [PATCH 16/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 71 ++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 1bae4a8130af5..8a4c9f3084efc 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -73,7 +73,7 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual getitem_22: torch.Tensor, # first gemm activation arg8_1: torch.Tensor, # rms norm weights arg9_1: torch.Tensor, # second gemm weights - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup slice_size = 2048 split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) @@ -98,15 +98,17 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual # First split only on first occurrence! # need to introduce splits since original graph does not have them. 
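# A sketch of the slicing convention the fused op assumes, mirroring should_slice()
# and slices() in this patch: on the first layer the full residual is split
# row-wise across TP ranks (only when the row count divides evenly and is large
# enough), and each rank then carries its own slice between layers. The helper
# below is illustrative only.

import torch

def _residual_slice(residual: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    assert residual.shape[0] % world_size == 0 and residual.shape[0] >= 128
    return torch.chunk(residual, world_size, dim=0)[rank]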
-@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) -def gemm_rs_ag_gemm(getitem_1: torch.Tensor, - getitem_28: torch.Tensor, - arg7_1: torch.Tensor, - getitem_22: torch.Tensor, - arg8_1: torch.Tensor, - arg9_1: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor]: - print(f"CUSTOM {getitem_1.shape}, should_slice={should_slice(getitem_1.shape)}, first={first_layer}") +schema_str="(Tensor(a) getitem_1, Tensor(a) getitem_28, Tensor arg7_1, Tensor getitem_22, Tensor arg8_1, Tensor arg9_1, bool first_layer) -> (Tensor, Tensor, Tensor)" + +@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=())#, schema=schema_str) +def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual + getitem_28: torch.Tensor, # my residual + arg7_1: torch.Tensor, # first gemm weights + getitem_22: torch.Tensor, # first gemm activatiions + arg8_1: torch.Tensor, # rms norm weights + arg9_1: torch.Tensor, # second gemm weights + first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + print(f"CUSTOM {getitem_1.shape}({getitem_28.shape}), should_slice={should_slice(getitem_1.shape)}, first={first_layer}") # this is terrible if True: @@ -133,10 +135,9 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, getitem_30 = auto_functionalized_4[2] permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) getitem_35 = torch.matmul(getitem_29, permute_5) - getitem_1 = getitem_26 - + getitem_30a = getitem_30.clone() print(f"DONE CUSTOM NAIVE {getitem_30.shape}") - return getitem_35, getitem_30 + return getitem_35, getitem_30, getitem_30a else: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) @@ -145,10 +146,10 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=arg8_1, epsilon=1e-05) getitem_29 = auto_functionalized_4[1] getitem_30 = auto_functionalized_4[2] - slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_28, getitem_30, 0, 0, slice_size) + getitem_28a = getitem_1 if first_layer else getitem_28 + slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_28a, getitem_30, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) getitem_31 = split_2[0] - getitem_1 = getitem_31 permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) @@ -156,7 +157,7 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, getitem_35 = getitem_34[0] print(f"DONE CUSTOM {getitem_31.shape}") - return getitem_35, getitem_31 # matmul, residual ##### + return getitem_35, getitem_31.clone(), slice_scatter_2 # this is wrong? do we need it? @@ -168,10 +169,11 @@ def gemm_rs_ag_gemm_fake(getitem_1: torch.Tensor, arg8_1: torch.Tensor, arg9_1: torch.Tensor, first_layer: bool, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mm_res = torch.empty((getitem_22.shape[0], arg9_1.shape[0]), device=getitem_22.device, dtype=getitem_22.dtype) #??? resid = torch.empty_like(getitem_1) - return (mm_res, resid) + my_resid = resid.clone() # last one right? 
or needs to be split + return (mm_res, resid, my_resid) # doesn't matter, only needed for signature @@ -232,9 +234,7 @@ def record_match_fn(match: Match): resid = torch.empty([4,4], device='cuda') resid_w = torch.empty([4,4], device='cuda') x2 = torch.empty([4,4], device='cuda') - inputs = [resid, resid, x, w, resid_w, x2] - inputs_small = inputs[0:2] - inputs_med = inputs[0:4] + inputs = [resid, x, w, resid_w, x2] register_replacement(match_gemm_rs_ag_gemm, replace_gemm_rs_ag_gemm, @@ -274,30 +274,37 @@ def find_getitem(node, idx): def process_matches(graph: fx.Graph, matches): print(f"len = {len(matches)}") - first_layer = True # hacky nodes = list(graph.nodes) first_match = None - min_node = None - for match in matches: - first_node_in_match = min(match.nodes, key=lambda x: nodes.index(x)) - if not min_node or nodes.index(first_node_in_match) < min_node: - min_node = nodes.index(first_node_in_match) - first_match = match + + def find_min_index(match) -> int: + return min(match.nodes, key=lambda x: nodes.index(x)) + + # "sort" matches in topo order + matches = sorted(matches, key=lambda x: find_min_index(x)) + replacements = [] for match in matches: last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): + if len(replacements) == 0: + replacements.append(graph.call_function(torch.ops.aten.empty.memory_format, + args = ([0, 0],), + kwargs = {"dtype": torch.float16, "device": "cuda", "pin_memory": False})) + graph.inserting_after(replacements[-1]) + kwargs = match.kwargs - kwargs["first_layer"] = match == first_match - kwargs["getitem_28"] = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) - first_layer = False + kwargs["first_layer"] = match == matches[0] + kwargs["getitem_28"] = replacements[-1] fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) + my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) + replacements.append(my_residual_node_new) rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) @@ -319,7 +326,7 @@ def process_matches(graph: fx.Graph, matches): # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches for node in match.nodes) + #assert all(node not in graph.nodes for match in matches for node in match.nodes) def dump_graph(graph: torch.fx.Graph, stage: str): From 269e7f9025e61d54e33302cc0f8d743fefd8ca19 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 24 Oct 2024 03:38:15 +0000 Subject: [PATCH 17/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8a4c9f3084efc..5c100beaac130 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -157,7 +157,7 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual getitem_35 = getitem_34[0] print(f"DONE CUSTOM {getitem_31.shape}") - return getitem_35, getitem_31.clone(), slice_scatter_2 + return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed # this is wrong? do we need it? 
@@ -289,11 +289,13 @@ def find_min_index(match) -> int: last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): - if len(replacements) == 0: + if False and len(replacements) == 0: replacements.append(graph.call_function(torch.ops.aten.empty.memory_format, args = ([0, 0],), kwargs = {"dtype": torch.float16, "device": "cuda", "pin_memory": False})) graph.inserting_after(replacements[-1]) + else: + replacements.append(kwargs["getitem_1"]) kwargs = match.kwargs kwargs["first_layer"] = match == matches[0] From 786bcc059f86655f7bdc6ce3fc710d36e054f0c5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 24 Oct 2024 03:40:17 +0000 Subject: [PATCH 18/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5c100beaac130..e7ceb24e40bc4 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -295,7 +295,7 @@ def find_min_index(match) -> int: kwargs = {"dtype": torch.float16, "device": "cuda", "pin_memory": False})) graph.inserting_after(replacements[-1]) else: - replacements.append(kwargs["getitem_1"]) + replacements.append(match.kwargs["getitem_1"]) kwargs = match.kwargs kwargs["first_layer"] = match == matches[0] From 570de5774c77c76d0cbedbab876ee571ac1b5173 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 28 Oct 2024 19:48:39 +0000 Subject: [PATCH 19/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 148 +++++++++++++---------------------- 1 file changed, 54 insertions(+), 94 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index e7ceb24e40bc4..687f5cc1f46ce 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -35,26 +35,21 @@ def should_slice(shape) -> bool: return (shape[0] % n_slices == 0 and shape[0] >= 128) -# getitem_1 = residual (mutation) -# arg7_1 = first gemm weights -# getitem_24 = first gemm activation -# arg8_1 = rms norm weights -# getitem_1a = residual -# arg9_1 = second gemm weights -def match_gemm_rs_ag_gemm(getitem_1, # residual - arg7_1, # first gemm weight - getitem_22, # first gemm activation - arg8_1, # rms norm weight - arg9_1, # second gemm weight +def match_gemm_rs_ag_gemm(residual, + #my_residual, + gemm_1_weights, + gemm_1_activations, + rms_norm_weight, + gemm_2_weights, ): - permute_2 = torch.ops.aten.permute.default(arg7_1, [1, 0]) - mm_1 = torch.ops.aten.mm.default(getitem_22, permute_2) + permute_2 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_2) auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? 
getitem_25 = auto_functionalized_4[1] - auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = getitem_1, weight = arg8_1, epsilon = 1e-05) + auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = residual, weight = rms_norm_weight, epsilon = 1e-05) getitem_27 = auto_functionalized_5[1] getitem_28 = auto_functionalized_5[2] # new residual - permute_3 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + permute_3 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) return mm_2, getitem_28 @@ -66,91 +61,59 @@ def slices(residual) -> List[torch.Tensor]: return residual_slices -@torch.library.custom_op("vllm::gemm_rs_ag_gemm_orig", mutates_args=()) -def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual - getitem_28: torch.Tensor, # my residual - arg7_1: torch.Tensor, # first gemm weights - getitem_22: torch.Tensor, # first gemm activation - arg8_1: torch.Tensor, # rms norm weights - arg9_1: torch.Tensor, # second gemm weights - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - slice_size = 2048 - split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) - getitem_26 = split_1[0]; split_1 = None - permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) - clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) - fused_matmul_reduce_scatter = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_22, clone, 'avg', 0, group_name) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = fused_matmul_reduce_scatter, residual = getitem_26, weight = arg8_1, epsilon = 1e-05) - getitem_29 = auto_functionalized_4[1] - getitem_30 = auto_functionalized_4[2] - slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_1, getitem_30, 0, 0, slice_size) - split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) - getitem_31 = split_2[0] - permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) - clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) - fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) # XXXXXXXXXXX - getitem_34 = fused_all_gather_matmul[1] - getitem_35 = getitem_34[0] - return getitem_35, getitem_31 - - -# First split only on first occurrence! -# need to introduce splits since original graph does not have them. 
- -schema_str="(Tensor(a) getitem_1, Tensor(a) getitem_28, Tensor arg7_1, Tensor getitem_22, Tensor arg8_1, Tensor arg9_1, bool first_layer) -> (Tensor, Tensor, Tensor)" +#schema_str="(Tensor(a) residual, Tensor(a) my_residual, Tensor gemm_1_weights, Tensor gemm_1_activations, Tensor rms_norm_weight, Tensor gemm_2_weights, bool first_layer) -> (Tensor, Tensor, Tensor)" @torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=())#, schema=schema_str) -def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual - getitem_28: torch.Tensor, # my residual - arg7_1: torch.Tensor, # first gemm weights - getitem_22: torch.Tensor, # first gemm activatiions - arg8_1: torch.Tensor, # rms norm weights - arg9_1: torch.Tensor, # second gemm weights +def gemm_rs_ag_gemm(residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - print(f"CUSTOM {getitem_1.shape}({getitem_28.shape}), should_slice={should_slice(getitem_1.shape)}, first={first_layer}") + print(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") # this is terrible if True: - res_slices = slices(getitem_1) + res_slices = slices(residual) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] else: slice_size = 2048 - print(f"SLICE_SIZE = {slice_size}, orig_shape={getitem_1.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") + print(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") - if should_slice(getitem_1.shape) and first_layer: - print(f"FIRST! rank={get_tensor_model_parallel_rank}") - split_1 = torch.ops.aten.split.Tensor(getitem_1, slice_size) + if should_slice(residual.shape) and first_layer: + print(f"FIRST! rank={get_tensor_model_parallel_rank()}") + split_1 = torch.ops.aten.split.Tensor(residual, slice_size) getitem_26 = split_1[0]; split_1 = None else: - getitem_26 = getitem_1 + getitem_26 = my_residual - if not should_slice(getitem_1.shape): + if not should_slice(residual.shape): print("NAIVE") - permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) - output = torch.matmul(getitem_22, permute_3) + permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + output = torch.matmul(gemm_1_activations, permute_3) # all reduce? 
- auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=arg8_1, epsilon=1e-05) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) getitem_29 = auto_functionalized_4[1] getitem_30 = auto_functionalized_4[2] - permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) getitem_35 = torch.matmul(getitem_29, permute_5) getitem_30a = getitem_30.clone() print(f"DONE CUSTOM NAIVE {getitem_30.shape}") return getitem_35, getitem_30, getitem_30a else: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - permute_3 = torch.ops.aten.permute.default(arg7_1, [1, 0]) + permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) - output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(getitem_22, clone, 'avg', 0, group_name) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=arg8_1, epsilon=1e-05) + output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, clone, 'avg', 0, group_name) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) getitem_29 = auto_functionalized_4[1] getitem_30 = auto_functionalized_4[2] - getitem_28a = getitem_1 if first_layer else getitem_28 - slice_scatter_2 = torch.ops.aten.slice_scatter.default(getitem_28a, getitem_30, 0, 0, slice_size) + residual_1 = residual if first_layer else my_residual + slice_scatter_2 = torch.ops.aten.slice_scatter.default(residual_1, getitem_30, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) getitem_31 = split_2[0] - permute_5 = torch.ops.aten.permute.default(arg9_1, [1, 0]) + permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) getitem_34 = fused_all_gather_matmul[1] @@ -162,23 +125,23 @@ def gemm_rs_ag_gemm(getitem_1: torch.Tensor, # residual # this is wrong? do we need it? @torch.library.register_fake("vllm::gemm_rs_ag_gemm") -def gemm_rs_ag_gemm_fake(getitem_1: torch.Tensor, - getitem_28: torch.Tensor, - arg7_1: torch.Tensor, - getitem_22: torch.Tensor, - arg8_1: torch.Tensor, - arg9_1: torch.Tensor, +def gemm_rs_ag_gemm_fake(residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - mm_res = torch.empty((getitem_22.shape[0], arg9_1.shape[0]), device=getitem_22.device, dtype=getitem_22.dtype) #??? - resid = torch.empty_like(getitem_1) + mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) #??? + resid = torch.empty_like(residual) my_resid = resid.clone() # last one right? 
or needs to be split return (mm_res, resid, my_resid) # doesn't matter, only needed for signature -def replace_gemm_rs_ag_gemm(getitem_1, arg7_1, getitem_22, arg8_1, arg9_1): - results = torch.ops.vllm.gemm_rs_ag_gemm(getitem_1, getitem_1, arg7_1, getitem_22, arg8_1, arg9_1) +def replace_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights): + results = torch.ops.vllm.gemm_rs_ag_gemm(residual, residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights) getitem_34 = results[0] getitem_35 = results[1] return getitem_34, getitem_35 @@ -283,30 +246,27 @@ def find_min_index(match) -> int: # "sort" matches in topo order matches = sorted(matches, key=lambda x: find_min_index(x)) - replacements = [] + + # this is pretty hacky since the order doesn't necessarily encode the dependency. + res_replacements = [] + my_res_replacements = [] for match in matches: last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) with graph.inserting_after(last_node_in_match): - if False and len(replacements) == 0: - replacements.append(graph.call_function(torch.ops.aten.empty.memory_format, - args = ([0, 0],), - kwargs = {"dtype": torch.float16, "device": "cuda", "pin_memory": False})) - graph.inserting_after(replacements[-1]) - else: - replacements.append(match.kwargs["getitem_1"]) - kwargs = match.kwargs kwargs["first_layer"] = match == matches[0] - kwargs["getitem_28"] = replacements[-1] + kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] + kwargs["my_residual"] = my_res_replacements[-1] if len(my_res_replacements) > 0 else match.kwargs["residual"] fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) - replacements.append(my_residual_node_new) + res_replacements.append(residual_node_new) + my_res_replacements.append(my_residual_node_new) rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) @@ -328,7 +288,7 @@ def find_min_index(match) -> int: # Finally, remove matched nodes graph.eliminate_dead_code() - #assert all(node not in graph.nodes for match in matches for node in match.nodes) + assert all(node not in graph.nodes for match in matches for node in match.nodes) def dump_graph(graph: torch.fx.Graph, stage: str): From 4aa4ab6c8a458c1352617736e6f9320fdc9797bd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 29 Oct 2024 21:22:08 +0000 Subject: [PATCH 20/72] working Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 687f5cc1f46ce..13a0f51a4f7e9 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -86,7 +86,9 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, split_1 = torch.ops.aten.split.Tensor(residual, slice_size) getitem_26 = split_1[0]; split_1 = None else: - getitem_26 = my_residual + #getitem_26 = my_residual + getitem_26 = residual + slice_size = residual.shape[0] if not should_slice(residual.shape): print("NAIVE") @@ -119,7 +121,7 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, getitem_34 = 
fused_all_gather_matmul[1] getitem_35 = getitem_34[0] - print(f"DONE CUSTOM {getitem_31.shape}") + print(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape}") return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed @@ -133,10 +135,27 @@ def gemm_rs_ag_gemm_fake(residual: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # this is terrible + if True: + res_slices = slices(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] # can we always use rank 0? + else: + slice_size = 2048 + + if should_slice(residual.shape) and first_layer: + print(f"FIRST! rank={get_tensor_model_parallel_rank()}") + split_1 = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = split_1[0]; split_1 = None + else: + #residual = my_residual + slice_size = residual.shape[0] + + # is this type correct? seems to be mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) #??? - resid = torch.empty_like(residual) - my_resid = resid.clone() # last one right? or needs to be split - return (mm_res, resid, my_resid) + + print(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape}") + + return (mm_res, my_residual, residual) # doesn't matter, only needed for signature From f6435dc92a5b8181cf75f440d4a7e628e4cc539e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Oct 2024 18:49:22 +0000 Subject: [PATCH 21/72] fix matcher. naive working Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 13a0f51a4f7e9..61c8273dd7a3c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -19,7 +19,6 @@ logger = init_logger(__name__) -aten = torch.ops.aten FILENO=0 @@ -91,17 +90,21 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, slice_size = residual.shape[0] if not should_slice(residual.shape): + # this branch probably broken print("NAIVE") permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) output = torch.matmul(gemm_1_activations, permute_3) - # all reduce? + + output = tensor_model_parallel_all_reduce(output) ### + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) getitem_29 = auto_functionalized_4[1] getitem_30 = auto_functionalized_4[2] + permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) getitem_35 = torch.matmul(getitem_29, permute_5) getitem_30a = getitem_30.clone() - print(f"DONE CUSTOM NAIVE {getitem_30.shape}") + print(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape}") return getitem_35, getitem_30, getitem_30a else: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup @@ -186,7 +189,8 @@ def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): if should_slice(getitem_1209.shape): group_name = torch.distributed.group.WORLD.group_name # factor out? 
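# --- Editor's sketch (not part of the patch): the all_gather below exists to
# rebuild the full-size residual from the per-rank slice left over by the
# reduce-scatter path before the final fused_add_rms_norm. A single-process
# stand-in with torch.chunk/torch.cat shows only the shape bookkeeping; in the
# real collective each rank contributes its own distinct slice. world_size=2
# here is an assumption mirroring the hard-coded value being replaced.
import torch

def gather_residual_sketch(residual_slice: torch.Tensor, world_size: int = 2) -> torch.Tensor:
    # Stand-in for _c10d_functional.all_gather_into_tensor + wait_tensor:
    # concatenating world_size slices of shape [N/world_size, H] restores [N, H].
    return torch.cat([residual_slice.clone() for _ in range(world_size)], dim=0)

full = torch.randn(128, 16)
per_rank = torch.chunk(full, 2, dim=0)[0]   # what one rank holds after slicing
assert gather_residual_sketch(per_rank, 2).shape == full.shape
# --- end sketch ---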
- all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, 2, group_name) + world_size = 2 # factor out + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, world_size, group_name) wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) else: wait_tensor = getitem_1209 @@ -198,12 +202,13 @@ def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): my_patterns: Optional[PatternMatcherPass] = None my_patterns2: Optional[PatternMatcherPass] = None +matches: List[Match] = [] def get_matches(): - global my_patterns, my_patterns2 - matches = [] + global my_patterns, my_patterns2, matches def record_match_fn(match: Match): + print(f"MATCHED {len(matches)}, {id(matches)}") matches.append(match) return False @@ -232,7 +237,6 @@ def record_match_fn(match: Match): fwd_only, [my_patterns2]) - return matches # find the output and the residual @@ -317,11 +321,13 @@ def dump_graph(graph: torch.fx.Graph, stage: str): def async_rewrite(graph: fx.Graph): - matches = get_matches() + global matches + rank = get_tensor_model_parallel_rank() + get_matches() matches.clear() count = my_patterns.apply(graph) - print(f"fused gemm match count = {len(matches)}") + print(f"fused gemm match count = {len(matches)} {id(matches)}") # a bit hacky if len(matches) > 0: From b654a8e202cb6f9b7f289dbc6eac97db7295e0f3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Oct 2024 19:16:01 +0000 Subject: [PATCH 22/72] move collective fusion to separate file Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 497 ++++++++++++++++++++++++++ 1 file changed, 497 insertions(+) create mode 100644 vllm/compilation/collective_fusion.py diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py new file mode 100644 index 0000000000000..ba61ddc112a6f --- /dev/null +++ b/vllm/compilation/collective_fusion.py @@ -0,0 +1,497 @@ +import copy +import operator +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.fx as fx +from typing import Tuple, List, Optional + +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, joint_fwd_bwd, Match + +from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from vllm.distributed import tensor_model_parallel_all_reduce + +from vllm.logger import init_logger + +from .compile_context import get_compile_context +from .levels import CompilationLevel + +logger = init_logger(__name__) + + +FILENO=0 + + +def pprint(x): + #print(x) + pass + + +# This check is a hack, copied from linear.py +def should_slice(shape) -> bool: + n_slices = get_tensor_model_parallel_world_size() + return (shape[0] % n_slices == 0 and shape[0] >= 128) + + +def match_gemm_rs_ag_gemm(residual, + #my_residual, + gemm_1_weights, + gemm_1_activations, + rms_norm_weight, + gemm_2_weights, + ): + permute_2 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_2) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? 
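# --- Editor's sketch (not part of the patch): eager reference for the pattern
# matched by match_gemm_rs_ag_gemm -- GEMM, all-reduce, fused residual-add
# RMSNorm, GEMM. At world_size == 1 the all-reduce is the identity, so a
# single-process version needs only the two matmuls plus the norm; eps=1e-5
# matches the epsilon used in the pattern. The RMSNorm-with-residual semantics
# below are the editor's reading of fused_add_rms_norm, not a vLLM API call.
import torch

def ref_gemm_ar_rmsnorm_gemm(residual, w1, x, norm_w, w2, eps=1e-5):
    out = x @ w1.t()                      # gemm 1 (all-reduce omitted: world_size 1)
    new_residual = residual + out         # residual add
    var = new_residual.pow(2).mean(-1, keepdim=True)
    normed = new_residual * torch.rsqrt(var + eps) * norm_w   # RMSNorm
    return normed @ w2.t(), new_residual  # gemm 2, plus the new residual

r = torch.randn(8, 16); x = torch.randn(8, 16)
w1 = torch.randn(16, 16); w2 = torch.randn(16, 16); nw = torch.ones(16)
mm2, resid = ref_gemm_ar_rmsnorm_gemm(r, w1, x, nw, w2)
# --- end sketch ---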
+ getitem_25 = auto_functionalized_4[1] + auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = residual, weight = rms_norm_weight, epsilon = 1e-05) + getitem_27 = auto_functionalized_5[1] + getitem_28 = auto_functionalized_5[2] # new residual + permute_3 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) + return mm_2, getitem_28 + + +def slices(residual) -> List[torch.Tensor]: + n_slices = get_tensor_model_parallel_world_size() + residual_slices = torch.chunk(residual, n_slices, dim=0) + #pprint(f"SLICES {[r.shape for r in residual_slices]}") + return residual_slices + + +#schema_str="(Tensor(a) residual, Tensor(a) my_residual, Tensor gemm_1_weights, Tensor gemm_1_activations, Tensor rms_norm_weight, Tensor gemm_2_weights, bool first_layer) -> (Tensor, Tensor, Tensor)" + +@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=())#, schema=schema_str) +def gemm_rs_ag_gemm(residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, + first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + print(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") + + # this is terrible + if True: + res_slices = slices(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + else: + slice_size = 2048 + print(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") + + if should_slice(residual.shape) and first_layer: + print(f"FIRST! rank={get_tensor_model_parallel_rank()}") + split_1 = torch.ops.aten.split.Tensor(residual, slice_size) + getitem_26 = split_1[0]; split_1 = None + else: + #getitem_26 = my_residual + getitem_26 = residual + slice_size = residual.shape[0] + + if not should_slice(residual.shape): + # this branch probably broken + print("NAIVE") + permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + output = torch.matmul(gemm_1_activations, permute_3) + + output = tensor_model_parallel_all_reduce(output) ### + + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) + getitem_29 = auto_functionalized_4[1] + getitem_30 = auto_functionalized_4[2] + + permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + getitem_35 = torch.matmul(getitem_29, permute_5) + getitem_30a = getitem_30.clone() + print(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape}") + return getitem_35, getitem_30, getitem_30a + else: + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) + output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, clone, 'avg', 0, group_name) + auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) + getitem_29 = auto_functionalized_4[1] + getitem_30 = auto_functionalized_4[2] + residual_1 = residual if first_layer else my_residual + slice_scatter_2 = 
torch.ops.aten.slice_scatter.default(residual_1, getitem_30, 0, 0, slice_size) + split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) + getitem_31 = split_2[0] + permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) + fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) + getitem_34 = fused_all_gather_matmul[1] + getitem_35 = getitem_34[0] + + print(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape}") + return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed + + +# this is wrong? do we need it? +@torch.library.register_fake("vllm::gemm_rs_ag_gemm") +def gemm_rs_ag_gemm_fake(residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, + first_layer: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # this is terrible + if True: + res_slices = slices(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] # can we always use rank 0? + else: + slice_size = 2048 + + if should_slice(residual.shape) and first_layer: + print(f"FIRST! rank={get_tensor_model_parallel_rank()}") + split_1 = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = split_1[0]; split_1 = None + else: + #residual = my_residual + slice_size = residual.shape[0] + + # is this type correct? seems to be + mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) #??? + + print(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape}") + + return (mm_res, my_residual, residual) + + +# doesn't matter, only needed for signature +def replace_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights): + results = torch.ops.vllm.gemm_rs_ag_gemm(residual, residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights) + getitem_34 = results[0] + getitem_35 = results[1] + return getitem_34, getitem_35 + + +def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): + permute_128 = torch.ops.aten.permute.default(arg227_1, [1, 0]) + mm_127 = torch.ops.aten.mm.default(getitem_1022, permute_128) + auto_functionalized_224 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_127, group_name = 'tp:0') # TODO: not same as group name + getitem_1024 = auto_functionalized_224[1] + auto_functionalized_225 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1024, residual = getitem_1020, weight = arg228_1, epsilon = 1e-05) + getitem_1026 = auto_functionalized_225[1] + return getitem_1026 + + +def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): + tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name + + permute_254 = torch.ops.aten.permute.default(arg227_1, [1, 0]) + mm_1 = torch.ops.aten.mm.default(getitem_1215, permute_254) + auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = tp_group_name) + getitem_1217 = auto_functionalized_161[1] + + if should_slice(getitem_1209.shape): + group_name = torch.distributed.group.WORLD.group_name # factor out? 
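# --- Editor's sketch (not part of the patch): the slice_scatter/split pair in
# gemm_rs_ag_gemm above writes the per-rank normalized residual into rows
# [0, slice_size) of the full-size residual buffer and then splits those rows
# back off as the next layer's my_residual. Toy shapes below are assumptions.
import torch

full_residual = torch.zeros(8, 4)        # full-size residual buffer (residual_1)
slice_size = 4
local = torch.ones(slice_size, 4)        # per-rank normalized residual (getitem_30)

# Write the local slice into rows [0, slice_size) of the full buffer ...
updated = torch.slice_scatter(full_residual, local, dim=0, start=0, end=slice_size)
# ... and split the first slice_size rows back off for the next layer.
my_residual = torch.split(updated, slice_size)[0]

assert torch.equal(my_residual, local)
assert updated.shape == full_residual.shape
# --- end sketch ---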
+ world_size = 2 # factor out + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, world_size, group_name) + wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) + else: + wait_tensor = getitem_1209 + + auto_functionalized_162 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1217, residual = wait_tensor, weight = arg228_1, epsilon = 1e-05) + getitem_1219 = auto_functionalized_162[1] + return getitem_1219 + + +my_patterns: Optional[PatternMatcherPass] = None +my_patterns2: Optional[PatternMatcherPass] = None +matches: List[Match] = [] + +def get_matches(): + global my_patterns, my_patterns2, matches + + def record_match_fn(match: Match): + #print(f"MATCHED {len(matches)}, {id(matches)}") + matches.append(match) + return False + + if not my_patterns: + my_patterns = PatternMatcherPass() + my_patterns2 = PatternMatcherPass() + + x = torch.empty([4,4], device='cuda') + w = torch.empty([4,4], device='cuda') + resid = torch.empty([4,4], device='cuda') + resid_w = torch.empty([4,4], device='cuda') + x2 = torch.empty([4,4], device='cuda') + inputs = [resid, x, w, resid_w, x2] + + register_replacement(match_gemm_rs_ag_gemm, + replace_gemm_rs_ag_gemm, + inputs, + fwd_only, + [my_patterns], + extra_check=record_match_fn) + + final_inputs = [x, w, resid, resid_w] + register_replacement(match_final, + replace_final, + final_inputs, + fwd_only, + [my_patterns2]) + + + +# find the output and the residual +def find_fn(nodes, op): + for node in reversed(nodes): + if node.op == "call_function" and node.target == op: + return node + return None + +def find_auto_fn(nodes, op): + for node in reversed(nodes): + if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: + return node + return None + +def find_getitem(node, idx): + for user in reversed(node.users): + if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: + return user + return None + +def process_matches(graph: fx.Graph, matches): + print(f"len = {len(matches)}") + + nodes = list(graph.nodes) + first_match = None + + def find_min_index(match) -> int: + return min(match.nodes, key=lambda x: nodes.index(x)) + + # "sort" matches in topo order + matches = sorted(matches, key=lambda x: find_min_index(x)) + + # this is pretty hacky since the order doesn't necessarily encode the dependency. 
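# --- Editor's sketch (not part of the patch): what the res_replacements /
# my_res_replacements lists below accomplish. Each fused gemm_rs_ag_gemm call
# returns (output, residual, my_residual), and the residual outputs of match i
# must feed the residual inputs of match i+1, which is why the matches are
# visited in topological order. fused_layer_stub is a hypothetical stand-in,
# not the real op.
import torch

def fused_layer_stub(residual, my_residual, x):
    return x + 1, residual + x, my_residual + x

def thread_residuals(layer_inputs, residual):
    res, my_res, outs = residual, residual, []
    for x in layer_inputs:
        out, res, my_res = fused_layer_stub(res, my_res, x)  # previous outputs feed the next layer
        outs.append(out)
    return outs, res, my_res

outs, res, my_res = thread_residuals([torch.ones(4) for _ in range(3)], torch.zeros(4))
# --- end sketch ---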
+ res_replacements = [] + my_res_replacements = [] + + for match in matches: + last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) + + with graph.inserting_after(last_node_in_match): + kwargs = match.kwargs + kwargs["first_layer"] = match == matches[0] + kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] + kwargs["my_residual"] = my_res_replacements[-1] if len(my_res_replacements) > 0 else match.kwargs["residual"] + fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) + residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) + my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) + res_replacements.append(residual_node_new) + my_res_replacements.append(my_residual_node_new) + + rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) + if gemm_node is None: + gemm_node = find_fn(match.nodes, torch.ops.symm_mem.fused_all_gather_matmul.default) + assert rms_node is not None + assert gemm_node is not None + + #assert len(rms_node.users) == 2 + #assert len(gemm_node.users) == 1 + + # meta["val"] is used by de-functionalization + rms_val = rms_node.meta["val"] + gemm_val = gemm_node.meta["val"] + fused_node.meta["val"] = (gemm_val, rms_val[2]) + + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + gemm_node.replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in matches for node in match.nodes) + + +def dump_graph(graph: torch.fx.Graph, stage: str): + logger.info("Printing graph to %s", f"{stage}.py") + with open(f"{stage}.py", "w") as f: + print(graph.python_code(root_module="self", verbose=True).src, file=f) + + +def async_rewrite(graph: fx.Graph): + global matches + rank = get_tensor_model_parallel_rank() + get_matches() + matches.clear() + + count = my_patterns.apply(graph) + print(f"fused gemm match count = {len(matches)} {id(matches)}") + + # a bit hacky + if len(matches) > 0: + count = my_patterns2.apply(graph) + print(f"final match count = {count}") + process_matches(graph, matches) + + return graph + + +def fix_functionalization(graph: fx.Graph): + """ + Rewrite the graph module to replace the pattern involving + torch._higher_order_ops.auto_functionalize.auto_functionalized + with a direct call to the inplace custom op. 
+ + # TODO: check if PyTorch nightly has fixed this issue + """ + + # debug code, if we want to see the graph before the transformation + # with open("before.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + nodes_to_remove = [] + + for node in graph.nodes: + # Identify the auto_functionalized node + if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa + if node.args[0] == torch.ops._C.rotary_embedding.default: + # manual replace for rotary_embedding + + # Now, collect the arguments + kwargs = node.kwargs + + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function(torch.ops._C.rotary_embedding.default, + kwargs=kwargs) + + # Remove the auto_functionalized node + # Since the node may have outputs, we need to handle its users + # Replace uses of the outputs (getitem nodes) with mm_node + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + for getitem_user in list(user.users): + if (getitem_user.op == 'call_function' + and getitem_user.target + == torch.ops.aten.slice_scatter.default): + # Replace the uses of slice_scatter node + # with mm_node + getitem_user.replace_all_uses_with(mm_node) + nodes_to_remove.append(getitem_user) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: + # manual replace for fused_add_rms_norm + # this is the most effective optimization for llama + # failing to do this will result in many unnecessary copies + + kwargs = node.kwargs + + input = kwargs['input'] + residual = kwargs['residual'] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + if user.args[1] == 1: + replace_node = input + elif user.args[1] == 2: + replace_node = residual + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.rms_norm.default: + # manual replace for rms_norm + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + weight = kwargs['weight'] + epsilon = kwargs['epsilon'] + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.rms_norm.default, + args=(out, input, weight, epsilon), + ) + + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + 
nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.silu_and_mul.default: + # manual replace for silu_and_mul + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.silu_and_mul.default, + args=(out, input), + ) + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + # Remove the nodes all at once + for node in nodes_to_remove: + graph.erase_node(node) + + +def collective_fusion(graph: fx.Graph): + global matches + rank = get_tensor_model_parallel_rank() + get_matches() + matches.clear() + + count = my_patterns.apply(graph) + print(f"fused gemm match count = {len(matches)} {id(matches)}") + + # a bit hacky + if len(matches) > 0: + count = my_patterns2.apply(graph) + print(f"final match count = {count}") + process_matches(graph, matches) + + return graph + From 54dde90d0d9cb0868b4c57481c00dea98508d87a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Oct 2024 19:20:41 +0000 Subject: [PATCH 23/72] move collective fusion to separate file Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 171 +------------------------- 1 file changed, 3 insertions(+), 168 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index ba61ddc112a6f..92b94108e4ba7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -178,7 +178,7 @@ def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): getitem_1026 = auto_functionalized_225[1] return getitem_1026 - +# TODO: wrap in custom op to prevent infinite recursion in inductor logging statement? def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name @@ -188,8 +188,8 @@ def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): getitem_1217 = auto_functionalized_161[1] if should_slice(getitem_1209.shape): - group_name = torch.distributed.group.WORLD.group_name # factor out? 
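# --- Editor's sketch (not part of the patch): the should_slice guard used just
# above decides between the sliced (reduce-scatter/all-gather) path and the
# naive all-reduce path. Same heuristic as should_slice() in this file, with
# the TP world size passed in explicitly so it runs without a process group;
# world_size=2 is an assumption for illustration.
import torch

def should_slice_sketch(shape, n_slices: int = 2) -> bool:
    return shape[0] % n_slices == 0 and shape[0] >= 128

assert should_slice_sketch(torch.Size([2048, 4096]), 2)      # large and evenly divisible: sliced path
assert not should_slice_sketch(torch.Size([64, 4096]), 2)    # too small: naive path
assert not should_slice_sketch(torch.Size([129, 4096]), 2)   # not divisible by TP size: naive path
# --- end sketch ---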
- world_size = 2 # factor out + group_name = torch.distributed.group.WORLD.group_name # TODO: factor out + world_size = get_tp_group().world_size # TODO: factor out all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, world_size, group_name) wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) else: @@ -208,7 +208,6 @@ def get_matches(): global my_patterns, my_patterns2, matches def record_match_fn(match: Match): - #print(f"MATCHED {len(matches)}, {id(matches)}") matches.append(match) return False @@ -314,170 +313,6 @@ def find_min_index(match) -> int: assert all(node not in graph.nodes for match in matches for node in match.nodes) -def dump_graph(graph: torch.fx.Graph, stage: str): - logger.info("Printing graph to %s", f"{stage}.py") - with open(f"{stage}.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, file=f) - - -def async_rewrite(graph: fx.Graph): - global matches - rank = get_tensor_model_parallel_rank() - get_matches() - matches.clear() - - count = my_patterns.apply(graph) - print(f"fused gemm match count = {len(matches)} {id(matches)}") - - # a bit hacky - if len(matches) > 0: - count = my_patterns2.apply(graph) - print(f"final match count = {count}") - process_matches(graph, matches) - - return graph - - -def fix_functionalization(graph: fx.Graph): - """ - Rewrite the graph module to replace the pattern involving - torch._higher_order_ops.auto_functionalize.auto_functionalized - with a direct call to the inplace custom op. - - # TODO: check if PyTorch nightly has fixed this issue - """ - - # debug code, if we want to see the graph before the transformation - # with open("before.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - nodes_to_remove = [] - - for node in graph.nodes: - # Identify the auto_functionalized node - if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa - if node.args[0] == torch.ops._C.rotary_embedding.default: - # manual replace for rotary_embedding - - # Now, collect the arguments - kwargs = node.kwargs - - query = kwargs['query'] - mm_node = query.args[0].args[0] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rotary_embedding.default, - kwargs=kwargs) - - # Remove the auto_functionalized node - # Since the node may have outputs, we need to handle its users - # Replace uses of the outputs (getitem nodes) with mm_node - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - for getitem_user in list(user.users): - if (getitem_user.op == 'call_function' - and getitem_user.target - == torch.ops.aten.slice_scatter.default): - # Replace the uses of slice_scatter node - # with mm_node - getitem_user.replace_all_uses_with(mm_node) - nodes_to_remove.append(getitem_user) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: - # manual replace for fused_add_rms_norm - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - input = kwargs['input'] - residual = kwargs['residual'] - - # Create a new call to 
torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = input - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.rms_norm.default: - # manual replace for rms_norm - - kwargs = node.kwargs - - input = kwargs['input'] - out = kwargs['out'] - weight = kwargs['weight'] - epsilon = kwargs['epsilon'] - # Create a new call to torch.ops._C.rotary_embedding.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.rms_norm.default, - args=(out, input, weight, epsilon), - ) - - replace_node = out - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.silu_and_mul.default: - # manual replace for silu_and_mul - - kwargs = node.kwargs - - input = kwargs['input'] - out = kwargs['out'] - - # Create a new call to torch.ops._C.rotary_embedding.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.silu_and_mul.default, - args=(out, input), - ) - replace_node = out - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - # Remove the nodes all at once - for node in nodes_to_remove: - graph.erase_node(node) - - def collective_fusion(graph: fx.Graph): global matches rank = get_tensor_model_parallel_rank() From e4b387153978c86cf430100569a8fb9de9e85ca4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Oct 2024 21:35:02 +0000 Subject: [PATCH 24/72] fix fake function Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 97 +++++++++++++++------------ vllm/distributed/parallel_state.py | 2 - 2 files changed, 53 insertions(+), 46 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 92b94108e4ba7..97225b1e442c4 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,10 +1,8 @@ -import copy import operator from typing import Callable, Dict, List, Optional, Tuple, Union import torch import torch.fx as fx -from typing import Tuple, List, Optional from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import 
PatternMatcherPass, register_replacement, fwd_only, joint_fwd_bwd, Match @@ -14,15 +12,9 @@ from vllm.logger import init_logger -from .compile_context import get_compile_context -from .levels import CompilationLevel - logger = init_logger(__name__) -FILENO=0 - - def pprint(x): #print(x) pass @@ -35,7 +27,6 @@ def should_slice(shape) -> bool: def match_gemm_rs_ag_gemm(residual, - #my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, @@ -60,9 +51,7 @@ def slices(residual) -> List[torch.Tensor]: return residual_slices -#schema_str="(Tensor(a) residual, Tensor(a) my_residual, Tensor gemm_1_weights, Tensor gemm_1_activations, Tensor rms_norm_weight, Tensor gemm_2_weights, bool first_layer) -> (Tensor, Tensor, Tensor)" - -@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=())#, schema=schema_str) +@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) def gemm_rs_ag_gemm(residual: torch.Tensor, my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, @@ -72,12 +61,11 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: print(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") + #### # this is terrible - if True: - res_slices = slices(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - else: - slice_size = 2048 + res_slices = slices(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + #### print(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") if should_slice(residual.shape) and first_layer: @@ -104,7 +92,7 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) getitem_35 = torch.matmul(getitem_29, permute_5) getitem_30a = getitem_30.clone() - print(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape}") + print(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape} {first_layer}") return getitem_35, getitem_30, getitem_30a else: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup @@ -124,11 +112,10 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, getitem_34 = fused_all_gather_matmul[1] getitem_35 = getitem_34[0] - print(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape}") + print(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape} {first_layer}") return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed -# this is wrong? do we need it? @torch.library.register_fake("vllm::gemm_rs_ag_gemm") def gemm_rs_ag_gemm_fake(residual: torch.Tensor, my_residual: torch.Tensor, @@ -138,30 +125,31 @@ def gemm_rs_ag_gemm_fake(residual: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ##### # this is terrible - if True: - res_slices = slices(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] # can we always use rank 0? - else: - slice_size = 2048 + res_slices = slices(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] # can we use rank 0 instead? + ##### - if should_slice(residual.shape) and first_layer: +# if should_slice(residual.shape) and first_layer: + if should_slice(gemm_1_activations.shape) and first_layer: print(f"FIRST! 
rank={get_tensor_model_parallel_rank()}") split_1 = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = split_1[0]; split_1 = None else: #residual = my_residual - slice_size = residual.shape[0] + #slice_size = residual.shape[0] + my_residual = residual - # is this type correct? seems to be - mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) #??? + # verify the type is always correct + mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) - print(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape}") + print(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape} {first_layer}") return (mm_res, my_residual, residual) -# doesn't matter, only needed for signature +# implementation doesn't matter, only needed for signature def replace_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights): results = torch.ops.vllm.gemm_rs_ag_gemm(residual, residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights) getitem_34 = results[0] @@ -169,33 +157,54 @@ def replace_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, rms_no return getitem_34, getitem_35 -def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): - permute_128 = torch.ops.aten.permute.default(arg227_1, [1, 0]) - mm_127 = torch.ops.aten.mm.default(getitem_1022, permute_128) - auto_functionalized_224 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_127, group_name = 'tp:0') # TODO: not same as group name +def match_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weights): + permute_128 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_127 = torch.ops.aten.mm.default(gemm_1_activations, permute_128) + + auto_functionalized_224 = torch.ops.higher_order.auto_functionalized( + torch.ops.vllm.inplace_all_reduce.default, + tensor = mm_127, + group_name = 'tp:0' # TODO: not same as group name + ) getitem_1024 = auto_functionalized_224[1] - auto_functionalized_225 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1024, residual = getitem_1020, weight = arg228_1, epsilon = 1e-05) + + auto_functionalized_225 = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input = getitem_1024, + residual = my_residual, + weight = rms_norm_weights, + epsilon = 1e-05 + ) getitem_1026 = auto_functionalized_225[1] + return getitem_1026 + # TODO: wrap in custom op to prevent infinite recursion in inductor logging statement? -def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): +def replace_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weights): tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name - permute_254 = torch.ops.aten.permute.default(arg227_1, [1, 0]) - mm_1 = torch.ops.aten.mm.default(getitem_1215, permute_254) + permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254) auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = tp_group_name) getitem_1217 = auto_functionalized_161[1] - if should_slice(getitem_1209.shape): + # is this the right thing to call it on? 
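# --- Editor's sketch (not part of the patch): why gemm_rs_ag_gemm_fake above
# only needs to produce tensors of the right shape/dtype/device. A fake (meta)
# implementation registered with torch.library.register_fake is used for
# tracing and must not do real compute. Toy op and the "editor_sketch"
# namespace are hypothetical; requires a PyTorch recent enough to provide the
# same torch.library API already used in this file.
import torch

@torch.library.custom_op("editor_sketch::toy_gemm", mutates_args=())
def toy_gemm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w.t()

@torch.library.register_fake("editor_sketch::toy_gemm")
def toy_gemm_fake(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Shapes/dtypes only; mirrors how the real op's output would look.
    return torch.empty((x.shape[0], w.shape[0]), device=x.device, dtype=x.dtype)

out = toy_gemm(torch.randn(4, 8), torch.randn(16, 8))
assert out.shape == (4, 16)
# --- end sketch ---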
+ if should_slice(gemm_1_activations.shape): group_name = torch.distributed.group.WORLD.group_name # TODO: factor out world_size = get_tp_group().world_size # TODO: factor out - all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, world_size, group_name) + all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(my_residual, world_size, group_name) wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) else: - wait_tensor = getitem_1209 + wait_tensor = my_residual + + auto_functionalized_162 = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input = getitem_1217, + residual = wait_tensor, + weight = rms_norm_weights, + epsilon = 1e-05) - auto_functionalized_162 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1217, residual = wait_tensor, weight = arg228_1, epsilon = 1e-05) getitem_1219 = auto_functionalized_162[1] return getitem_1219 @@ -223,7 +232,7 @@ def record_match_fn(match: Match): inputs = [resid, x, w, resid_w, x2] register_replacement(match_gemm_rs_ag_gemm, - replace_gemm_rs_ag_gemm, + match_gemm_rs_ag_gemm, inputs, fwd_only, [my_patterns], diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c2fbe70d3e7d9..e7aac04f42bf8 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -966,7 +966,6 @@ def init_distributed_environment( init_method=distributed_init_method, world_size=world_size, rank=rank) - print(f"INIT {backend}, {distributed_init_method}, {world_size}, {rank}") # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 @@ -986,7 +985,6 @@ def init_distributed_environment( "world group already initialized with a different world size") group_name = torch.distributed.group.WORLD.group_name - print(f"WORLD! 
{group_name}") _symmetric_memory.enable_symm_mem_for_group(group_name) From 3468420a7f57caea77aac7710cca5ed6d766b9ca Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 30 Oct 2024 23:49:40 +0000 Subject: [PATCH 25/72] use InductorPass from @ProExpertProg's PR Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 228 ++++++++++++-------------- vllm/envs.py | 5 + 2 files changed, 114 insertions(+), 119 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 97225b1e442c4..99319d8c6dca2 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -7,8 +7,10 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, joint_fwd_bwd, Match +from vllm.compilation.inductor_pass import InductorPass from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.model_executor.layers.linear import should_slice from vllm.logger import init_logger @@ -20,12 +22,6 @@ def pprint(x): pass -# This check is a hack, copied from linear.py -def should_slice(shape) -> bool: - n_slices = get_tensor_model_parallel_world_size() - return (shape[0] % n_slices == 0 and shape[0] >= 128) - - def match_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, @@ -59,17 +55,17 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - print(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") + pprint(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") #### # this is terrible res_slices = slices(residual) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] #### - print(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") + pprint(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") if should_slice(residual.shape) and first_layer: - print(f"FIRST! rank={get_tensor_model_parallel_rank()}") + pprint(f"FIRST! 
rank={get_tensor_model_parallel_rank()}") split_1 = torch.ops.aten.split.Tensor(residual, slice_size) getitem_26 = split_1[0]; split_1 = None else: @@ -79,7 +75,7 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, if not should_slice(residual.shape): # this branch probably broken - print("NAIVE") + pprint("NAIVE") permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) output = torch.matmul(gemm_1_activations, permute_3) @@ -92,7 +88,7 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) getitem_35 = torch.matmul(getitem_29, permute_5) getitem_30a = getitem_30.clone() - print(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape} {first_layer}") + pprint(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape} {first_layer}") return getitem_35, getitem_30, getitem_30a else: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup @@ -112,7 +108,8 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, getitem_34 = fused_all_gather_matmul[1] getitem_35 = getitem_34[0] - print(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape} {first_layer}") + pprint(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape} {first_layer}") + # TODO: can we avoid clone here? return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed @@ -133,18 +130,16 @@ def gemm_rs_ag_gemm_fake(residual: torch.Tensor, # if should_slice(residual.shape) and first_layer: if should_slice(gemm_1_activations.shape) and first_layer: - print(f"FIRST! rank={get_tensor_model_parallel_rank()}") + pprint(f"FIRST! rank={get_tensor_model_parallel_rank()}") split_1 = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = split_1[0]; split_1 = None else: - #residual = my_residual - #slice_size = residual.shape[0] my_residual = residual # verify the type is always correct mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) - print(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape} {first_layer}") + pprint(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape} {first_layer}") return (mm_res, my_residual, residual) @@ -209,43 +204,6 @@ def replace_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weig return getitem_1219 -my_patterns: Optional[PatternMatcherPass] = None -my_patterns2: Optional[PatternMatcherPass] = None -matches: List[Match] = [] - -def get_matches(): - global my_patterns, my_patterns2, matches - - def record_match_fn(match: Match): - matches.append(match) - return False - - if not my_patterns: - my_patterns = PatternMatcherPass() - my_patterns2 = PatternMatcherPass() - - x = torch.empty([4,4], device='cuda') - w = torch.empty([4,4], device='cuda') - resid = torch.empty([4,4], device='cuda') - resid_w = torch.empty([4,4], device='cuda') - x2 = torch.empty([4,4], device='cuda') - inputs = [resid, x, w, resid_w, x2] - - register_replacement(match_gemm_rs_ag_gemm, - match_gemm_rs_ag_gemm, - inputs, - fwd_only, - [my_patterns], - extra_check=record_match_fn) - - final_inputs = [x, w, resid, resid_w] - register_replacement(match_final, - replace_final, - final_inputs, - fwd_only, - [my_patterns2]) - - # find the output and the residual def find_fn(nodes, op): @@ -254,88 +212,120 @@ def find_fn(nodes, op): return node return None + def find_auto_fn(nodes, op): for node in reversed(nodes): 
if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: return node return None + def find_getitem(node, idx): for user in reversed(node.users): if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: return user return None -def process_matches(graph: fx.Graph, matches): - print(f"len = {len(matches)}") - - nodes = list(graph.nodes) - first_match = None - - def find_min_index(match) -> int: - return min(match.nodes, key=lambda x: nodes.index(x)) - - # "sort" matches in topo order - matches = sorted(matches, key=lambda x: find_min_index(x)) - - # this is pretty hacky since the order doesn't necessarily encode the dependency. - res_replacements = [] - my_res_replacements = [] - - for match in matches: - last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) - - with graph.inserting_after(last_node_in_match): - kwargs = match.kwargs - kwargs["first_layer"] = match == matches[0] - kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] - kwargs["my_residual"] = my_res_replacements[-1] if len(my_res_replacements) > 0 else match.kwargs["residual"] - fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) - graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) - residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) - res_replacements.append(residual_node_new) - my_res_replacements.append(my_residual_node_new) - - rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) - gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) - if gemm_node is None: - gemm_node = find_fn(match.nodes, torch.ops.symm_mem.fused_all_gather_matmul.default) - assert rms_node is not None - assert gemm_node is not None - - #assert len(rms_node.users) == 2 - #assert len(gemm_node.users) == 1 - - # meta["val"] is used by de-functionalization - rms_val = rms_node.meta["val"] - gemm_val = gemm_node.meta["val"] - fused_node.meta["val"] = (gemm_val, rms_val[2]) - - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - gemm_node.replace_all_uses_with(result_node_new) - - # Finally, remove matched nodes - graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches for node in match.nodes) +class CollectiveFusionPass(InductorPass): + def __init__(self): + self.my_patterns = PatternMatcherPass() + self.my_patterns2 = PatternMatcherPass() + self.matches: List[Match] = [] + x = torch.empty([4,4], device='cuda') + w = torch.empty([4,4], device='cuda') + resid = torch.empty([4,4], device='cuda') + resid_w = torch.empty([4,4], device='cuda') + x2 = torch.empty([4,4], device='cuda') + inputs = [resid, x, w, resid_w, x2] -def collective_fusion(graph: fx.Graph): - global matches - rank = get_tensor_model_parallel_rank() - get_matches() - matches.clear() + register_replacement(match_gemm_rs_ag_gemm, + match_gemm_rs_ag_gemm, + inputs, + fwd_only, + [self.my_patterns], + extra_check=lambda m: self.record_match(m)) - count = my_patterns.apply(graph) - print(f"fused gemm match count = {len(matches)} {id(matches)}") + final_inputs = [x, w, resid, resid_w] + register_replacement(match_final, + replace_final, + final_inputs, + fwd_only, + [self.my_patterns2]) - # a bit hacky - if len(matches) > 0: - count = my_patterns2.apply(graph) - print(f"final 
match count = {count}") - process_matches(graph, matches) + def record_match(self, match: Match) -> bool: + # Hijack the extra_check to record the match and + # save it for post-processing. + self.matches.append(match) - return graph + # Return False to prevent automatic replacement. + return False + def process_matches(self, graph: fx.Graph): + pprint(f"len = {len(self.matches)}") + + nodes = list(graph.nodes) + first_match = None + + def find_min_index(match) -> int: + return min(match.nodes, key=lambda x: nodes.index(x)) + + # "sort" matches in topo order. + matches = sorted(self.matches, key=lambda x: find_min_index(x)) + + res_replacements = [] + my_res_replacements = [] + + for match in matches: + last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) + + with graph.inserting_after(last_node_in_match): + kwargs = match.kwargs + kwargs["first_layer"] = match == matches[0] + kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] + kwargs["my_residual"] = my_res_replacements[-1] if len(my_res_replacements) > 0 else match.kwargs["residual"] + fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) + + graph.inserting_after(fused_node) + result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) + residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) + my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) + res_replacements.append(residual_node_new) + my_res_replacements.append(my_residual_node_new) + + rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) + if gemm_node is None: + gemm_node = find_fn(match.nodes, torch.ops.symm_mem.fused_all_gather_matmul.default) + assert rms_node is not None + assert gemm_node is not None + + #assert len(rms_node.users) == 2 + #assert len(gemm_node.users) == 1 + + # meta["val"] is used by de-functionalization + rms_val = rms_node.meta["val"] + gemm_val = gemm_node.meta["val"] + fused_node.meta["val"] = (gemm_val, rms_val[2]) + + find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + gemm_node.replace_all_uses_with(result_node_new) + + # Finally, remove matched nodes + graph.eliminate_dead_code() + assert all(node not in graph.nodes for match in matches for node in match.nodes) + + def __call__(self, graph: torch.fx.Graph): + self.dump_graph(graph, "before_collective_fusion") + count = self.my_patterns.apply(graph) + logger.info(f"fused gemm match count = {len(self.matches)}") + + # Don't apply final pattern unless we've matched and replaced the gemm+collective ops. + if len(self.matches) > 0: + count =self. 
my_patterns2.apply(graph) + logger.info(f"final match count = {count}") + self.process_matches(graph) + + self.dump_graph(graph, "after_collective_fusion") + self.matches.clear() diff --git a/vllm/envs.py b/vllm/envs.py index c896770e5f6bc..9b8aa2934104e 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -221,6 +221,11 @@ def get_default_config_root(): "VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), + # Internal flag for dumping the model graph at different stages of + # custom pass compilation + "VLLM_TORCH_COMPILE_DUMP": + lambda: list(os.environ.get("VLLM_TORCH_COMPILE_DUMP", "").split(",")), + # API key for VLLM API server "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None), From e2c9ef09cfa88f94af5a24826d437da519ffe080 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 01:19:10 +0000 Subject: [PATCH 26/72] rebase Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 61c8273dd7a3c..0cd9bcb2a88d9 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -339,6 +339,7 @@ def async_rewrite(graph: fx.Graph): return graph +collective_fusion_pass = CollectiveFusionPass() def wrap_inductor(graph, example_inputs, From 81465d2696fc5b7cdd8641000cf626f1bf3c4819 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 14:07:59 +0000 Subject: [PATCH 27/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 147 +++++++++++--------------- vllm/model_executor/layers/linear.py | 5 + vllm/model_executor/models/llama.py | 13 +-- 3 files changed, 74 insertions(+), 91 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 99319d8c6dca2..9b6846877fb35 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,5 +1,5 @@ import operator -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union, Iterable import torch import torch.fx as fx @@ -10,41 +10,44 @@ from vllm.compilation.inductor_pass import InductorPass from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.model_executor.layers.linear import should_slice +from vllm.model_executor.layers.linear import should_slice, slice_residual from vllm.logger import init_logger logger = init_logger(__name__) -def pprint(x): - #print(x) - pass +def match_gemm_rs_ag_gemm( + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + all_reduce = torch.ops.higher_order.auto_functionalized( + torch.ops.vllm.inplace_all_reduce.default, + tensor = mm_1, + group_name = 'tp:0' # how to deal with groupname? 
+ ) + all_reduce = all_reduce[1] -def match_gemm_rs_ag_gemm(residual, - gemm_1_weights, - gemm_1_activations, - rms_norm_weight, - gemm_2_weights, - ): - permute_2 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_2) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = 'tp:0') # how to deal with groupname? - getitem_25 = auto_functionalized_4[1] - auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = residual, weight = rms_norm_weight, epsilon = 1e-05) - getitem_27 = auto_functionalized_5[1] - getitem_28 = auto_functionalized_5[2] # new residual - permute_3 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) - return mm_2, getitem_28 + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input = all_reduce, + residual = residual, + weight = rms_norm_weight, + epsilon = 1e-05 + ) + normalized = norm_res[1] + new_residual = norm_res[2] + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) -def slices(residual) -> List[torch.Tensor]: - n_slices = get_tensor_model_parallel_world_size() - residual_slices = torch.chunk(residual, n_slices, dim=0) - #pprint(f"SLICES {[r.shape for r in residual_slices]}") - return residual_slices + return mm_2, new_residual @torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) @@ -55,31 +58,20 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - pprint(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") - - #### - # this is terrible - res_slices = slices(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - #### - pprint(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") if should_slice(residual.shape) and first_layer: - pprint(f"FIRST! 
rank={get_tensor_model_parallel_rank()}") + res_slices = slice_residual(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] split_1 = torch.ops.aten.split.Tensor(residual, slice_size) getitem_26 = split_1[0]; split_1 = None else: - #getitem_26 = my_residual getitem_26 = residual slice_size = residual.shape[0] if not should_slice(residual.shape): - # this branch probably broken - pprint("NAIVE") permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) output = torch.matmul(gemm_1_activations, permute_3) - - output = tensor_model_parallel_all_reduce(output) ### + output = tensor_model_parallel_all_reduce(output) auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) getitem_29 = auto_functionalized_4[1] @@ -88,7 +80,6 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) getitem_35 = torch.matmul(getitem_29, permute_5) getitem_30a = getitem_30.clone() - pprint(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape} {first_layer}") return getitem_35, getitem_30, getitem_30a else: group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup @@ -108,9 +99,8 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, getitem_34 = fused_all_gather_matmul[1] getitem_35 = getitem_34[0] - pprint(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape} {first_layer}") # TODO: can we avoid clone here? - return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed + return getitem_35, getitem_31.clone(), slice_scatter_2 @torch.library.register_fake("vllm::gemm_rs_ag_gemm") @@ -122,61 +112,56 @@ def gemm_rs_ag_gemm_fake(residual: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - ##### - # this is terrible - res_slices = slices(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] # can we use rank 0 instead? - ##### -# if should_slice(residual.shape) and first_layer: if should_slice(gemm_1_activations.shape) and first_layer: - pprint(f"FIRST! 
rank={get_tensor_model_parallel_rank()}") + res_slices = slice_residual(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - my_residual = split_1[0]; split_1 = None + my_residual = split_1[0] else: my_residual = residual # verify the type is always correct - mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) - - pprint(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape} {first_layer}") + mm_res = torch.empty( + (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), + device=gemm_1_activations.device, + dtype=gemm_1_activations.dtype + ) return (mm_res, my_residual, residual) -# implementation doesn't matter, only needed for signature -def replace_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights): - results = torch.ops.vllm.gemm_rs_ag_gemm(residual, residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights) - getitem_34 = results[0] - getitem_35 = results[1] - return getitem_34, getitem_35 - +def match_final(my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) -def match_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weights): - permute_128 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_127 = torch.ops.aten.mm.default(gemm_1_activations, permute_128) - - auto_functionalized_224 = torch.ops.higher_order.auto_functionalized( + all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, - tensor = mm_127, + tensor = mm_1, group_name = 'tp:0' # TODO: not same as group name ) - getitem_1024 = auto_functionalized_224[1] + all_reduce = all_reduce[1] - auto_functionalized_225 = torch.ops.higher_order.auto_functionalized( + norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, - input = getitem_1024, + input = all_reduce, residual = my_residual, weight = rms_norm_weights, epsilon = 1e-05 ) - getitem_1026 = auto_functionalized_225[1] + normalized = norm_res[1] - return getitem_1026 + return normalized # TODO: wrap in custom op to prevent infinite recursion in inductor logging statement? -def replace_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weights): +def replace_final(my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) @@ -186,8 +171,8 @@ def replace_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weig # is this the right thing to call it on? 
if should_slice(gemm_1_activations.shape): - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out - world_size = get_tp_group().world_size # TODO: factor out + group_name = torch.distributed.group.WORLD.group_name + world_size = get_tensor_model_parallel_world_size() all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(my_residual, world_size, group_name) wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) else: @@ -204,23 +189,21 @@ def replace_final(my_residual, gemm_1_weights, gemm_1_activations, rms_norm_weig return getitem_1219 - -# find the output and the residual -def find_fn(nodes, op): +def find_fn(nodes: Iterable[fx.Node], op): for node in reversed(nodes): if node.op == "call_function" and node.target == op: return node return None -def find_auto_fn(nodes, op): +def find_auto_fn(nodes: Iterable[fx.Node], op): for node in reversed(nodes): if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: return node return None -def find_getitem(node, idx): +def find_getitem(node: Iterable[fx.Node], idx): for user in reversed(node.users): if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: return user @@ -263,8 +246,6 @@ def record_match(self, match: Match) -> bool: return False def process_matches(self, graph: fx.Graph): - pprint(f"len = {len(self.matches)}") - nodes = list(graph.nodes) first_match = None @@ -316,7 +297,7 @@ def find_min_index(match) -> int: graph.eliminate_dead_code() assert all(node not in graph.nodes for match in matches for node in match.nodes) - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: fx.Graph): self.dump_graph(graph, "before_collective_fusion") count = self.my_patterns.apply(graph) logger.info(f"fused gemm match count = {len(self.matches)}") diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 858358911dcc1..bba0ee855b030 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -254,6 +254,11 @@ def should_slice(shape) -> bool: return (shape[0] % n_slices == 0 and shape[0] >= 128) +def slice_residual(residual) -> List[torch.Tensor]: + n_slices = get_tensor_model_parallel_world_size() + return torch.chunk(residual, n_slices, dim=0) + + class MatmulRS(LinearMethodBase): #Fused Gemm-ReduceScatter without quantization. 
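For reference, a standalone sketch of the slicing contract that should_slice and slice_residual establish for the fused path. The real helpers take the slice count from the tensor-parallel world size; the explicit n_slices argument here is illustrative only.

import torch

def should_slice_example(shape: torch.Size, n_slices: int) -> bool:
    # Partition only when dim 0 splits evenly across ranks and is large
    # enough (>= 128 rows) for the scatter to pay off.
    return shape[0] % n_slices == 0 and shape[0] >= 128

def slice_residual_example(residual: torch.Tensor, n_slices: int):
    # Each tensor-parallel rank keeps chunk[rank] of the residual.
    return torch.chunk(residual, n_slices, dim=0)

x = torch.randn(256, 64)
assert should_slice_example(x.shape, n_slices=4)
assert [c.shape[0] for c in slice_residual_example(x, 4)] == [64, 64, 64, 64]
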
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5e4ffbd6b2fa7..339ede528586f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,7 +38,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, - should_slice) + should_slice, + slice_residual) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig @@ -290,15 +291,11 @@ def forward( pprint(f"RESIDUAL SHAPE = {residual.shape}") - def slices(residual) -> List[torch.Tensor]: + def slices(residual: torch.Tensor) -> List[torch.Tensor]: if not self.fuse_gemms or not should_slice(residual.shape): - pprint(f"SLICES TOO SMALL {[residual.shape]}") return [] - - n_slices = get_tensor_model_parallel_world_size() - residual_slices = torch.chunk(residual, n_slices, dim=0) - pprint(f"SLICES {[r.shape for r in residual_slices]}") - return residual_slices + else: + return slice_residual(residual) orig_residual_shape = residual.shape From 0dd9ca6f90320436984bec1556f3fd1d1617e8aa Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 14:19:55 +0000 Subject: [PATCH 28/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 41 +++++---------------------- vllm/compilation/utils.py | 25 ++++++++++++++++ 2 files changed, 32 insertions(+), 34 deletions(-) create mode 100644 vllm/compilation/utils.py diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 9b6846877fb35..57bf95565e8e6 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -4,10 +4,10 @@ import torch import torch.fx as fx -from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, joint_fwd_bwd, Match from vllm.compilation.inductor_pass import InductorPass +from vllm.compilation.utils import find_fn, find_auto_fn, find_getitem from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.distributed import tensor_model_parallel_all_reduce from vllm.model_executor.layers.linear import should_slice, slice_residual @@ -189,27 +189,6 @@ def replace_final(my_residual: torch.Tensor, return getitem_1219 -def find_fn(nodes: Iterable[fx.Node], op): - for node in reversed(nodes): - if node.op == "call_function" and node.target == op: - return node - return None - - -def find_auto_fn(nodes: Iterable[fx.Node], op): - for node in reversed(nodes): - if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: - return node - return None - - -def find_getitem(node: Iterable[fx.Node], idx): - for user in reversed(node.users): - if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: - return user - return None - - class CollectiveFusionPass(InductorPass): def __init__(self): self.my_patterns = PatternMatcherPass() @@ -275,20 +254,13 @@ def find_min_index(match) -> int: res_replacements.append(residual_node_new) my_res_replacements.append(my_residual_node_new) - rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) - gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) - if gemm_node is None: - gemm_node = 
find_fn(match.nodes, torch.ops.symm_mem.fused_all_gather_matmul.default) + rms_node = find_auto_fn(reversed(match.nodes), torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(reversed(match.nodes), torch.ops.aten.mm.default) assert rms_node is not None assert gemm_node is not None - #assert len(rms_node.users) == 2 - #assert len(gemm_node.users) == 1 - - # meta["val"] is used by de-functionalization - rms_val = rms_node.meta["val"] - gemm_val = gemm_node.meta["val"] - fused_node.meta["val"] = (gemm_val, rms_val[2]) + assert len(rms_node.users) == 2 + assert len(gemm_node.users) == 1 or len(gemm_node.users) == 2 find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) gemm_node.replace_all_uses_with(result_node_new) @@ -302,7 +274,8 @@ def __call__(self, graph: fx.Graph): count = self.my_patterns.apply(graph) logger.info(f"fused gemm match count = {len(self.matches)}") - # Don't apply final pattern unless we've matched and replaced the gemm+collective ops. + # Don't apply final pattern unless we've matched and replaced the + # gemm+collective ops. if len(self.matches) > 0: count =self. my_patterns2.apply(graph) logger.info(f"final match count = {count}") diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py new file mode 100644 index 0000000000000..189a8ec653684 --- /dev/null +++ b/vllm/compilation/utils.py @@ -0,0 +1,25 @@ +import operator +import torch +import torch.fx as fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from typing import Iterable, Optional + +def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: + for node in nodes: + if node.op == "call_function" and node.target == op: + return node + return None + + +def find_auto_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: + for node in nodes: + if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: + return node + return None + + +def find_getitem(node: Iterable[fx.Node], idx: int) -> Optional[fx.Node]: + for user in node.users: + if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: + return user + return None From bb2f2d05a2bb7a29f1f4a4b123180301d4e6e04a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 14:31:51 +0000 Subject: [PATCH 29/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 6 +++--- vllm/compilation/utils.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 57bf95565e8e6..74b536379f745 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -7,7 +7,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, joint_fwd_bwd, Match from vllm.compilation.inductor_pass import InductorPass -from vllm.compilation.utils import find_fn, find_auto_fn, find_getitem +from vllm.compilation.utils import find_fn, find_auto_fn, find_getitem, last_node_in_match from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.distributed import tensor_model_parallel_all_reduce from vllm.model_executor.layers.linear import should_slice, slice_residual @@ -238,9 +238,9 @@ def find_min_index(match) -> int: my_res_replacements = [] for match in matches: - last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) + last_node = last_node_in_match(match) 
- with graph.inserting_after(last_node_in_match): + with graph.inserting_after(last_node): kwargs = match.kwargs kwargs["first_layer"] = match == matches[0] kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 189a8ec653684..8558d5ff04f29 100644 --- a/vllm/compilation/utils.py +++ b/vllm/compilation/utils.py @@ -2,6 +2,7 @@ import torch import torch.fx as fx from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor.pattern_matcher import Match from typing import Iterable, Optional def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: @@ -23,3 +24,12 @@ def find_getitem(node: Iterable[fx.Node], idx: int) -> Optional[fx.Node]: if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: return user return None + + +def last_node_in_match(match: Match) -> fx.Node: + if len(match.nodes) > 0: + graph = match.nodes[0].graph + for n in reversed(graph.nodes): + if n in reversed(match.nodes): + return n + raise ValueError("No nodes in graph") From c78ce79c56f84bd946949514da2c4d70b7b1b17b Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 14:47:13 +0000 Subject: [PATCH 30/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 74b536379f745..dbd72e2166e19 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -8,7 +8,7 @@ from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import find_fn, find_auto_fn, find_getitem, last_node_in_match -from vllm.distributed.parallel_state import get_tp_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from vllm.distributed.parallel_state import get_tp_group, get_world_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank from vllm.distributed import tensor_model_parallel_all_reduce from vllm.model_executor.layers.linear import should_slice, slice_residual @@ -16,6 +16,14 @@ logger = init_logger(__name__) +# TODO: factor out somehow +TP_GROUP_NAME = "tp:0" + + +# how to do this properly? +def get_world_name() -> str: + return torch.distributed.group.WORLD.group_name + def match_gemm_rs_ag_gemm( residual: torch.Tensor, @@ -30,7 +38,7 @@ def match_gemm_rs_ag_gemm( all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, - group_name = 'tp:0' # how to deal with groupname? + group_name = TP_GROUP_NAME # how to deal with groupname? 
capture w/lambda ) all_reduce = all_reduce[1] @@ -82,7 +90,7 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, getitem_30a = getitem_30.clone() return getitem_35, getitem_30, getitem_30a else: - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup + group_name = get_world_name() permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, clone, 'avg', 0, group_name) @@ -141,7 +149,7 @@ def match_final(my_residual: torch.Tensor, all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, - group_name = 'tp:0' # TODO: not same as group name + group_name = TP_GROUP_NAME ) all_reduce = all_reduce[1] @@ -162,16 +170,14 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: - tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name - permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254) - auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = tp_group_name) + auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = TP_GROUP_NAME) getitem_1217 = auto_functionalized_161[1] # is this the right thing to call it on? if should_slice(gemm_1_activations.shape): - group_name = torch.distributed.group.WORLD.group_name + group_name = get_world_name() world_size = get_tensor_model_parallel_world_size() all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(my_residual, world_size, group_name) wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) From 82fc8079a399c0b514dc782c4b1f7d945bae7b8a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 14:56:20 +0000 Subject: [PATCH 31/72] cleanups renames Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 56 ++++++++++++++------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index dbd72e2166e19..4fdf0adeab940 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -38,7 +38,7 @@ def match_gemm_rs_ag_gemm( all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, - group_name = TP_GROUP_NAME # how to deal with groupname? 
capture w/lambda + group_name = TP_GROUP_NAME ) all_reduce = all_reduce[1] @@ -71,44 +71,48 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, res_slices = slice_residual(residual) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - getitem_26 = split_1[0]; split_1 = None + getitem_26 = split_1[0] else: getitem_26 = residual slice_size = residual.shape[0] + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + if not should_slice(residual.shape): - permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - output = torch.matmul(gemm_1_activations, permute_3) - output = tensor_model_parallel_all_reduce(output) - - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) - getitem_29 = auto_functionalized_4[1] - getitem_30 = auto_functionalized_4[2] - - permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - getitem_35 = torch.matmul(getitem_29, permute_5) - getitem_30a = getitem_30.clone() - return getitem_35, getitem_30, getitem_30a + output = torch.matmul(gemm_1_activations, gemm_1_w_perm) + reduced_output = tensor_model_parallel_all_reduce(output) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduced_output, + residual=getitem_26, + weight=rms_norm_weight, + epsilon=1e-05 + ) + normalized = norm_res[1] + new_residual = norm_res[2] + + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.matmul(normalized, gemm_2_w_perm) + return mm_2, new_residual, new_residual.clone() else: group_name = get_world_name() - permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) - output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, clone, 'avg', 0, group_name) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) - getitem_29 = auto_functionalized_4[1] - getitem_30 = auto_functionalized_4[2] + output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, gemm_1_w_perm, 'avg', 0, group_name) + + norm_res = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) + getitem_29 = norm_res[1] + getitem_30 = norm_res[2] residual_1 = residual if first_layer else my_residual slice_scatter_2 = torch.ops.aten.slice_scatter.default(residual_1, getitem_30, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) getitem_31 = split_2[0] - permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) - fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) - getitem_34 = fused_all_gather_matmul[1] - getitem_35 = getitem_34[0] + + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [gemm_2_w_perm], 0, group_name) + mm_2 = fused_all_gather_matmul[1] # TODO: can we avoid clone here? 
- return getitem_35, getitem_31.clone(), slice_scatter_2 + return mm_2[0], getitem_31.clone(), slice_scatter_2 @torch.library.register_fake("vllm::gemm_rs_ag_gemm") From 2c0e799ab7b01973786f2788210ba298c1b49844 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 15:19:10 +0000 Subject: [PATCH 32/72] fix formatting Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 1 + vllm/compilation/collective_fusion.py | 211 ++++++++++++++------------ vllm/compilation/utils.py | 13 +- vllm/distributed/parallel_state.py | 3 +- vllm/model_executor/layers/linear.py | 45 ++---- vllm/model_executor/models/llama.py | 26 ++-- 6 files changed, 150 insertions(+), 149 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0cd9bcb2a88d9..d1b68e810956b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -13,6 +13,7 @@ from vllm.logger import init_logger from vllm.utils import weak_ref_tensors +from .collective_fusion import CollectiveFusionPass from .counter import compilation_counter from .inductor_pass import InductorPass from .pass_manager import PostGradPassManager diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 4fdf0adeab940..b06f86ecde266 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,18 +1,19 @@ import operator -from typing import Callable, Dict, List, Optional, Tuple, Union, Iterable +from typing import List, Tuple import torch import torch.fx as fx - -from torch._inductor.pattern_matcher import PatternMatcherPass, register_replacement, fwd_only, joint_fwd_bwd, Match +from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, + fwd_only, register_replacement) from vllm.compilation.inductor_pass import InductorPass -from vllm.compilation.utils import find_fn, find_auto_fn, find_getitem, last_node_in_match -from vllm.distributed.parallel_state import get_tp_group, get_world_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank +from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, + last_node_in_match) from vllm.distributed import tensor_model_parallel_all_reduce -from vllm.model_executor.layers.linear import should_slice, slice_residual - +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger +from vllm.model_executor.layers.linear import should_slice, slice_residual logger = init_logger(__name__) @@ -26,29 +27,27 @@ def get_world_name() -> str: def match_gemm_rs_ag_gemm( - residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, - tensor = mm_1, - group_name = TP_GROUP_NAME - ) + tensor=mm_1, + group_name=TP_GROUP_NAME) all_reduce = all_reduce[1] norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, - input = all_reduce, - residual = residual, - weight = rms_norm_weight, - epsilon = 1e-05 - ) + 
input=all_reduce, + residual=residual, + weight=rms_norm_weight, + epsilon=1e-05) normalized = norm_res[1] new_residual = norm_res[2] @@ -59,21 +58,19 @@ def match_gemm_rs_ag_gemm( @torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) -def gemm_rs_ag_gemm(residual: torch.Tensor, - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def gemm_rs_ag_gemm( + residual: torch.Tensor, old_my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, + first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if should_slice(residual.shape) and first_layer: res_slices = slice_residual(residual) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - getitem_26 = split_1[0] + residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = residual_chunk[0] else: - getitem_26 = residual + my_residual = residual slice_size = residual.shape[0] gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) @@ -85,10 +82,9 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, input=reduced_output, - residual=getitem_26, + residual=my_residual, weight=rms_norm_weight, - epsilon=1e-05 - ) + epsilon=1e-05) normalized = norm_res[1] new_residual = norm_res[2] @@ -97,18 +93,28 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, return mm_2, new_residual, new_residual.clone() else: group_name = get_world_name() - output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, gemm_1_w_perm, 'avg', 0, group_name) + output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + gemm_1_activations, gemm_1_w_perm, 'avg', 0, group_name) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] - norm_res = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) - getitem_29 = norm_res[1] - getitem_30 = norm_res[2] - residual_1 = residual if first_layer else my_residual - slice_scatter_2 = torch.ops.aten.slice_scatter.default(residual_1, getitem_30, 0, 0, slice_size) + residual_1 = residual if first_layer else old_my_residual + slice_scatter_2 = torch.ops.aten.slice_scatter.default( + residual_1, new_residual, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) getitem_31 = split_2[0] gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [gemm_2_w_perm], 0, group_name) + fused_all_gather_matmul = ( + torch.ops.symm_mem.fused_all_gather_matmul.default( + normalized, [gemm_2_w_perm], 0, group_name)) mm_2 = fused_all_gather_matmul[1] # TODO: can we avoid clone here? 
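# For context, a minimal sketch of the custom-op/fake pairing that
# gemm_rs_ag_gemm relies on: the registered fake (meta) implementation only
# produces correctly shaped empty tensors, so the graph can be traced and
# pattern matched without launching any collectives. The "example::" op and
# the names below are illustrative and not part of this patch.
import torch

@torch.library.custom_op("example::fused_mm", mutates_args=())
def fused_mm(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w.t()

@torch.library.register_fake("example::fused_mm")
def fused_mm_fake(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.empty((x.shape[0], w.shape[0]),
                       device=x.device, dtype=x.dtype)
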
@@ -116,14 +122,15 @@ def gemm_rs_ag_gemm(residual: torch.Tensor, @torch.library.register_fake("vllm::gemm_rs_ag_gemm") -def gemm_rs_ag_gemm_fake(residual: torch.Tensor, - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, - first_layer: bool, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def gemm_rs_ag_gemm_fake( + residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, + first_layer: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if should_slice(gemm_1_activations.shape) and first_layer: res_slices = slice_residual(residual) @@ -137,14 +144,12 @@ def gemm_rs_ag_gemm_fake(residual: torch.Tensor, mm_res = torch.empty( (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, - dtype=gemm_1_activations.dtype - ) + dtype=gemm_1_activations.dtype) return (mm_res, my_residual, residual) -def match_final(my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, +def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) @@ -152,79 +157,79 @@ def match_final(my_residual: torch.Tensor, all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, - tensor = mm_1, - group_name = TP_GROUP_NAME - ) + tensor=mm_1, + group_name=TP_GROUP_NAME) all_reduce = all_reduce[1] norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, - input = all_reduce, - residual = my_residual, - weight = rms_norm_weights, - epsilon = 1e-05 - ) + input=all_reduce, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) normalized = norm_res[1] return normalized -# TODO: wrap in custom op to prevent infinite recursion in inductor logging statement? -def replace_final(my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, +# TODO: wrap in custom op to prevent infinite recursion in inductor logging +# statement? +def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254) - auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = TP_GROUP_NAME) + auto_functionalized_161 = torch.ops.higher_order.auto_functionalized( + torch.ops.vllm.inplace_all_reduce.default, + tensor=mm_1, + group_name=TP_GROUP_NAME) getitem_1217 = auto_functionalized_161[1] # is this the right thing to call it on? 
if should_slice(gemm_1_activations.shape): group_name = get_world_name() world_size = get_tensor_model_parallel_world_size() - all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(my_residual, world_size, group_name) - wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) + all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( + my_residual, world_size, group_name) + wait_tensor = torch.ops._c10d_functional.wait_tensor.default( + all_gather) else: wait_tensor = my_residual auto_functionalized_162 = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, - input = getitem_1217, - residual = wait_tensor, - weight = rms_norm_weights, - epsilon = 1e-05) + input=getitem_1217, + residual=wait_tensor, + weight=rms_norm_weights, + epsilon=1e-05) getitem_1219 = auto_functionalized_162[1] return getitem_1219 class CollectiveFusionPass(InductorPass): + def __init__(self): self.my_patterns = PatternMatcherPass() self.my_patterns2 = PatternMatcherPass() self.matches: List[Match] = [] - x = torch.empty([4,4], device='cuda') - w = torch.empty([4,4], device='cuda') - resid = torch.empty([4,4], device='cuda') - resid_w = torch.empty([4,4], device='cuda') - x2 = torch.empty([4,4], device='cuda') + x = torch.empty([4, 4], device='cuda') + w = torch.empty([4, 4], device='cuda') + resid = torch.empty([4, 4], device='cuda') + resid_w = torch.empty([4, 4], device='cuda') + x2 = torch.empty([4, 4], device='cuda') inputs = [resid, x, w, resid_w, x2] register_replacement(match_gemm_rs_ag_gemm, match_gemm_rs_ag_gemm, inputs, - fwd_only, - [self.my_patterns], + fwd_only, [self.my_patterns], extra_check=lambda m: self.record_match(m)) final_inputs = [x, w, resid, resid_w] - register_replacement(match_final, - replace_final, - final_inputs, - fwd_only, - [self.my_patterns2]) + register_replacement(match_final, replace_final, final_inputs, + fwd_only, [self.my_patterns2]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and @@ -236,7 +241,6 @@ def record_match(self, match: Match) -> bool: def process_matches(self, graph: fx.Graph): nodes = list(graph.nodes) - first_match = None def find_min_index(match) -> int: return min(match.nodes, key=lambda x: nodes.index(x)) @@ -244,8 +248,8 @@ def find_min_index(match) -> int: # "sort" matches in topo order. 
matches = sorted(self.matches, key=lambda x: find_min_index(x)) - res_replacements = [] - my_res_replacements = [] + res_replacements: List[fx.Node] = [] + my_res_replacements: List[fx.Node] = [] for match in matches: last_node = last_node_in_match(match) @@ -253,42 +257,53 @@ def find_min_index(match) -> int: with graph.inserting_after(last_node): kwargs = match.kwargs kwargs["first_layer"] = match == matches[0] - kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] - kwargs["my_residual"] = my_res_replacements[-1] if len(my_res_replacements) > 0 else match.kwargs["residual"] - fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) + kwargs["residual"] = res_replacements[-1] if len( + res_replacements) > 0 else match.kwargs["residual"] + kwargs["old_my_residual"] = my_res_replacements[-1] if len( + my_res_replacements) > 0 else match.kwargs["residual"] + fused_node = graph.call_function( + torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) - residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) + result_node_new = graph.call_function(operator.getitem, + (fused_node, 0)) + residual_node_new = graph.call_function( + operator.getitem, (fused_node, 1)) + my_residual_node_new = graph.call_function( + operator.getitem, (fused_node, 2)) res_replacements.append(residual_node_new) my_res_replacements.append(my_residual_node_new) - rms_node = find_auto_fn(reversed(match.nodes), torch.ops._C.fused_add_rms_norm.default) - gemm_node = find_fn(reversed(match.nodes), torch.ops.aten.mm.default) + rms_node = find_auto_fn(reversed(match.nodes), + torch.ops._C.fused_add_rms_norm.default) + gemm_node = find_fn(reversed(match.nodes), + torch.ops.aten.mm.default) assert rms_node is not None assert gemm_node is not None assert len(rms_node.users) == 2 assert len(gemm_node.users) == 1 or len(gemm_node.users) == 2 - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) + residual_getter_node = find_getitem(rms_node, 2) + assert residual_getter_node is not None + residual_getter_node.replace_all_uses_with(residual_node_new) gemm_node.replace_all_uses_with(result_node_new) # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches for node in match.nodes) + assert all(node not in graph.nodes for match in matches + for node in match.nodes) def __call__(self, graph: fx.Graph): self.dump_graph(graph, "before_collective_fusion") count = self.my_patterns.apply(graph) - logger.info(f"fused gemm match count = {len(self.matches)}") + logger.info("fused gemm match count = %d", len(self.matches)) # Don't apply final pattern unless we've matched and replaced the # gemm+collective ops. if len(self.matches) > 0: - count =self. 
my_patterns2.apply(graph) - logger.info(f"final match count = {count}") + count = self.my_patterns2.apply(graph) + logger.info("final match count = %d", count) self.process_matches(graph) self.dump_graph(graph, "after_collective_fusion") diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 8558d5ff04f29..12a84db22e877 100644 --- a/vllm/compilation/utils.py +++ b/vllm/compilation/utils.py @@ -1,9 +1,10 @@ import operator -import torch +from typing import Iterable, Optional + import torch.fx as fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import Match -from typing import Iterable, Optional + def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: for node in nodes: @@ -14,14 +15,16 @@ def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: def find_auto_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: for node in nodes: - if node.op == "call_function" and node.target == auto_functionalized and node.args[0] == op: + if (node.op == "call_function" and node.target == auto_functionalized + and node.args[0] == op): return node return None -def find_getitem(node: Iterable[fx.Node], idx: int) -> Optional[fx.Node]: +def find_getitem(node: fx.Node, idx: int) -> Optional[fx.Node]: for user in node.users: - if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: + if (user.op == "call_function" and user.target == operator.getitem + and user.args[1] == idx): return user return None diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e7aac04f42bf8..30a946f91d99b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -206,7 +206,8 @@ def __init__( self.use_xpu_communicator = use_xpu_communicator # Initialize pynvshmem - if has_flux and torch.distributed.get_world_size(self.device_group) > 1: + if has_flux and torch.distributed.get_world_size( + self.device_group) > 1: flux.init_flux_shm(self.device_group) # lazy import to avoid documentation build error diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index bba0ee855b030..d8574a9308dc3 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -10,7 +10,6 @@ import torch import torch.nn.functional as F -import torch.distributed as D from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -18,7 +17,7 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tp_group, get_world_group +from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -41,11 +40,6 @@ ] -def pprint(x): - #print(x) - pass - - def adjust_marlin_shard(param, shard_size, shard_offset): marlin_tile_size = getattr(param, "marlin_tile_size", None) if marlin_tile_size is None: @@ -264,6 +258,7 @@ class MatmulRS(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add + self.group_name = torch.distributed.group.WORLD.group_name def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, @@ -277,7 +272,6 @@ def create_weights(self, layer: torch.nn.Module, set_weight_attrs(weight, {"input_dim": 1, "output_dim": 
0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - pprint(f"inpp={input_size_per_partition}, output_part_siz={output_partition_sizes}, input_size={input_size}, output_size={output_size}") def apply(self, layer: torch.nn.Module, @@ -285,26 +279,17 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: assert bias is None - pprint(f"MATMUL_RS {get_tp_group().rank} {x.shape}, {layer.weight.transpose(1,0).shape}") - if not should_slice(x.shape): - pprint("MATMUL_RS naive") output = torch.matmul(x, layer.weight.transpose(1, 0)) - # total hack - output = tensor_model_parallel_all_reduce(output) + # This is a bit hacky + return tensor_model_parallel_all_reduce(output) else: - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - output = torch.ops.symm_mem.fused_matmul_reduce_scatter( + return torch.ops.symm_mem.fused_matmul_reduce_scatter( x, layer.weight.transpose(1, 0).contiguous(), "avg", - scatter_dim=0, # ? - group_name=group_name - ) - - pprint(f"MATMUL_RS DONE {get_tp_group().rank} {output.shape}") - - return output + scatter_dim=0, + group_name=self.group_name) class AGMatmul(LinearMethodBase): @@ -312,6 +297,7 @@ class AGMatmul(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add + self.group_name = torch.distributed.group.WORLD.group_name def create_weights(self, layer: torch.nn.Module, input_size_per_partition: int, @@ -332,23 +318,16 @@ def apply(self, bias: Optional[torch.Tensor] = None) -> torch.Tensor: assert bias is None - pprint(f"AG_MATMUL {get_tp_group().rank}, {x.shape}, {layer.weight.transpose(1,0).shape}") - if not should_slice(x.shape): - output = torch.matmul(x, layer.weight.transpose(1,0)) + return torch.matmul(x, layer.weight.transpose(1, 0)) else: - group_name = torch.distributed.group.WORLD.group_name ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( x, - [layer.weight.transpose(1,0).contiguous()], + [layer.weight.transpose(1, 0).contiguous()], gather_dim=0, - group_name=group_name, + group_name=self.group_name, ) - output = mm_outputs[0] - - pprint(f"AG_MATMUL DONE {get_tp_group().rank}, {output.shape}") - - return output + return mm_outputs[0] class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 339ede528586f..42a829e9aabc7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -22,7 +22,6 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union -import os import torch from torch import nn from transformers import LlamaConfig @@ -32,13 +31,12 @@ from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, get_tp_group) + get_tp_group, tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - RowParallelLinear, - should_slice, + RowParallelLinear, should_slice, slice_residual) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType @@ -79,7 +77,7 @@ def __init__( bias: bool = False, prefix: 
str = "", last_layer: bool = False, - fuse_gemms = True, + fuse_gemms=True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -221,7 +219,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - fuse_gemms = True, + fuse_gemms=True, ) -> None: super().__init__() self.fuse_gemms = fuse_gemms @@ -312,13 +310,17 @@ def slices(residual: torch.Tensor) -> List[torch.Tensor]: attn_metadata=attn_metadata) # Fully Connected - assert (hidden_states.shape == my_residual.shape), f"{hidden_states.shape} != {my_residual.shape}" + assert (hidden_states.shape == my_residual.shape + ), f"{hidden_states.shape} != {my_residual.shape}" hidden_states, my_residual = self.post_attention_layernorm( hidden_states, my_residual) hidden_states = self.mlp(hidden_states) - pprint(f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}") - if self.fuse_gemms and self.last_layer and should_slice(orig_residual_shape): + pprint( + f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}" + ) + if self.fuse_gemms and self.last_layer and should_slice( + orig_residual_shape): pprint(f"FINAL REDUCE {my_residual.shape}") if False: residual = tensor_model_parallel_all_gather(my_residual, 0) @@ -326,15 +328,15 @@ def slices(residual: torch.Tensor) -> List[torch.Tensor]: residual = torch.ops._c10d_functional.all_gather_into_tensor( my_residual.contiguous(), get_tp_group().world_size, - torch.distributed.group.WORLD.group_name - ) + torch.distributed.group.WORLD.group_name) residual = torch.ops._c10d_functional.wait_tensor(residual) pprint(f"GOT HERE2 {my_residual.shape}, {residual.shape}") else: residual = my_residual - assert (hidden_states.shape == residual.shape), f"{hidden_states.shape} != {residual.shape}" + assert (hidden_states.shape == residual.shape + ), f"{hidden_states.shape} != {residual.shape}" return hidden_states, residual From 0dc3c046c41b7c834c61b8384d684fb7027af4b7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 15:34:15 +0000 Subject: [PATCH 33/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b06f86ecde266..d91de33b82517 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -106,10 +106,10 @@ def gemm_rs_ag_gemm( new_residual = norm_res[2] residual_1 = residual if first_layer else old_my_residual - slice_scatter_2 = torch.ops.aten.slice_scatter.default( + slice_scatter = torch.ops.aten.slice_scatter.default( residual_1, new_residual, 0, 0, slice_size) - split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) - getitem_31 = split_2[0] + split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) + new_residual = split_2[0] gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) fused_all_gather_matmul = ( @@ -118,7 +118,7 @@ def gemm_rs_ag_gemm( mm_2 = fused_all_gather_matmul[1] # TODO: can we avoid clone here? 
- return mm_2[0], getitem_31.clone(), slice_scatter_2 + return mm_2[0], new_residual.clone(), slice_scatter @torch.library.register_fake("vllm::gemm_rs_ag_gemm") From 59d2100761696b35e02148aab68d165d038c756c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 16:36:05 +0000 Subject: [PATCH 34/72] revert some hacks Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 4 ++++ vllm/compilation/collective_fusion.py | 29 +++++++++++++-------------- vllm/distributed/parallel_state.py | 1 - 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d1b68e810956b..5305a38ee1164 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -360,7 +360,11 @@ def wrap_inductor(graph, logger.info("Compiling a graph for shape %s", runtime_shape) from torch._inductor import config + torch._inductor.config._micro_pipeline_tp = True + # Set to False to avoid infinite recursion logging + torch._inductor.config.implicit_fallbacks = True + current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d91de33b82517..3e1ca5f03fee7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -172,8 +172,6 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return normalized -# TODO: wrap in custom op to prevent infinite recursion in inductor logging -# statement? def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -183,7 +181,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, torch.ops.vllm.inplace_all_reduce.default, tensor=mm_1, group_name=TP_GROUP_NAME) - getitem_1217 = auto_functionalized_161[1] + reduced = auto_functionalized_161[1] # is this the right thing to call it on? 
if should_slice(gemm_1_activations.shape): @@ -196,22 +194,21 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, else: wait_tensor = my_residual - auto_functionalized_162 = torch.ops.higher_order.auto_functionalized( + norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, - input=getitem_1217, + input=reduced, residual=wait_tensor, weight=rms_norm_weights, epsilon=1e-05) - getitem_1219 = auto_functionalized_162[1] - return getitem_1219 + return norm_res[1] class CollectiveFusionPass(InductorPass): def __init__(self): - self.my_patterns = PatternMatcherPass() - self.my_patterns2 = PatternMatcherPass() + self.gemm_rs_ag_gemm_pattern = PatternMatcherPass() + self.final_pattern = PatternMatcherPass() self.matches: List[Match] = [] x = torch.empty([4, 4], device='cuda') @@ -224,12 +221,14 @@ def __init__(self): register_replacement(match_gemm_rs_ag_gemm, match_gemm_rs_ag_gemm, inputs, - fwd_only, [self.my_patterns], + fwd_only, [self.gemm_rs_ag_gemm_pattern], extra_check=lambda m: self.record_match(m)) final_inputs = [x, w, resid, resid_w] - register_replacement(match_final, replace_final, final_inputs, - fwd_only, [self.my_patterns2]) + register_replacement(match_final, + #torch.ops.vllm.gemm_ag_final, + replace_final, + final_inputs, fwd_only, [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and @@ -242,7 +241,7 @@ def record_match(self, match: Match) -> bool: def process_matches(self, graph: fx.Graph): nodes = list(graph.nodes) - def find_min_index(match) -> int: + def find_min_index(match: Match) -> int: return min(match.nodes, key=lambda x: nodes.index(x)) # "sort" matches in topo order. @@ -296,13 +295,13 @@ def find_min_index(match) -> int: def __call__(self, graph: fx.Graph): self.dump_graph(graph, "before_collective_fusion") - count = self.my_patterns.apply(graph) + count = self.gemm_rs_ag_gemm_pattern.apply(graph) logger.info("fused gemm match count = %d", len(self.matches)) # Don't apply final pattern unless we've matched and replaced the # gemm+collective ops. 
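# For context, a toy sketch of how record_match cooperates with the inductor
# pattern matcher: the extra_check hook stores the Match and returns False,
# so apply() counts candidates without rewriting them and process_matches()
# can perform the rewrite with cross-layer bookkeeping. The add/sub pattern
# below is a stand-in for the real fused-gemm pattern.
import torch
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                             fwd_only, register_replacement)

recorded: list = []

def record(match: Match) -> bool:
    recorded.append(match)
    return False  # record only; the rewrite happens in a later pass

def pattern(x: torch.Tensor, y: torch.Tensor):
    return torch.add(x, y)

def replacement(x: torch.Tensor, y: torch.Tensor):
    return torch.sub(x, y)

patterns = PatternMatcherPass()
register_replacement(pattern, replacement, [torch.empty(8), torch.empty(8)],
                     fwd_only, [patterns], extra_check=record)
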
if len(self.matches) > 0: - count = self.my_patterns2.apply(graph) + count = self.final_pattern.apply(graph) logger.info("final match count = %d", count) self.process_matches(graph) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 30a946f91d99b..907449ed568e9 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -205,7 +205,6 @@ def __init__( self.use_hpu_communicator = use_hpu_communicator self.use_xpu_communicator = use_xpu_communicator - # Initialize pynvshmem if has_flux and torch.distributed.get_world_size( self.device_group) > 1: flux.init_flux_shm(self.device_group) From e0d72039bac198398b6589ceb0550a5630405fc6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 16:39:49 +0000 Subject: [PATCH 35/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 3 ++- vllm/compilation/collective_fusion.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 5305a38ee1164..bb49ec30af806 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -340,7 +340,8 @@ def async_rewrite(graph: fx.Graph): return graph -collective_fusion_pass = CollectiveFusionPass() + +collective_fusion_pass: Optional[CollectiveFusionPass] = None def wrap_inductor(graph, example_inputs, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 3e1ca5f03fee7..b9c7eeb75ca69 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -225,10 +225,13 @@ def __init__(self): extra_check=lambda m: self.record_match(m)) final_inputs = [x, w, resid, resid_w] - register_replacement(match_final, - #torch.ops.vllm.gemm_ag_final, - replace_final, - final_inputs, fwd_only, [self.final_pattern]) + register_replacement( + match_final, + #torch.ops.vllm.gemm_ag_final, + replace_final, + final_inputs, + fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and From 70c625083ada7e016e5075faa383074604d8a554 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 18:36:55 +0000 Subject: [PATCH 36/72] back out llama model changes Signed-off-by: Bill Nell --- vllm/model_executor/models/llama.py | 85 +++-------------------------- 1 file changed, 7 insertions(+), 78 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 42a829e9aabc7..355b2f3ef8b28 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -30,14 +30,12 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_group, tensor_model_parallel_all_gather) + get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, - RowParallelLinear, should_slice, - slice_residual) + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig @@ -61,11 +59,6 @@ maybe_prefix) -def pprint(x): - #print(x) - pass - - 
class LlamaMLP(nn.Module): def __init__( @@ -76,8 +69,6 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, prefix: str = "", - last_layer: bool = False, - fuse_gemms=True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -86,15 +77,13 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", - fuse_ag_gemm=fuse_gemms) - + ) self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.down_proj", - fuse_gemm_rs=(not last_layer) and fuse_gemms, ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -116,7 +105,6 @@ def __init__( hidden_size: int, num_heads: int, num_kv_heads: int, - first_layer: bool, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, @@ -124,7 +112,6 @@ def __init__( bias: bool = False, cache_config: Optional[CacheConfig] = None, prefix: str = "", - fuse_gemms=True, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -159,14 +146,14 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", - fuse_ag_gemm=(not first_layer) and fuse_gemms) + ) + self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias, quant_config=quant_config, prefix=f"{prefix}.o_proj", - fuse_gemm_rs=fuse_gemms, ) is_neox_style = True @@ -211,18 +198,11 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - # Hack: pass in whether this is the first/last layer - # so we know if we can rewrite AllReduce -> ReduceScatter + AllGather, - # and then propagate the AllGather to the next layer. 
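The hack comment removed just above spells out the invariant this whole series leans on: a tensor-parallel AllReduce can be decomposed into a ReduceScatter (each rank keeps its slice of the residual) followed by an AllGather, and those two halves are what later get fused into the neighboring GEMMs. A small sketch of the equivalence in plain torch.distributed, assuming an initialized process group and a leading dimension divisible by the world size (illustrative only, not part of the patch):

import torch
import torch.distributed as dist

def allreduce_via_rs_ag(x: torch.Tensor) -> torch.Tensor:
    world = dist.get_world_size()
    shard = torch.empty(x.shape[0] // world, *x.shape[1:],
                        dtype=x.dtype, device=x.device)
    # ReduceScatter: each rank ends up with the sum of one slice of x.
    dist.reduce_scatter_tensor(shard, x)
    out = torch.empty_like(x)
    # AllGather: reassemble the fully reduced tensor on every rank.
    dist.all_gather_into_tensor(out, shard)
    return out  # same result as dist.all_reduce(x) with the default sum op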
- first_layer: bool, - last_layer: bool, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - fuse_gemms=True, ) -> None: super().__init__() - self.fuse_gemms = fuse_gemms self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -242,7 +222,6 @@ def __init__( num_heads=config.num_attention_heads, num_kv_heads=getattr(config, "num_key_value_heads", config.num_attention_heads), - first_layer=first_layer, rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, @@ -250,7 +229,6 @@ def __init__( bias=attention_bias, cache_config=cache_config, prefix=f"{prefix}.self_attn", - fuse_gemms=self.fuse_gemms, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -259,17 +237,12 @@ def __init__( quant_config=quant_config, bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", - last_layer=last_layer, - fuse_gemms=self.fuse_gemms, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.first_layer = first_layer - self.last_layer = last_layer - def forward( self, positions: torch.Tensor, @@ -283,60 +256,17 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - assert (hidden_states.shape == residual.shape) hidden_states, residual = self.input_layernorm( hidden_states, residual) - - pprint(f"RESIDUAL SHAPE = {residual.shape}") - - def slices(residual: torch.Tensor) -> List[torch.Tensor]: - if not self.fuse_gemms or not should_slice(residual.shape): - return [] - else: - return slice_residual(residual) - - orig_residual_shape = residual.shape - - # Partition residual - residual_slices = slices(residual) if self.first_layer else [] - if len(residual_slices) > 0: - my_residual = residual_slices[get_tensor_model_parallel_rank()] - else: - my_residual = residual - hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata) # Fully Connected - assert (hidden_states.shape == my_residual.shape - ), f"{hidden_states.shape} != {my_residual.shape}" - hidden_states, my_residual = self.post_attention_layernorm( - hidden_states, my_residual) + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) hidden_states = self.mlp(hidden_states) - - pprint( - f"LAST_LAYER = {self.last_layer}, #slices = {len(residual_slices)}" - ) - if self.fuse_gemms and self.last_layer and should_slice( - orig_residual_shape): - pprint(f"FINAL REDUCE {my_residual.shape}") - if False: - residual = tensor_model_parallel_all_gather(my_residual, 0) - else: - residual = torch.ops._c10d_functional.all_gather_into_tensor( - my_residual.contiguous(), - get_tp_group().world_size, - torch.distributed.group.WORLD.group_name) - residual = torch.ops._c10d_functional.wait_tensor(residual) - - pprint(f"GOT HERE2 {my_residual.shape}, {residual.shape}") - else: - residual = my_residual - - assert (hidden_states.shape == residual.shape - ), f"{hidden_states.shape} != {residual.shape}" return hidden_states, residual @@ -349,7 +279,6 @@ def __init__(self, prefix: str = "", layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer): super().__init__() - fuse_gemms = bool(os.environ.get("VLLM_FUSE_GEMMS", "0") == "1") config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config From 
96d97560433cd3d0224142ca45dbee5c0071deca Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 18:38:12 +0000 Subject: [PATCH 37/72] back out models/utils changes Signed-off-by: Bill Nell --- vllm/model_executor/models/utils.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index b3ee6bc5d99c7..dcfd2cb7d2622 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -501,32 +501,14 @@ def make_layers( """Make a list of layers with the given layer function, taking pipeline parallelism into account. """ - import inspect - from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.utils import get_pp_indices start_layer, end_layer = get_pp_indices(num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size) - - # Determine if layer_fn accepts first/last args by inspecting its signature - sig = inspect.signature(layer_fn) - has_firstlast_args = ('first_layer' - in sig.parameters) and ('last_layer' - in sig.parameters) - - def make_one_layer(idx, start_layer, end_layer): - if has_firstlast_args: - return maybe_offload_to_cpu( - layer_fn(prefix=f"{prefix}.{idx}", - first_layer=(idx == start_layer), - last_layer=(idx == end_layer - 1))) - else: - return maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) - modules = torch.nn.ModuleList( [PPMissingLayer() for _ in range(start_layer)] + [ - make_one_layer(idx, start_layer, end_layer) + maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) for idx in range(start_layer, end_layer) ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) return start_layer, end_layer, modules From 707df6a5345235746052315f70aeb75eedfa623c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Thu, 31 Oct 2024 19:12:12 +0000 Subject: [PATCH 38/72] remove cruft Signed-off-by: Bill Nell --- flux_env.sh | 17 -- vllm/model_executor/layers/linear.py | 248 ++------------------------- 2 files changed, 13 insertions(+), 252 deletions(-) delete mode 100644 flux_env.sh diff --git a/flux_env.sh b/flux_env.sh deleted file mode 100644 index 8979ce0858d0c..0000000000000 --- a/flux_env.sh +++ /dev/null @@ -1,17 +0,0 @@ -#Point to the directory containing the flux .so files: -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/nm-vllm/flux_experiment/lib - -export NVSHMEM_BOOTSTRAP_MPI_PLUGIN=nvshmem_bootstrap_torch.so - -# Env variables for symmetric heap allocation. 
-# These are needed for supporting CUDA_VISIBLE DEVICES -# This is big enough for llama3 8b, but should be set correctly -export NVSHMEM_SYMMETRIC_SIZE=$((8*1024**3)) -export NVSHMEM_DISABLE_CUDA_VMM=1 # moving from cpp to shell - -# Not sure if these are needed -export CUDA_DEVICE_MAX_CONNECTIONS=1 -export BYTED_TORCH_BYTECCL=O0 -export NCCL_IB_TIMEOUT=${NCCL_IB_TIMEOUT:=23} -export NCCL_IB_GID_INDEX=${NCCL_IB_GID_INDEX:=3} -export NVSHMEM_IB_GID_INDEX=3 diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d8574a9308dc3..46ef11e7d02c6 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,12 +2,6 @@ from abc import abstractmethod from typing import Dict, List, Optional, Tuple -try: - import flux - has_flux = True -except ImportError: - has_flux = False - import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter @@ -17,7 +11,6 @@ split_tensor_along_last_dim, tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_tp_group from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) @@ -144,192 +137,6 @@ def apply(self, return F.linear(x, layer.weight, bias) -class FluxGemmRS(LinearMethodBase): - #Fused Gemm-ReduceScatter without quantization. - - def __init__(self, separate_bias_add: bool = False): - self.separate_bias_add = separate_bias_add - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - self.gemm_rs_op = flux.GemmRS( - get_tp_group().device_group, - 1, # One node - 8192, # Max M. TODO: Pass in correctly. - output_size, # N - # TODO: Pass in input dtype correctly. - # TODO: It would be nicer to modify flux to dispatch based on dtype - # at run time, but I don't know what the downside would be. - # Similar comment for max m. - torch.float16, - # Note: transpose_weight=False means that B is transposed - transpose_weight=False, - # Note: bfloat16 requires fuse_reduction=False. - fuse_reduction=False, - ) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert bias is None - - output = self.gemm_rs_op.forward(x, layer.weight) - output = output.squeeze(0) - - return output - - -class FluxAGCook(LinearMethodBase): - #Fused AllGather-Gemm without quantization. 
- - def __init__(self, separate_bias_add: bool = False): - self.separate_bias_add = separate_bias_add - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - self.ag_gemm_op = flux.AGKernel( - get_tp_group().device_group, - 1, # One node - 8192, # Max M. TODO: Pass in correctly. - weight.shape[0], # N - weight.shape[1], # K - # TODO: Pass in input dtype correctly. - # TODO: It would be nicer to modify flux to dispatch based on dtype - # at run time, but I don't know what the downside would be. - # Similar comment for max m. - torch.float16, - torch.float16, - # Note: transpose_weight=False means that B is transposed - transpose_weight=False, - # Note: if local_copy=True, I hit the following runtime error: - # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648 - # Check failed: 33554432((input.numel() * input.element_size())) - # == 139836453421056((this->chunk_size)) - local_copy=False, - ) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert bias is None - - output = self.ag_gemm_op.forward(x, layer.weight) - - return output - - -# This check is a hack -def should_slice(shape) -> bool: - n_slices = get_tensor_model_parallel_world_size() - return (shape[0] % n_slices == 0 and shape[0] >= 128) - - -def slice_residual(residual) -> List[torch.Tensor]: - n_slices = get_tensor_model_parallel_world_size() - return torch.chunk(residual, n_slices, dim=0) - - -class MatmulRS(LinearMethodBase): - #Fused Gemm-ReduceScatter without quantization. - - def __init__(self, separate_bias_add: bool = False): - self.separate_bias_add = separate_bias_add - self.group_name = torch.distributed.group.WORLD.group_name - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert bias is None - - if not should_slice(x.shape): - output = torch.matmul(x, layer.weight.transpose(1, 0)) - # This is a bit hacky - return tensor_model_parallel_all_reduce(output) - else: - return torch.ops.symm_mem.fused_matmul_reduce_scatter( - x, - layer.weight.transpose(1, 0).contiguous(), - "avg", - scatter_dim=0, - group_name=self.group_name) - - -class AGMatmul(LinearMethodBase): - #Fused AllGather-Gemm without quantization. 
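The MatmulRS and AGMatmul fallbacks being deleted in this hunk (the same fused symm_mem calls come back as the non-flux path in collective_fusion.py later in the series) wrap two fused ops from torch's symmetric-memory prototype. As a reference for what those calls compute, here is an unfused sketch with equivalent semantics; shapes and the explicit process-group argument are illustrative, and the real fused ops additionally overlap the communication with the matmul:

import torch
import torch.distributed as dist

def matmul_reduce_scatter_ref(a: torch.Tensor, b: torch.Tensor,
                              group: dist.ProcessGroup) -> torch.Tensor:
    # fused_matmul_reduce_scatter(a, b, "avg", scatter_dim=0, group_name)
    # computes a reduce-scatter of (a @ b) along dim 0 with mean reduction.
    full = a @ b
    world = dist.get_world_size(group)
    out = torch.empty(full.shape[0] // world, full.shape[1],
                      dtype=full.dtype, device=full.device)
    dist.reduce_scatter_tensor(out, full, op=dist.ReduceOp.AVG, group=group)
    return out

def all_gather_matmul_ref(a_shard: torch.Tensor, b: torch.Tensor,
                          group: dist.ProcessGroup) -> torch.Tensor:
    # fused_all_gather_matmul(a_shard, [b], gather_dim=0, group_name) returns
    # (gathered_a, [gathered_a @ b]); this reproduces the second element.
    world = dist.get_world_size(group)
    a_full = torch.empty(a_shard.shape[0] * world, a_shard.shape[1],
                         dtype=a_shard.dtype, device=a_shard.device)
    dist.all_gather_into_tensor(a_full, a_shard, group=group)
    return a_full @ b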
- - def __init__(self, separate_bias_add: bool = False): - self.separate_bias_add = separate_bias_add - self.group_name = torch.distributed.group.WORLD.group_name - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) - set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) - layer.register_parameter("weight", weight) - set_weight_attrs(weight, extra_weight_attrs) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert bias is None - - if not should_slice(x.shape): - return torch.matmul(x, layer.weight.transpose(1, 0)) - else: - ag_output, mm_outputs = torch.ops.symm_mem.fused_all_gather_matmul( - x, - [layer.weight.transpose(1, 0).contiguous()], - gather_dim=0, - group_name=self.group_name, - ) - return mm_outputs[0] - - class LinearBase(torch.nn.Module): """Base linear layer. @@ -350,8 +157,6 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", - fuse_gemm_rs: bool = False, - fuse_ag_gemm: bool = False, ): super().__init__() @@ -362,18 +167,9 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - self.quant_method: Optional[QuantizeMethodBase] = None - - tp_size = get_tensor_model_parallel_world_size() - - if fuse_gemm_rs and tp_size > 1: - assert (quant_config is None) - self.quant_method = FluxGemmRS() if has_flux else MatmulRS() - elif fuse_ag_gemm and tp_size > 1: - assert (quant_config is None) - self.quant_method = FluxAGCook() if has_flux else AGMatmul() - elif quant_config is None: - self.quant_method = UnquantizedLinearMethod() + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) @@ -486,15 +282,9 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, - prefix: str = "", - fuse_ag_gemm: bool = False): - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - fuse_ag_gemm=fuse_ag_gemm) + prefix: str = ""): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config, prefix) self.gather_output = gather_output @@ -629,8 +419,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - fuse_ag_gemm: bool = False): + prefix: str = ""): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) @@ -641,8 +430,7 @@ def __init__(self, skip_bias_add=skip_bias_add, params_dtype=params_dtype, quant_config=quant_config, - prefix=prefix, - fuse_ag_gemm=fuse_ag_gemm) + prefix=prefix) def weight_loader(self, param: Parameter, @@ -879,8 +667,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - fuse_ag_gemm: bool = False): + prefix: str = ""): self.hidden_size = 
hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -913,8 +700,7 @@ def __init__(self, skip_bias_add=skip_bias_add, params_dtype=params_dtype, quant_config=quant_config, - prefix=prefix, - fuse_ag_gemm=fuse_ag_gemm) + prefix=prefix) def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { @@ -1209,20 +995,12 @@ def __init__(self, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - fuse_gemm_rs: bool = False): - super().__init__(input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - fuse_gemm_rs=fuse_gemm_rs) + prefix: str = ""): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config, prefix) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results - if fuse_gemm_rs: - self.reduce_results = False # Divide the weight matrix along the last dimension. self.tp_rank = get_tensor_model_parallel_rank() From 04ec8ca1cc9abb42f3d5f558483ecdbf277b4401 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 1 Nov 2024 18:06:33 +0000 Subject: [PATCH 39/72] move utilities Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b9c7eeb75ca69..d9db7bc949de6 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -13,7 +13,6 @@ from vllm.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger -from vllm.model_executor.layers.linear import should_slice, slice_residual logger = init_logger(__name__) @@ -26,6 +25,17 @@ def get_world_name() -> str: return torch.distributed.group.WORLD.group_name +# This check is a hack +def should_slice(shape) -> bool: + n_slices = get_tensor_model_parallel_world_size() + return (shape[0] % n_slices == 0 and shape[0] >= 128) + + +def slice_residual(residual) -> List[torch.Tensor]: + n_slices = get_tensor_model_parallel_world_size() + return torch.chunk(residual, n_slices, dim=0) + + def match_gemm_rs_ag_gemm( residual: torch.Tensor, gemm_1_weights: torch.Tensor, From c3bb8757a27ff11de24790ea95e47873972b4b17 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 1 Nov 2024 21:51:14 +0000 Subject: [PATCH 40/72] add flux support Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 173 ++++++++++++++++-- .../device_communicators/pynccl_wrapper.py | 9 + vllm/distributed/parallel_state.py | 14 +- vllm/envs.py | 5 + 4 files changed, 183 insertions(+), 18 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d9db7bc949de6..80668039a9cf7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -6,14 +6,26 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) +import vllm.envs as envs + from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) -from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed import tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, 
get_tensor_model_parallel_world_size) + get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger +use_flux = False +if envs.VLLM_USE_FLUX: + try: + import flux + use_flux = True + print("USE FLUX") + except ImportError: + use_flux = False + + logger = init_logger(__name__) # TODO: factor out somehow @@ -67,7 +79,7 @@ def match_gemm_rs_ag_gemm( return mm_2, new_residual -@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) +@torch.library.custom_op("vllm::gemm_rs_ag_gemm_old", mutates_args=()) def gemm_rs_ag_gemm( residual: torch.Tensor, old_my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, @@ -131,7 +143,7 @@ def gemm_rs_ag_gemm( return mm_2[0], new_residual.clone(), slice_scatter -@torch.library.register_fake("vllm::gemm_rs_ag_gemm") +#@torch.library.register_fake("vllm::gemm_rs_ag_gemm") def gemm_rs_ag_gemm_fake( residual: torch.Tensor, my_residual: torch.Tensor, @@ -159,6 +171,124 @@ def gemm_rs_ag_gemm_fake( return (mm_res, my_residual, residual) +def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weights: torch.Size): + + if use_flux: + gemm_rs_op = flux.GemmRS( + get_tp_group().device_group, + 1, # One node + 8192, # Max M. TODO: Pass in correctly. + gemm_1_weights[0], # N + # TODO: Pass in input dtype correctly. + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. + torch.float16, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: bfloat16 requires fuse_reduction=False. + fuse_reduction=False, + ) + + ag_gemm_op = flux.AGKernel( + get_tp_group().device_group, + 1, # One node + 8192, # Max M. TODO: Pass in correctly. + gemm_2_weights[0], # N + gemm_2_weights[1], # K + # TODO: Pass in input dtype correctly. + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. 
+ torch.float16, + torch.float16, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: if local_copy=True, I hit the following runtime error: + # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648 + # Check failed: 33554432((input.numel() * input.element_size())) + # == 139836453421056((this->chunk_size)) + local_copy=False, + ) + + gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0) + ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt) + + name = f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_{gemm_2_weights[0]}_{gemm_2_weights[1]}" + else: + group_name = get_world_name() + + gemm_rs = lambda act, wt: torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + act, wt.transpose(1,0), 'avg', 0, group_name) + + ag_gemm = lambda act, wt: torch.ops.symm_mem.fused_all_gather_matmul.default( + act, [wt.transpose(1,0)], 0, group_name)[1] + + name = "gemm_rs_ag_gemm" + + def gemm_rs_ag_gemm( + residual: torch.Tensor, old_my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, + first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if should_slice(residual.shape) and first_layer: + res_slices = slice_residual(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = residual_chunk[0] + else: + my_residual = residual + slice_size = residual.shape[0] + + if not should_slice(residual.shape): + output = torch.matmul(gemm_1_activations, gemm_1_weights.transpose(1,0)) + reduced_output = tensor_model_parallel_all_reduce(output) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduced_output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1,0)) + return mm_2, new_residual, new_residual.clone() + else: + group_name = get_world_name() + output = gemm_rs(gemm_1_activations, gemm_1_weights) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + residual_1 = residual if first_layer else old_my_residual + slice_scatter = torch.ops.aten.slice_scatter.default( + residual_1, new_residual, 0, 0, slice_size) + split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) + new_residual = split_2[0] + + mm_2 = ag_gemm(normalized, gemm_2_weights) + + # TODO: can we avoid clone here? + return mm_2[0], new_residual.clone(), slice_scatter + + if not hasattr(torch.ops.vllm, name): + logger.info("registering torch.ops.vllm.%s", name) + torch.library.custom_op(f"vllm::{name}", gemm_rs_ag_gemm, mutates_args=()) + torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake) + assert getattr(torch.ops.vllm, name) + + return getattr(torch.ops.vllm, name).default + + def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -181,7 +311,8 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return normalized - +# Register this as a custom op since all reduce cannot be torch.compiled. 
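One detail worth calling out in get_gemm_rs_ag_gemm above: a separate custom op is registered per weight shape and dtype, each paired with a fake implementation so the graph can still be traced with meta tensors, and the cached op is fetched back through getattr(torch.ops.vllm, name). A stripped-down sketch of that registration pattern (the op body is a placeholder, not the fused kernel; assumes a PyTorch version that provides torch.library.custom_op):

import torch

def get_scaled_add(n: int):
    name = f"scaled_add_{n}"

    def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + float(n) * y

    def scaled_add_fake(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Shape/dtype-only implementation used during fake-tensor tracing.
        return torch.empty_like(x)

    if not hasattr(torch.ops.vllm, name):
        torch.library.custom_op(f"vllm::{name}", scaled_add, mutates_args=())
        torch.library.register_fake(f"vllm::{name}", scaled_add_fake)

    return getattr(torch.ops.vllm, name).default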
+#@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -195,12 +326,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, # is this the right thing to call it on? if should_slice(gemm_1_activations.shape): - group_name = get_world_name() - world_size = get_tensor_model_parallel_world_size() - all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( - my_residual, world_size, group_name) - wait_tensor = torch.ops._c10d_functional.wait_tensor.default( - all_gather) + if True: + group_name = get_world_name() + world_size = get_tensor_model_parallel_world_size() + all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( + my_residual, world_size, group_name) + wait_tensor = torch.ops._c10d_functional.wait_tensor.default( + all_gather) + else: + wait_tensor = tensor_model_parallel_all_gather(my_residual) else: wait_tensor = my_residual @@ -214,6 +348,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return norm_res[1] +#@torch.library.register_fake("vllm::gemm_ag_final") +def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: + return torch.empty([gemm_1_activations.shape[0], + my_residual.shape[1]], + dtype=my_residual.dtype, device=my_residual.device) + + class CollectiveFusionPass(InductorPass): def __init__(self): @@ -273,8 +416,14 @@ def find_min_index(match: Match) -> int: res_replacements) > 0 else match.kwargs["residual"] kwargs["old_my_residual"] = my_res_replacements[-1] if len( my_res_replacements) > 0 else match.kwargs["residual"] + + # TODO: use get + gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape + gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape + fused_node = graph.call_function( - torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) + get_gemm_rs_ag_gemm(use_flux, gemm_1_w, gemm_2_w), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index ff88f72470b27..aa7717af46484 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -315,6 +315,15 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 907449ed568e9..6b403ff465094 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -30,12 +30,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union from 
unittest.mock import patch -try: - import flux - has_flux = True -except ImportError: - has_flux = False - import torch import torch.distributed from torch.distributed import Backend, ProcessGroup, _symmetric_memory @@ -45,6 +39,14 @@ from vllm.platforms import current_platform from vllm.utils import direct_register_custom_op, supports_custom_op +has_flux = False +if envs.VLLM_USE_FLUX: + try: + import flux + has_flux = True + except ImportError: + has_flux = False + @dataclass class GraphCaptureContext: diff --git a/vllm/envs.py b/vllm/envs.py index 9b8aa2934104e..7be4856fc9d91 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -70,6 +70,7 @@ VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False + VLLM_USE_FLUX: bool = False def get_default_cache_root(): @@ -462,6 +463,10 @@ def get_default_config_root(): # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))), + + # If set, try to use the flux fused collective comminucation gemm kernels + "VLLM_USE_FLUX": + lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))), } # end-env-vars-definition From e4225626221e28f973bc23613292af6d59692fff Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 1 Nov 2024 22:00:19 +0000 Subject: [PATCH 41/72] add flux support Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 126 +++++------------- .../device_communicators/pynccl_wrapper.py | 1 + vllm/envs.py | 2 +- 3 files changed, 37 insertions(+), 92 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 80668039a9cf7..816227c2a4857 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -7,27 +7,27 @@ fwd_only, register_replacement) import vllm.envs as envs - from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) -from vllm.distributed import tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( - get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) from vllm.logger import init_logger +logger = init_logger(__name__) + use_flux = False if envs.VLLM_USE_FLUX: try: import flux use_flux = True - print("USE FLUX") + logger.info("USING FLUX") except ImportError: use_flux = False - -logger = init_logger(__name__) - # TODO: factor out somehow TP_GROUP_NAME = "tp:0" @@ -79,71 +79,6 @@ def match_gemm_rs_ag_gemm( return mm_2, new_residual -@torch.library.custom_op("vllm::gemm_rs_ag_gemm_old", mutates_args=()) -def gemm_rs_ag_gemm( - residual: torch.Tensor, old_my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - if should_slice(residual.shape) and first_layer: - res_slices = slice_residual(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) - my_residual = residual_chunk[0] - else: - my_residual = residual - slice_size = residual.shape[0] - - gemm_1_w_perm = 
torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - - if not should_slice(residual.shape): - output = torch.matmul(gemm_1_activations, gemm_1_w_perm) - reduced_output = tensor_model_parallel_all_reduce(output) - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduced_output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.matmul(normalized, gemm_2_w_perm) - return mm_2, new_residual, new_residual.clone() - else: - group_name = get_world_name() - output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default( - gemm_1_activations, gemm_1_w_perm, 'avg', 0, group_name) - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - residual_1 = residual if first_layer else old_my_residual - slice_scatter = torch.ops.aten.slice_scatter.default( - residual_1, new_residual, 0, 0, slice_size) - split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) - new_residual = split_2[0] - - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - fused_all_gather_matmul = ( - torch.ops.symm_mem.fused_all_gather_matmul.default( - normalized, [gemm_2_w_perm], 0, group_name)) - mm_2 = fused_all_gather_matmul[1] - - # TODO: can we avoid clone here? - return mm_2[0], new_residual.clone(), slice_scatter - - -#@torch.library.register_fake("vllm::gemm_rs_ag_gemm") def gemm_rs_ag_gemm_fake( residual: torch.Tensor, my_residual: torch.Tensor, @@ -171,7 +106,8 @@ def gemm_rs_ag_gemm_fake( return (mm_res, my_residual, residual) -def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weights: torch.Size): +def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, + gemm_2_weights: torch.Size): if use_flux: gemm_rs_op = flux.GemmRS( @@ -214,15 +150,18 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weigh gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0) ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt) - name = f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_{gemm_2_weights[0]}_{gemm_2_weights[1]}" + name = (f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_" + f"{gemm_2_weights[0]}_{gemm_2_weights[1]}") else: group_name = get_world_name() - gemm_rs = lambda act, wt: torch.ops.symm_mem.fused_matmul_reduce_scatter.default( - act, wt.transpose(1,0), 'avg', 0, group_name) + gemm_rs = lambda act, wt: \ + torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + act, wt.transpose(1, 0), 'avg', 0, group_name) - ag_gemm = lambda act, wt: torch.ops.symm_mem.fused_all_gather_matmul.default( - act, [wt.transpose(1,0)], 0, group_name)[1] + ag_gemm = lambda act, wt: \ + torch.ops.symm_mem.fused_all_gather_matmul.default( + act, [wt.transpose(1, 0)], 0, group_name)[1] name = "gemm_rs_ag_gemm" @@ -230,7 +169,8 @@ def gemm_rs_ag_gemm( residual: torch.Tensor, old_my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + first_layer: bool + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if should_slice(residual.shape) and first_layer: res_slices = slice_residual(residual) @@ 
-242,7 +182,8 @@ def gemm_rs_ag_gemm( slice_size = residual.shape[0] if not should_slice(residual.shape): - output = torch.matmul(gemm_1_activations, gemm_1_weights.transpose(1,0)) + output = torch.matmul(gemm_1_activations, + gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) norm_res = torch.ops.higher_order.auto_functionalized( @@ -254,10 +195,9 @@ def gemm_rs_ag_gemm( normalized = norm_res[1] new_residual = norm_res[2] - mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1,0)) + mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1, 0)) return mm_2, new_residual, new_residual.clone() else: - group_name = get_world_name() output = gemm_rs(gemm_1_activations, gemm_1_weights) norm_res = torch.ops.higher_order.auto_functionalized( @@ -282,7 +222,9 @@ def gemm_rs_ag_gemm( if not hasattr(torch.ops.vllm, name): logger.info("registering torch.ops.vllm.%s", name) - torch.library.custom_op(f"vllm::{name}", gemm_rs_ag_gemm, mutates_args=()) + torch.library.custom_op(f"vllm::{name}", + gemm_rs_ag_gemm, + mutates_args=()) torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake) assert getattr(torch.ops.vllm, name) @@ -311,6 +253,7 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return normalized + # Register this as a custom op since all reduce cannot be torch.compiled. #@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, @@ -329,8 +272,9 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, if True: group_name = get_world_name() world_size = get_tensor_model_parallel_world_size() - all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( - my_residual, world_size, group_name) + all_gather = ( + torch.ops._c10d_functional.all_gather_into_tensor.default( + my_residual, world_size, group_name)) wait_tensor = torch.ops._c10d_functional.wait_tensor.default( all_gather) else: @@ -352,9 +296,9 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: - return torch.empty([gemm_1_activations.shape[0], - my_residual.shape[1]], - dtype=my_residual.dtype, device=my_residual.device) + return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]], + dtype=my_residual.dtype, + device=my_residual.device) class CollectiveFusionPass(InductorPass): @@ -421,9 +365,9 @@ def find_min_index(match: Match) -> int: gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape - fused_node = graph.call_function( - get_gemm_rs_ag_gemm(use_flux, gemm_1_w, gemm_2_w), - kwargs=kwargs) + fused_node = graph.call_function(get_gemm_rs_ag_gemm( + use_flux, gemm_1_w, gemm_2_w), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index aa7717af46484..c5cc6b33fcce4 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -325,6 +325,7 @@ def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, datatype, comm, stream)) + __all__ = [ "NCCLLibrary", 
"ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", "ncclComm_t", "cudaStream_t", "buffer_type" diff --git a/vllm/envs.py b/vllm/envs.py index 7be4856fc9d91..4ac6767c0a3ff 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -464,7 +464,7 @@ def get_default_config_root(): "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))), - # If set, try to use the flux fused collective comminucation gemm kernels + # If set, try to use the flux fused collective communication gemm kernels. "VLLM_USE_FLUX": lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))), } From 4edbcab2ae33e6901b35be432018b497a6ee9cfd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 2 Nov 2024 02:18:47 +0000 Subject: [PATCH 42/72] add types to flux kernels Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 816227c2a4857..7ae734fc5276f 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -106,7 +106,10 @@ def gemm_rs_ag_gemm_fake( return (mm_res, my_residual, residual) -def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, +def get_gemm_rs_ag_gemm(use_flux: bool, + gemm_1_type, + gemm_1_weights: torch.Size, + gemm_2_type, gemm_2_weights: torch.Size): if use_flux: @@ -119,7 +122,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. # Similar comment for max m. - torch.float16, + gemm_1_type, # Note: transpose_weight=False means that B is transposed transpose_weight=False, # Note: bfloat16 requires fuse_reduction=False. @@ -136,8 +139,8 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. # Similar comment for max m. 
- torch.float16, - torch.float16, + gemm_2_type, + gemm_2_type, # Note: transpose_weight=False means that B is transposed transpose_weight=False, # Note: if local_copy=True, I hit the following runtime error: @@ -150,8 +153,10 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0) ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt) - name = (f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_" - f"{gemm_2_weights[0]}_{gemm_2_weights[1]}") + gemm_1_str = str(gemm_1_type).removeprefix("torch.") + gemm_2_str = str(gemm_2_type).removeprefix("torch.") + name = (f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_" + f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}") else: group_name = get_world_name() @@ -362,11 +367,11 @@ def find_min_index(match: Match) -> int: my_res_replacements) > 0 else match.kwargs["residual"] # TODO: use get - gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape - gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape + gemm_1 = kwargs["gemm_1_weights"].meta["val"] + gemm_2 = kwargs["gemm_2_weights"].meta["val"] fused_node = graph.call_function(get_gemm_rs_ag_gemm( - use_flux, gemm_1_w, gemm_2_w), + use_flux, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, gemm_2.shape), kwargs=kwargs) graph.inserting_after(fused_node) From 689d819b65e2afd120cca5e4ff3e045646a3eb0c Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 4 Nov 2024 14:54:56 +0000 Subject: [PATCH 43/72] wip Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7ae734fc5276f..a5fd5b631be93 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -260,7 +260,7 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, # Register this as a custom op since all reduce cannot be torch.compiled. -#@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) +@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -274,7 +274,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, # is this the right thing to call it on? if should_slice(gemm_1_activations.shape): - if True: + if True: #not use_flux: group_name = get_world_name() world_size = get_tensor_model_parallel_world_size() all_gather = ( @@ -297,7 +297,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return norm_res[1] -#@torch.library.register_fake("vllm::gemm_ag_final") +@torch.library.register_fake("vllm::gemm_ag_final") def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -329,8 +329,8 @@ def __init__(self): final_inputs = [x, w, resid, resid_w] register_replacement( match_final, - #torch.ops.vllm.gemm_ag_final, - replace_final, + torch.ops.vllm.gemm_ag_final, + #replace_final, final_inputs, fwd_only, [self.final_pattern]) @@ -343,7 +343,7 @@ def record_match(self, match: Match) -> bool: # Return False to prevent automatic replacement. 
return False - def process_matches(self, graph: fx.Graph): + def process_matches(self, graph: fx.Graph, num_to_process: int): nodes = list(graph.nodes) def find_min_index(match: Match) -> int: @@ -355,7 +355,7 @@ def find_min_index(match: Match) -> int: res_replacements: List[fx.Node] = [] my_res_replacements: List[fx.Node] = [] - for match in matches: + for match in matches[:num_to_process]: last_node = last_node_in_match(match) with graph.inserting_after(last_node): @@ -401,20 +401,22 @@ def find_min_index(match: Match) -> int: # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches - for node in match.nodes) + assert all(node not in graph.nodes for match in matches[:num_to_process] for node in match.nodes) def __call__(self, graph: fx.Graph): self.dump_graph(graph, "before_collective_fusion") count = self.gemm_rs_ag_gemm_pattern.apply(graph) logger.info("fused gemm match count = %d", len(self.matches)) + num_to_process = 1 + # Don't apply final pattern unless we've matched and replaced the # gemm+collective ops. if len(self.matches) > 0: - count = self.final_pattern.apply(graph) - logger.info("final match count = %d", count) - self.process_matches(graph) + if True or num_to_process == len(self.matches): + count = self.final_pattern.apply(graph) + logger.info("final match count = %d", count) + self.process_matches(graph, num_to_process) self.dump_graph(graph, "after_collective_fusion") self.matches.clear() From 30d4fb98b4706519dc065d5ab58217410acec473 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 4 Nov 2024 21:26:29 +0000 Subject: [PATCH 44/72] improve perf. Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 96 +++++++++++++-------------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a5fd5b631be93..23e95a70cfc91 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -6,6 +6,7 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) +import vllm._custom_ops as ops import vllm.envs as envs from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, @@ -16,6 +17,7 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group) from vllm.logger import init_logger +from vllm.utils import direct_register_custom_op logger = init_logger(__name__) @@ -106,6 +108,7 @@ def gemm_rs_ag_gemm_fake( return (mm_res, my_residual, residual) +# TODO: factor out groupnames, etc. 
def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_1_weights: torch.Size, @@ -177,13 +180,13 @@ def gemm_rs_ag_gemm( first_layer: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if should_slice(residual.shape) and first_layer: + if first_layer and should_slice(residual.shape): res_slices = slice_residual(residual) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = residual_chunk[0] else: - my_residual = residual + my_residual = residual #.clone() slice_size = residual.shape[0] if not should_slice(residual.shape): @@ -191,46 +194,41 @@ def gemm_rs_ag_gemm( gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduced_output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1, 0)) - return mm_2, new_residual, new_residual.clone() + ops.fused_add_rms_norm(input=reduced_output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) + + mm_2 = torch.matmul(reduced_output, gemm_2_weights.transpose(1, 0)) + return mm_2, my_residual, my_residual.clone() else: output = gemm_rs(gemm_1_activations, gemm_1_weights) - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] + ops.fused_add_rms_norm(input=output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) residual_1 = residual if first_layer else old_my_residual slice_scatter = torch.ops.aten.slice_scatter.default( - residual_1, new_residual, 0, 0, slice_size) + residual_1, my_residual, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) - new_residual = split_2[0] - - mm_2 = ag_gemm(normalized, gemm_2_weights) # TODO: can we avoid clone here? - return mm_2[0], new_residual.clone(), slice_scatter + new_residual = split_2[0] #.clone() + + mm_2 = ag_gemm(output, gemm_2_weights) + + return mm_2[0], new_residual, slice_scatter if not hasattr(torch.ops.vllm, name): logger.info("registering torch.ops.vllm.%s", name) - torch.library.custom_op(f"vllm::{name}", - gemm_rs_ag_gemm, - mutates_args=()) - torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake) + direct_register_custom_op( + name, + gemm_rs_ag_gemm, + mutates_args=[], + fake_impl=gemm_rs_ag_gemm_fake + ) assert getattr(torch.ops.vllm, name) return getattr(torch.ops.vllm, name).default @@ -260,19 +258,14 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, # Register this as a custom op since all reduce cannot be torch.compiled. -@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254) - auto_functionalized_161 = torch.ops.higher_order.auto_functionalized( - torch.ops.vllm.inplace_all_reduce.default, - tensor=mm_1, - group_name=TP_GROUP_NAME) - reduced = auto_functionalized_161[1] - # is this the right thing to call it on? 
+ reduced = tensor_model_parallel_all_reduce(mm_1) + if should_slice(gemm_1_activations.shape): if True: #not use_flux: group_name = get_world_name() @@ -287,17 +280,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, else: wait_tensor = my_residual - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, + ops.fused_add_rms_norm( input=reduced, residual=wait_tensor, weight=rms_norm_weights, epsilon=1e-05) - return norm_res[1] + return reduced -@torch.library.register_fake("vllm::gemm_ag_final") def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -306,6 +297,14 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, device=my_residual.device) +direct_register_custom_op( + "gemm_ag_final", + replace_final, + mutates_args=[], + fake_impl=replace_final_fake +) + + class CollectiveFusionPass(InductorPass): def __init__(self): @@ -343,7 +342,7 @@ def record_match(self, match: Match) -> bool: # Return False to prevent automatic replacement. return False - def process_matches(self, graph: fx.Graph, num_to_process: int): + def process_matches(self, graph: fx.Graph): nodes = list(graph.nodes) def find_min_index(match: Match) -> int: @@ -355,7 +354,7 @@ def find_min_index(match: Match) -> int: res_replacements: List[fx.Node] = [] my_res_replacements: List[fx.Node] = [] - for match in matches[:num_to_process]: + for match in matches: last_node = last_node_in_match(match) with graph.inserting_after(last_node): @@ -401,22 +400,19 @@ def find_min_index(match: Match) -> int: # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches[:num_to_process] for node in match.nodes) + assert all(node not in graph.nodes for match in matches for node in match.nodes) def __call__(self, graph: fx.Graph): self.dump_graph(graph, "before_collective_fusion") count = self.gemm_rs_ag_gemm_pattern.apply(graph) logger.info("fused gemm match count = %d", len(self.matches)) - num_to_process = 1 - # Don't apply final pattern unless we've matched and replaced the # gemm+collective ops. 
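This revision also swaps the auto_functionalized fused_add_rms_norm nodes for direct calls to ops.fused_add_rms_norm inside the replacement bodies. That kernel updates both of its first two arguments in place (input receives the normalized output, residual receives the pre-norm sum), which is why the code above now returns the reduced output and residual tensors directly instead of unpacking them from a functionalized node. For reference, an eager-mode sketch of what the fused kernel computes (simplified; exact upcasting and epsilon handling may differ):

import torch

def fused_add_rms_norm_ref(x: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float = 1e-5):
    # Residual add first; the fused kernel writes this back into `residual`.
    added = x + residual
    # RMSNorm over the last dim; the fused kernel writes this back into `x`.
    variance = added.float().pow(2).mean(dim=-1, keepdim=True)
    normed = (added.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return normed, added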
if len(self.matches) > 0: - if True or num_to_process == len(self.matches): - count = self.final_pattern.apply(graph) - logger.info("final match count = %d", count) - self.process_matches(graph, num_to_process) + count = self.final_pattern.apply(graph) + logger.info("final match count = %d", count) + self.process_matches(graph) self.dump_graph(graph, "after_collective_fusion") self.matches.clear() From 9bc764da8a7d7497ebfafdd27a247ead8db084b4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 18:37:34 +0000 Subject: [PATCH 45/72] support custom ar Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 121 +++++++++++++++++++++++--- 1 file changed, 107 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 23e95a70cfc91..8392f2f6ae50a 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,5 +1,5 @@ import operator -from typing import List, Tuple +from typing import Any, Callable, List, Optional, Sequence, Tuple import torch import torch.fx as fx @@ -60,6 +60,7 @@ def match_gemm_rs_ag_gemm( gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + #all_reduce = tensor_model_parallel_all_reduce(mm_1) all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, tensor=mm_1, @@ -81,6 +82,35 @@ def match_gemm_rs_ag_gemm( return mm_2, new_residual +def match_gemm_rs_ag_gemm_custom_ar( + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + + #all_reduce = tensor_model_parallel_all_reduce(mm_1) + all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, + TP_GROUP_NAME) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) + + return mm_2, new_residual + + def gemm_rs_ag_gemm_fake( residual: torch.Tensor, my_residual: torch.Tensor, @@ -91,7 +121,7 @@ def gemm_rs_ag_gemm_fake( first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if should_slice(gemm_1_activations.shape) and first_layer: + if first_layer and should_slice(gemm_1_activations.shape): res_slices = slice_residual(residual) slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] split_1 = torch.ops.aten.split.Tensor(residual, slice_size) @@ -240,6 +270,7 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + #all_reduce = tensor_model_parallel_all_reduce(mm_1) all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, tensor=mm_1, @@ -257,6 +288,26 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return normalized +def match_final_custom_ar(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, + 
gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + + all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, + TP_GROUP_NAME) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) + normalized = norm_res[1] + + return normalized + + # Register this as a custom op since all reduce cannot be torch.compiled. def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, @@ -304,6 +355,45 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, fake_impl=replace_final_fake ) +# Copied from pattern_matcher.py fwd_only so we can use tracing_mode="fake" +# TODO: convert args to fake tenors and forward to original fwd_only. +@torch.no_grad() +def fake_fwd_only( + fn: Callable[..., Any], + args: Sequence[Any], + *, + run_functional_passes: bool = True, + get_decomp_fn: Optional[Callable[..., Any]] = None, +) -> torch.fx.GraphModule: + from torch._dispatch.python import enable_python_dispatcher + from torch.fx.experimental.proxy_tensor import make_fx + from torch._inductor.decomposition import select_decomp_table + + """Build a normalized inference graph, for use with fx_to_pattern""" + # TODO - look into using aot autograd, asserting no mutating ops here + with enable_python_dispatcher(): + decompositions = ( + get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() + ) + #with torch._dynamo.utils.detect_fake_mode(args) as fm: + # new_args = [] + # for arg in args: + # if isinstance(arg, torch.Tensor): + # new_args.append(fm.from_tensor(arg)) + # else: + # new_args.append(arg) + #gm = make_fx(fn, decompositions, tracing_mode="real")(*new_args) + gm = make_fx(fn, decompositions, tracing_mode="fake")(*args) + + from torch._inductor.fx_passes.post_grad import remove_noop_ops + + if run_functional_passes: + remove_noop_ops(gm.graph) + gm.graph.eliminate_dead_code() + + gm.recompile() + return gm + class CollectiveFusionPass(InductorPass): @@ -319,20 +409,23 @@ def __init__(self): x2 = torch.empty([4, 4], device='cuda') inputs = [resid, x, w, resid_w, x2] - register_replacement(match_gemm_rs_ag_gemm, - match_gemm_rs_ag_gemm, - inputs, - fwd_only, [self.gemm_rs_ag_gemm_pattern], - extra_check=lambda m: self.record_match(m)) + for m in [match_gemm_rs_ag_gemm, match_gemm_rs_ag_gemm_custom_ar]: + register_replacement(m, + match_gemm_rs_ag_gemm, + inputs, + fake_fwd_only, + [self.gemm_rs_ag_gemm_pattern], + extra_check=lambda m: self.record_match(m)) final_inputs = [x, w, resid, resid_w] - register_replacement( - match_final, - torch.ops.vllm.gemm_ag_final, - #replace_final, - final_inputs, - fwd_only, - [self.final_pattern]) + for m in [match_final, match_final_custom_ar]: + register_replacement( + m, + torch.ops.vllm.gemm_ag_final, + #replace_final, + final_inputs, + fake_fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and From 72bb3b6201c4fa6d30269fb3b86a82a06fcfd91a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 19:42:22 +0000 Subject: [PATCH 46/72] tweaks Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff 
--git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 8392f2f6ae50a..e35350bcfd82c 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -31,6 +31,8 @@ use_flux = False # TODO: factor out somehow +# register multiple patterns for all tp names (0-numgpus-1)? +# or pass as additional args? TP_GROUP_NAME = "tp:0" @@ -220,8 +222,8 @@ def gemm_rs_ag_gemm( slice_size = residual.shape[0] if not should_slice(residual.shape): - output = torch.matmul(gemm_1_activations, - gemm_1_weights.transpose(1, 0)) + output = torch.ops.aten.mm.default(gemm_1_activations, + gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) ops.fused_add_rms_norm(input=reduced_output, @@ -229,7 +231,8 @@ def gemm_rs_ag_gemm( weight=rms_norm_weight, epsilon=1e-05) - mm_2 = torch.matmul(reduced_output, gemm_2_weights.transpose(1, 0)) + mm_2 = torch.ops.aten.mm.default(reduced_output, + gemm_2_weights.transpose(1, 0)) return mm_2, my_residual, my_residual.clone() else: output = gemm_rs(gemm_1_activations, gemm_1_weights) @@ -308,7 +311,7 @@ def match_final_custom_ar(my_residual: torch.Tensor, gemm_1_weights: torch.Tenso return normalized -# Register this as a custom op since all reduce cannot be torch.compiled. +# Register this as a custom op since all reduce cannot be torch.compiled yet. def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -318,16 +321,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, reduced = tensor_model_parallel_all_reduce(mm_1) if should_slice(gemm_1_activations.shape): - if True: #not use_flux: - group_name = get_world_name() - world_size = get_tensor_model_parallel_world_size() - all_gather = ( - torch.ops._c10d_functional.all_gather_into_tensor.default( - my_residual, world_size, group_name)) - wait_tensor = torch.ops._c10d_functional.wait_tensor.default( - all_gather) - else: - wait_tensor = tensor_model_parallel_all_gather(my_residual) + wait_tensor = tensor_model_parallel_all_gather(my_residual) else: wait_tensor = my_residual @@ -355,7 +349,10 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, fake_impl=replace_final_fake ) -# Copied from pattern_matcher.py fwd_only so we can use tracing_mode="fake" +# Copied from pattern_matcher.py fwd_only so we can use tracing_mode="fake". +# "real" mode chokes on custom ar primitives since the custom ar data structure(s) +# have not been set up. We could also try to only register the custom_ar patterns +# if custom ar has been initialized. Not sure how hard that is. # TODO: convert args to fake tenors and forward to original fwd_only. @torch.no_grad() def fake_fwd_only( @@ -375,6 +372,7 @@ def fake_fwd_only( decompositions = ( get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() ) + # This doesn't seem to work. #with torch._dynamo.utils.detect_fake_mode(args) as fm: # new_args = [] # for arg in args: @@ -409,6 +407,8 @@ def __init__(self): x2 = torch.empty([4, 4], device='cuda') inputs = [resid, x, w, resid_w, x2] + # register multiple patterns for all group/world names? 
+ for m in [match_gemm_rs_ag_gemm, match_gemm_rs_ag_gemm_custom_ar]: register_replacement(m, match_gemm_rs_ag_gemm, From da034eb3e53099901b025c8e0d19de1c8d1f30a4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 21:36:10 +0000 Subject: [PATCH 47/72] factor out group names Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 199 +++++++++++--------------- 1 file changed, 82 insertions(+), 117 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index e35350bcfd82c..469a040d4fb02 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -30,11 +30,6 @@ except ImportError: use_flux = False -# TODO: factor out somehow -# register multiple patterns for all tp names (0-numgpus-1)? -# or pass as additional args? -TP_GROUP_NAME = "tp:0" - # how to do this properly? def get_world_name() -> str: @@ -52,65 +47,43 @@ def slice_residual(residual) -> List[torch.Tensor]: return torch.chunk(residual, n_slices, dim=0) -def match_gemm_rs_ag_gemm( - residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) - - #all_reduce = tensor_model_parallel_all_reduce(mm_1) - all_reduce = torch.ops.higher_order.auto_functionalized( - torch.ops.vllm.inplace_all_reduce.default, - tensor=mm_1, - group_name=TP_GROUP_NAME) - all_reduce = all_reduce[1] - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) +def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool): + def match_gemm_rs_ag_gemm( + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + + #all_reduce = tensor_model_parallel_all_reduce(mm_1) + if custom_ar: + all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, + tp_group_name) + else: + all_reduce = torch.ops.higher_order.auto_functionalized( + torch.ops.vllm.inplace_all_reduce.default, + tensor=mm_1, + group_name=tp_group_name) + all_reduce = all_reduce[1] - return mm_2, new_residual + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) -def match_gemm_rs_ag_gemm_custom_ar( - residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = 
torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) - - #all_reduce = tensor_model_parallel_all_reduce(mm_1) - all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, - TP_GROUP_NAME) - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] + return mm_2, new_residual - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) - - return mm_2, new_residual + return match_gemm_rs_ag_gemm def gemm_rs_ag_gemm_fake( @@ -267,48 +240,36 @@ def gemm_rs_ag_gemm( return getattr(torch.ops.vllm, name).default -def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weights: torch.Tensor) -> torch.Tensor: - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) - - #all_reduce = tensor_model_parallel_all_reduce(mm_1) - all_reduce = torch.ops.higher_order.auto_functionalized( - torch.ops.vllm.inplace_all_reduce.default, - tensor=mm_1, - group_name=TP_GROUP_NAME) - all_reduce = all_reduce[1] - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=my_residual, - weight=rms_norm_weights, - epsilon=1e-05) - normalized = norm_res[1] - - return normalized - +def get_match_final(tp_group_name: str, use_custom_ar: bool): + def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> torch.Tensor: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) -def match_final_custom_ar(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weights: torch.Tensor) -> torch.Tensor: - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + #all_reduce = tensor_model_parallel_all_reduce(mm_1) + if use_custom_ar: + all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, + tp_group_name) + else: + all_reduce = torch.ops.higher_order.auto_functionalized( + torch.ops.vllm.inplace_all_reduce.default, + tensor=mm_1, + group_name=tp_group_name) + all_reduce = all_reduce[1] - all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, - TP_GROUP_NAME) + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) + normalized = norm_res[1] - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=my_residual, - weight=rms_norm_weights, - epsilon=1e-05) - normalized = norm_res[1] + return normalized - return normalized + return match_final # Register this as a custom op since all reduce cannot be torch.compiled yet. @@ -406,26 +367,30 @@ def __init__(self): resid_w = torch.empty([4, 4], device='cuda') x2 = torch.empty([4, 4], device='cuda') inputs = [resid, x, w, resid_w, x2] - - # register multiple patterns for all group/world names? 
- - for m in [match_gemm_rs_ag_gemm, match_gemm_rs_ag_gemm_custom_ar]: - register_replacement(m, - match_gemm_rs_ag_gemm, - inputs, - fake_fwd_only, - [self.gemm_rs_ag_gemm_pattern], - extra_check=lambda m: self.record_match(m)) - final_inputs = [x, w, resid, resid_w] - for m in [match_final, match_final_custom_ar]: - register_replacement( - m, - torch.ops.vllm.gemm_ag_final, - #replace_final, - final_inputs, - fake_fwd_only, - [self.final_pattern]) + + # register multiple patterns for all group names, fill out to max_gpus. + group_names = ["tp:0"] + + for group_name in group_names: + for m in [get_match_gemm_rs_ag_gemm(group_name, False), + get_match_gemm_rs_ag_gemm(group_name, True)]: + register_replacement(m, + m, + inputs, + fake_fwd_only, + [self.gemm_rs_ag_gemm_pattern], + extra_check=lambda m: self.record_match(m)) + + for m in [get_match_final(group_name, False), + get_match_final(group_name, True)]: + register_replacement( + m, + torch.ops.vllm.gemm_ag_final, + #replace_final, + final_inputs, + fake_fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and From aa401314582f66274821d347bc7b8df66a091cda Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 21:37:41 +0000 Subject: [PATCH 48/72] factor out group names Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 469a040d4fb02..15af08b409751 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -121,8 +121,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_2_weights: torch.Size): if use_flux: + print(f"DG = {get_tp_group().device_group}") gemm_rs_op = flux.GemmRS( - get_tp_group().device_group, + get_tp_group().device_group, # XXXXXXXXXXXXXXX 1, # One node 8192, # Max M. TODO: Pass in correctly. gemm_1_weights[0], # N @@ -138,7 +139,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, ) ag_gemm_op = flux.AGKernel( - get_tp_group().device_group, + get_tp_group().device_group, # XXXXXXXXXXXXXXX 1, # One node 8192, # Max M. TODO: Pass in correctly. gemm_2_weights[0], # N @@ -166,7 +167,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, name = (f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_" f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}") else: - group_name = get_world_name() + group_name = get_world_name() # XXXXXXXXXXXXXXXX make parameter gemm_rs = lambda act, wt: \ torch.ops.symm_mem.fused_matmul_reduce_scatter.default( From a54883b31d7637e30b9c16879adf90bf7da96a5a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 22:34:29 +0000 Subject: [PATCH 49/72] factor out group names, cleanups, etc. 
Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 225 ++++++++++++-------------- vllm/distributed/parallel_state.py | 7 +- 2 files changed, 108 insertions(+), 124 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 15af08b409751..ab64476499b20 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,5 +1,5 @@ import operator -from typing import Any, Callable, List, Optional, Sequence, Tuple +from typing import List, Tuple import torch import torch.fx as fx @@ -14,8 +14,8 @@ from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - get_tp_group) + get_group_from_group_name, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.utils import direct_register_custom_op @@ -48,20 +48,21 @@ def slice_residual(residual) -> List[torch.Tensor]: def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool): + def match_gemm_rs_ag_gemm( - residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) #all_reduce = tensor_model_parallel_all_reduce(mm_1) if custom_ar: - all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, - tp_group_name) + all_reduce = torch.ops.vllm.outplace_all_reduce.default( + mm_1, tp_group_name) else: all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, @@ -98,6 +99,7 @@ def gemm_rs_ag_gemm_fake( if first_layer and should_slice(gemm_1_activations.shape): res_slices = slice_residual(residual) + # is this rank ok? slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] split_1 = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = split_1[0] @@ -114,16 +116,14 @@ def gemm_rs_ag_gemm_fake( # TODO: factor out groupnames, etc. -def get_gemm_rs_ag_gemm(use_flux: bool, - gemm_1_type, - gemm_1_weights: torch.Size, - gemm_2_type, - gemm_2_weights: torch.Size): +def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, + gemm_1_weights: torch.Size, gemm_2_type, + gemm_2_weights: torch.Size, tp_group_name: str): if use_flux: - print(f"DG = {get_tp_group().device_group}") + device_group = get_group_from_group_name(tp_group_name).device_group gemm_rs_op = flux.GemmRS( - get_tp_group().device_group, # XXXXXXXXXXXXXXX + device_group, 1, # One node 8192, # Max M. TODO: Pass in correctly. gemm_1_weights[0], # N @@ -139,7 +139,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, ) ag_gemm_op = flux.AGKernel( - get_tp_group().device_group, # XXXXXXXXXXXXXXX + device_group, 1, # One node 8192, # Max M. TODO: Pass in correctly. 
gemm_2_weights[0], # N @@ -164,20 +164,24 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_str = str(gemm_1_type).removeprefix("torch.") gemm_2_str = str(gemm_2_type).removeprefix("torch.") - name = (f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_" - f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}") + group_str = tp_group_name.replace(":", "_") + name = ( + f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_" + f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{group_str}" + ) else: - group_name = get_world_name() # XXXXXXXXXXXXXXXX make parameter + world_group_name = get_world_name() gemm_rs = lambda act, wt: \ torch.ops.symm_mem.fused_matmul_reduce_scatter.default( - act, wt.transpose(1, 0), 'avg', 0, group_name) + act, wt.transpose(1, 0), 'avg', 0, world_group_name) ag_gemm = lambda act, wt: \ torch.ops.symm_mem.fused_all_gather_matmul.default( - act, [wt.transpose(1, 0)], 0, group_name)[1] + act, [wt.transpose(1, 0)], 0, world_group_name)[1] - name = "gemm_rs_ag_gemm" + group_str = tp_group_name.replace(":", "_") + name = f"gemm_rs_ag_gemm_{group_str}" def gemm_rs_ag_gemm( residual: torch.Tensor, old_my_residual: torch.Tensor, @@ -188,11 +192,12 @@ def gemm_rs_ag_gemm( if first_layer and should_slice(residual.shape): res_slices = slice_residual(residual) + # is this rank ok? slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = residual_chunk[0] else: - my_residual = residual #.clone() + my_residual = residual #.clone() slice_size = residual.shape[0] if not should_slice(residual.shape): @@ -230,29 +235,30 @@ def gemm_rs_ag_gemm( if not hasattr(torch.ops.vllm, name): logger.info("registering torch.ops.vllm.%s", name) - direct_register_custom_op( - name, - gemm_rs_ag_gemm, - mutates_args=[], - fake_impl=gemm_rs_ag_gemm_fake - ) + direct_register_custom_op(name, + gemm_rs_ag_gemm, + mutates_args=[], + fake_impl=gemm_rs_ag_gemm_fake) assert getattr(torch.ops.vllm, name) return getattr(torch.ops.vllm, name).default def get_match_final(tp_group_name: str, use_custom_ar: bool): - def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weights: torch.Tensor, - ) -> torch.Tensor: + + def match_final( + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, + ) -> torch.Tensor: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) #all_reduce = tensor_model_parallel_all_reduce(mm_1) if use_custom_ar: - all_reduce = torch.ops.vllm.outplace_all_reduce.default(mm_1, - tp_group_name) + all_reduce = torch.ops.vllm.outplace_all_reduce.default( + mm_1, tp_group_name) else: all_reduce = torch.ops.higher_order.auto_functionalized( torch.ops.vllm.inplace_all_reduce.default, @@ -287,11 +293,10 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, else: wait_tensor = my_residual - ops.fused_add_rms_norm( - input=reduced, - residual=wait_tensor, - weight=rms_norm_weights, - epsilon=1e-05) + ops.fused_add_rms_norm(input=reduced, + residual=wait_tensor, + weight=rms_norm_weights, + epsilon=1e-05) return reduced @@ -304,55 +309,10 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, device=my_residual.device) -direct_register_custom_op( - "gemm_ag_final", - replace_final, - mutates_args=[], - fake_impl=replace_final_fake -) - -# 
Copied from pattern_matcher.py fwd_only so we can use tracing_mode="fake". -# "real" mode chokes on custom ar primitives since the custom ar data structure(s) -# have not been set up. We could also try to only register the custom_ar patterns -# if custom ar has been initialized. Not sure how hard that is. -# TODO: convert args to fake tenors and forward to original fwd_only. -@torch.no_grad() -def fake_fwd_only( - fn: Callable[..., Any], - args: Sequence[Any], - *, - run_functional_passes: bool = True, - get_decomp_fn: Optional[Callable[..., Any]] = None, -) -> torch.fx.GraphModule: - from torch._dispatch.python import enable_python_dispatcher - from torch.fx.experimental.proxy_tensor import make_fx - from torch._inductor.decomposition import select_decomp_table - - """Build a normalized inference graph, for use with fx_to_pattern""" - # TODO - look into using aot autograd, asserting no mutating ops here - with enable_python_dispatcher(): - decompositions = ( - get_decomp_fn() if get_decomp_fn is not None else select_decomp_table() - ) - # This doesn't seem to work. - #with torch._dynamo.utils.detect_fake_mode(args) as fm: - # new_args = [] - # for arg in args: - # if isinstance(arg, torch.Tensor): - # new_args.append(fm.from_tensor(arg)) - # else: - # new_args.append(arg) - #gm = make_fx(fn, decompositions, tracing_mode="real")(*new_args) - gm = make_fx(fn, decompositions, tracing_mode="fake")(*args) - - from torch._inductor.fx_passes.post_grad import remove_noop_ops - - if run_functional_passes: - remove_noop_ops(gm.graph) - gm.graph.eliminate_dead_code() - - gm.recompile() - return gm +direct_register_custom_op("gemm_ag_final", + replace_final, + mutates_args=[], + fake_impl=replace_final_fake) class CollectiveFusionPass(InductorPass): @@ -362,36 +322,42 @@ def __init__(self): self.final_pattern = PatternMatcherPass() self.matches: List[Match] = [] - x = torch.empty([4, 4], device='cuda') - w = torch.empty([4, 4], device='cuda') - resid = torch.empty([4, 4], device='cuda') - resid_w = torch.empty([4, 4], device='cuda') - x2 = torch.empty([4, 4], device='cuda') - inputs = [resid, x, w, resid_w, x2] - final_inputs = [x, w, resid, resid_w] - - # register multiple patterns for all group names, fill out to max_gpus. - group_names = ["tp:0"] - - for group_name in group_names: - for m in [get_match_gemm_rs_ag_gemm(group_name, False), - get_match_gemm_rs_ag_gemm(group_name, True)]: - register_replacement(m, - m, - inputs, - fake_fwd_only, - [self.gemm_rs_ag_gemm_pattern], - extra_check=lambda m: self.record_match(m)) - - for m in [get_match_final(group_name, False), - get_match_final(group_name, True)]: - register_replacement( - m, - torch.ops.vllm.gemm_ag_final, - #replace_final, - final_inputs, - fake_fwd_only, - [self.final_pattern]) + with torch._dynamo.utils.detect_fake_mode(): + x = torch.empty([4, 4], device='cuda') + w = torch.empty([4, 4], device='cuda') + resid = torch.empty([4, 4], device='cuda') + resid_w = torch.empty([4, 4], device='cuda') + x2 = torch.empty([4, 4], device='cuda') + inputs = [resid, x, w, resid_w, x2] + final_inputs = [x, w, resid, resid_w] + + # register multiple patterns for all group names. 
+ max_gpus = 8 # TODO: get this officially + group_names = [f"tp:{rank}" for rank in range(max_gpus)] + + for group_name in group_names: + for m in [ + get_match_gemm_rs_ag_gemm(group_name, False), + get_match_gemm_rs_ag_gemm(group_name, True) + ]: + register_replacement( + m, + m, + inputs, + fwd_only, [self.gemm_rs_ag_gemm_pattern], + extra_check=lambda m: self.record_match(m)) + + for m in [ + get_match_final(group_name, False), + get_match_final(group_name, True) + ]: + register_replacement( + m, + torch.ops.vllm.gemm_ag_final, + #replace_final, + final_inputs, + fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and @@ -428,8 +394,20 @@ def find_min_index(match: Match) -> int: gemm_1 = kwargs["gemm_1_weights"].meta["val"] gemm_2 = kwargs["gemm_2_weights"].meta["val"] + ar_node = find_auto_fn( + match.nodes, torch.ops.vllm.inplace_all_reduce.default) + if ar_node is not None: + tp_group_name = ar_node.kwargs["group_name"] + else: + ar_node = find_fn( + match.nodes, + torch.ops.vllm.outplace_all_reduce.default) + assert ar_node is not None + tp_group_name = ar_node.args[1] + fused_node = graph.call_function(get_gemm_rs_ag_gemm( - use_flux, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, gemm_2.shape), + use_flux, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, + gemm_2.shape, tp_group_name), kwargs=kwargs) graph.inserting_after(fused_node) @@ -459,7 +437,8 @@ def find_min_index(match: Match) -> int: # Finally, remove matched nodes graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches for node in match.nodes) + assert all(node not in graph.nodes for match in matches + for node in match.nodes) def __call__(self, graph: fx.Graph): self.dump_graph(graph, "before_collective_fusion") diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 6b403ff465094..0d5b97eba3850 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -104,11 +104,16 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) -def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: +def get_group_from_group_name(group_name: str) -> "GroupCoordinator": assert group_name in _groups, f"Group {group_name} is not found." 
group = _groups[group_name]() if group is None: raise ValueError(f"Group {group_name} is destroyed.") + return group + + +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + group = get_group_from_name(group_name) return group._all_reduce_out_place(tensor) From 0e2a024b61c43e201c02d0b5d8d9bccc9d87cb77 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 5 Nov 2024 22:58:24 +0000 Subject: [PATCH 50/72] fix some todos Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 103 +++++++++++++------------- 1 file changed, 51 insertions(+), 52 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index ab64476499b20..20e0499ff4bbc 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -14,8 +14,7 @@ from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( - get_group_from_group_name, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_group_from_group_name, get_tensor_model_parallel_world_size) from vllm.logger import init_logger from vllm.utils import direct_register_custom_op @@ -31,12 +30,12 @@ use_flux = False -# how to do this properly? +# TODO: is this right? def get_world_name() -> str: return torch.distributed.group.WORLD.group_name -# This check is a hack +# TODO: This check is a hack def should_slice(shape) -> bool: n_slices = get_tensor_model_parallel_world_size() return (shape[0] % n_slices == 0 and shape[0] >= 128) @@ -59,6 +58,7 @@ def match_gemm_rs_ag_gemm( gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + # It would be nice to do this instead of having two separate patterns #all_reduce = tensor_model_parallel_all_reduce(mm_1) if custom_ar: all_reduce = torch.ops.vllm.outplace_all_reduce.default( @@ -87,47 +87,20 @@ def match_gemm_rs_ag_gemm( return match_gemm_rs_ag_gemm -def gemm_rs_ag_gemm_fake( - residual: torch.Tensor, - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, - first_layer: bool, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - if first_layer and should_slice(gemm_1_activations.shape): - res_slices = slice_residual(residual) - # is this rank ok? - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - my_residual = split_1[0] - else: - my_residual = residual - - # verify the type is always correct - mm_res = torch.empty( - (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), - device=gemm_1_activations.device, - dtype=gemm_1_activations.dtype) - - return (mm_res, my_residual, residual) - - -# TODO: factor out groupnames, etc. def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_1_weights: torch.Size, gemm_2_type, gemm_2_weights: torch.Size, tp_group_name: str): + group = get_group_from_group_name(tp_group_name) + device_group = group.device_group + rank = group.rank_in_group + if use_flux: - device_group = get_group_from_group_name(tp_group_name).device_group gemm_rs_op = flux.GemmRS( device_group, 1, # One node 8192, # Max M. TODO: Pass in correctly. gemm_1_weights[0], # N - # TODO: Pass in input dtype correctly. # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. 
# Similar comment for max m. @@ -144,7 +117,6 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, 8192, # Max M. TODO: Pass in correctly. gemm_2_weights[0], # N gemm_2_weights[1], # K - # TODO: Pass in input dtype correctly. # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. # Similar comment for max m. @@ -192,12 +164,11 @@ def gemm_rs_ag_gemm( if first_layer and should_slice(residual.shape): res_slices = slice_residual(residual) - # is this rank ok? - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + slice_size = res_slices[rank].shape[0] residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) my_residual = residual_chunk[0] else: - my_residual = residual #.clone() + my_residual = residual slice_size = residual.shape[0] if not should_slice(residual.shape): @@ -225,14 +196,37 @@ def gemm_rs_ag_gemm( slice_scatter = torch.ops.aten.slice_scatter.default( residual_1, my_residual, 0, 0, slice_size) split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) - - # TODO: can we avoid clone here? - new_residual = split_2[0] #.clone() + new_residual = split_2[0] mm_2 = ag_gemm(output, gemm_2_weights) return mm_2[0], new_residual, slice_scatter + def gemm_rs_ag_gemm_fake( + residual: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, + gemm_2_weights: torch.Tensor, + first_layer: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if first_layer and should_slice(gemm_1_activations.shape): + res_slices = slice_residual(residual) + slice_size = res_slices[rank].shape[0] + split_1 = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = split_1[0] + else: + my_residual = residual + + # TODO: verify the type is always correct + mm_res = torch.empty( + (gemm_1_activations.shape[0], gemm_2_weights.shape[0]), + device=gemm_1_activations.device, + dtype=gemm_1_activations.dtype) + + return (mm_res, my_residual, residual) + if not hasattr(torch.ops.vllm, name): logger.info("registering torch.ops.vllm.%s", name) direct_register_custom_op(name, @@ -255,6 +249,7 @@ def match_final( gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + # TODO: it would be nice to be able to use the official api directly. #all_reduce = tensor_model_parallel_all_reduce(mm_1) if use_custom_ar: all_reduce = torch.ops.vllm.outplace_all_reduce.default( @@ -322,6 +317,8 @@ def __init__(self): self.final_pattern = PatternMatcherPass() self.matches: List[Match] = [] + # Run in fake mode so that we don't call real functions + # when tracing the patterns. with torch._dynamo.utils.detect_fake_mode(): x = torch.empty([4, 4], device='cuda') w = torch.empty([4, 4], device='cuda') @@ -351,13 +348,9 @@ def __init__(self): get_match_final(group_name, False), get_match_final(group_name, True) ]: - register_replacement( - m, - torch.ops.vllm.gemm_ag_final, - #replace_final, - final_inputs, - fwd_only, - [self.final_pattern]) + register_replacement(m, torch.ops.vllm.gemm_ag_final, + final_inputs, fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and @@ -394,6 +387,8 @@ def find_min_index(match: Match) -> int: gemm_1 = kwargs["gemm_1_weights"].meta["val"] gemm_2 = kwargs["gemm_2_weights"].meta["val"] + # Extract group_name from matched code. 
Use to + # generate proper replacement code. ar_node = find_auto_fn( match.nodes, torch.ops.vllm.inplace_all_reduce.default) if ar_node is not None: @@ -405,9 +400,13 @@ def find_min_index(match: Match) -> int: assert ar_node is not None tp_group_name = ar_node.args[1] - fused_node = graph.call_function(get_gemm_rs_ag_gemm( - use_flux, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, - gemm_2.shape, tp_group_name), + fused_gemm_func = get_gemm_rs_ag_gemm(use_flux, gemm_1.dtype, + gemm_1.shape, + gemm_2.dtype, + gemm_2.shape, + tp_group_name) + + fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs) graph.inserting_after(fused_node) From b01205d915c4a2d5887e7697322962ab83d4c0a4 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 8 Nov 2024 22:07:41 +0000 Subject: [PATCH 51/72] find max m for flux kernels Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 45 ++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 20e0499ff4bbc..1917b3c7ccf46 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -87,9 +87,14 @@ def match_gemm_rs_ag_gemm( return match_gemm_rs_ag_gemm -def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, - gemm_1_weights: torch.Size, gemm_2_type, - gemm_2_weights: torch.Size, tp_group_name: str): +def get_gemm_rs_ag_gemm(use_flux: bool, + gemm_1_type, + gemm_1_weights: torch.Size, + gemm_1_max_m: int, + gemm_2_type, + gemm_2_weights: torch.Size, + gemm_2_max_m: int, + tp_group_name: str): group = get_group_from_group_name(tp_group_name) device_group = group.device_group @@ -99,7 +104,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_rs_op = flux.GemmRS( device_group, 1, # One node - 8192, # Max M. TODO: Pass in correctly. + gemm_1_max_m, gemm_1_weights[0], # N # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. @@ -114,7 +119,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, ag_gemm_op = flux.AGKernel( device_group, 1, # One node - 8192, # Max M. TODO: Pass in correctly. + gemm_2_max_m, gemm_2_weights[0], # N gemm_2_weights[1], # K # TODO: It would be nicer to modify flux to dispatch based on dtype @@ -138,8 +143,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_2_str = str(gemm_2_type).removeprefix("torch.") group_str = tp_group_name.replace(":", "_") name = ( - f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_" - f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{group_str}" + f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_" + f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_" + f"{group_str}" ) else: world_group_name = get_world_name() @@ -275,7 +281,7 @@ def match_final( # Register this as a custom op since all reduce cannot be torch.compiled yet. 
-def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, +def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) @@ -296,7 +302,7 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return reduced -def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, +def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]], @@ -305,9 +311,9 @@ def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, direct_register_custom_op("gemm_ag_final", - replace_final, + gemm_ag_final, mutates_args=[], - fake_impl=replace_final_fake) + fake_impl=gemm_ag_final_fake) class CollectiveFusionPass(InductorPass): @@ -360,6 +366,18 @@ def record_match(self, match: Match) -> bool: # Return False to prevent automatic replacement. return False + def find_max_m(self, matches) -> Tuple[int, int]: + gemm_1_max_m = 0 + gemm_2_max_m = 0 + for m in matches: + gemm_1 = m.kwargs["gemm_1_weights"].meta["val"] + gemm_2 = m.kwargs["gemm_2_weights"].meta["val"] + gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[0]) + gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[0]) + assert gemm_1_max_m > 0 + assert gemm_2_max_m > 0 + return gemm_1_max_m, gemm_2_max_m + def process_matches(self, graph: fx.Graph): nodes = list(graph.nodes) @@ -372,6 +390,9 @@ def find_min_index(match: Match) -> int: res_replacements: List[fx.Node] = [] my_res_replacements: List[fx.Node] = [] + gemm_1_max_m, gemm_2_max_m = self.find_max_m(matches) + logger.info("max m = %d, %d", gemm_1_max_m, gemm_2_max_m) + for match in matches: last_node = last_node_in_match(match) @@ -402,8 +423,10 @@ def find_min_index(match: Match) -> int: fused_gemm_func = get_gemm_rs_ag_gemm(use_flux, gemm_1.dtype, gemm_1.shape, + gemm_1_max_m, gemm_2.dtype, gemm_2.shape, + gemm_2_max_m, tp_group_name) fused_node = graph.call_function(fused_gemm_func, From 9f90853b91e4f2d4c9e7a4f1b547e7a5d5cadd99 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 8 Nov 2024 23:29:44 +0000 Subject: [PATCH 52/72] rebase Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 4 +--- vllm/compilation/collective_fusion.py | 22 +++++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index bb49ec30af806..d8646cb741ba7 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -341,11 +341,9 @@ def async_rewrite(graph: fx.Graph): return graph -collective_fusion_pass: Optional[CollectiveFusionPass] = None - def wrap_inductor(graph, example_inputs, - additional_inductor_config, + additional_inductor_config = None, do_logging=False, runtime_shape: Optional[int] = None, use_inductor: bool = True): diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 1917b3c7ccf46..0eacf35d793e3 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -8,6 +8,7 @@ import vllm._custom_ops as ops import vllm.envs as envs +from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, 
last_node_in_match) @@ -318,7 +319,26 @@ def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, class CollectiveFusionPass(InductorPass): - def __init__(self): + _instance: 'Optional[CollectiveFusionPass]' = None + + @classmethod + def instance(cls, config: CompilationConfig): + """ + Get the singleton instance of the CollectiveFusionPass. + If the instance exists, the config is updated but + initialization is not repeated. + """ + if cls._instance is None: + cls._instance = CollectiveFusionPass(config) + else: + cls._instance.config = config + return cls._instance + + def __init__(self, config): + assert self.__class__._instance is None, \ + "FusionPass singleton instance already exists" + super().__init__(config) + self.gemm_rs_ag_gemm_pattern = PatternMatcherPass() self.final_pattern = PatternMatcherPass() self.matches: List[Match] = [] From a21fb98919ae6a4b63f5cd29dedba90600a8c6d0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 8 Nov 2024 23:45:13 +0000 Subject: [PATCH 53/72] add error check Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 0eacf35d793e3..c905e6f04dbeb 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -424,9 +424,10 @@ def find_min_index(match: Match) -> int: kwargs["old_my_residual"] = my_res_replacements[-1] if len( my_res_replacements) > 0 else match.kwargs["residual"] - # TODO: use get - gemm_1 = kwargs["gemm_1_weights"].meta["val"] - gemm_2 = kwargs["gemm_2_weights"].meta["val"] + gemm_1 = kwargs["gemm_1_weights"].meta.get("val") + gemm_2 = kwargs["gemm_2_weights"].meta.get("val") + if gemm_1 is None or gemm_2 is None: + raise ValueError("Missing 'val' in gemm weights meta data") # Extract group_name from matched code. Use to # generate proper replacement code. 
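The fused replacement op is registered per tensor-parallel group name, so process_matches has to recover the group name from whichever all-reduce flavor the pattern captured. A minimal standalone sketch of that lookup, assuming vLLM's inplace_all_reduce/outplace_all_reduce custom ops are registered (the pass itself goes through the find_auto_fn/find_fn helpers rather than scanning nodes directly, and the name extract_tp_group_name is illustrative):

from typing import Optional, Sequence

import torch
import torch.fx as fx


def extract_tp_group_name(nodes: Sequence[fx.Node]) -> Optional[str]:
    # Walk the matched nodes and pull the TP group name out of whichever
    # all-reduce variant the pattern matched.
    for node in nodes:
        if node.op != "call_function":
            continue
        if (node.target == torch.ops.higher_order.auto_functionalized
                and node.args
                and node.args[0] == torch.ops.vllm.inplace_all_reduce.default):
            # The functionalized in-place op carries group_name in kwargs.
            return node.kwargs["group_name"]
        if node.target == torch.ops.vllm.outplace_all_reduce.default:
            # The out-of-place (custom allreduce) op takes it positionally.
            return node.args[1]
    return None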
From 7aa754666ec9091eb7b0977dccbad12420e52f72 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 9 Nov 2024 00:39:23 +0000 Subject: [PATCH 54/72] review comments Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index c905e6f04dbeb..b271dfbfa1473 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -36,7 +36,7 @@ def get_world_name() -> str: return torch.distributed.group.WORLD.group_name -# TODO: This check is a hack +# 128 is tile_size on sm90 def should_slice(shape) -> bool: n_slices = get_tensor_model_parallel_world_size() return (shape[0] % n_slices == 0 and shape[0] >= 128) @@ -392,8 +392,8 @@ def find_max_m(self, matches) -> Tuple[int, int]: for m in matches: gemm_1 = m.kwargs["gemm_1_weights"].meta["val"] gemm_2 = m.kwargs["gemm_2_weights"].meta["val"] - gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[0]) - gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[0]) + gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1]) + gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1]) assert gemm_1_max_m > 0 assert gemm_2_max_m > 0 return gemm_1_max_m, gemm_2_max_m From 65fcaf50b671ab290f4ace45e32a8e086eb81260 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 9 Nov 2024 01:08:24 +0000 Subject: [PATCH 55/72] format Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 3 +-- vllm/compilation/collective_fusion.py | 27 +++++++++------------------ 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index d8646cb741ba7..b1be3f405f81f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -6,7 +6,6 @@ import torch import torch.fx as fx -from typing import Tuple, List, Optional import vllm.envs as envs from vllm.config import CompilationConfig @@ -343,7 +342,7 @@ def async_rewrite(graph: fx.Graph): def wrap_inductor(graph, example_inputs, - additional_inductor_config = None, + additional_inductor_config=None, do_logging=False, runtime_shape: Optional[int] = None, use_inductor: bool = True): diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b271dfbfa1473..a96d435f5bd9e 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,5 +1,5 @@ import operator -from typing import List, Tuple +from typing import List, Optional, Tuple import torch import torch.fx as fx @@ -88,14 +88,10 @@ def match_gemm_rs_ag_gemm( return match_gemm_rs_ag_gemm -def get_gemm_rs_ag_gemm(use_flux: bool, - gemm_1_type, - gemm_1_weights: torch.Size, - gemm_1_max_m: int, - gemm_2_type, - gemm_2_weights: torch.Size, - gemm_2_max_m: int, - tp_group_name: str): +def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, + gemm_1_weights: torch.Size, gemm_1_max_m: int, + gemm_2_type, gemm_2_weights: torch.Size, + gemm_2_max_m: int, tp_group_name: str): group = get_group_from_group_name(tp_group_name) device_group = group.device_group @@ -146,8 +142,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, name = ( f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_" f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_" - f"{group_str}" - ) + f"{group_str}") else: world_group_name = get_world_name() @@ -442,13 +437,9 @@ def find_min_index(match: Match) -> int: assert ar_node is not None tp_group_name = ar_node.args[1] - fused_gemm_func = 
get_gemm_rs_ag_gemm(use_flux, gemm_1.dtype, - gemm_1.shape, - gemm_1_max_m, - gemm_2.dtype, - gemm_2.shape, - gemm_2_max_m, - tp_group_name) + fused_gemm_func = get_gemm_rs_ag_gemm( + use_flux, gemm_1.dtype, gemm_1.shape, gemm_1_max_m, + gemm_2.dtype, gemm_2.shape, gemm_2_max_m, tp_group_name) fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs) From 6ce19bda9695d56a8b3a811d528a99ce097c8444 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 9 Nov 2024 07:58:37 +0000 Subject: [PATCH 56/72] fix cudagraph support Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 42 ++++++++++++++++----------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index a96d435f5bd9e..20608cdaadae4 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -30,21 +30,26 @@ except ImportError: use_flux = False +FLUX_TILE_SIZE: int = 128 + # TODO: is this right? def get_world_name() -> str: return torch.distributed.group.WORLD.group_name -# 128 is tile_size on sm90 +# Note: this heuristic is unique to flux def should_slice(shape) -> bool: n_slices = get_tensor_model_parallel_world_size() - return (shape[0] % n_slices == 0 and shape[0] >= 128) + return (shape[0] % (FLUX_TILE_SIZE * n_slices) == 0 + and shape[0] >= FLUX_TILE_SIZE * n_slices) -def slice_residual(residual) -> List[torch.Tensor]: +# This is really inefficient. Should only pick the slice required. +def residual_slice_shape(residual, rank) -> List[torch.Size]: n_slices = get_tensor_model_parallel_world_size() - return torch.chunk(residual, n_slices, dim=0) + slices = torch.chunk(residual, n_slices, dim=0) + return slices[rank].shape def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool): @@ -101,7 +106,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_rs_op = flux.GemmRS( device_group, 1, # One node - gemm_1_max_m, + gemm_1_max_m, # M gemm_1_weights[0], # N # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. 
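The should_slice heuristic above is what gates the sliced flux path: GemmRS and AGKernel work on 128-row tiles, so the residual is only split across ranks when every rank gets a whole number of tiles. A self-contained sketch of the check, with the world size passed in explicitly instead of read from the TP group (names here are illustrative):

FLUX_TILE_SIZE = 128  # tile size the heuristic in this patch assumes


def should_slice(num_tokens: int, world_size: int) -> bool:
    # Slice only when the token dimension splits evenly into world_size
    # chunks that are each a whole number of flux tiles; otherwise the
    # pass keeps the plain mm + all-reduce + fused_add_rms_norm path.
    chunk = FLUX_TILE_SIZE * world_size
    return num_tokens % chunk == 0 and num_tokens >= chunk


# With 2 GPUs: 512 tokens give two 256-row slices, so the fused path is
# taken; 384 tokens fail the 384 % 256 check, so the fallback is used.
assert should_slice(512, 2)
assert not should_slice(384, 2)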
@@ -116,7 +121,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, ag_gemm_op = flux.AGKernel( device_group, 1, # One node - gemm_2_max_m, + gemm_2_max_m, # M gemm_2_weights[0], # N gemm_2_weights[1], # K # TODO: It would be nicer to modify flux to dispatch based on dtype @@ -165,13 +170,12 @@ def gemm_rs_ag_gemm( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if first_layer and should_slice(residual.shape): - res_slices = slice_residual(residual) - slice_size = res_slices[rank].shape[0] - residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) + slice_shape = residual_slice_shape(residual, rank)[0] + residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape) my_residual = residual_chunk[0] else: my_residual = residual - slice_size = residual.shape[0] + slice_shape = residual.shape[0] if not should_slice(residual.shape): output = torch.ops.aten.mm.default(gemm_1_activations, @@ -196,8 +200,8 @@ def gemm_rs_ag_gemm( residual_1 = residual if first_layer else old_my_residual slice_scatter = torch.ops.aten.slice_scatter.default( - residual_1, my_residual, 0, 0, slice_size) - split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) + residual_1, my_residual, 0, 0, slice_shape) + split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape) new_residual = split_2[0] mm_2 = ag_gemm(output, gemm_2_weights) @@ -214,9 +218,8 @@ def gemm_rs_ag_gemm_fake( first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if first_layer and should_slice(gemm_1_activations.shape): - res_slices = slice_residual(residual) - slice_size = res_slices[rank].shape[0] - split_1 = torch.ops.aten.split.Tensor(residual, slice_size) + slice_shape = residual_slice_shape(residual, rank)[0] + split_1 = torch.ops.aten.split.Tensor(residual, slice_shape) my_residual = split_1[0] else: my_residual = residual @@ -385,10 +388,15 @@ def find_max_m(self, matches) -> Tuple[int, int]: gemm_1_max_m = 0 gemm_2_max_m = 0 for m in matches: - gemm_1 = m.kwargs["gemm_1_weights"].meta["val"] - gemm_2 = m.kwargs["gemm_2_weights"].meta["val"] + #gemm_1 = m.kwargs["gemm_1_weights"].meta["val"] + #gemm_2 = m.kwargs["gemm_2_weights"].meta["val"] + #gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1]) + #gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1]) + gemm_1 = m.kwargs["residual"].meta["val"] + gemm_2 = m.kwargs["residual"].meta["val"] gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1]) gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1]) + assert gemm_1_max_m > 0 assert gemm_2_max_m > 0 return gemm_1_max_m, gemm_2_max_m From ddc0b2068227c083b0c753dc0c22a9684dda0d26 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 9 Nov 2024 19:37:36 +0000 Subject: [PATCH 57/72] perf improvements Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 20608cdaadae4..8b51e0f9bda47 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -45,11 +45,16 @@ def should_slice(shape) -> bool: and shape[0] >= FLUX_TILE_SIZE * n_slices) -# This is really inefficient. Should only pick the slice required. 
-def residual_slice_shape(residual, rank) -> List[torch.Size]: +def residual_slice_shape(residual, rank) -> int: + n_slices = get_tensor_model_parallel_world_size() + chunk, rem = divmod(residual.shape[0], n_slices) + return chunk if rank < n_slices - 1 or rem == 0 else rem + + +def residual_slice_shape_fake(residual, rank) -> int: n_slices = get_tensor_model_parallel_world_size() slices = torch.chunk(residual, n_slices, dim=0) - return slices[rank].shape + return slices[rank].shape[0] def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool): @@ -170,7 +175,7 @@ def gemm_rs_ag_gemm( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if first_layer and should_slice(residual.shape): - slice_shape = residual_slice_shape(residual, rank)[0] + slice_shape = residual_slice_shape(residual, rank) residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape) my_residual = residual_chunk[0] else: @@ -218,7 +223,7 @@ def gemm_rs_ag_gemm_fake( first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if first_layer and should_slice(gemm_1_activations.shape): - slice_shape = residual_slice_shape(residual, rank)[0] + slice_shape = residual_slice_shape_fake(residual, rank) split_1 = torch.ops.aten.split.Tensor(residual, slice_shape) my_residual = split_1[0] else: From 039d2853191a7749e35c42ca921276d657a440cd Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 9 Nov 2024 23:01:38 +0000 Subject: [PATCH 58/72] cleanups Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 4 +- vllm/compilation/collective_fusion.py | 119 ++++++++++++-------------- 2 files changed, 58 insertions(+), 65 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index b1be3f405f81f..c44c53244bc27 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -342,7 +342,7 @@ def async_rewrite(graph: fx.Graph): def wrap_inductor(graph, example_inputs, - additional_inductor_config=None, + additional_inductor_config: Optional[Dict] = None, do_logging=False, runtime_shape: Optional[int] = None, use_inductor: bool = True): @@ -360,7 +360,7 @@ def wrap_inductor(graph, from torch._inductor import config torch._inductor.config._micro_pipeline_tp = True - # Set to False to avoid infinite recursion logging + # Set to False to avoid infinite recursion logging? torch._inductor.config.implicit_fallbacks = True current_config = config.shallow_copy_dict() diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 8b51e0f9bda47..0f76e7a4c1260 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -1,12 +1,11 @@ import operator -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch import torch.fx as fx from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -import vllm._custom_ops as ops import vllm.envs as envs from vllm.compilation.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass @@ -33,31 +32,31 @@ FLUX_TILE_SIZE: int = 128 -# TODO: is this right? +# TODO: is this ok? 
def get_world_name() -> str: return torch.distributed.group.WORLD.group_name # Note: this heuristic is unique to flux -def should_slice(shape) -> bool: +def should_slice(shape: torch.Size) -> bool: n_slices = get_tensor_model_parallel_world_size() return (shape[0] % (FLUX_TILE_SIZE * n_slices) == 0 and shape[0] >= FLUX_TILE_SIZE * n_slices) -def residual_slice_shape(residual, rank) -> int: +def residual_slice_shape(residual: torch.Tensor, rank: int) -> int: n_slices = get_tensor_model_parallel_world_size() chunk, rem = divmod(residual.shape[0], n_slices) return chunk if rank < n_slices - 1 or rem == 0 else rem -def residual_slice_shape_fake(residual, rank) -> int: +def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int: n_slices = get_tensor_model_parallel_world_size() slices = torch.chunk(residual, n_slices, dim=0) return slices[rank].shape[0] -def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool): +def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool) -> Callable: def match_gemm_rs_ag_gemm( residual: torch.Tensor, @@ -69,8 +68,8 @@ def match_gemm_rs_ag_gemm( gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) - # It would be nice to do this instead of having two separate patterns - #all_reduce = tensor_model_parallel_all_reduce(mm_1) + # It would be nice to do this instead of having two separate patterns. + # all_reduce = tensor_model_parallel_all_reduce(mm_1) if custom_ar: all_reduce = torch.ops.vllm.outplace_all_reduce.default( mm_1, tp_group_name) @@ -98,10 +97,10 @@ def match_gemm_rs_ag_gemm( return match_gemm_rs_ag_gemm -def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, - gemm_1_weights: torch.Size, gemm_1_max_m: int, - gemm_2_type, gemm_2_weights: torch.Size, - gemm_2_max_m: int, tp_group_name: str): +def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype, + gemm_1_weights: torch.Size, gemm_2_type: torch.dtype, + gemm_2_weights: torch.Size, + tp_group_name: str) -> Callable: group = get_group_from_group_name(tp_group_name) device_group = group.device_group @@ -111,7 +110,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_rs_op = flux.GemmRS( device_group, 1, # One node - gemm_1_max_m, # M + max_m, # max M gemm_1_weights[0], # N # TODO: It would be nicer to modify flux to dispatch based on dtype # at run time, but I don't know what the downside would be. 
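residual_slice_shape and its fake twin above do the same job on two sides of the tracing boundary: the eager helper computes the per-rank slice length arithmetically, while the fake variant keeps torch.chunk semantics so the registered meta function agrees with the graph it replaces. A standalone sketch of the pair, working on row counts instead of live tensors (the slice_rows names are illustrative); under should_slice the rows divide evenly, so the two agree on the fused path:

import torch


def slice_rows(num_rows: int, rank: int, n_slices: int) -> int:
    # Runtime version: plain arithmetic, no tensor work. Every rank gets
    # `chunk` rows except possibly the last, which takes the remainder.
    chunk, rem = divmod(num_rows, n_slices)
    return chunk if rank < n_slices - 1 or rem == 0 else rem


def slice_rows_fake(num_rows: int, rank: int, n_slices: int) -> int:
    # Fake-tensor version: mirror torch.chunk, which is what the matched
    # graph (and the fake op registration) uses to split the residual.
    slices = torch.chunk(torch.empty(num_rows, 1), n_slices, dim=0)
    return slices[rank].shape[0]


# Exact division is the only case the fused path takes, and there the
# two helpers report the same slice length.
assert slice_rows(512, 1, 4) == slice_rows_fake(512, 1, 4) == 128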
@@ -126,7 +125,7 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, ag_gemm_op = flux.AGKernel( device_group, 1, # One node - gemm_2_max_m, # M + max_m, # max M gemm_2_weights[0], # N gemm_2_weights[1], # K # TODO: It would be nicer to modify flux to dispatch based on dtype @@ -149,10 +148,9 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_type, gemm_1_str = str(gemm_1_type).removeprefix("torch.") gemm_2_str = str(gemm_2_type).removeprefix("torch.") group_str = tp_group_name.replace(":", "_") - name = ( - f"gemm_rs_ag_gemm_{gemm_1_str}_{gemm_1_weights[0]}_{gemm_1_max_m}_" - f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_{gemm_2_max_m}_" - f"{group_str}") + name = (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_weights[0]}_" + f"{gemm_2_str}_{gemm_2_weights[0]}_{gemm_2_weights[1]}_" + f"{group_str}") else: world_group_name = get_world_name() @@ -187,10 +185,10 @@ def gemm_rs_ag_gemm( gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) - ops.fused_add_rms_norm(input=reduced_output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) + torch.ops._C.fused_add_rms_norm.default(input=reduced_output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) mm_2 = torch.ops.aten.mm.default(reduced_output, gemm_2_weights.transpose(1, 0)) @@ -198,16 +196,20 @@ def gemm_rs_ag_gemm( else: output = gemm_rs(gemm_1_activations, gemm_1_weights) - ops.fused_add_rms_norm(input=output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) + torch.ops._C.fused_add_rms_norm.default(input=output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) residual_1 = residual if first_layer else old_my_residual - slice_scatter = torch.ops.aten.slice_scatter.default( - residual_1, my_residual, 0, 0, slice_shape) - split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape) - new_residual = split_2[0] + #if False: + #slice_scatter = torch.ops.aten.slice_scatter.default( + # residual_1, my_residual, 0, 0, slice_shape) + #split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape) + #new_residual = split_2[0] + #else: + slice_scatter = my_residual + new_residual = residual_1 mm_2 = ag_gemm(output, gemm_2_weights) @@ -248,7 +250,7 @@ def gemm_rs_ag_gemm_fake( return getattr(torch.ops.vllm, name).default -def get_match_final(tp_group_name: str, use_custom_ar: bool): +def get_match_final(tp_group_name: str, use_custom_ar: bool) -> Callable: def match_final( my_residual: torch.Tensor, @@ -260,7 +262,7 @@ def match_final( mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) # TODO: it would be nice to be able to use the official api directly. 
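Aside on the registered-op naming a few hunks above: every distinct combination of max M, GEMM dtypes/shapes and TP group gets its own vllm custom op, so the fused replacement can be looked up by name later. A sketch of that mangling for the flux branch (the helper name and the concrete shape numbers below are made up for illustration):

import torch

def fused_gemm_rs_ag_gemm_name(max_m: int, gemm_1_type: torch.dtype,
                               gemm_1_n: int, gemm_2_type: torch.dtype,
                               gemm_2_n: int, gemm_2_k: int,
                               tp_group_name: str) -> str:
    # Mirrors the f-string in the flux branch of get_gemm_rs_ag_gemm(); the
    # result is later resolved via getattr(torch.ops.vllm, name).default.
    gemm_1_str = str(gemm_1_type).removeprefix("torch.")
    gemm_2_str = str(gemm_2_type).removeprefix("torch.")
    group_str = tp_group_name.replace(":", "_")
    return (f"gemm_rs_ag_gemm_{max_m}_{gemm_1_str}_{gemm_1_n}_"
            f"{gemm_2_str}_{gemm_2_n}_{gemm_2_k}_{group_str}")

print(fused_gemm_rs_ag_gemm_name(8192, torch.float16, 4096,
                                 torch.float16, 4096, 11008, "tp:0"))
# gemm_rs_ag_gemm_8192_float16_4096_float16_4096_11008_tp_0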
- #all_reduce = tensor_model_parallel_all_reduce(mm_1) + # all_reduce = tensor_model_parallel_all_reduce(mm_1) if use_custom_ar: all_reduce = torch.ops.vllm.outplace_all_reduce.default( mm_1, tp_group_name) @@ -288,8 +290,8 @@ def match_final( def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: - permute_254 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_254) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, + gemm_1_weights.transpose(1, 0)) reduced = tensor_model_parallel_all_reduce(mm_1) @@ -298,10 +300,10 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, else: wait_tensor = my_residual - ops.fused_add_rms_norm(input=reduced, - residual=wait_tensor, - weight=rms_norm_weights, - epsilon=1e-05) + torch.ops._C.fused_add_rms_norm.default(input=reduced, + residual=wait_tensor, + weight=rms_norm_weights, + epsilon=1e-05) return reduced @@ -325,7 +327,7 @@ class CollectiveFusionPass(InductorPass): _instance: 'Optional[CollectiveFusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig): + def instance(cls, config: CompilationConfig) -> "CollectiveFusionPass": """ Get the singleton instance of the CollectiveFusionPass. If the instance exists, the config is updated but @@ -358,8 +360,8 @@ def __init__(self, config): final_inputs = [x, w, resid, resid_w] # register multiple patterns for all group names. - max_gpus = 8 # TODO: get this officially - group_names = [f"tp:{rank}" for rank in range(max_gpus)] + world_size = get_tensor_model_parallel_world_size() + group_names = [f"tp:{rank}" for rank in range(world_size)] for group_name in group_names: for m in [ @@ -389,24 +391,15 @@ def record_match(self, match: Match) -> bool: # Return False to prevent automatic replacement. 
return False - def find_max_m(self, matches) -> Tuple[int, int]: - gemm_1_max_m = 0 - gemm_2_max_m = 0 + def find_max_m(self, matches: List[Match]) -> int: + max_m = 0 for m in matches: - #gemm_1 = m.kwargs["gemm_1_weights"].meta["val"] - #gemm_2 = m.kwargs["gemm_2_weights"].meta["val"] - #gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1]) - #gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1]) - gemm_1 = m.kwargs["residual"].meta["val"] - gemm_2 = m.kwargs["residual"].meta["val"] - gemm_1_max_m = max(gemm_1_max_m, gemm_1.shape[1]) - gemm_2_max_m = max(gemm_2_max_m, gemm_2.shape[1]) - - assert gemm_1_max_m > 0 - assert gemm_2_max_m > 0 - return gemm_1_max_m, gemm_2_max_m - - def process_matches(self, graph: fx.Graph): + residual = m.kwargs["residual"].meta["val"] + max_m = max(max_m, residual.shape[1]) + assert max_m > 0 + return max_m + + def process_matches(self, graph: fx.Graph) -> None: nodes = list(graph.nodes) def find_min_index(match: Match) -> int: @@ -418,8 +411,8 @@ def find_min_index(match: Match) -> int: res_replacements: List[fx.Node] = [] my_res_replacements: List[fx.Node] = [] - gemm_1_max_m, gemm_2_max_m = self.find_max_m(matches) - logger.info("max m = %d, %d", gemm_1_max_m, gemm_2_max_m) + max_m = self.find_max_m(matches) + logger.info("max m = %d", max_m) for match in matches: last_node = last_node_in_match(match) @@ -451,8 +444,8 @@ def find_min_index(match: Match) -> int: tp_group_name = ar_node.args[1] fused_gemm_func = get_gemm_rs_ag_gemm( - use_flux, gemm_1.dtype, gemm_1.shape, gemm_1_max_m, - gemm_2.dtype, gemm_2.shape, gemm_2_max_m, tp_group_name) + use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype, + gemm_2.shape, tp_group_name) fused_node = graph.call_function(fused_gemm_func, kwargs=kwargs) From 515f56cb4819edab4e59d3c62359ecb7c62939b7 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 22 Nov 2024 16:59:36 +0000 Subject: [PATCH 59/72] wip Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 0f76e7a4c1260..9fe8f369cda3b 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -172,6 +172,8 @@ def gemm_rs_ag_gemm( first_layer: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + #print(f"START RESIDUAL {residual.shape}") + if first_layer and should_slice(residual.shape): slice_shape = residual_slice_shape(residual, rank) residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape) @@ -180,11 +182,15 @@ def gemm_rs_ag_gemm( my_residual = residual slice_shape = residual.shape[0] + #print(f"MY RESIDUAL {my_residual.shape}") + if not should_slice(residual.shape): output = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) + #print(f"NAIVE GEMM1 {gemm_1_activations.shape}, {gemm_1_weights.shape}, {output.shape}") + torch.ops._C.fused_add_rms_norm.default(input=reduced_output, residual=my_residual, weight=rms_norm_weight, @@ -192,27 +198,30 @@ def gemm_rs_ag_gemm( mm_2 = torch.ops.aten.mm.default(reduced_output, gemm_2_weights.transpose(1, 0)) + + #print(f"NAIVE GEMM2 {gemm_2_weights.shape}, {gemm_2_weights.shape}, {output.shape}") + return mm_2, my_residual, my_residual.clone() else: output = gemm_rs(gemm_1_activations, gemm_1_weights) + #print(f"FLUX GEMM1 {gemm_1_activations.shape}, {gemm_1_weights.shape}, {output.shape}") + 
torch.ops._C.fused_add_rms_norm.default(input=output, residual=my_residual, weight=rms_norm_weight, epsilon=1e-05) residual_1 = residual if first_layer else old_my_residual - #if False: - #slice_scatter = torch.ops.aten.slice_scatter.default( - # residual_1, my_residual, 0, 0, slice_shape) - #split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape) - #new_residual = split_2[0] - #else: - slice_scatter = my_residual - new_residual = residual_1 + slice_scatter = torch.ops.aten.slice_scatter.default( + residual_1, my_residual, 0, 0, slice_shape) + split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_shape) + new_residual = split_2[0] mm_2 = ag_gemm(output, gemm_2_weights) + #print(f"FLUX GEMM2 {gemm_2_weights.shape}, {gemm_2_weights.shape}, {output.shape}") + return mm_2[0], new_residual, slice_scatter def gemm_rs_ag_gemm_fake( @@ -300,6 +309,8 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, else: wait_tensor = my_residual + #print(f"FINAL RESIDUAL {wait_tensor.shape}") + torch.ops._C.fused_add_rms_norm.default(input=reduced, residual=wait_tensor, weight=rms_norm_weights, From 5a6be3cf4f9950bd0455cd16cc4c4791e4ce4271 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 22 Nov 2024 18:02:41 +0000 Subject: [PATCH 60/72] rebase Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 3 +-- vllm/compilation/collective_fusion.py | 4 ++-- vllm/config.py | 2 ++ .../distributed/device_communicators/pynccl_wrapper.py | 10 ---------- vllm/envs.py | 5 ----- 5 files changed, 5 insertions(+), 19 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index c44c53244bc27..0005ba90899fa 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -359,9 +359,8 @@ def wrap_inductor(graph, from torch._inductor import config + # Enable support for symmetric memory ops in the inductor. torch._inductor.config._micro_pipeline_tp = True - # Set to False to avoid infinite recursion logging? - torch._inductor.config.implicit_fallbacks = True current_config = config.shallow_copy_dict() from torch._inductor.compile_fx import compile_fx diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 9fe8f369cda3b..dbf9becf598f6 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -7,7 +7,7 @@ fwd_only, register_replacement) import vllm.envs as envs -from vllm.compilation.config import CompilationConfig +from vllm.config import CompilationConfig from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) @@ -350,7 +350,7 @@ def instance(cls, config: CompilationConfig) -> "CollectiveFusionPass": cls._instance.config = config return cls._instance - def __init__(self, config): + def __init__(self, config: CompilationConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) diff --git a/vllm/config.py b/vllm/config.py index eae6f909e3933..fb9d84241f48a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2168,12 +2168,14 @@ class PassConfig(BaseModel): - dump_graph_stages: list of stages for which we want to dump the graph. Each pass defines its own stages (before, after, maybe in-between). - dump_graph_dir: directory to dump the graphs. Default is . + - enable_collective_fusion: whether to enable the custom collective communication fusion pass. 
- enable_fusion: whether to enable the custom fusion pass. - enable_reshape: whether to enable the custom reshape elimination pass. TODO better pass enabling system. """ dump_graph_stages: List[str] = Field(default_factory=list) dump_graph_dir: Path = Field(default=Path(".")) + enable_collective_fusion: bool = True enable_fusion: bool = True enable_reshape: bool = True diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index c5cc6b33fcce4..ff88f72470b27 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -315,16 +315,6 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) - def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, - count: int, datatype: int, comm: ncclComm_t, - stream: cudaStream_t) -> None: - # `datatype` actually should be `ncclDataType_t` - # which is an aliases of `ctypes.c_int` - # when we pass int to a function, it will be converted to `ctypes.c_int` - # by ctypes automatically - self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, - datatype, comm, stream)) - __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", diff --git a/vllm/envs.py b/vllm/envs.py index 4ac6767c0a3ff..e8ae19416f7f7 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -222,11 +222,6 @@ def get_default_config_root(): "VLLM_ENGINE_ITERATION_TIMEOUT_S": lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), - # Internal flag for dumping the model graph at different stages of - # custom pass compilation - "VLLM_TORCH_COMPILE_DUMP": - lambda: list(os.environ.get("VLLM_TORCH_COMPILE_DUMP", "").split(",")), - # API key for VLLM API server "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None), From d4b0aa29ddf0c3d4d690323826c42a4f090a1070 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 22 Nov 2024 19:12:22 +0000 Subject: [PATCH 61/72] fixing Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 321 -------------------------- vllm/compilation/collective_fusion.py | 4 +- vllm/compilation/pass_manager.py | 4 + vllm/config.py | 7 +- vllm/entrypoints/llm.py | 2 +- 5 files changed, 12 insertions(+), 326 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0005ba90899fa..0ef9e7d79666b 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -12,7 +12,6 @@ from vllm.logger import init_logger from vllm.utils import weak_ref_tensors -from .collective_fusion import CollectiveFusionPass from .counter import compilation_counter from .inductor_pass import InductorPass from .pass_manager import PostGradPassManager @@ -20,326 +19,6 @@ logger = init_logger(__name__) -FILENO=0 - - -def pprint(x): - #print(x) - pass - - -# This check is a hack, copied from linear.py -def should_slice(shape) -> bool: - n_slices = get_tensor_model_parallel_world_size() - return (shape[0] % n_slices == 0 and shape[0] >= 128) - - -def match_gemm_rs_ag_gemm(residual, - #my_residual, - gemm_1_weights, - gemm_1_activations, - rms_norm_weight, - gemm_2_weights, - ): - permute_2 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, permute_2) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = 
mm_1, group_name = 'tp:0') # how to deal with groupname? - getitem_25 = auto_functionalized_4[1] - auto_functionalized_5 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_25, residual = residual, weight = rms_norm_weight, epsilon = 1e-05) - getitem_27 = auto_functionalized_5[1] - getitem_28 = auto_functionalized_5[2] # new residual - permute_3 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.ops.aten.mm.default(getitem_27, permute_3) - return mm_2, getitem_28 - - -def slices(residual) -> List[torch.Tensor]: - n_slices = get_tensor_model_parallel_world_size() - residual_slices = torch.chunk(residual, n_slices, dim=0) - #pprint(f"SLICES {[r.shape for r in residual_slices]}") - return residual_slices - - -#schema_str="(Tensor(a) residual, Tensor(a) my_residual, Tensor gemm_1_weights, Tensor gemm_1_activations, Tensor rms_norm_weight, Tensor gemm_2_weights, bool first_layer) -> (Tensor, Tensor, Tensor)" - -@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=())#, schema=schema_str) -def gemm_rs_ag_gemm(residual: torch.Tensor, - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - print(f"CUSTOM {residual.shape}({my_residual.shape}), should_slice={should_slice(residual.shape)}, first={first_layer}") - - # this is terrible - if True: - res_slices = slices(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - else: - slice_size = 2048 - print(f"SLICE_SIZE = {slice_size}, orig_shape={residual.shape}, slice_shapes=[{[x.shape for x in res_slices]}]") - - if should_slice(residual.shape) and first_layer: - print(f"FIRST! 
rank={get_tensor_model_parallel_rank()}") - split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - getitem_26 = split_1[0]; split_1 = None - else: - #getitem_26 = my_residual - getitem_26 = residual - slice_size = residual.shape[0] - - if not should_slice(residual.shape): - # this branch probably broken - print("NAIVE") - permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - output = torch.matmul(gemm_1_activations, permute_3) - - output = tensor_model_parallel_all_reduce(output) ### - - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) - getitem_29 = auto_functionalized_4[1] - getitem_30 = auto_functionalized_4[2] - - permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - getitem_35 = torch.matmul(getitem_29, permute_5) - getitem_30a = getitem_30.clone() - print(f"DONE CUSTOM NAIVE {getitem_35.shape}, {getitem_30.shape}, {getitem_30a.shape}") - return getitem_35, getitem_30, getitem_30a - else: - group_name = torch.distributed.group.WORLD.group_name # TODO: factor out to setup - permute_3 = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - clone = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format) - output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default(gemm_1_activations, clone, 'avg', 0, group_name) - auto_functionalized_4 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input=output, residual=getitem_26, weight=rms_norm_weight, epsilon=1e-05) - getitem_29 = auto_functionalized_4[1] - getitem_30 = auto_functionalized_4[2] - residual_1 = residual if first_layer else my_residual - slice_scatter_2 = torch.ops.aten.slice_scatter.default(residual_1, getitem_30, 0, 0, slice_size) - split_2 = torch.ops.aten.split.Tensor(slice_scatter_2, slice_size) - getitem_31 = split_2[0] - permute_5 = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - clone_1 = torch.ops.aten.clone.default(permute_5, memory_format = torch.contiguous_format) - fused_all_gather_matmul = torch.ops.symm_mem.fused_all_gather_matmul.default(getitem_29, [clone_1], 0, group_name) - getitem_34 = fused_all_gather_matmul[1] - getitem_35 = getitem_34[0] - - print(f"DONE CUSTOM {getitem_35.shape}, {getitem_31.shape}, {slice_scatter_2.shape}") - return getitem_35, getitem_31.clone(), slice_scatter_2 # check if clones are needed - - -# this is wrong? do we need it? -@torch.library.register_fake("vllm::gemm_rs_ag_gemm") -def gemm_rs_ag_gemm_fake(residual: torch.Tensor, - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, - gemm_2_weights: torch.Tensor, - first_layer: bool, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # this is terrible - if True: - res_slices = slices(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] # can we always use rank 0? - else: - slice_size = 2048 - - if should_slice(residual.shape) and first_layer: - print(f"FIRST! rank={get_tensor_model_parallel_rank()}") - split_1 = torch.ops.aten.split.Tensor(residual, slice_size) - my_residual = split_1[0]; split_1 = None - else: - #residual = my_residual - slice_size = residual.shape[0] - - # is this type correct? seems to be - mm_res = torch.empty((gemm_1_activations.shape[0], gemm_2_weights.shape[0]), device=gemm_1_activations.device, dtype=gemm_1_activations.dtype) #??? 
- - print(f"DONE FAKE = {mm_res.shape}, {my_residual.shape}, {residual.shape}") - - return (mm_res, my_residual, residual) - - -# doesn't matter, only needed for signature -def replace_gemm_rs_ag_gemm(residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights): - results = torch.ops.vllm.gemm_rs_ag_gemm(residual, residual, gemm_1_weights, gemm_1_activations, rms_norm_weight, gemm_2_weights) - getitem_34 = results[0] - getitem_35 = results[1] - return getitem_34, getitem_35 - - -def match_final(arg227_1, getitem_1022, getitem_1020, arg228_1): - permute_128 = torch.ops.aten.permute.default(arg227_1, [1, 0]) - mm_127 = torch.ops.aten.mm.default(getitem_1022, permute_128) - auto_functionalized_224 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_127, group_name = 'tp:0') # TODO: not same as group name - getitem_1024 = auto_functionalized_224[1] - auto_functionalized_225 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1024, residual = getitem_1020, weight = arg228_1, epsilon = 1e-05) - getitem_1026 = auto_functionalized_225[1] - return getitem_1026 - - -def replace_final(arg227_1, getitem_1215, getitem_1209, arg228_1): - tp_group_name = "tp:0" # f"tp:{group_name}" # TODO: not same as group name - - permute_254 = torch.ops.aten.permute.default(arg227_1, [1, 0]) - mm_1 = torch.ops.aten.mm.default(getitem_1215, permute_254) - auto_functionalized_161 = torch.ops.higher_order.auto_functionalized(torch.ops.vllm.inplace_all_reduce.default, tensor = mm_1, group_name = tp_group_name) - getitem_1217 = auto_functionalized_161[1] - - if should_slice(getitem_1209.shape): - group_name = torch.distributed.group.WORLD.group_name # factor out? - world_size = 2 # factor out - all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem_1209, world_size, group_name) - wait_tensor = torch.ops._c10d_functional.wait_tensor.default(all_gather_into_tensor) - else: - wait_tensor = getitem_1209 - - auto_functionalized_162 = torch.ops.higher_order.auto_functionalized(torch.ops._C.fused_add_rms_norm.default, input = getitem_1217, residual = wait_tensor, weight = arg228_1, epsilon = 1e-05) - getitem_1219 = auto_functionalized_162[1] - return getitem_1219 - - -my_patterns: Optional[PatternMatcherPass] = None -my_patterns2: Optional[PatternMatcherPass] = None -matches: List[Match] = [] - -def get_matches(): - global my_patterns, my_patterns2, matches - - def record_match_fn(match: Match): - print(f"MATCHED {len(matches)}, {id(matches)}") - matches.append(match) - return False - - if not my_patterns: - my_patterns = PatternMatcherPass() - my_patterns2 = PatternMatcherPass() - - x = torch.empty([4,4], device='cuda') - w = torch.empty([4,4], device='cuda') - resid = torch.empty([4,4], device='cuda') - resid_w = torch.empty([4,4], device='cuda') - x2 = torch.empty([4,4], device='cuda') - inputs = [resid, x, w, resid_w, x2] - - register_replacement(match_gemm_rs_ag_gemm, - replace_gemm_rs_ag_gemm, - inputs, - fwd_only, - [my_patterns], - extra_check=record_match_fn) - - final_inputs = [x, w, resid, resid_w] - register_replacement(match_final, - replace_final, - final_inputs, - fwd_only, - [my_patterns2]) - - - -# find the output and the residual -def find_fn(nodes, op): - for node in reversed(nodes): - if node.op == "call_function" and node.target == op: - return node - return None - -def find_auto_fn(nodes, op): - for node in reversed(nodes): - if node.op == "call_function" and 
node.target == auto_functionalized and node.args[0] == op: - return node - return None - -def find_getitem(node, idx): - for user in reversed(node.users): - if user.op == "call_function" and user.target == operator.getitem and user.args[1] == idx: - return user - return None - -def process_matches(graph: fx.Graph, matches): - print(f"len = {len(matches)}") - - nodes = list(graph.nodes) - first_match = None - - def find_min_index(match) -> int: - return min(match.nodes, key=lambda x: nodes.index(x)) - - # "sort" matches in topo order - matches = sorted(matches, key=lambda x: find_min_index(x)) - - # this is pretty hacky since the order doesn't necessarily encode the dependency. - res_replacements = [] - my_res_replacements = [] - - for match in matches: - last_node_in_match = match.nodes[-1] #max(match.nodes, key=lambda x: nodes.index(x)) - - with graph.inserting_after(last_node_in_match): - kwargs = match.kwargs - kwargs["first_layer"] = match == matches[0] - kwargs["residual"] = res_replacements[-1] if len(res_replacements) > 0 else match.kwargs["residual"] - kwargs["my_residual"] = my_res_replacements[-1] if len(my_res_replacements) > 0 else match.kwargs["residual"] - fused_node = graph.call_function(torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) - - graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, (fused_node, 0)) - residual_node_new = graph.call_function(operator.getitem, (fused_node, 1)) - my_residual_node_new = graph.call_function(operator.getitem, (fused_node, 2)) - res_replacements.append(residual_node_new) - my_res_replacements.append(my_residual_node_new) - - rms_node = find_auto_fn(match.nodes, torch.ops._C.fused_add_rms_norm.default) - gemm_node = find_fn(match.nodes, torch.ops.aten.mm.default) - if gemm_node is None: - gemm_node = find_fn(match.nodes, torch.ops.symm_mem.fused_all_gather_matmul.default) - assert rms_node is not None - assert gemm_node is not None - - #assert len(rms_node.users) == 2 - #assert len(gemm_node.users) == 1 - - # meta["val"] is used by de-functionalization - rms_val = rms_node.meta["val"] - gemm_val = gemm_node.meta["val"] - fused_node.meta["val"] = (gemm_val, rms_val[2]) - - find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - gemm_node.replace_all_uses_with(result_node_new) - - # Finally, remove matched nodes - graph.eliminate_dead_code() - assert all(node not in graph.nodes for match in matches for node in match.nodes) - - -def dump_graph(graph: torch.fx.Graph, stage: str): - logger.info("Printing graph to %s", f"{stage}.py") - with open(f"{stage}.py", "w") as f: - print(graph.python_code(root_module="self", verbose=True).src, file=f) - - -def async_rewrite(graph: fx.Graph): - global matches - rank = get_tensor_model_parallel_rank() - get_matches() - matches.clear() - - count = my_patterns.apply(graph) - print(f"fused gemm match count = {len(matches)} {id(matches)}") - - # a bit hacky - if len(matches) > 0: - print("FINAL MATCH") - count = my_patterns2.apply(graph) - print(f"final match count = {count}") - print("FINAL MATCH DONE") - process_matches(graph, matches) - - return graph - - def wrap_inductor(graph, example_inputs, additional_inductor_config: Optional[Dict] = None, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index dbf9becf598f6..7adfe93bdbb5a 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -8,7 +8,7 @@ import vllm.envs as envs from vllm.config import CompilationConfig -from 
vllm.compilation.inductor_pass import InductorPass +from .vllm_inductor_pass import VllmInductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) from vllm.distributed import (tensor_model_parallel_all_gather, @@ -333,7 +333,7 @@ def gemm_ag_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, fake_impl=gemm_ag_final_fake) -class CollectiveFusionPass(InductorPass): +class CollectiveFusionPass(VllmInductorPass): _instance: 'Optional[CollectiveFusionPass]' = None diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py index fb522ae053e97..833d177624af0 100644 --- a/vllm/compilation/pass_manager.py +++ b/vllm/compilation/pass_manager.py @@ -5,6 +5,7 @@ from vllm.config import CompilationConfig from vllm.logger import init_logger +from .collective_fusion import CollectiveFusionPass from .fix_functionalization import FixFunctionalizationPass from .fusion import FusionPass from .inductor_pass import InductorPass @@ -47,6 +48,9 @@ def configure(self, pass_config: CompilationConfig.PassConfig): if pass_config.enable_fusion: self.passes += [FusionPass.instance(pass_config)] + if pass_config.enable_collective_fusion: + self.passes += [CollectiveFusionPass.instance(pass_config)] + self.fix_functionalization = FixFunctionalizationPass(pass_config) def add(self, pass_: InductorPass): diff --git a/vllm/config.py b/vllm/config.py index fb9d84241f48a..70e5cc43f4a09 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2168,7 +2168,8 @@ class PassConfig(BaseModel): - dump_graph_stages: list of stages for which we want to dump the graph. Each pass defines its own stages (before, after, maybe in-between). - dump_graph_dir: directory to dump the graphs. Default is . - - enable_collective_fusion: whether to enable the custom collective communication fusion pass. + - enable_collective_fusion: whether to enable the custom collective + communication fusion pass. - enable_fusion: whether to enable the custom fusion pass. - enable_reshape: whether to enable the custom reshape elimination pass. TODO better pass enabling system. @@ -2187,7 +2188,8 @@ def uuid(self): compilation. """ dict_ = self.model_dump( - include={"enable_fusion", "enable_reshape"}) + include={"enable_collective_fusion", "enable_fusion", + "enable_reshape"}) encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).digest() @@ -2400,6 +2402,7 @@ def __post_init__(self): self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True self.compilation_config.use_inductor = True + self.compilation_config.pass_config.enable_collective_fusion = False self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1551a9a998160..7021e63bc9d40 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -171,7 +171,7 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + compilation_config: Optional[Union[int, Dict[str, Any], CompilationConfig]] = None, **kwargs, ) -> None: ''' From 2c15cd36746aef4f10a9921d8189541a97f66122 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 16:45:18 +0000 Subject: [PATCH 62/72] fix merge problems. 
make dump graph nicer Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 13 ++++++-- vllm/compilation/collective_fusion.py | 43 ++++++++++---------------- vllm/compilation/utils.py | 37 +++++++++++++++++++++- vllm/compilation/vllm_inductor_pass.py | 19 ++---------- vllm/config.py | 6 ++-- vllm/entrypoints/llm.py | 3 +- 6 files changed, 69 insertions(+), 52 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0ef9e7d79666b..8ced57e21f06c 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -15,14 +15,15 @@ from .counter import compilation_counter from .inductor_pass import InductorPass from .pass_manager import PostGradPassManager +from .utils import dump_graph logger = init_logger(__name__) -def wrap_inductor(graph, - example_inputs, +def wrap_inductor(graph: fx.GraphModule, + example_inputs: Sequence[Any], additional_inductor_config: Optional[Dict] = None, - do_logging=False, + do_logging: bool = False, runtime_shape: Optional[int] = None, use_inductor: bool = True): if not use_inductor: @@ -252,9 +253,15 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_configs.init_during_runtime() self.configure_post_pass() + dump_graph(self.compilation_configs.pass_config, graph.graph, + "before_split_graph") + self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.splitting_ops) + dump_graph(self.compilation_configs.pass_config, graph.graph, + "after_split_graph") + from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) logger.debug("%s", lazy_format_graph_code("after split", diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7adfe93bdbb5a..7ba64d4f7ea0f 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -7,10 +7,9 @@ fwd_only, register_replacement) import vllm.envs as envs -from vllm.config import CompilationConfig -from .vllm_inductor_pass import VllmInductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) +from vllm.config import CompilationConfig from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( @@ -18,6 +17,8 @@ from vllm.logger import init_logger from vllm.utils import direct_register_custom_op +from .vllm_inductor_pass import VllmInductorPass + logger = init_logger(__name__) use_flux = False @@ -62,7 +63,7 @@ def match_gemm_rs_ag_gemm( residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, + rms_norm_weights: torch.Tensor, gemm_2_weights: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) @@ -84,7 +85,7 @@ def match_gemm_rs_ag_gemm( torch.ops._C.fused_add_rms_norm.default, input=all_reduce, residual=residual, - weight=rms_norm_weight, + weight=rms_norm_weights, epsilon=1e-05) normalized = norm_res[1] new_residual = norm_res[2] @@ -168,12 +169,10 @@ def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype, def gemm_rs_ag_gemm( residual: torch.Tensor, old_my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, + rms_norm_weights: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool ) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - #print(f"START RESIDUAL {residual.shape}") - if first_layer and should_slice(residual.shape): slice_shape = residual_slice_shape(residual, rank) residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape) @@ -182,34 +181,26 @@ def gemm_rs_ag_gemm( my_residual = residual slice_shape = residual.shape[0] - #print(f"MY RESIDUAL {my_residual.shape}") - if not should_slice(residual.shape): output = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) - #print(f"NAIVE GEMM1 {gemm_1_activations.shape}, {gemm_1_weights.shape}, {output.shape}") - torch.ops._C.fused_add_rms_norm.default(input=reduced_output, residual=my_residual, - weight=rms_norm_weight, + weight=rms_norm_weights, epsilon=1e-05) mm_2 = torch.ops.aten.mm.default(reduced_output, gemm_2_weights.transpose(1, 0)) - #print(f"NAIVE GEMM2 {gemm_2_weights.shape}, {gemm_2_weights.shape}, {output.shape}") - return mm_2, my_residual, my_residual.clone() else: output = gemm_rs(gemm_1_activations, gemm_1_weights) - #print(f"FLUX GEMM1 {gemm_1_activations.shape}, {gemm_1_weights.shape}, {output.shape}") - torch.ops._C.fused_add_rms_norm.default(input=output, residual=my_residual, - weight=rms_norm_weight, + weight=rms_norm_weights, epsilon=1e-05) residual_1 = residual if first_layer else old_my_residual @@ -220,8 +211,6 @@ def gemm_rs_ag_gemm( mm_2 = ag_gemm(output, gemm_2_weights) - #print(f"FLUX GEMM2 {gemm_2_weights.shape}, {gemm_2_weights.shape}, {output.shape}") - return mm_2[0], new_residual, slice_scatter def gemm_rs_ag_gemm_fake( @@ -229,7 +218,7 @@ def gemm_rs_ag_gemm_fake( my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, + rms_norm_weights: torch.Tensor, gemm_2_weights: torch.Tensor, first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -309,8 +298,6 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, else: wait_tensor = my_residual - #print(f"FINAL RESIDUAL {wait_tensor.shape}") - torch.ops._C.fused_add_rms_norm.default(input=reduced, residual=wait_tensor, weight=rms_norm_weights, @@ -362,11 +349,11 @@ def __init__(self, config: CompilationConfig): # Run in fake mode so that we don't call real functions # when tracing the patterns. with torch._dynamo.utils.detect_fake_mode(): - x = torch.empty([4, 4], device='cuda') - w = torch.empty([4, 4], device='cuda') - resid = torch.empty([4, 4], device='cuda') - resid_w = torch.empty([4, 4], device='cuda') - x2 = torch.empty([4, 4], device='cuda') + x = torch.empty([4, 4], device='cuda', dtype=torch.float16) + w = torch.empty([4, 4], device='cuda', dtype=torch.float16) + resid = torch.empty([4, 4], device='cuda', dtype=torch.float16) + resid_w = torch.empty([4, 4], device='cuda', dtype=torch.float16) + x2 = torch.empty([4, 4], device='cuda', dtype=torch.float16) inputs = [resid, x, w, resid_w, x2] final_inputs = [x, w, resid, resid_w] @@ -492,6 +479,8 @@ def find_min_index(match: Match) -> int: for node in match.nodes) def __call__(self, graph: fx.Graph): + # TODO: disable if chunk prefill size is too small + # or when doing decode. 
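For readers of the cleaned-up fused op above: when slicing does not apply, gemm_rs_ag_gemm falls back to GEMM, all-reduce, fused add + RMSNorm, then a second GEMM. A rough single-device sketch of that reference path (the helper name is illustrative, the tensor-parallel all-reduce is elided, and the long-hand norm reflects my reading of fused_add_rms_norm's semantics, not its actual kernel):

import torch

def unfused_reference(act, w1, residual, rms_w, w2, eps=1e-5):
    # GEMM 1; the TP all-reduce of this output is omitted in this sketch.
    out = act @ w1.t()
    # fused_add_rms_norm written out long-hand: accumulate into the residual,
    # then RMS-normalize the sum before the second GEMM.
    residual = residual + out
    normed = residual * torch.rsqrt(
        residual.pow(2).mean(-1, keepdim=True) + eps) * rms_w
    # GEMM 2 on the normalized activations.
    return normed @ w2.t(), residual

act, resid = torch.randn(16, 64), torch.randn(16, 64)
w1, w2 = torch.randn(64, 64), torch.randn(64, 64)
mm_2, new_residual = unfused_reference(act, w1, resid, torch.ones(64), w2)
print(mm_2.shape, new_residual.shape)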
self.dump_graph(graph, "before_collective_fusion") count = self.gemm_rs_ag_gemm_pattern.apply(graph) logger.info("fused gemm match count = %d", len(self.matches)) diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 12a84db22e877..82a6010393b64 100644 --- a/vllm/compilation/utils.py +++ b/vllm/compilation/utils.py @@ -1,10 +1,25 @@ import operator -from typing import Iterable, Optional +import os +from typing import Dict, Iterable, Optional import torch.fx as fx from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import Match +from vllm.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +from vllm.logger import init_logger + +# yapf: enable + +logger = init_logger(__name__) + +COUNTS: Dict[str, int] = {} + def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: for node in nodes: @@ -36,3 +51,23 @@ def last_node_in_match(match: Match) -> fx.Node: if n in reversed(match.nodes): return n raise ValueError("No nodes in graph") + + +def dump_graph(config: CompilationConfig.PassConfig, graph: fx.Graph, + name: str) -> None: + global COUNTS + count = COUNTS.get(name, 0) + + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = config.dump_graph_dir / f"{name}{rank}-{count}.py" + COUNTS[name] = count + 1 + + os.makedirs(config.dump_graph_dir, exist_ok=True) + logger.info("%s printing graph to %s", name, filepath) + with open(filepath, "w") as f: + src = graph.owning_module.print_readable(print_output=False) + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index dbf6b8f7789e1..e43a95863eee2 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -3,15 +3,10 @@ import torch from vllm.config import CompilationConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable from vllm.logger import init_logger from .inductor_pass import InductorPass +from .utils import dump_graph as utils_dump_graph logger = init_logger(__name__) @@ -32,17 +27,7 @@ def __init__(self, config: CompilationConfig.PassConfig): def dump_graph(self, graph: torch.fx.Graph, stage: str): if stage in self.config.dump_graph_stages: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("%s printing graph to %s", self.pass_name, filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) + utils_dump_graph(self.config, graph, stage) def begin(self): self._start_time = time.perf_counter_ns() diff --git a/vllm/config.py 
b/vllm/config.py index 70e5cc43f4a09..e5e5520f5228b 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2187,9 +2187,9 @@ def uuid(self): Do not include dump_graph_* in the hash - they don't affect compilation. """ - dict_ = self.model_dump( - include={"enable_collective_fusion", "enable_fusion", - "enable_reshape"}) + dict_ = self.model_dump(include={ + "enable_collective_fusion", "enable_fusion", "enable_reshape" + }) encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") return hashlib.sha256(encoded).digest() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 7021e63bc9d40..d860242d4674e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -171,7 +171,8 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, - compilation_config: Optional[Union[int, Dict[str, Any], CompilationConfig]] = None, + compilation_config: Optional[Union[int, Dict[str, Any], + CompilationConfig]] = None, **kwargs, ) -> None: ''' From da18a92a281faa7c777e78c7ad12a8609cc4147a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 19:00:13 +0000 Subject: [PATCH 63/72] disable collective fusion when chunk size is too small Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 23 +++++++---------------- vllm/compilation/utils.py | 23 +++++++++++++++-------- vllm/config.py | 9 +++++++++ 3 files changed, 31 insertions(+), 24 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 7ba64d4f7ea0f..bc5a98606cc1d 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -8,7 +8,7 @@ import vllm.envs as envs from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, - last_node_in_match) + last_node_in_match, use_cc_kernels) from vllm.config import CompilationConfig from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) @@ -26,25 +26,16 @@ try: import flux use_flux = True - logger.info("USING FLUX") + logger.info("Using flux kernels for collective communication fusion.") except ImportError: + logger.info("Attempting to use flux but flux not installed.") use_flux = False -FLUX_TILE_SIZE: int = 128 - -# TODO: is this ok? 
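Side note on the uuid() tweak above: the new enable_collective_fusion flag now participates in the hash that keys compiled artifacts, while the dump_graph_* options stay excluded because they do not change the generated code. A standalone sketch of that hashing scheme (the helper name is illustrative):

import hashlib
import json

def pass_config_uuid(enable_collective_fusion: bool, enable_fusion: bool,
                     enable_reshape: bool) -> bytes:
    # Only flags that affect the compiled graph are hashed, mirroring
    # PassConfig.uuid(); dump-graph settings are deliberately left out.
    flags = {
        "enable_collective_fusion": enable_collective_fusion,
        "enable_fusion": enable_fusion,
        "enable_reshape": enable_reshape,
    }
    encoded = json.dumps(flags, sort_keys=True).encode("utf-8")
    return hashlib.sha256(encoded).digest()

print(pass_config_uuid(True, True, True).hex())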
def get_world_name() -> str: return torch.distributed.group.WORLD.group_name -# Note: this heuristic is unique to flux -def should_slice(shape: torch.Size) -> bool: - n_slices = get_tensor_model_parallel_world_size() - return (shape[0] % (FLUX_TILE_SIZE * n_slices) == 0 - and shape[0] >= FLUX_TILE_SIZE * n_slices) - - def residual_slice_shape(residual: torch.Tensor, rank: int) -> int: n_slices = get_tensor_model_parallel_world_size() chunk, rem = divmod(residual.shape[0], n_slices) @@ -173,7 +164,7 @@ def gemm_rs_ag_gemm( first_layer: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if first_layer and should_slice(residual.shape): + if first_layer and use_cc_kernels(residual.shape[0]): slice_shape = residual_slice_shape(residual, rank) residual_chunk = torch.ops.aten.split.Tensor(residual, slice_shape) my_residual = residual_chunk[0] @@ -181,7 +172,7 @@ def gemm_rs_ag_gemm( my_residual = residual slice_shape = residual.shape[0] - if not should_slice(residual.shape): + if not use_cc_kernels(residual.shape[0]): output = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) @@ -222,7 +213,7 @@ def gemm_rs_ag_gemm_fake( gemm_2_weights: torch.Tensor, first_layer: bool, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if first_layer and should_slice(gemm_1_activations.shape): + if first_layer and use_cc_kernels(gemm_1_activations.shape[0]): slice_shape = residual_slice_shape_fake(residual, rank) split_1 = torch.ops.aten.split.Tensor(residual, slice_shape) my_residual = split_1[0] @@ -293,7 +284,7 @@ def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, reduced = tensor_model_parallel_all_reduce(mm_1) - if should_slice(gemm_1_activations.shape): + if use_cc_kernels(gemm_1_activations.shape[0]): wait_tensor = tensor_model_parallel_all_gather(my_residual) else: wait_tensor = my_residual diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 82a6010393b64..60d7ab7c55cab 100644 --- a/vllm/compilation/utils.py +++ b/vllm/compilation/utils.py @@ -6,20 +6,20 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import Match -from vllm.config import CompilationConfig -# yapf: disable from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( get_tensor_model_parallel_world_size as get_tp_world_size) from vllm.distributed import model_parallel_is_initialized as p_is_init from vllm.logger import init_logger -# yapf: enable - logger = init_logger(__name__) COUNTS: Dict[str, int] = {} +# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h +# Can be 256 on sm80. 
+FLUX_TILE_SIZE: int = 128 + def find_fn(nodes: Iterable[fx.Node], op) -> Optional[fx.Node]: for node in nodes: @@ -53,21 +53,28 @@ def last_node_in_match(match: Match) -> fx.Node: raise ValueError("No nodes in graph") -def dump_graph(config: CompilationConfig.PassConfig, graph: fx.Graph, - name: str) -> None: +def dump_graph(pass_config, graph: fx.Graph, name: str) -> None: global COUNTS count = COUNTS.get(name, 0) # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" - filepath = config.dump_graph_dir / f"{name}{rank}-{count}.py" + filepath = pass_config.dump_graph_dir / f"{name}{rank}-{count}.py" COUNTS[name] = count + 1 - os.makedirs(config.dump_graph_dir, exist_ok=True) + os.makedirs(pass_config.dump_graph_dir, exist_ok=True) logger.info("%s printing graph to %s", name, filepath) with open(filepath, "w") as f: src = graph.owning_module.print_readable(print_output=False) # Add imports so it's not full of errors print("import torch; from torch import device", file=f) print(src, file=f) + + +# Note: this heuristic is unique to flux +def use_cc_kernels(m_shape: int, n_slices: Optional[int] = None) -> bool: + if n_slices is None: + n_slices = get_tp_world_size() + return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0 + and m_shape >= FLUX_TILE_SIZE * n_slices) diff --git a/vllm/config.py b/vllm/config.py index e5e5520f5228b..087c8903b765e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass +from vllm.compilation.utils import use_cc_kernels from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, get_quantization_config) @@ -2421,6 +2422,14 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION + n_slices = self.parallel_config.world_size + max_tokens = self.scheduler_config.max_num_batched_tokens + if not use_cc_kernels(max_tokens / n_slices, n_slices): + logger.info( + ("Disabling collective fusion pass since chunked prefill size " + "%d is too small."), max_tokens) + self.compilation_config.pass_config.enable_collective_fusion = False + current_platform.check_and_update_config(self) def __str__(self): From bead129de7cd5f8e19ade280a8cb640390555e32 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 19:04:16 +0000 Subject: [PATCH 64/72] fix mypy Signed-off-by: Bill Nell --- vllm/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config.py b/vllm/config.py index 087c8903b765e..fac1f2ba4f6dc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2424,7 +2424,7 @@ def __post_init__(self): n_slices = self.parallel_config.world_size max_tokens = self.scheduler_config.max_num_batched_tokens - if not use_cc_kernels(max_tokens / n_slices, n_slices): + if not use_cc_kernels(int(max_tokens / n_slices), n_slices): logger.info( ("Disabling collective fusion pass since chunked prefill size " "%d is too small."), max_tokens) From 72953ccfe8d3398b71176fa04e5439346f260e82 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 19:13:09 +0000 Subject: [PATCH 65/72] fix yapf Signed-off-by: Bill Nell --- vllm/compilation/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 60d7ab7c55cab..7cef1d4225967 100644 --- a/vllm/compilation/utils.py +++ 
b/vllm/compilation/utils.py @@ -7,8 +7,10 @@ from torch._inductor.pattern_matcher import Match from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +# yapf: disable from vllm.distributed import ( get_tensor_model_parallel_world_size as get_tp_world_size) +# yapf: enable from vllm.distributed import model_parallel_is_initialized as p_is_init from vllm.logger import init_logger From 8724fabf4ce18e75cead543903f93591dcb1f15a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 22:20:59 +0000 Subject: [PATCH 66/72] disable collective fusion if TP is not on Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 11 ++++++++++- vllm/compilation/collective_fusion.py | 6 +++++- vllm/compilation/utils.py | 3 ++- vllm/envs.py | 2 +- 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 8ced57e21f06c..7613ea6870b4f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -26,6 +26,9 @@ def wrap_inductor(graph: fx.GraphModule, do_logging: bool = False, runtime_shape: Optional[int] = None, use_inductor: bool = True): + + print(f"WRAP_INDUCTOR {graph}") + if not use_inductor: return graph @@ -150,12 +153,15 @@ def call_module(self, target: torch.fx.node.Target, assert isinstance(target, str) output = super().call_module(target, args, kwargs) + print(f"TARGET {target}") + if target in self.compile_submod_names: index = self.compile_submod_names.index(target) submod = self.fetch_attr(target) sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] + print(f"COMPILE {target}") compiled_graph_for_general_shape = wrap_inductor( submod, args, @@ -259,7 +265,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.splitting_ops) - dump_graph(self.compilation_configs.pass_config, graph.graph, + dump_graph(self.compilation_configs.pass_config, self.split_gm.graph, "after_split_graph") from torch._dynamo.utils import lazy_format_graph_code @@ -274,6 +280,8 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not item.is_splitting_graph ] + print(f"submod_names_to_compile = {submod_names_to_compile}") + # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, @@ -413,6 +421,7 @@ def __call__(self, *args) -> Any: if entry.need_to_compile and not entry.compiled: entry.compiled = True # args are real arguments + print(f"COMPILE ENTRY") entry.runnable = wrap_inductor( self.graph, args, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index bc5a98606cc1d..35c3591fb9496 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -470,10 +470,14 @@ def find_min_index(match: Match) -> int: for node in match.nodes) def __call__(self, graph: fx.Graph): + if not (model_parallel_is_initialized() and + get_tensor_model_parallel_world_size() > 1): + return + # TODO: disable if chunk prefill size is too small # or when doing decode. 
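The guard added above makes the pass a no-op outside tensor-parallel runs; a later commit in this series moves the same decision into the config. The runtime check on its own, as a sketch reusing the vllm.distributed helpers referenced elsewhere in this series (the wrapper name is illustrative):

from vllm.distributed import (get_tensor_model_parallel_world_size,
                              model_parallel_is_initialized)

def collective_fusion_applicable() -> bool:
    # Same condition as the guard above: the fused GEMM + collective patterns
    # can only appear in graphs built with tensor parallelism enabled.
    return (model_parallel_is_initialized()
            and get_tensor_model_parallel_world_size() > 1)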
self.dump_graph(graph, "before_collective_fusion") - count = self.gemm_rs_ag_gemm_pattern.apply(graph) + self.gemm_rs_ag_gemm_pattern.apply(graph) logger.info("fused gemm match count = %d", len(self.matches)) # Don't apply final pattern unless we've matched and replaced the diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 7cef1d4225967..95af7be94e376 100644 --- a/vllm/compilation/utils.py +++ b/vllm/compilation/utils.py @@ -1,5 +1,6 @@ import operator import os +from pathlib import Path from typing import Dict, Iterable, Optional import torch.fx as fx @@ -62,7 +63,7 @@ def dump_graph(pass_config, graph: fx.Graph, name: str) -> None: # Make sure filename includes rank in the distributed setting parallel = p_is_init() and get_tp_world_size() > 1 rank = f"-{get_tp_rank()}" if parallel else "" - filepath = pass_config.dump_graph_dir / f"{name}{rank}-{count}.py" + filepath = Path(pass_config.dump_graph_dir) / f"{name}{rank}-{count}.py" COUNTS[name] = count + 1 os.makedirs(pass_config.dump_graph_dir, exist_ok=True) diff --git a/vllm/envs.py b/vllm/envs.py index e8ae19416f7f7..852f101be2758 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -461,7 +461,7 @@ def get_default_config_root(): # If set, try to use the flux fused collective communication gemm kernels. "VLLM_USE_FLUX": - lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))), + lambda: bool(int(os.getenv("VLLM_USE_FLUX", "1"))), } # end-env-vars-definition From ec07de114745caaaad0a6be506a2dffdf0ab7d2d Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 22:22:28 +0000 Subject: [PATCH 67/72] remove cruft Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 9 --------- vllm/compilation/collective_fusion.py | 7 ++++--- vllm/compilation/utils.py | 4 ++-- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 7613ea6870b4f..60599f297e97a 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -26,9 +26,6 @@ def wrap_inductor(graph: fx.GraphModule, do_logging: bool = False, runtime_shape: Optional[int] = None, use_inductor: bool = True): - - print(f"WRAP_INDUCTOR {graph}") - if not use_inductor: return graph @@ -153,15 +150,12 @@ def call_module(self, target: torch.fx.node.Target, assert isinstance(target, str) output = super().call_module(target, args, kwargs) - print(f"TARGET {target}") - if target in self.compile_submod_names: index = self.compile_submod_names.index(target) submod = self.fetch_attr(target) sym_shape_indices = [ i for i, x in enumerate(args) if isinstance(x, torch.SymInt) ] - print(f"COMPILE {target}") compiled_graph_for_general_shape = wrap_inductor( submod, args, @@ -280,8 +274,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: if not item.is_splitting_graph ] - print(f"submod_names_to_compile = {submod_names_to_compile}") - # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, @@ -421,7 +413,6 @@ def __call__(self, *args) -> Any: if entry.need_to_compile and not entry.compiled: entry.compiled = True # args are real arguments - print(f"COMPILE ENTRY") entry.runnable = wrap_inductor( self.graph, args, diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 35c3591fb9496..79b01af999db1 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,7 +10,8 @@ from 
vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match, use_cc_kernels) from vllm.config import CompilationConfig -from vllm.distributed import (tensor_model_parallel_all_gather, +from vllm.distributed import (model_parallel_is_initialized, + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( get_group_from_group_name, get_tensor_model_parallel_world_size) @@ -470,8 +471,8 @@ def find_min_index(match: Match) -> int: for node in match.nodes) def __call__(self, graph: fx.Graph): - if not (model_parallel_is_initialized() and - get_tensor_model_parallel_world_size() > 1): + if not (model_parallel_is_initialized() + and get_tensor_model_parallel_world_size() > 1): return # TODO: disable if chunk prefill size is too small diff --git a/vllm/compilation/utils.py b/vllm/compilation/utils.py index 95af7be94e376..195e7c8e812fe 100644 --- a/vllm/compilation/utils.py +++ b/vllm/compilation/utils.py @@ -7,11 +7,11 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from torch._inductor.pattern_matcher import Match -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +# yapf: enable # yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank from vllm.distributed import ( get_tensor_model_parallel_world_size as get_tp_world_size) -# yapf: enable from vllm.distributed import model_parallel_is_initialized as p_is_init from vllm.logger import init_logger From 6e26b9a5f4c9869091367cf989b428063e830b25 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 25 Nov 2024 23:52:33 +0000 Subject: [PATCH 68/72] disable collective fusion pass if TP is not enabled Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 7 +------ vllm/config.py | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 79b01af999db1..b3ba5b16f3b37 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -10,8 +10,7 @@ from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match, use_cc_kernels) from vllm.config import CompilationConfig -from vllm.distributed import (model_parallel_is_initialized, - tensor_model_parallel_all_gather, +from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( get_group_from_group_name, get_tensor_model_parallel_world_size) @@ -471,10 +470,6 @@ def find_min_index(match: Match) -> int: for node in match.nodes) def __call__(self, graph: fx.Graph): - if not (model_parallel_is_initialized() - and get_tensor_model_parallel_world_size() > 1): - return - # TODO: disable if chunk prefill size is too small # or when doing decode. 
self.dump_graph(graph, "before_collective_fusion") diff --git a/vllm/config.py b/vllm/config.py index fac1f2ba4f6dc..c59d22f0799ed 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2422,13 +2422,20 @@ def __post_init__(self): "Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - n_slices = self.parallel_config.world_size - max_tokens = self.scheduler_config.max_num_batched_tokens - if not use_cc_kernels(int(max_tokens / n_slices), n_slices): - logger.info( - ("Disabling collective fusion pass since chunked prefill size " - "%d is too small."), max_tokens) - self.compilation_config.pass_config.enable_collective_fusion = False + if self.compilation_config.pass_config.enable_collective_fusion: + n_slices = self.parallel_config.world_size + max_tokens = self.scheduler_config.max_num_batched_tokens + if not use_cc_kernels(int(max_tokens / n_slices), n_slices): + logger.info( + ("Disabling collective fusion pass since chunked prefill " + "size %d is too small."), max_tokens) + self.compilation_config.pass_config.enable_collective_fusion = \ + False + if n_slices == 1: + logger.info("Disabling collective fusion pass since tensor " + "parallelism is not enabled.") + self.compilation_config.pass_config.enable_collective_fusion = \ + False current_platform.check_and_update_config(self) From f69ae533370602c141d6971a1f3907af4cc1c70e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 26 Nov 2024 19:48:51 +0000 Subject: [PATCH 69/72] wip Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 10 ++++++---- vllm/compilation/collective_fusion.py | 28 ++++++++++++++++++++------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 60599f297e97a..65c8d84f841c3 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -253,14 +253,16 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_configs.init_during_runtime() self.configure_post_pass() - dump_graph(self.compilation_configs.pass_config, graph.graph, - "before_split_graph") + if "before_split_graph" in self.compilation_configs.pass_config.dump_graph_stages: + dump_graph(self.compilation_configs.pass_config, graph.graph, + "before_split_graph") self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.splitting_ops) - dump_graph(self.compilation_configs.pass_config, self.split_gm.graph, - "after_split_graph") + if "after_split_graph" in self.compilation_configs.pass_config.dump_graph_stages: + dump_graph(self.compilation_configs.pass_config, self.split_gm.graph, + "after_split_graph") from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index b3ba5b16f3b37..4c83238389cc1 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -8,7 +8,7 @@ import vllm.envs as envs from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, - last_node_in_match, use_cc_kernels) + last_node_in_match) from vllm.config import CompilationConfig from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) @@ -32,6 +32,17 @@ use_flux = False +# Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h +# Can be 256 on sm80. 
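+# use_cc_kernels() below gates the fused path on the GEMM M dimension being a
+# positive multiple of FLUX_TILE_SIZE * tp_world_size; e.g. with 8 TP ranks,
+# only M values that are non-zero multiples of 1024 (= 128 * 8) qualify.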
+FLUX_TILE_SIZE: int = 128 + + +def use_cc_kernels(m_shape: int) -> bool: + n_slices = get_tensor_model_parallel_world_size() + return (m_shape % (FLUX_TILE_SIZE * n_slices) == 0 + and m_shape >= FLUX_TILE_SIZE * n_slices) + + def get_world_name() -> str: return torch.distributed.group.WORLD.group_name @@ -134,8 +145,11 @@ def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype, local_copy=False, ) - gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0) - ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt) + def gemm_rs(act, wt): + return gemm_rs_op.forward(act, wt).squeeze(0) + + def ag_gemm(act, wt): + return ag_gemm_op.forward(act, wt) gemm_1_str = str(gemm_1_type).removeprefix("torch.") gemm_2_str = str(gemm_2_type).removeprefix("torch.") @@ -146,12 +160,12 @@ def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype, else: world_group_name = get_world_name() - gemm_rs = lambda act, wt: \ - torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + def gemm_rs(act, wt): + return torch.ops.symm_mem.fused_matmul_reduce_scatter.default( act, wt.transpose(1, 0), 'avg', 0, world_group_name) - ag_gemm = lambda act, wt: \ - torch.ops.symm_mem.fused_all_gather_matmul.default( + def ag_gemm(act, wt): + return torch.ops.symm_mem.fused_all_gather_matmul.default( act, [wt.transpose(1, 0)], 0, world_group_name)[1] group_str = tp_group_name.replace(":", "_") From 41ab065ce5c1f68d26117028bdf46bbcf3a82a08 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 26 Nov 2024 22:38:09 +0000 Subject: [PATCH 70/72] rebase + simplify Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 64 ++++++++++++++++----------- vllm/distributed/parallel_state.py | 2 +- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 4c83238389cc1..ef274abac4d16 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -72,16 +72,16 @@ def match_gemm_rs_ag_gemm( mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) # It would be nice to do this instead of having two separate patterns. - # all_reduce = tensor_model_parallel_all_reduce(mm_1) - if custom_ar: - all_reduce = torch.ops.vllm.outplace_all_reduce.default( - mm_1, tp_group_name) - else: - all_reduce = torch.ops.higher_order.auto_functionalized( - torch.ops.vllm.inplace_all_reduce.default, - tensor=mm_1, - group_name=tp_group_name) - all_reduce = all_reduce[1] + all_reduce = tensor_model_parallel_all_reduce(mm_1) + #if custom_ar: + # all_reduce = torch.ops.vllm.outplace_all_reduce.default( + # mm_1, tp_group_name) + #else: + # all_reduce = torch.ops.higher_order.auto_functionalized( + # torch.ops.vllm.inplace_all_reduce.default, + # tensor=mm_1, + # group_name=tp_group_name) + # all_reduce = all_reduce[1] norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, @@ -265,16 +265,16 @@ def match_final( mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) # TODO: it would be nice to be able to use the official api directly. 
- # all_reduce = tensor_model_parallel_all_reduce(mm_1) - if use_custom_ar: - all_reduce = torch.ops.vllm.outplace_all_reduce.default( - mm_1, tp_group_name) - else: - all_reduce = torch.ops.higher_order.auto_functionalized( - torch.ops.vllm.inplace_all_reduce.default, - tensor=mm_1, - group_name=tp_group_name) - all_reduce = all_reduce[1] + all_reduce = tensor_model_parallel_all_reduce(mm_1) + #if use_custom_ar: + # all_reduce = torch.ops.vllm.outplace_all_reduce.default( + # mm_1, tp_group_name) + #else: + # all_reduce = torch.ops.higher_order.auto_functionalized( + # torch.ops.vllm.inplace_all_reduce.default, + # tensor=mm_1, + # group_name=tp_group_name) + # all_reduce = all_reduce[1] norm_res = torch.ops.higher_order.auto_functionalized( torch.ops._C.fused_add_rms_norm.default, @@ -366,10 +366,18 @@ def __init__(self, config: CompilationConfig): world_size = get_tensor_model_parallel_world_size() group_names = [f"tp:{rank}" for rank in range(world_size)] + m = get_match_gemm_rs_ag_gemm(group_names[0], False) + register_replacement( + m, + m, + inputs, + fwd_only, [self.gemm_rs_ag_gemm_pattern], + extra_check=lambda m: self.record_match(m)) + for group_name in group_names: for m in [ - get_match_gemm_rs_ag_gemm(group_name, False), - get_match_gemm_rs_ag_gemm(group_name, True) + #get_match_gemm_rs_ag_gemm(group_name, False), + #get_match_gemm_rs_ag_gemm(group_name, True) ]: register_replacement( m, @@ -380,9 +388,11 @@ def __init__(self, config: CompilationConfig): for m in [ get_match_final(group_name, False), - get_match_final(group_name, True) + #get_match_final(group_name, True) ]: - register_replacement(m, torch.ops.vllm.gemm_ag_final, + torch._inductor.pattern_matcher._seen_patterns.clear() + register_replacement(m, + torch.ops.vllm.gemm_ag_final, final_inputs, fwd_only, [self.final_pattern]) @@ -435,14 +445,14 @@ def find_min_index(match: Match) -> int: # Extract group_name from matched code. Use to # generate proper replacement code. 
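             # torch.ops.vllm.all_reduce takes (tensor, group_name), so the TP
             # group name is read from args[1] of the matched node.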
- ar_node = find_auto_fn( - match.nodes, torch.ops.vllm.inplace_all_reduce.default) + #ar_node = find_auto_fn(match.nodes, torch.ops.vllm.inplace_all_reduce.default) + ar_node = None if ar_node is not None: tp_group_name = ar_node.kwargs["group_name"] else: ar_node = find_fn( match.nodes, - torch.ops.vllm.outplace_all_reduce.default) + torch.ops.vllm.all_reduce.default) assert ar_node is not None tp_group_name = ar_node.args[1] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0d5b97eba3850..c4d2b42c18e97 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -113,7 +113,7 @@ def get_group_from_group_name(group_name: str) -> "GroupCoordinator": def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: - group = get_group_from_name(group_name) + group = get_group_from_group_name(group_name) return group._all_reduce_out_place(tensor) From b75cbbae6e88d1a786cc139f5af9cfb83f9e3845 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 26 Nov 2024 22:40:34 +0000 Subject: [PATCH 71/72] rebase + simplify Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 133 ++++++++------------------ 1 file changed, 41 insertions(+), 92 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index ef274abac4d16..83ab5024e4108 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -59,45 +59,30 @@ def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int: return slices[rank].shape[0] -def get_match_gemm_rs_ag_gemm(tp_group_name: str, custom_ar: bool) -> Callable: - - def match_gemm_rs_ag_gemm( +def match_gemm_rs_ag_gemm( residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor, gemm_2_weights: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) - - # It would be nice to do this instead of having two separate patterns. 
- all_reduce = tensor_model_parallel_all_reduce(mm_1) - #if custom_ar: - # all_reduce = torch.ops.vllm.outplace_all_reduce.default( - # mm_1, tp_group_name) - #else: - # all_reduce = torch.ops.higher_order.auto_functionalized( - # torch.ops.vllm.inplace_all_reduce.default, - # tensor=mm_1, - # group_name=tp_group_name) - # all_reduce = all_reduce[1] - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=residual, - weight=rms_norm_weights, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) - - return mm_2, new_residual - - return match_gemm_rs_ag_gemm +) -> Tuple[torch.Tensor, torch.Tensor]: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=residual, + weight=rms_norm_weights, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) + mm_2 = torch.ops.aten.mm.default(normalized, gemm_2_w_perm) + + return mm_2, new_residual def get_gemm_rs_ag_gemm(use_flux: bool, max_m: int, gemm_1_type: torch.dtype, @@ -253,40 +238,26 @@ def gemm_rs_ag_gemm_fake( return getattr(torch.ops.vllm, name).default -def get_match_final(tp_group_name: str, use_custom_ar: bool) -> Callable: - - def match_final( +def match_final( my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor, - ) -> torch.Tensor: - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) - - # TODO: it would be nice to be able to use the official api directly. - all_reduce = tensor_model_parallel_all_reduce(mm_1) - #if use_custom_ar: - # all_reduce = torch.ops.vllm.outplace_all_reduce.default( - # mm_1, tp_group_name) - #else: - # all_reduce = torch.ops.higher_order.auto_functionalized( - # torch.ops.vllm.inplace_all_reduce.default, - # tensor=mm_1, - # group_name=tp_group_name) - # all_reduce = all_reduce[1] - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=all_reduce, - residual=my_residual, - weight=rms_norm_weights, - epsilon=1e-05) - normalized = norm_res[1] - - return normalized - - return match_final +) -> torch.Tensor: + gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) + mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) + + all_reduce = tensor_model_parallel_all_reduce(mm_1) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=all_reduce, + residual=my_residual, + weight=rms_norm_weights, + epsilon=1e-05) + normalized = norm_res[1] + + return normalized # Register this as a custom op since all reduce cannot be torch.compiled yet. @@ -362,39 +333,17 @@ def __init__(self, config: CompilationConfig): inputs = [resid, x, w, resid_w, x2] final_inputs = [x, w, resid, resid_w] - # register multiple patterns for all group names. 
- world_size = get_tensor_model_parallel_world_size() - group_names = [f"tp:{rank}" for rank in range(world_size)] - - m = get_match_gemm_rs_ag_gemm(group_names[0], False) register_replacement( - m, - m, + match_gemm_rs_ag_gemm, + match_gemm_rs_ag_gemm, inputs, fwd_only, [self.gemm_rs_ag_gemm_pattern], extra_check=lambda m: self.record_match(m)) - for group_name in group_names: - for m in [ - #get_match_gemm_rs_ag_gemm(group_name, False), - #get_match_gemm_rs_ag_gemm(group_name, True) - ]: - register_replacement( - m, - m, - inputs, - fwd_only, [self.gemm_rs_ag_gemm_pattern], - extra_check=lambda m: self.record_match(m)) - - for m in [ - get_match_final(group_name, False), - #get_match_final(group_name, True) - ]: - torch._inductor.pattern_matcher._seen_patterns.clear() - register_replacement(m, - torch.ops.vllm.gemm_ag_final, - final_inputs, fwd_only, - [self.final_pattern]) + register_replacement(match_final + torch.ops.vllm.gemm_ag_final, + final_inputs, fwd_only, + [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and From 7e2c490c131a7e9c2db4c41314e1de92befac527 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Tue, 26 Nov 2024 23:05:33 +0000 Subject: [PATCH 72/72] cleanup Signed-off-by: Bill Nell --- vllm/compilation/backends.py | 10 +++--- vllm/compilation/collective_fusion.py | 52 +++++++++++---------------- 2 files changed, 27 insertions(+), 35 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 65c8d84f841c3..b0189725d427d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -253,16 +253,18 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.compilation_configs.init_during_runtime() self.configure_post_pass() - if "before_split_graph" in self.compilation_configs.pass_config.dump_graph_stages: + if ("before_split_graph" + in self.compilation_configs.pass_config.dump_graph_stages): dump_graph(self.compilation_configs.pass_config, graph.graph, "before_split_graph") self.split_gm, self.piecewise_graphs = split_graph( graph, self.compilation_configs.splitting_ops) - if "after_split_graph" in self.compilation_configs.pass_config.dump_graph_stages: - dump_graph(self.compilation_configs.pass_config, self.split_gm.graph, - "after_split_graph") + if ("after_split_graph" + in self.compilation_configs.pass_config.dump_graph_stages): + dump_graph(self.compilation_configs.pass_config, + self.split_gm.graph, "after_split_graph") from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 83ab5024e4108..167b619135de5 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -31,7 +31,6 @@ logger.info("Attempting to use flux but flux not installed.") use_flux = False - # Depends on arch, see auto_tile_shape in include/flux/gemm_hparams.h # Can be 256 on sm80. 
FLUX_TILE_SIZE: int = 128 @@ -60,11 +59,11 @@ def residual_slice_shape_fake(residual: torch.Tensor, rank: int) -> int: def match_gemm_rs_ag_gemm( - residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weights: torch.Tensor, - gemm_2_weights: torch.Tensor, + residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, + gemm_2_weights: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) @@ -239,10 +238,10 @@ def gemm_rs_ag_gemm_fake( def match_final( - my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, - gemm_1_activations: torch.Tensor, - rms_norm_weights: torch.Tensor, + my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor, ) -> torch.Tensor: gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) mm_1 = torch.ops.aten.mm.default(gemm_1_activations, gemm_1_w_perm) @@ -260,7 +259,7 @@ def match_final( return normalized -# Register this as a custom op since all reduce cannot be torch.compiled yet. +# Register this as a custom op since all gather cannot be torch.compiled yet. def gemm_ag_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -333,17 +332,14 @@ def __init__(self, config: CompilationConfig): inputs = [resid, x, w, resid_w, x2] final_inputs = [x, w, resid, resid_w] - register_replacement( - match_gemm_rs_ag_gemm, - match_gemm_rs_ag_gemm, - inputs, - fwd_only, [self.gemm_rs_ag_gemm_pattern], - extra_check=lambda m: self.record_match(m)) + register_replacement(match_gemm_rs_ag_gemm, + match_gemm_rs_ag_gemm, + inputs, + fwd_only, [self.gemm_rs_ag_gemm_pattern], + extra_check=lambda m: self.record_match(m)) - register_replacement(match_final - torch.ops.vllm.gemm_ag_final, - final_inputs, fwd_only, - [self.final_pattern]) + register_replacement(match_final, torch.ops.vllm.gemm_ag_final, + final_inputs, fwd_only, [self.final_pattern]) def record_match(self, match: Match) -> bool: # Hijack the extra_check to record the match and @@ -394,16 +390,10 @@ def find_min_index(match: Match) -> int: # Extract group_name from matched code. Use to # generate proper replacement code. - #ar_node = find_auto_fn(match.nodes, torch.ops.vllm.inplace_all_reduce.default) - ar_node = None - if ar_node is not None: - tp_group_name = ar_node.kwargs["group_name"] - else: - ar_node = find_fn( - match.nodes, - torch.ops.vllm.all_reduce.default) - assert ar_node is not None - tp_group_name = ar_node.args[1] + ar_node = find_fn(match.nodes, + torch.ops.vllm.all_reduce.default) + assert ar_node is not None + tp_group_name = ar_node.args[1] fused_gemm_func = get_gemm_rs_ag_gemm( use_flux, max_m, gemm_1.dtype, gemm_1.shape, gemm_2.dtype,
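

# A minimal, self-contained sketch of the torch._inductor register_replacement
# mechanism the collective fusion pass is built on, using a toy mm+relu
# pattern instead of the real gemm/all-reduce patterns (the names and shapes
# here are illustrative only, not taken from the patch):
import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)


def toy_pattern(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Subgraph to search for: a matmul followed by a relu.
    return torch.relu(torch.mm(x, w))


def toy_replacement(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Functionally equivalent subgraph substituted for each match.
    return torch.clamp(torch.mm(x, w), min=0.0)


toy_matcher = PatternMatcherPass()

# fwd_only traces the pattern and replacement into ATen-level FX graphs; the
# example inputs only fix dtypes/ranks for tracing, not the runtime shapes.
register_replacement(toy_pattern, toy_replacement,
                     [torch.empty(16, 16), torch.empty(16, 16)],
                     fwd_only, [toy_matcher])

# toy_matcher.apply(graph) would then rewrite every match inside an
# ATen-level fx.Graph, analogous to what the pass does in __call__ above.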