[V1] Fix torch profiling for offline inference (vllm-project#11125)
Signed-off-by: Roger Wang <[email protected]>
ywang96 authored Dec 12, 2024
1 parent 85362f0 commit 4816d20
Showing 2 changed files with 21 additions and 14 deletions.
31 changes: 19 additions & 12 deletions examples/offline_inference_with_profiler.py
@@ -1,4 +1,5 @@
 import os
+import time
 
 from vllm import LLM, SamplingParams
 
@@ -15,19 +16,25 @@
 # Create a sampling params object.
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
-# Create an LLM.
-llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
-
-llm.start_profile()
-
-# Generate texts from the prompts. The output is a list of RequestOutput objects
-# that contain the prompt, generated text, and other information.
-outputs = llm.generate(prompts, sampling_params)
-
-llm.stop_profile()
-
-# Print the outputs.
-for output in outputs:
-    prompt = output.prompt
-    generated_text = output.outputs[0].text
-    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+if __name__ == "__main__":
+
+    # Create an LLM.
+    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
+
+    llm.start_profile()
+
+    # Generate texts from the prompts. The output is a list of RequestOutput
+    # objects that contain the prompt, generated text, and other information.
+    outputs = llm.generate(prompts, sampling_params)
+
+    llm.stop_profile()
+
+    # Print the outputs.
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    # Add a buffer to wait for profiler in the background process
+    # (in case MP is on) to finish writing profiling output.
+    time.sleep(10)
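The example change is about process hygiene: with the V1 engine, the engine core can run in a background process, so the script's engine construction moves under an if __name__ == "__main__": guard (spawn-based multiprocessing re-imports the module in the child), and the parent sleeps at the end so the profiler in that background process can finish writing its trace. A minimal standalone sketch of the failure mode the guard prevents — the names here are illustrative, not vLLM's:

import multiprocessing as mp
import time


def worker() -> None:
    # Stands in for the background engine-core process, which still needs
    # time to flush its profiling output after the parent's work is done.
    time.sleep(1)
    print("worker: trace flushed")


if __name__ == "__main__":
    # Without this guard, the "spawn" start method would re-execute the
    # module's top-level code in the child process, constructing a second
    # engine instead of just running worker().
    mp.set_start_method("spawn")
    proc = mp.Process(target=worker)
    proc.start()
    print("parent: work submitted")
    # The vLLM example uses a fixed time.sleep(10) because it does not hold
    # a handle on the profiler's writer; here we can simply join.
    proc.join()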
4 changes: 2 additions & 2 deletions vllm/v1/engine/core_client.py
@@ -105,7 +105,7 @@ def shutdown(self):
     def __del__(self):
         self.shutdown()
 
-    async def profile(self, is_start=True) -> None:
+    def profile(self, is_start=True) -> None:
         self.engine_core.profile(is_start)
 
 
@@ -212,7 +212,7 @@ def add_request(self, request: EngineCoreRequest) -> None:
     def abort_requests(self, request_ids: List[str]) -> None:
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
-    async def profile(self, is_start=True) -> None:
+    def profile(self, is_start=True) -> None:
         self._send_input(EngineCoreRequestType.PROFILE,
                          EngineCoreProfile(is_start))
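The actual bug fix is in core_client.py: profile was declared async def, but the offline (synchronous) path calls it without awaiting, so the call only created a coroutine object that was never executed and the profile start/stop request never reached the engine core. Dropping async makes the plain call run. A minimal sketch of the pitfall, independent of vLLM:

class BrokenClient:
    async def profile(self, is_start=True) -> None:
        # Never runs when called without "await": the call below only
        # creates a coroutine object and drops it.
        print("profile request sent")


class FixedClient:
    def profile(self, is_start=True) -> None:
        print("profile request sent")


BrokenClient().profile()  # no output; Python warns "coroutine ... was never awaited"
FixedClient().profile()   # prints "profile request sent"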
