From ba7752795567e3f2bfcc1dca340d107e003d32ad Mon Sep 17 00:00:00 2001
From: William Lin
Date: Thu, 12 Sep 2024 21:30:00 -0700
Subject: [PATCH] [bugfix] torch profiler bug for single gpu with GPUExecutor
 (#8354)

---
 examples/offline_inference_with_profiler.py |  2 +-
 vllm/engine/async_llm_engine.py             | 15 +++++++++++++--
 vllm/engine/llm_engine.py                   | 15 +++++++++++++--
 3 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py
index 906c9502800d8..1f00d26808771 100644
--- a/examples/offline_inference_with_profiler.py
+++ b/examples/offline_inference_with_profiler.py
@@ -16,7 +16,7 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
 
 llm.start_profile()
 
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 01114e9843ce4..8a07ce1c965e1 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -13,6 +13,7 @@
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
+from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
@@ -1019,7 +1020,17 @@ def remove_logger(self, logger_name: str) -> None:
         self.engine.remove_logger(logger_name=logger_name)
 
     async def start_profile(self) -> None:
-        self.engine.model_executor._run_workers("start_profile")
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes
+        if type(self.engine.model_executor) == GPUExecutorAsync:
+            self.engine.model_executor.start_profile()
+        else:
+            self.engine.model_executor._run_workers("start_profile")
 
     async def stop_profile(self) -> None:
-        self.engine.model_executor._run_workers("stop_profile")
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes
+        if type(self.engine.model_executor) == GPUExecutorAsync:
+            self.engine.model_executor.stop_profile()
+        else:
+            self.engine.model_executor._run_workers("stop_profile")
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 0573921a40fc3..dfdbc22ef00e1 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -26,6 +26,7 @@
 from vllm.engine.output_processor.stop_checker import StopChecker
 from vllm.engine.output_processor.util import create_output_by_sequence_group
 from vllm.executor.executor_base import ExecutorBase
+from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
                          InputRegistry, LLMInputs, PromptInputs)
@@ -1597,10 +1598,20 @@ def check_health(self) -> None:
         self.model_executor.check_health()
 
     def start_profile(self) -> None:
-        self.model_executor.start_profile()
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes (MultiprocessingGPUExecutor)
+        if type(self.model_executor) == GPUExecutor:
+            self.model_executor.start_profile()
+        else:
+            self.model_executor._run_workers("start_profile")
 
     def stop_profile(self) -> None:
-        self.model_executor.stop_profile()
+        # using type instead of isinstance to check to avoid capturing
+        # inherited classes (MultiprocessingGPUExecutor)
+        if type(self.model_executor) == GPUExecutor:
+            self.model_executor.stop_profile()
+        else:
+            self.model_executor._run_workers("stop_profile")
 
     def is_tracing_enabled(self) -> bool:
         return self.tracer is not None
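
Note (commentary, not part of the patch): the fix hinges on an exact-type check rather than isinstance(). Below is a minimal, runnable sketch of that dispatch pattern; the class bodies are hypothetical stand-ins for the real executors in vllm/executor/, kept only to show why an isinstance() check would mis-route the multiprocessing subclass onto the single-GPU branch.

    # Hypothetical stand-ins for the vllm.executor classes, illustration only.
    class GPUExecutor:
        def start_profile(self) -> None:
            print("single-GPU path: profile the driver worker directly")

    class MultiprocessingGPUExecutor(GPUExecutor):
        """Subclasses GPUExecutor but must fan work out to its workers."""
        def _run_workers(self, method: str) -> None:
            print(f"multi-worker path: broadcast {method!r} to all workers")

    def start_profile(executor) -> None:
        # isinstance(executor, GPUExecutor) would also be True for the
        # subclass and would skip _run_workers, so compare types exactly.
        if type(executor) == GPUExecutor:
            executor.start_profile()
        else:
            executor._run_workers("start_profile")

    start_profile(GPUExecutor())                 # single-GPU path
    start_profile(MultiprocessingGPUExecutor())  # multi-worker path

With the patch applied, the single-GPU example (tensor_parallel_size=1) takes the direct start_profile()/stop_profile() path, while distributed executors keep broadcasting via _run_workers. The example script also presumes the torch profiler output directory is configured (via the VLLM_TORCH_PROFILER_DIR environment variable in vLLM of this vintage).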