Skip to content

Commit

Permalink
[mypy] Fix mypy warnings in api_server.py (#11941)
Browse files Browse the repository at this point in the history
Signed-off-by: Fred Reiss <[email protected]>
  • Loading branch information
frreiss authored Jan 11, 2025
1 parent d45cbe7 commit c9f09a4
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union

import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
Expand Down Expand Up @@ -420,6 +420,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead.")

res = await fallback_handler.create_pooling(request, raw_request)

generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
Expand Down Expand Up @@ -494,7 +496,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)


TASK_HANDLERS = {
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
Expand Down Expand Up @@ -652,7 +654,7 @@ async def add_request_id(request: Request, call_next):
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
Expand Down

0 comments on commit c9f09a4

Please sign in to comment.