From 80fde42611736bf88f138be39ad3417d722eda42 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 24 Dec 2024 07:20:13 +0000 Subject: [PATCH] Fix fallback Signed-off-by: DarkLight1337 --- vllm/entrypoints/openai/api_server.py | 44 +++++++++++++++++++-------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e8ebb127704f8..3e50613a73dd3 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -45,7 +45,9 @@ DetokenizeRequest, DetokenizeResponse, EmbeddingRequest, - EmbeddingResponse, ErrorResponse, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, LoadLoraAdapterRequest, PoolingRequest, PoolingResponse, ScoreRequest, ScoreResponse, @@ -401,18 +403,36 @@ async def create_completion(request: CompletionRequest, raw_request: Request): async def create_embedding(request: EmbeddingRequest, raw_request: Request): handler = embedding(raw_request) if handler is None: - if base(raw_request).model_config.runner_type == "pooling": - logger.warning( - "Embeddings API will become exclusive to embedding models " - "in a future release. To return the hidden states directly, " - "use the Pooling API (`/pooling`) instead.") - - return await create_pooling(request, raw_request) - - return base(raw_request).create_error_response( - message="The model does not support Embeddings API") + fallback_handler = pooling(raw_request) + if fallback_handler is None: + return base(raw_request).create_error_response( + message="The model does not support Embeddings API") + + logger.warning( + "Embeddings API will become exclusive to embedding models " + "in a future release. To return the hidden states directly, " + "use the Pooling API (`/pooling`) instead.") + + res = await fallback_handler.create_pooling(request, raw_request) + if isinstance(res, PoolingResponse): + generator = EmbeddingResponse( + id=res.id, + object=res.object, + created=res.created, + model=res.model, + data=[ + EmbeddingResponseData( + index=d.index, + embedding=d.data, # type: ignore + ) for d in res.data + ], + usage=res.usage, + ) + else: + generator = res + else: + generator = await handler.create_embedding(request, raw_request) - generator = await handler.create_embedding(request, raw_request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code)