Skip to content

Commit

Permalink
Fix fallback
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
  • Loading branch information
DarkLight1337 committed Dec 24, 2024
1 parent 6ff8b70 commit 80fde42
Showing 1 changed file with 32 additions and 12 deletions.
44 changes: 32 additions & 12 deletions vllm/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
DetokenizeRequest,
DetokenizeResponse,
EmbeddingRequest,
EmbeddingResponse, ErrorResponse,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse,
LoadLoraAdapterRequest,
PoolingRequest, PoolingResponse,
ScoreRequest, ScoreResponse,
Expand Down Expand Up @@ -401,18 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
handler = embedding(raw_request)
if handler is None:
if base(raw_request).model_config.runner_type == "pooling":
logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")

return await create_pooling(request, raw_request)

return base(raw_request).create_error_response(
message="The model does not support Embeddings API")
fallback_handler = pooling(raw_request)
if fallback_handler is None:
return base(raw_request).create_error_response(
message="The model does not support Embeddings API")

logger.warning(
"Embeddings API will become exclusive to embedding models "
"in a future release. To return the hidden states directly, "
"use the Pooling API (`/pooling`) instead.")

res = await fallback_handler.create_pooling(request, raw_request)
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
object=res.object,
created=res.created,
model=res.model,
data=[
EmbeddingResponseData(
index=d.index,
embedding=d.data, # type: ignore
) for d in res.data
],
usage=res.usage,
)
else:
generator = res
else:
generator = await handler.create_embedding(request, raw_request)

generator = await handler.create_embedding(request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
Expand Down

0 comments on commit 80fde42

Please sign in to comment.