From a54ed8024953dc6b59906072a7a89cd4791ec4f0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 17 Sep 2024 19:50:37 +0200 Subject: [PATCH] [Model] Add mistral function calling format to all models loaded with "mistral" format (#8515) Co-authored-by: Cyrus Leung --- examples/offline_chat_with_tools.py | 138 ++++++++++++++++++ .../decoder_only/language/test_mistral.py | 67 +++++++++ vllm/entrypoints/llm.py | 6 +- vllm/entrypoints/openai/serving_chat.py | 9 +- vllm/transformers_utils/tokenizers/mistral.py | 8 +- 5 files changed, 219 insertions(+), 9 deletions(-) create mode 100644 examples/offline_chat_with_tools.py diff --git a/examples/offline_chat_with_tools.py b/examples/offline_chat_with_tools.py new file mode 100644 index 0000000000000..e69a6c067e4da --- /dev/null +++ b/examples/offline_chat_with_tools.py @@ -0,0 +1,138 @@ +# ruff: noqa +import json +import random +import string + +from vllm import LLM +from vllm.sampling_params import SamplingParams + +# This script is an offline demo for function calling +# +# If you want to run a server/client setup, please follow this code: +# +# - Server: +# +# ```bash +# vllm serve mistralai/Mistral-7B-Instruct-v0.3 --tokenizer-mode mistral --load-format mistral --config-format mistral +# ``` +# +# - Client: +# +# ```bash +# curl --location 'http://:8000/v1/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --header 'Authorization: Bearer token' \ +# --data '{ +# "model": "mistralai/Mistral-7B-Instruct-v0.3" +# "messages": [ +# { +# "role": "user", +# "content": [ +# {"type" : "text", "text": "Describe this image in detail please."}, +# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, +# {"type" : "text", "text": "and this one as well. Answer in French."}, +# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} +# ] +# } +# ] +# }' +# ``` +# +# Usage: +# python demo.py simple +# python demo.py advanced + +model_name = "mistralai/Mistral-7B-Instruct-v0.3" +# or switch to "mistralai/Mistral-Nemo-Instruct-2407" +# or "mistralai/Mistral-Large-Instruct-2407" +# or any other mistral model with function calling ability + +sampling_params = SamplingParams(max_tokens=8192, temperature=0.0) +llm = LLM(model=model_name, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") + + +def generate_random_id(length=9): + characters = string.ascii_letters + string.digits + random_id = ''.join(random.choice(characters) for _ in range(length)) + return random_id + + +# simulate an API that can be called +def get_current_weather(city: str, state: str, unit: 'str'): + return (f"The weather in {city}, {state} is 85 degrees {unit}. It is " + "partly cloudly, with highs in the 90's.") + + +tool_funtions = {"get_current_weather": get_current_weather} + +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] + +messages = [{ + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" +}] + +outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools) +output = outputs[0].outputs[0].text.strip() + +# append the assistant message +messages.append({ + "role": "assistant", + "content": output, +}) + +# let's now actually parse and execute the model's output simulating an API call by using the +# above defined function +tool_calls = json.loads(output) +tool_answers = [ + tool_funtions[call['name']](**call['arguments']) for call in tool_calls +] + +# append the answer as a tool message and let the LLM give you an answer +messages.append({ + "role": "tool", + "content": "\n\n".join(tool_answers), + "tool_call_id": generate_random_id(), +}) + +outputs = llm.chat(messages, sampling_params, tools=tools) + +print(outputs[0].outputs[0].text.strip()) +# yields +# 'The weather in Dallas, TX is 85 degrees fahrenheit. ' +# 'It is partly cloudly, with highs in the 90's.' diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/decoder_only/language/test_mistral.py index 687ba6a03a691..26f90456849f1 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/decoder_only/language/test_mistral.py @@ -4,13 +4,61 @@ """ import pytest +from vllm import SamplingParams + from ...utils import check_logprobs_close MODELS = [ "mistralai/Mistral-7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.3", + # Mistral-Nemo is to big for CI, but passes locally + # "mistralai/Mistral-Nemo-Instruct-2407" ] +SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) + +# for function calling +TOOLS = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] +MSGS = [{ + "role": + "user", + "content": ("Can you tell me what the temperate" + " will be in Dallas, in fahrenheit?") +}] +EXPECTED_FUNC_CALL = ( + '[{"name": "get_current_weather", "arguments": ' + '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]') + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @@ -81,3 +129,22 @@ def test_mistral_format( name_0="hf", name_1="mistral", ) + + +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("model", MODELS[1:]) # v1 can't do func calling +def test_mistral_function_calling( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, + dtype=dtype, + tokenizer_mode="mistral", + config_format="mistral", + load_format="mistral") as vllm_model: + outputs = vllm_model.model.chat(MSGS, + tools=TOOLS, + sampling_params=SAMPLING_PARAMS) + + assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a26b721093521..248b070611cd2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ from contextlib import contextmanager -from typing import ClassVar, List, Optional, Sequence, Union, cast, overload +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, + overload) from tqdm import tqdm @@ -357,6 +358,7 @@ def chat( lora_request: Optional[LoRARequest] = None, chat_template: Optional[str] = None, add_generation_prompt: bool = True, + tools: Optional[List[Dict[str, Any]]] = None, ) -> List[RequestOutput]: """ Generate responses for a chat conversation. @@ -401,6 +403,7 @@ def chat( messages=messages, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) else: prompt = apply_hf_chat_template( @@ -408,6 +411,7 @@ def chat( conversation=conversation, chat_template=chat_template, add_generation_prompt=add_generation_prompt, + tools=tools, ) inputs: PromptInputs diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 58e42fb5363fb..d28362a12abdb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -123,7 +123,8 @@ async def create_chat_completion( ] prompt: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: prompt = apply_mistral_chat_template( tokenizer, messages=request.messages, @@ -159,10 +160,10 @@ async def create_chat_completion( return self.create_error_response( "tool_choice = \"required\" is not supported!") - # "auto" tools requires --enable-auto-tool-choice - # and --tool-call-parser - if request.tool_choice == "auto" and not ( + if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( self.enable_auto_tools and self.tool_parser is not None): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser return self.create_error_response( "\"auto\" tool choice requires " "--enable-auto-tool-choice and --tool-call-parser to be set") diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index ea1910ed20ec3..7a228a3efa6e8 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -165,10 +165,9 @@ def apply_chat_template(self, messages: List["ChatCompletionMessageParam"], tools: Optional[Dict[str, Any]] = None, **kwargs) -> List[int]: - assert tools is None, "`tools` are not yet supported." - request = ChatCompletionRequest( - messages=messages) # type: ignore[type-var] + request = ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] encoded = self.mistral.encode_chat_completion(request) # encode-decode to get clean prompt @@ -176,7 +175,8 @@ def apply_chat_template(self, def convert_tokens_to_string(self, tokens: List[str]) -> str: if isinstance(self.tokenizer, Tekkenizer): - return "".join(tokens) + return "".join(t for t in tokens + if t not in self.tokenizer._all_special_tokens) else: return self.tokenizer.decode(tokens) # type: ignore[arg-type]