Skip to content

Commit

Permalink
[Bugfix]: allow extra fields in requests to openai compatible server (#…
Browse files Browse the repository at this point in the history
…10463)

Signed-off-by: Guillaume Calmettes <[email protected]>
  • Loading branch information
gcalmettes authored Nov 20, 2024
1 parent 0cd3d97 commit c68f7ed
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 15 deletions.
26 changes: 13 additions & 13 deletions tests/entrypoints/openai/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,19 +899,19 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI):


@pytest.mark.asyncio
async def test_extra_fields(client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info:
await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "system",
"content": "You are a helpful assistant.",
"extra_field": "0",
}], # type: ignore
temperature=0,
seed=0)

assert "extra_forbidden" in exc_info.value.message
async def test_extra_fields_allowed(client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": "what is 1+1?",
"extra_field": "0",
}], # type: ignore
temperature=0,
seed=0)

content = resp.choices[0].message.content
assert content is not None


@pytest.mark.asyncio
Expand Down
18 changes: 16 additions & 2 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@
from typing_extensions import Annotated

from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.logger import init_logger
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
RequestOutputKind, SamplingParams)
from vllm.sequence import Logprob
from vllm.utils import random_uuid

logger = init_logger(__name__)

# torch is mocked during docs generation,
# so we have to provide the values as literals
_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807)
Expand All @@ -35,8 +38,19 @@


class OpenAIBaseModel(BaseModel):
# OpenAI API does not allow extra fields
model_config = ConfigDict(extra="forbid")
# OpenAI API does allow extra fields
model_config = ConfigDict(extra="allow")

@model_validator(mode="before")
@classmethod
def __log_extra_fields__(cls, data):
if isinstance(data, dict):
extra_fields = data.keys() - cls.model_fields.keys()
if extra_fields:
logger.warning(
"The following fields were present in the request "
"but ignored: %s", extra_fields)
return data


class ErrorResponse(OpenAIBaseModel):
Expand Down

0 comments on commit c68f7ed

Please sign in to comment.