From c68f7ede6a4aef0cd31f531b5d7ec22ab224de95 Mon Sep 17 00:00:00 2001 From: Guillaume Calmettes Date: Wed, 20 Nov 2024 22:42:21 +0100 Subject: [PATCH] [Bugfix]: allow extra fields in requests to openai compatible server (#10463) Signed-off-by: Guillaume Calmettes --- tests/entrypoints/openai/test_chat.py | 26 +++++++++++++------------- vllm/entrypoints/openai/protocol.py | 18 ++++++++++++++++-- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8d13f64dce01c..843d15e768093 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -899,19 +899,19 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): @pytest.mark.asyncio -async def test_extra_fields(client: openai.AsyncOpenAI): - with pytest.raises(BadRequestError) as exc_info: - await client.chat.completions.create( - model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant.", - "extra_field": "0", - }], # type: ignore - temperature=0, - seed=0) - - assert "extra_forbidden" in exc_info.value.message +async def test_extra_fields_allowed(client: openai.AsyncOpenAI): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?", + "extra_field": "0", + }], # type: ignore + temperature=0, + seed=0) + + content = resp.choices[0].message.content + assert content is not None @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b7b064ae01f05..a82212677f63a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -9,12 +9,15 @@ from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid +logger = init_logger(__name__) + # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) @@ -35,8 +38,19 @@ class OpenAIBaseModel(BaseModel): - # OpenAI API does not allow extra fields - model_config = ConfigDict(extra="forbid") + # OpenAI API does allow extra fields + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def __log_extra_fields__(cls, data): + if isinstance(data, dict): + extra_fields = data.keys() - cls.model_fields.keys() + if extra_fields: + logger.warning( + "The following fields were present in the request " + "but ignored: %s", extra_fields) + return data class ErrorResponse(OpenAIBaseModel):