
Commit

works now
jikunshang authored and xuechendi committed Nov 26, 2024
1 parent d850872 commit bc8acd2
Showing 8 changed files with 34 additions and 21 deletions.
1 change: 0 additions & 1 deletion vllm/engine/llm_engine.py
@@ -1443,7 +1443,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:

outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)

# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
2 changes: 1 addition & 1 deletion vllm/engine/mm_arg_utils.py
@@ -911,7 +911,7 @@ def create_model_config(self, model:str = None) -> ModelConfig:
model=model if model is not None else self.model,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer=cast(str, model),
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,
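For orientation, a minimal sketch (not part of this commit) of how create_model_config might now be driven, one call per served model, with the tokenizer name taken from the model argument rather than self.tokenizer. It assumes the method lives on the AsyncEngineArgs class in vllm.engine.mm_arg_utils (the class the mypy annotations later on this page reference) and that its constructor accepts a model argument as in upstream vLLM; the model names are hypothetical.

# Sketch only: one ModelConfig per served model; after this change each
# config's tokenizer defaults to the model name that is passed in.
from vllm.engine.mm_arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(model="facebook/opt-125m")
model_configs = [
    engine_args.create_model_config(model=name)
    for name in ("facebook/opt-125m", "facebook/opt-350m")
]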
5 changes: 5 additions & 0 deletions vllm/engine/multiprocessing/__init__.py
@@ -27,6 +27,7 @@ class RPCProcessRequest:
prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
model: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@@ -39,6 +40,7 @@ def __init__(
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
model: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -52,6 +54,7 @@ def __init__(
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
model: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -68,6 +71,7 @@ def __init__(
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
request_id: Optional[str] = None,
model: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -85,6 +89,7 @@ def __init__(
self.prompt = prompt
self.params = params
self.request_id = request_id
self.model = model
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
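For context, a minimal sketch (not part of this commit) of constructing an RPCProcessRequest that carries the new model field; the prompt, sampling parameters, and model name below are illustrative only.

# Sketch only: the RPC request now records which model it targets so the
# engine process can route it to the matching LLMEngine.
from vllm import SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

request = RPCProcessRequest(
    prompt="Hello, world",                  # illustrative prompt
    params=SamplingParams(max_tokens=16),   # illustrative params
    request_id="cmpl-0",
    model="facebook/opt-125m",              # hypothetical model name
)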
11 changes: 10 additions & 1 deletion vllm/engine/multiprocessing/mm_client.py
@@ -361,6 +361,12 @@ async def get_input_preprocessor(self) -> InputPreprocessor:
async def get_tokenizer(self, lora_request: Optional[LoRARequest] = None):
return await self.tokenizer.get_lora_tokenizer_async(lora_request)

async def get_tokenizer_mm(self, model, lora_request: Optional[LoRARequest] = None):

Check failure on line 364 in vllm/engine/multiprocessing/mm_client.py [GitHub Actions / ruff (3.12)]: E501 Line too long (88 > 80)
for tokenizer in self.tokenizers:
if tokenizer.tokenizer_id == model:
return await tokenizer.get_lora_tokenizer_async(lora_request)
raise ValueError(f"Tokenizer for model {model} not found.")

async def get_decoding_config(self) -> DecodingConfig:
return self.decoding_config

@@ -458,6 +464,7 @@ def generate(
prompt: Optional[PromptType] = None,
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
model: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -489,7 +496,7 @@ def generate(
assert (prompt is not None and sampling_params is not None
and request_id is not None)

return self._process_request(prompt, sampling_params, request_id,
return self._process_request(prompt, sampling_params, request_id, model,
lora_request, trace_headers,
prompt_adapter_request, priority)

@@ -570,6 +577,7 @@ async def _process_request(
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
model: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -618,6 +626,7 @@ async def _process_request(
prompt=prompt,
params=params,
request_id=request_id,
model=model,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
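As a reference for the E501 failure above, a sketch of the same per-model tokenizer lookup wrapped to fit 80 columns, written as a standalone helper so it runs outside the client class. As in the hunk, each tokenizer entry is assumed to expose tokenizer_id and get_lora_tokenizer_async.

# Sketch only: behaviourally identical to the hunk above, wrapped for E501;
# 'tokenizers' stands in for self.tokenizers on the client.
from typing import Optional

from vllm.lora.request import LoRARequest


async def get_tokenizer_mm(tokenizers,
                           model: str,
                           lora_request: Optional[LoRARequest] = None):
    for tokenizer in tokenizers:
        if tokenizer.tokenizer_id == model:
            return await tokenizer.get_lora_tokenizer_async(lora_request)
    raise ValueError(f"Tokenizer for model {model} not found.")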
30 changes: 15 additions & 15 deletions vllm/engine/multiprocessing/mm_engine.py
@@ -71,20 +71,19 @@ def __init__(self,

# get configs from args and kwargs, determine how many models to load
vllm_config = kwargs.get('vllm_config')
print(f"aaaa {vllm_config}")
models_load = [model_config.model for model_config in vllm_config.model_configs ]

Check failure on line 74 in vllm/engine/multiprocessing/mm_engine.py [GitHub Actions / ruff (3.12)]: E501 Line too long (89 > 80)
self.engines = []

for i, model in enumerate(models_load):
print(f"create engine for model: {model}")
vllm_config.model_config = vllm_config.model_configs[i]
self.engines.append(LLMEngine(model=model, *args, **kwargs))

Check failure on line 79 in vllm/engine/multiprocessing/mm_engine.py [GitHub Actions / ruff (3.12)]: B026 Star-arg unpacking after a keyword argument is strongly discouraged
self.log_requests = log_requests

self.use_async_sockets = use_async_sockets
# if self.use_async_sockets:
# self.engine.process_request_outputs_callback = \
# self._async_socket_engine_callback
if self.use_async_sockets:
for engine in self.engines:
engine.process_request_outputs_callback = \
self._async_socket_engine_callback

self.ctx = zmq.Context() # type: ignore[attr-defined]

@@ -215,7 +214,7 @@ def engine_step(self) -> List[RequestOutput]:
try:
res = []
for engine in self.engines:
res.append(engine.step())
res += engine.step()
return res
except SystemExit:
raise
@@ -269,14 +268,16 @@ def _handle_process_request(self, request: RPCProcessRequest):
self._send_outputs(rpc_err)

try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
for engine in self.engines:
if engine.model_config.model == request.model:
engine.add_request(
request_id=request_id,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)

if self.log_requests:
logger.info("Added request %s.", request.request_id)
@@ -372,7 +373,6 @@ def signal_handler(*_) -> None:
def run_mm_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, engine_alive):
try:
print(f"bbbb {engine_args}")
engine = MMLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)
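On the B026 warning above: ruff flags star-arg unpacking that follows a keyword argument, so the hunk's LLMEngine(model=model, *args, **kwargs) could become LLMEngine(*args, model=model, **kwargs) with no change in behaviour. A small self-contained sketch (stand-in names, not the real LLMEngine) demonstrating that the two call forms bind identically:

# Sketch only: the B026-clean ordering is call-for-call equivalent to the
# form used in the hunk.
def build(*args, model=None, **kwargs):      # stand-in for LLMEngine(...)
    return args, model, kwargs

positional = ("usage_context",)              # stand-in for the *args tuple
keyword = {"log_stats": True}                # stand-in for the **kwargs dict

flagged = build(model="facebook/opt-125m", *positional, **keyword)  # B026
clean = build(*positional, model="facebook/opt-125m", **keyword)    # ok
assert flagged == clean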
1 change: 0 additions & 1 deletion vllm/entrypoints/openai/mm_api_server.py
@@ -128,7 +128,6 @@ async def build_async_engine_client_from_engine_args(
Returns the Client or None if the creation failed.
"""

# Fall back
# TODO: fill out feature matrix.
if (MMLLMEngineClient.is_unsupported_config(engine_args)

Check failure on line 133 in vllm/entrypoints/openai/mm_api_server.py [GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)]: Argument 1 to "is_unsupported_config" of "MMLLMEngineClient" has incompatible type "vllm.engine.mm_arg_utils.AsyncEngineArgs"; expected "vllm.engine.arg_utils.AsyncEngineArgs" [arg-type]
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -123,7 +123,7 @@ async def create_chat_completion(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request)

Check failure on line 126 in vllm/entrypoints/openai/serving_chat.py [GitHub Actions / ruff (3.12)]: E501 Line too long (94 > 80)
Check failure on line 126 in vllm/entrypoints/openai/serving_chat.py [GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)]: "EngineClient" has no attribute "get_tokenizer_mm"; maybe "get_tokenizer"? [attr-defined]

tool_parser = self.tool_parser

3 changes: 2 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -99,7 +99,7 @@ async def create_completion(
prompt_adapter_request,
) = self._maybe_get_adapters(request)

tokenizer = await self.engine_client.get_tokenizer(lora_request)
tokenizer = await self.engine_client.get_tokenizer_mm(request.model, lora_request)

Check failure on line 102 in vllm/entrypoints/openai/serving_completion.py [GitHub Actions / ruff (3.12)]: E501 Line too long (94 > 80)
Check failure on line 102 in vllm/entrypoints/openai/serving_completion.py [GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)]: "EngineClient" has no attribute "get_tokenizer_mm"; maybe "get_tokenizer"? [attr-defined]

request_prompts, engine_prompts = self._preprocess_completion(
request,
@@ -148,6 +148,7 @@ async def create_completion(
engine_prompt,
sampling_params,
request_id_item,
request.model,

Check failure on line 151 in vllm/entrypoints/openai/serving_completion.py [GitHub Actions / mypy (3.9, 3.10, 3.11, 3.12)]: Argument 4 to "generate" of "EngineClient" has incompatible type "str"; expected "Optional[LoRARequest]" (reported as "LoRARequest | None" on 3.10+) [arg-type]
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
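On the mypy failures at line 151: in the base EngineClient.generate signature the fourth parameter is lora_request, so request.model passed positionally is checked as a LoRARequest. A hedged fragment (not standalone code, names taken from the hunk above) showing the same call with model passed by keyword, which avoids the positional mismatch; note the EngineClient protocol itself would still need a matching model parameter for the type check to pass.

# Fragment only: same call as the hunk, with the target model passed by
# keyword rather than as the fourth positional argument.
self.engine_client.generate(
    engine_prompt,
    sampling_params,
    request_id_item,
    model=request.model,
    lora_request=lora_request,
    prompt_adapter_request=prompt_adapter_request,
    trace_headers=trace_headers,
)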
