From e656b06882a9180cb5159ad3bf3fd15b08db4e33 Mon Sep 17 00:00:00 2001 From: Varun Vinayak Shenoy Date: Fri, 22 Nov 2024 21:13:29 -0800 Subject: [PATCH] [Bugfix] Internal Server Error when tool_choice is incorrect. (#10567) Signed-off-by: Varun Shenoy Signed-off-by: Maxime Fournioux <55544262+mfournioux@users.noreply.github.com> --- tests/entrypoints/openai/test_chat.py | 14 ++++++++++++++ vllm/entrypoints/openai/protocol.py | 12 ++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 843d15e768093..8d23a2be6f9bb 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -829,6 +829,20 @@ async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, "name": "nondefined_function_name" } }) + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema + } + }], + tool_choice={}) @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9db5951e5fe5b..f343732174014 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -478,17 +478,17 @@ def check_tool_usage(cls, data): # it matches a valid tool if isinstance(data["tool_choice"], dict): valid_tool = False - specified_function = data["tool_choice"]["function"] + specified_function = data["tool_choice"].get("function") if not specified_function: raise ValueError( - "Incorrectly formatted `tool_choice`. Should be like " - "`{\"type\": \"function\"," + "Expected field `function` in `tool_choice`." + " Correct usage: `{\"type\": \"function\"," " \"function\": {\"name\": \"my_function\"}}`") - specified_function_name = specified_function["name"] + specified_function_name = specified_function.get("name") if not specified_function_name: raise ValueError( - "Incorrectly formatted `tool_choice`. Should be like " - "`{\"type\": \"function\", " + "Expected field `name` in `function` in `tool_choice`." + "Correct usage: `{\"type\": \"function\", " "\"function\": {\"name\": \"my_function\"}}`") for tool in data["tools"]: if tool["function"]["name"] == specified_function_name: