From 91642db952458fbb6ae7c2d167757dc86b105991 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 11 Dec 2024 10:43:05 -0800 Subject: [PATCH 01/85] [torch.compile] use depyf to dump torch.compile internals (#10972) Signed-off-by: youkaichao --- requirements-common.txt | 1 + vllm/compilation/backends.py | 69 ++++++++++++++++++---------------- vllm/compilation/decorators.py | 2 +- vllm/compilation/monitor.py | 23 ++++++++++-- vllm/compilation/wrapper.py | 4 +- vllm/config.py | 6 ++- vllm/worker/model_runner.py | 3 +- 7 files changed, 66 insertions(+), 42 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 792cd58e80669..850b8f4101701 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -33,3 +33,4 @@ six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that need setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. compressed-tensors == 0.8.0 # required for compressed-tensors +depyf==0.18.0 # required for profiling and debugging torch.compile diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index f002a8ff905b1..09a3daa731829 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -9,7 +9,7 @@ import torch.fx as fx import vllm.envs as envs -from vllm.config import CompilationConfig +from vllm.config import CompilationConfig, VllmConfig from vllm.logger import init_logger from vllm.utils import weak_ref_tensors @@ -149,14 +149,15 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): """ def __init__(self, module: torch.fx.GraphModule, - compile_submod_names: List[str], - compilation_configs: CompilationConfig, graph_pool): + compile_submod_names: List[str], vllm_config: VllmConfig, + graph_pool): super().__init__(module) from torch._guards import detect_fake_mode self.fake_mode = detect_fake_mode() self.compile_submod_names = compile_submod_names - self.compilation_configs = compilation_configs + self.compilation_config = vllm_config.compilation_config self.graph_pool = graph_pool + self.vllm_config = vllm_config def run(self, *args): fake_args = [ @@ -182,15 +183,15 @@ def call_module(self, target: torch.fx.node.Target, compiled_graph_for_general_shape = wrap_inductor( submod, args, - self.compilation_configs.inductor_compile_config, - self.compilation_configs, + self.compilation_config.inductor_compile_config, + self.compilation_config, graph_index=index, num_graphs=len(self.compile_submod_names), runtime_shape=None, - use_inductor=self.compilation_configs.use_inductor) + use_inductor=self.compilation_config.use_inductor) self.module.__dict__[target] = PiecewiseBackend( - submod, self.compilation_configs, self.graph_pool, index, + submod, self.vllm_config, self.graph_pool, index, len(self.compile_submod_names), sym_shape_indices, compiled_graph_for_general_shape) @@ -211,7 +212,8 @@ class VllmBackend: which handles the post-grad passes. 
""" - compilation_configs: CompilationConfig + vllm_config: VllmConfig + compilation_config: CompilationConfig graph_pool: Any _called: bool = False # the graph we compiled @@ -227,7 +229,7 @@ class VllmBackend: def __init__( self, - compilation_configs: CompilationConfig, + vllm_config: VllmConfig, ): global global_graph_pool if global_graph_pool is None: @@ -244,13 +246,14 @@ def __init__( self.sym_tensor_indices = [] self.input_buffers = [] - self.compilation_configs = compilation_configs + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config # `torch.compile` is JIT compiled, so we don't need to # do anything here def configure_post_pass(self): - config = self.compilation_configs + config = self.compilation_config self.post_grad_pass_manager.configure(config.pass_config) # Post-grad custom passes are run using the post_grad_custom_post_pass @@ -271,7 +274,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: from .monitor import torch_compile_start_time dynamo_time = time.time() - torch_compile_start_time logger.info("Dynamo bytecode transform time: %.2f s", dynamo_time) - self.compilation_configs.compilation_time += dynamo_time + self.compilation_config.compilation_time += dynamo_time # we control the compilation process, each instance can only be # called once @@ -281,7 +284,7 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_configs.splitting_ops) + graph, self.compilation_config.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) @@ -298,13 +301,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # propagate the split graph to the piecewise backend, # compile submodules with symbolic shapes PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile, - self.compilation_configs, + self.vllm_config, self.graph_pool).run(*example_inputs) self._called = True - if not self.compilation_configs.use_cudagraph or \ - not self.compilation_configs.cudagraph_copy_inputs: + if not self.compilation_config.use_cudagraph or \ + not self.compilation_config.cudagraph_copy_inputs: return self.split_gm # if we need to copy input buffers for cudagraph @@ -364,10 +367,9 @@ class ConcreteSizeEntry: class PiecewiseBackend: - def __init__(self, graph: fx.GraphModule, - compilation_configs: CompilationConfig, graph_pool: Any, - piecewise_compile_index: int, total_piecewise_compiles: int, - sym_shape_indices: List[int], + def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, + graph_pool: Any, piecewise_compile_index: int, + total_piecewise_compiles: int, sym_shape_indices: List[int], compiled_graph_for_general_shape: Callable): """ The backend for piecewise compilation. @@ -375,7 +377,7 @@ def __init__(self, graph: fx.GraphModule, We will compile `self.graph` once for the general shape, and then compile for different shapes specified in - `compilation_configs.compile_sizes`. + `compilation_config.compile_sizes`. Independently, we will capture cudagraph for different shapes. @@ -383,7 +385,8 @@ def __init__(self, graph: fx.GraphModule, compile it first, and then capture cudagraph. 
""" self.graph = graph - self.compilation_configs = compilation_configs + self.vllm_config = vllm_config + self.compilation_config = vllm_config.compilation_config self.graph_pool = graph_pool self.piecewise_compile_index = piecewise_compile_index self.total_piecewise_compiles = total_piecewise_compiles @@ -393,10 +396,10 @@ def __init__(self, graph: fx.GraphModule, piecewise_compile_index == total_piecewise_compiles - 1) self.compile_sizes: Set[int] = set( - self.compilation_configs.compile_sizes) + self.compilation_config.compile_sizes) self.capture_sizes: Set[int] = set( - self.compilation_configs.capture_sizes - ) if self.compilation_configs.use_cudagraph else set() + self.compilation_config.capture_sizes + ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -423,7 +426,7 @@ def __call__(self, *args) -> Any: self.first_run_finished = True # no specific sizes to compile if self.is_last_graph and not self.to_be_compiled_sizes: - end_monitoring_torch_compile(self.compilation_configs) + end_monitoring_torch_compile(self.vllm_config) return self.compiled_graph_for_general_shape(*args) runtime_shape = args[self.sym_shape_indices[0]] @@ -443,28 +446,28 @@ def __call__(self, *args) -> Any: entry.runnable = wrap_inductor( self.graph, args, - self.compilation_configs.inductor_compile_config, - self.compilation_configs, + self.compilation_config.inductor_compile_config, + self.compilation_config, graph_index=self.piecewise_compile_index, num_graphs=self.total_piecewise_compiles, runtime_shape=runtime_shape, - use_inductor=self.compilation_configs.use_inductor) + use_inductor=self.compilation_config.use_inductor) # finished compilations for all required shapes if self.is_last_graph and not self.to_be_compiled_sizes: - end_monitoring_torch_compile(self.compilation_configs) + end_monitoring_torch_compile(self.vllm_config) if not entry.use_cudagraph: return entry.runnable(*args) if entry.cudagraph is None: - if entry.num_finished_warmup < self.compilation_configs.cudagraph_num_of_warmups: # noqa + if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa entry.num_finished_warmup += 1 if self.is_first_graph: logger.debug( "Warming up %s/%s for shape %s", entry.num_finished_warmup, - self.compilation_configs.cudagraph_num_of_warmups, + self.compilation_config.cudagraph_num_of_warmups, runtime_shape) return entry.runnable(*args) diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 938430fe2a501..805a217ee6ca1 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -185,7 +185,7 @@ def __call__(self, *args, **kwargs): "Unsupported dynamic dimensions" f" {dims} for argument {k} with type {type(arg)}.") # here, it is the starting point of the `torch.compile` process - start_monitoring_torch_compile(self.vllm_config.compilation_config) + start_monitoring_torch_compile(self.vllm_config) # if we don't use custom dispatcher, we can directly call the # compiled function and let torch.compile handle the dispatching, diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py index 3348674b09af2..b97e40415b41b 100644 --- a/vllm/compilation/monitor.py +++ b/vllm/compilation/monitor.py @@ -1,19 +1,36 @@ +import os import time -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.logger import init_logger logger = init_logger(__name__) +context_manager = None torch_compile_start_time: float = 
0.0 -def start_monitoring_torch_compile(compilation_config: CompilationConfig): +def start_monitoring_torch_compile(vllm_config: VllmConfig): global torch_compile_start_time torch_compile_start_time = time.time() + compilation_config: CompilationConfig = vllm_config.compilation_config + if compilation_config.level == CompilationLevel.PIECEWISE and \ + compilation_config.debug_dump_path: + import depyf + path = os.path.join(compilation_config.debug_dump_path, + f"rank_{vllm_config.parallel_config.rank}") + global context_manager + context_manager = depyf.prepare_debug(path) + context_manager.__enter__() -def end_monitoring_torch_compile(compilation_config: CompilationConfig): + +def end_monitoring_torch_compile(vllm_config: VllmConfig): + compilation_config: CompilationConfig = vllm_config.compilation_config if compilation_config.level == CompilationLevel.PIECEWISE: logger.info("torch.compile takes %.2f s in total", compilation_config.compilation_time) + global context_manager + if context_manager is not None: + context_manager.__exit__(None, None, None) + context_manager = None diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index bc4d292fef402..c10241b483169 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -32,8 +32,8 @@ def __init__(self, # default compilation settings # compiling the forward method - backend = get_current_vllm_config( - ).compilation_config.init_backend() + vllm_config = get_current_vllm_config() + backend = vllm_config.compilation_config.init_backend(vllm_config) compiled_callable = torch.compile( self.forward, diff --git a/vllm/config.py b/vllm/config.py index 322c8f8990a40..7f9be5a3a98bc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2222,6 +2222,7 @@ class CompilationConfig(BaseModel): - 1: dynamo as is. - 2: dynamo once. - 3: piecewise compilation. + - debug_dump_path: the path to dump the debug information. - backend: the backend for compilation. It needs to be a string. - "" (empty string): use the default backend. - "eager"/"openxla"/...: use the specified backend registered in PyTorch. @@ -2289,6 +2290,7 @@ class CompilationConfig(BaseModel): certain small batchsizes, where inductor is good at optimizing. 
""" # noqa level: int = 0 + debug_dump_path: str = "" backend: str = "" custom_ops: List[str] = Field(default_factory=list) splitting_ops: List[str] = Field(default_factory=lambda: [ @@ -2394,7 +2396,7 @@ def model_post_init(self, __context: Any) -> None: self.static_forward_context = {} self.compilation_time = 0.0 - def init_backend(self) -> Union[str, Callable]: + def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: raise ValueError("No compilation level is set.") @@ -2413,7 +2415,7 @@ def init_backend(self) -> Union[str, Callable]: # merge with the config use_inductor assert self.level == CompilationLevel.PIECEWISE from vllm.compilation.backends import VllmBackend - return VllmBackend(self) + return VllmBackend(vllm_config) def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): """To complete the initialization of config, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 551b84435fdc0..26fd486130ce6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1162,7 +1162,8 @@ def load_model(self) -> None: if self.vllm_config.compilation_config.level ==\ CompilationLevel.DYNAMO_AS_IS and supports_dynamo(): - backend = self.vllm_config.compilation_config.init_backend() + backend = self.vllm_config.compilation_config.init_backend( + self.vllm_config) self.model = torch.compile( self.model, fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, From d643c2aba1cd5421200f3a3bad1813dd067233b4 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 11 Dec 2024 10:49:23 -0800 Subject: [PATCH 02/85] [V1] Use input_ids as input for text-only models (#11032) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 68 +++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8d9976ded7c5e..e75be21ef2d91 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -61,6 +61,7 @@ def __init__( self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ cache_config.cache_dtype] + self.is_multimodal_model = model_config.is_multimodal_model self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_model_len = model_config.max_model_len @@ -103,6 +104,11 @@ def __init__( # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( reversed(self.vllm_config.compilation_config.capture_sizes)) + + # Persistent buffers for CUDA graphs. + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) @@ -310,7 +316,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_start_loc_np[0] = 0 np.cumsum(seq_lens, out=seq_start_loc_np[1:]) - input_ids = input_ids.to(self.device, non_blocking=True) + self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, + non_blocking=True) self.positions[:total_num_scheduled_tokens].copy_(positions, non_blocking=True) query_start_loc = query_start_loc.to(self.device, non_blocking=True) @@ -331,7 +338,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # token from the partial request. # TODO: Support prompt logprobs. 
logits_indices = query_start_loc[1:] - 1 - return input_ids, attn_metadata, logits_indices + return attn_metadata, logits_indices def _prepare_sampling( self, @@ -427,13 +434,15 @@ def execute_model( ) -> ModelRunnerOutput: self._update_states(scheduler_output) - # Run the encoder. - self._execute_encoder(scheduler_output) - encoder_outputs = self._gather_encoder_outputs(scheduler_output) + if self.is_multimodal_model: + # Run the multimodal encoder if any. + self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + else: + encoder_outputs = [] # Prepare the decoder inputs. - input_ids, attn_metadata, logits_indices = self._prepare_inputs( - scheduler_output) + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -444,29 +453,39 @@ def execute_model( else: # Eager mode. num_input_tokens = num_scheduled_tokens - attn_metadata.num_input_tokens = num_input_tokens - # Get the inputs embeds. - if encoder_outputs: - inputs_embeds = self.model.get_input_embeddings( - input_ids, encoder_outputs) + if self.is_multimodal_model: + # NOTE(woosuk): To unify token ids and soft tokens (vision + # embeddings), we always use embeddings (rather than token ids) + # as input to the multimodal model, even when the input is text. + input_ids = self.input_ids[:num_scheduled_tokens] + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + inputs_embeds = self.inputs_embeds[:num_input_tokens] + input_ids = None else: - inputs_embeds = self.model.get_input_embeddings(input_ids) - # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings), - # always use embeddings (rather than token ids) as input to the model. - # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + # For text-only models, we use token ids as input. + # While it is possible to use embeddings as input just like the + # multimodal models, it is not desirable for performance since + # then the embedding layer is not included in the CUDA graph. + input_ids = self.input_ids[:num_input_tokens] + inputs_embeds = None # Run the decoder. # Use persistent buffers for CUDA graphs. 
with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( - input_ids=None, + input_ids=input_ids, positions=self.positions[:num_input_tokens], kv_caches=self.kv_caches, attn_metadata=None, - inputs_embeds=self.inputs_embeds[:num_input_tokens], + inputs_embeds=inputs_embeds, ) hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[logits_indices] @@ -534,13 +553,20 @@ def _dummy_run( num_tokens: int, kv_caches: List[torch.Tensor], ) -> torch.Tensor: + if self.is_multimodal_model: + input_ids = None + inputs_embeds = self.inputs_embeds[:num_tokens] + else: + input_ids = self.input_ids[:num_tokens] + inputs_embeds = None with set_forward_context(None, self.vllm_config): hidden_states = model( - input_ids=None, + input_ids=input_ids, positions=self.positions[:num_tokens], kv_caches=kv_caches, attn_metadata=None, - inputs_embeds=self.inputs_embeds[:num_tokens]) + inputs_embeds=inputs_embeds, + ) return hidden_states def profile_run(self) -> None: From 66aaa7722df3d7ef9e9bd2942cab5cd0d7473174 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 11 Dec 2024 10:59:50 -0800 Subject: [PATCH 03/85] [torch.compile] remove graph logging in ci (#11110) Signed-off-by: youkaichao --- vllm/compilation/backends.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 09a3daa731829..4a5dc337d01b8 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -287,9 +287,11 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: graph, self.compilation_config.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code - logger.debug("%s", lazy_format_graph_code("before split", self.graph)) - logger.debug("%s", lazy_format_graph_code("after split", - self.split_gm)) + + # depyf will hook lazy_format_graph_code and dump the graph + # for debugging, no need to print the graph here + lazy_format_graph_code("before split", self.graph) + lazy_format_graph_code("after split", self.split_gm) compilation_counter.num_piecewise_graphs_seen += len( self.piecewise_graphs) From 72ff3a968682e6a3f7620ab59f2baf5e8eb2777b Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:36:35 -0800 Subject: [PATCH 04/85] [core] Bump ray to use _overlap_gpu_communication in compiled graph tests (#10410) Signed-off-by: Rui Qiao Signed-off-by: Rui Qiao Co-authored-by: Rui Qiao --- requirements-test.in | 2 +- requirements-test.txt | 2 +- vllm/envs.py | 8 ++++++++ vllm/executor/ray_gpu_executor.py | 17 ++++++++++------- 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/requirements-test.in b/requirements-test.in index c0b228148ab31..57fddb416317e 100644 --- a/requirements-test.in +++ b/requirements-test.in @@ -13,7 +13,7 @@ einops # required for MPT, qwen-vl and Mamba httpx librosa # required for audio tests peft -ray[adag]==2.35 +ray[adag]==2.40.0 sentence-transformers # required for embedding tests soundfile # required for audio tests timm # required for internvl test diff --git a/requirements-test.txt b/requirements-test.txt index 8ceb705cdffd7..c786a1249bddb 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -410,7 +410,7 @@ pyyaml==6.0.2 # ray # timm # transformers -ray[adag]==2.35.0 +ray[adag]==2.40.0 # via -r requirements-test.in redis==5.2.0 # via tensorizer diff --git a/vllm/envs.py b/vllm/envs.py index be5d9985b63a4..bc8c1499e9534 100644 --- a/vllm/envs.py +++ 
b/vllm/envs.py @@ -45,6 +45,7 @@ VLLM_USE_RAY_SPMD_WORKER: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True + VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM: bool = True VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 @@ -337,6 +338,13 @@ def get_default_config_root(): lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) ), + # If the env var is set, it enables GPU communication overlap in + # Ray's compiled DAG. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM", "1")) + ), + # Use dedicated multiprocess context for workers. # Both spawn and fork work "VLLM_WORKER_MULTIPROC_METHOD": diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4263fb27265f6..4bf5cbbd18ffe 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -414,12 +414,10 @@ def _check_ray_adag_installation(self): import pkg_resources from packaging import version - required_version = version.parse("2.35") + required_version = version.parse("2.40") current_version = version.parse( pkg_resources.get_distribution("ray").version) - # TODO: update the constraint once we adapt to the backward - # incompatible API change from ray 2.36 - if current_version != required_version: + if current_version < required_version: raise ValueError(f"Ray version {required_version} is " f"required, but found {current_version}") @@ -445,6 +443,8 @@ def _compiled_ray_dag(self, enable_asyncio: bool): logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) + logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) with InputNode() as input_data: # Example DAG: PP=2, TP=4 # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 @@ -480,7 +480,10 @@ def _compiled_ray_dag(self, enable_asyncio: bool): forward_dag = MultiOutputNode(outputs) - return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) + return forward_dag.experimental_compile( + enable_asyncio=enable_asyncio, + _overlap_gpu_communication=envs. 
+ VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) def __del__(self): self.shutdown() @@ -507,8 +510,8 @@ async def execute_model_async( serialized_data = self.input_encoder.encode(execute_model_req) dag_future = await self.forward_dag.execute_async(serialized_data) - outputs = await dag_future - return self.output_decoder.decode(outputs[0]) + output = await dag_future[0] + return self.output_decoder.decode(output) async def _driver_execute_model_async( self, From d1e21a979bba4712f48dac1bbf410e0b57c92e7a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 12 Dec 2024 06:18:16 +0800 Subject: [PATCH 05/85] [CI/Build] Split up VLM tests (#11083) Signed-off-by: DarkLight1337 --- .buildkite/test-pipeline.yaml | 32 ++++++--- pyproject.toml | 3 +- .../vision_language/test_models.py | 72 ++++++++++++------- tests/utils.py | 37 ++++++---- 4 files changed, 94 insertions(+), 50 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index df4fa7a6ee9ba..aca505178df06 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -321,7 +321,7 @@ steps: ##### models test ##### -- label: Basic Models Test # 30min +- label: Basic Models Test # 24min source_file_dependencies: - vllm/ - tests/models @@ -331,7 +331,7 @@ steps: - pytest -v -s models/test_registry.py - pytest -v -s models/test_initialization.py -- label: Language Models Test (Standard) # 42min +- label: Language Models Test (Standard) # 32min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -342,7 +342,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - pytest -v -s models/embedding/language -m core_model -- label: Language Models Test (Extended) # 50min +- label: Language Models Test (Extended) # 1h10min optional: true source_file_dependencies: - vllm/ @@ -353,7 +353,7 @@ steps: - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 26min +- label: Multi-Modal Models Test (Standard) # 28min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -369,7 +369,7 @@ steps: - pytest -v -s models/encoder_decoder/language -m core_model - pytest -v -s models/encoder_decoder/vision_language -m core_model -- label: Multi-Modal Models Test (Extended) # 1h15m +- label: Multi-Modal Models Test (Extended) 1 # 1h16m optional: true source_file_dependencies: - vllm/ @@ -380,14 +380,24 @@ steps: commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' + - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model and not quant_model' - pytest -v -s models/embedding/vision_language -m 'not core_model' - pytest -v -s models/encoder_decoder/language -m 'not core_model' - pytest -v -s models/encoder_decoder/vision_language -m 'not 
core_model' +- label: Multi-Modal Models Test (Extended) 2 # 38m + optional: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/vision_language + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' + # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test optional: true @@ -446,11 +456,11 @@ steps: - pytest -v -s ./compile/test_basic_correctness.py - pytest -v -s ./compile/test_wrapper.py - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' # Avoid importing model tests that cause CUDA reinitialization error - - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus - - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - - pytest models/decoder_only/vision_language/test_models.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py @@ -540,7 +550,7 @@ steps: # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py - - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' - pytest -v -s -x lora/test_mixtral.py - label: LM Eval Large Models # optional diff --git a/pyproject.toml b/pyproject.toml index 253b706a774a7..c5a14ecf5aea9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,7 +96,8 @@ markers = [ "core_model: enable this model test in each PR instead of only nightly", "cpu_model: enable this model test in CPU tests", "quant_model: run this model test under Quantized category", - "distributed_2_gpus: run this test only in distributed tests for 2 GPUs", + "split: run this test as part of a split", + "distributed: run this test only in distributed GPU tests", "skip_v1: do not run this test with v1", "optional: optional tests that are automatically skipped, include --optional to run them", ] diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index ed8f34a677f84..3101d1d2ea831 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -1,7 +1,9 @@ """Common tests for testing .generate() functionality for single / multiple image, embedding, and video support for different VLMs in vLLM. 
""" +import math import os +from collections import defaultdict from pathlib import PosixPath from typing import Type @@ -10,11 +12,12 @@ from transformers.utils import is_flash_attn_2_available from vllm.platforms import current_platform -from vllm.utils import cuda_device_count_stateless, identity +from vllm.utils import identity from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets, _VideoAssets) -from ....utils import fork_new_process_for_each_test, large_gpu_mark +from ....utils import (fork_new_process_for_each_test, large_gpu_mark, + multi_gpu_marks) from ...utils import check_outputs_equal from .vlm_utils import custom_inputs, model_utils, runners from .vlm_utils.case_filtering import get_parametrized_options @@ -382,7 +385,7 @@ prompt_path_encoder=model_utils.qwen_prompt_path_encoder, ), ### Tensor parallel / multi-gpu broadcast tests - "broadcast-chameleon": VLMTestInfo( + "chameleon-broadcast": VLMTestInfo( models=["facebook/chameleon-7b"], prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, @@ -393,43 +396,25 @@ vllm_output_post_proc = lambda vllm_output, model: vllm_output[:2], hf_output_post_proc = lambda hf_output, model: hf_output[:2], comparator=check_outputs_equal, - marks=[ - pytest.mark.distributed_2_gpus, - pytest.mark.skipif( - cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.", - ), - ], + marks=multi_gpu_marks(num_gpus=2), **COMMON_BROADCAST_SETTINGS # type: ignore ), - "broadcast-llava": VLMTestInfo( + "llava-broadcast": VLMTestInfo( models=["llava-hf/llava-1.5-7b-hf"], prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", max_model_len=4096, auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - marks=[ - pytest.mark.distributed_2_gpus, - pytest.mark.skipif( - cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.", - ) - ], + marks=multi_gpu_marks(num_gpus=2), **COMMON_BROADCAST_SETTINGS # type: ignore ), - "broadcast-llava_next": VLMTestInfo( + "llava_next-broadcast": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], prompt_formatter=lambda img_prompt: f"[INST] {img_prompt} [/INST]", max_model_len=10240, auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - marks=[ - pytest.mark.distributed_2_gpus, - pytest.mark.skipif( - cuda_device_count_stateless() < 2, - reason="Need at least 2 GPUs to run the test.", - ) - ], + marks=multi_gpu_marks(num_gpus=2), **COMMON_BROADCAST_SETTINGS # type: ignore ), ### Custom input edge-cases for specific models @@ -468,6 +453,41 @@ # yapf: enable +def _mark_splits( + test_settings: dict[str, VLMTestInfo], + *, + num_groups: int, +) -> dict[str, VLMTestInfo]: + name_by_test_info_id = {id(v): k for k, v in test_settings.items()} + test_infos_by_model = defaultdict[str, list[VLMTestInfo]](list) + + for info in test_settings.values(): + for model in info.models: + test_infos_by_model[model].append(info) + + models = sorted(test_infos_by_model.keys()) + split_size = math.ceil(len(models) / num_groups) + + new_test_settings = dict[str, VLMTestInfo]() + + for i in range(num_groups): + models_in_group = models[i * split_size:(i + 1) * split_size] + + for model in models_in_group: + for info in test_infos_by_model[model]: + new_marks = (info.marks or []) + [pytest.mark.split(group=i)] + new_info = info._replace(marks=new_marks) + new_test_settings[name_by_test_info_id[id(info)]] = new_info + + missing_keys 
= test_settings.keys() - new_test_settings.keys() + assert not missing_keys, f"Missing keys: {missing_keys}" + + return new_test_settings + + +VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2) + + ### Test wrappers # Wrappers around the core test running func for: # - single image diff --git a/tests/utils.py b/tests/utils.py index a893667e144a6..afeb708f3bcdc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -682,10 +682,12 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator: - """Gets a pytest skipif mark, which triggers ig the the device doesn't have - meet a minimum memory requirement in gb; can be leveraged via - @large_gpu_test to skip tests in environments without enough resources, or - called when filtering tests to run directly. + """ + Get a pytest mark, which skips the test if the GPU doesn't meet + a minimum memory requirement in GB. + + This can be leveraged via `@large_gpu_test` to skip tests in environments + without enough resources, or called when filtering tests to run directly. """ try: if current_platform.is_cpu(): @@ -712,26 +714,37 @@ def large_gpu_test(*, min_gb: int): Currently, the CI machine uses L4 GPU which has 24 GB VRAM. """ - test_skipif = large_gpu_mark(min_gb) + mark = large_gpu_mark(min_gb) def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: - return test_skipif(f) + return mark(f) return wrapper -def multi_gpu_test(*, num_gpus: int): - """ - Decorate a test to be run only when multiple GPUs are available. - """ - test_selector = getattr(pytest.mark, f"distributed_{num_gpus}_gpus") +def multi_gpu_marks(*, num_gpus: int): + """Get a collection of pytest marks to apply for `@multi_gpu_test`.""" + test_selector = pytest.mark.distributed(num_gpus=num_gpus) test_skipif = pytest.mark.skipif( cuda_device_count_stateless() < num_gpus, reason=f"Need at least {num_gpus} GPUs to run the test.", ) + return [test_selector, test_skipif] + + +def multi_gpu_test(*, num_gpus: int): + """ + Decorate a test to be run only when multiple GPUs are available. + """ + marks = multi_gpu_marks(num_gpus=num_gpus) + def wrapper(f: Callable[_P, None]) -> Callable[_P, None]: - return test_selector(test_skipif(fork_new_process_for_each_test(f))) + func = fork_new_process_for_each_test(f) + for mark in reversed(marks): + func = mark(func) + + return func return wrapper From 452a723bf2e8410ee9b47f82f90c7ea48aa6d14f Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Wed, 11 Dec 2024 18:34:54 -0500 Subject: [PATCH 06/85] [V1][Core] Remove should_shutdown to simplify core process termination (#11113) Signed-off-by: Tyler Michael Smith --- vllm/v1/engine/core.py | 13 ++----------- vllm/v1/engine/core_client.py | 6 ------ 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 55a5c4dff3a5c..a26ffe74a3ae8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -5,7 +5,6 @@ import threading import time from multiprocessing.process import BaseProcess -from multiprocessing.sharedctypes import Synchronized from typing import List, Tuple, Type, Union import zmq @@ -133,13 +132,9 @@ def __init__( input_path: str, output_path: str, ready_path: str, - should_shutdown: Synchronized, ): super().__init__(vllm_config, executor_class, usage_context) - # Signal from main process to shutdown (multiprocessing.Value). - self.should_shutdown = should_shutdown - # Background Threads and Queues for IO. 
These enable us to # overlap ZMQ socket IO with GPU since they release the GIL, # and to overlap some serialization/deserialization with the @@ -195,7 +190,6 @@ def make_engine_core_process( input_path: str, output_path: str, ready_path: str, - should_shutdown: Synchronized, ) -> BaseProcess: # The current process might have CUDA context, # so we need to spawn a new process. @@ -210,7 +204,6 @@ def make_engine_core_process( "vllm_config": vllm_config, "executor_class": executor_class, "usage_context": usage_context, - "should_shutdown": should_shutdown } # Run EngineCore busy loop in background process. proc = context.Process(target=EngineCoreProc.run_engine_core, @@ -260,8 +253,8 @@ def signal_handler(signum, frame): def run_busy_loop(self): """Core busy loop of the EngineCore.""" - # Loop until we get a shutdown signal. - while not self.should_shutdown: + # Loop until process is sent a SIGINT or SIGTERM + while True: # 1) Poll the input queue until there is work to do. if not self.scheduler.has_unfinished_requests(): while True: @@ -272,8 +265,6 @@ def run_busy_loop(self): except queue.Empty: self._log_stats() logger.debug("EngineCore busy loop waiting.") - if self.should_shutdown: - return except BaseException: raise diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 4d96b323d1662..1d5ddf4db4d7c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -1,5 +1,4 @@ import atexit -import multiprocessing from typing import List, Union import msgspec @@ -149,21 +148,16 @@ def __init__( self.input_socket.bind(input_path) # Start EngineCore in background process. - self.should_shutdown = multiprocessing.Value('b', False, lock=False) self.proc = EngineCoreProc.make_engine_core_process( *args, input_path=input_path, output_path=output_path, ready_path=ready_path, - should_shutdown=self.should_shutdown, **kwargs, ) atexit.register(self.shutdown) def shutdown(self): - # Send shutdown signal to background process. - self.should_shutdown = True - # Shut down the zmq context. self.ctx.destroy(linger=0) From 4e116833686f3e0c0a223b05b5859ad76843a017 Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Wed, 11 Dec 2024 19:55:30 -0500 Subject: [PATCH 07/85] [V1] VLM preprocessor hashing (#11020) Signed-off-by: Roger Wang Signed-off-by: Alexander Matveev Co-authored-by: Michael Goin Co-authored-by: Roger Wang --- examples/offline_inference_vision_language.py | 126 ++++++++++++-- requirements-common.txt | 1 + tests/v1/engine/test_engine_core.py | 1 + tests/v1/engine/test_engine_core_client.py | 1 + vllm/config.py | 10 +- vllm/engine/arg_utils.py | 8 + vllm/v1/engine/__init__.py | 3 +- vllm/v1/engine/core.py | 18 +- vllm/v1/engine/mm_input_mapper.py | 156 ++++++++++++++++-- vllm/v1/engine/processor.py | 35 ++-- vllm/v1/utils.py | 21 +++ 11 files changed, 332 insertions(+), 48 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index c6a274ee5894b..5e210126dc8fe 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -5,6 +5,8 @@ For most models, the prompt format should follow corresponding examples on HuggingFace model repository. 
""" +import random + from transformers import AutoTokenizer from vllm import LLM, SamplingParams @@ -23,7 +25,9 @@ def run_llava(question: str, modality: str): prompt = f"USER: \n{question}\nASSISTANT:" - llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) + llm = LLM(model="llava-hf/llava-1.5-7b-hf", + max_model_len=4096, + mm_cache_preprocessor=args.mm_cache_preprocessor) stop_token_ids = None return llm, prompt, stop_token_ids @@ -33,7 +37,9 @@ def run_llava_next(question: str, modality: str): assert modality == "image" prompt = f"[INST] \n{question} [/INST]" - llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192) + llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", + max_model_len=8192, + mm_cache_preprocessor=args.mm_cache_preprocessor) stop_token_ids = None return llm, prompt, stop_token_ids @@ -44,7 +50,9 @@ def run_llava_next_video(question: str, modality: str): assert modality == "video" prompt = f"USER: