[Tool parsing] Improve / correct mistral tool parsing #10333

Merged
93 changes: 82 additions & 11 deletions tests/models/decoder_only/language/test_mistral.py
@@ -2,9 +2,13 @@

Run `pytest tests/models/test_mistral.py`.
"""
import copy

import pytest

from vllm import SamplingParams
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)

from ...utils import check_logprobs_close

@@ -58,17 +62,69 @@
},
"required": ["city", "state", "unit"]
}
},
Contributor Author
Make the test much more difficult and complex to show the community to what extent function calling can be used with Mistral models.

}, {
"type": "function",
"function": {
"name": "rewrite",
"description": "Rewrites text",
"parameters": {
"type": "object",
"required": [],
"properties": {
"text": {
"type": "string",
"description": "The input text to rewrite."
}
}
}
}
}]
MSGS = [{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}]
EXPECTED_FUNC_CALL = (
'[{"name": "get_current_weather", "arguments": '
'{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]')
MSGS = [
{
"role": "system",
"content": "You are an assistant."
},
{
"role":
"user",
"content":
"Could you please rewrite the below article? \n\n My English needs improvving, maybe I make errors." # noqa
},
{
"role":
"assistant",
"content":
"",
"tool_calls": [{
"id": "bbc5b7ede",
"type": "function",
"function": {
"name":
"rewrite",
"arguments":
'{\"text\":\"My English needs improvving, maybe I make errors.\"}' # noqa
}
}]
},
{
"role": "tool",
"content":
"{\"action\":\"rewrite\",\"outcome\":\"My English needs improving, maybe I make errors.\"}", # noqa
"tool_call_id": "bbc5b7ede",
"name": "rewrite"
},
{
"role": "assistant",
"content": "---\n\nMy English needs improving, maybe I make errors"
},
{
"role":
"user",
"content": ("Can you tell me what the temperate"
" will be in Dallas, in fahrenheit?")
}
]


@pytest.mark.parametrize("model", MODELS)
@@ -175,8 +231,23 @@ def test_mistral_function_calling(
tokenizer_mode="mistral",
config_format="mistral",
load_format="mistral") as vllm_model:
outputs = vllm_model.model.chat(MSGS,

msgs = copy.deepcopy(MSGS)
outputs = vllm_model.model.chat(msgs,
tools=TOOLS,
sampling_params=SAMPLING_PARAMS)

assert outputs[0].outputs[0].text.strip() == EXPECTED_FUNC_CALL
tokenizer = vllm_model.model.get_tokenizer()
tool_parser = MistralToolParser(tokenizer)

model_output = outputs[0].outputs[0].text.strip()
assert model_output.startswith(tool_parser.bot_token), model_output
parsed_message = tool_parser.extract_tool_calls(model_output, None)
Contributor Author
Cleaner to let the parser take care of correctly extracting the dict


assert parsed_message.tools_called
assert parsed_message.tool_calls[0].id == "0UAqFzWsD"
assert parsed_message.tool_calls[
0].function.name == "get_current_weather"
assert parsed_message.tool_calls[
0].function.arguments == '{"city": "Dallas", "state": "TX", "unit": "fahrenheit"}' # noqa
assert parsed_message.content is None
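For context, the same flow can be exercised end to end through the OpenAI-compatible server. The sketch below is illustrative and not part of this PR: it assumes a vLLM server started with --enable-auto-tool-choice --tool-call-parser mistral --tokenizer-mode mistral (the same flags used in the review discussion further down), and the model name and tool description are placeholders.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

# illustrative tool definition mirroring the test's get_current_weather tool
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given city",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "state": {"type": "string"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["city", "state", "unit"],
        },
    },
}]

completion = client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.3",  # placeholder model name
    messages=[{
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, in fahrenheit?",
    }],
    tools=tools,
    tool_choice="auto",
)

# with the parser fix, the tool call arrives as structured data rather than raw text
print(completion.choices[0].message.tool_calls)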
39 changes: 5 additions & 34 deletions vllm/entrypoints/openai/serving_chat.py
@@ -30,6 +30,7 @@
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.utils import iterate_with_cancellation

logger = init_logger(__name__)
@@ -127,41 +128,11 @@ async def create_chat_completion(
return self.create_error_response(
"tool_choice = \"required\" is not supported!")

# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in in the instances by
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affect messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tools_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
# because of issues with pydantic we need to potentially
# re-serialize the tool_calls field of the request
# for more info: see comment in `maybe_serialize_tool_calls`
if isinstance(tokenizer, MistralTokenizer):
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get(
"tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(
tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break
request.messages[i][
"tool_calls"] = validated_tool_calls
maybe_serialize_tool_calls(request)
Contributor Author
Moving this out of serving_chat.py just to clean up the method a bit. This is a very general method and the error correction here is very Mistral-specific, so it is probably better placed in tokenizers/mistral.py.

Contributor
@gcalmettes gcalmettes Nov 14, 2024
Good point!

I had originally thought about putting it directly in the MistralTokenizer but did not in the end, because the same problem would occur for any other future models having a tokenizer that does not rely on Jinja chat templates (none right now, so this was highly hypothetical).
Factoring the logic into a function like you did is a good solution that would still work with other non-chat-template models 👍
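For readers unfamiliar with the underlying pydantic behaviour, here is a minimal standalone illustration of the iterator issue described above. It is a sketch of the general pattern under assumed pydantic v2 semantics (as documented in the linked issues), not vLLM code.

from typing import Iterable

from pydantic import BaseModel


class ToolCall(BaseModel):
    id: str
    name: str


class AssistantMessage(BaseModel):
    role: str
    tool_calls: Iterable[ToolCall]


msg = AssistantMessage.model_validate({
    "role": "assistant",
    "tool_calls": [{"id": "bbc5b7ede", "name": "rewrite"}],
})

# The Iterable field is materialized lazily as a pydantic-core
# ValidatorIterator rather than a list.
print(type(msg.tool_calls))

# The documented workaround is to consume the iterator so the tool calls are
# actually deserialized; maybe_serialize_tool_calls does this for assistant
# messages in the request.
validated_tool_calls = list(msg.tool_calls)
print(validated_tool_calls[0].name)  # "rewrite"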


if (request.tool_choice == "auto" and
not (self.enable_auto_tools and tool_parser is not None)
25 changes: 17 additions & 8 deletions vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -62,7 +62,7 @@ def __init__(self, tokenizer: AnyTokenizer):
] # map what has been streamed for each tool so far to a list
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the tool call token in "
@@ -84,16 +84,25 @@ def extract_tool_calls(
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)

# first remove the BOT token
tool_content = model_output.replace(self.bot_token, "").strip()

try:

# use a regex to find the tool call. remove the BOT token
# and make sure to replace single quotes with double quotes
raw_tool_call = self.tool_call_regex.findall(
model_output.replace(self.bot_token, ""))[0]
# we first try to directly load the json as parsing very nested
# jsons is difficult
try:
function_call_arr = json.loads(tool_content)
except json.JSONDecodeError:
# use a regex to find the part corresponding to the tool call.
# NOTE: This use case should not happen if the model is trained
# correctly. It's an easy possible fix so it's included, but
# can be brittle for very complex / highly nested tool calls
raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
function_call_arr = json.loads(raw_tool_call)

# load the JSON, and then use it to build the Function and
# Tool Call
function_call_arr = json.loads(raw_tool_call)
tool_calls: List[MistralToolCall] = [
MistralToolCall(
type="function",
Expand All @@ -116,7 +125,7 @@ def extract_tool_calls(
# return information to just treat the tool call as regular JSON
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
content=model_output)
content=tool_content)

def extract_tool_calls_streaming(
self,
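To see why the non-greedy regex was loosened and why a direct json.loads attempt now comes first, here is a small standalone sketch. The payload is made up; only the [TOOL_CALLS] token and the two regex patterns come from the diff. With a nested array inside "arguments", the old non-greedy pattern stops at the first "}]" it encounters and yields invalid JSON, while parsing the stripped content directly succeeds.

import json
import re

model_output = ('[TOOL_CALLS][{"name": "plot", '
                '"arguments": {"points": [{"x": 1}, {"x": 2}]}}]')

# first remove the BOT token, as in extract_tool_calls
tool_content = model_output.replace("[TOOL_CALLS]", "").strip()

# old pattern: non-greedy, truncates at the inner "}]" of the points array
old_match = re.compile(r"\[{.*?}\]", re.DOTALL).findall(tool_content)[0]
try:
    json.loads(old_match)
except json.JSONDecodeError as err:
    print("old regex breaks on nested tool calls:", err)

# new behaviour: try to load the whole payload first ...
function_call_arr = json.loads(tool_content)
print(function_call_arr[0]["name"])  # "plot"

# ... and keep the (now greedy) regex only as a fallback
new_match = re.compile(r"\[{.*}\]", re.DOTALL).findall(tool_content)[0]
assert json.loads(new_match) == function_call_arr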
4 changes: 2 additions & 2 deletions vllm/transformers_utils/tokenizers/__init__.py
@@ -1,3 +1,3 @@
from .mistral import MistralTokenizer
from .mistral import MistralTokenizer, maybe_serialize_tool_calls

__all__ = ["MistralTokenizer"]
__all__ = ["MistralTokenizer", "maybe_serialize_tool_calls"]
70 changes: 66 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
@@ -7,6 +7,7 @@
import huggingface_hub
from huggingface_hub import HfApi, hf_hub_download
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.base import SpecialTokens
# yapf: disable
from mistral_common.tokens.tokenizers.mistral import (
MistralTokenizer as PublicMistralTokenizer)
@@ -29,6 +30,43 @@ class Encoding:
input_ids: List[int]


def maybe_serialize_tool_calls(request: ChatCompletionRequest):
# SEE: https://github.com/vllm-project/vllm/pull/9951
# Credits go to: @gcalmettes
# NOTE: There is currently a bug in pydantic where attributes
# declared as iterables are replaced in the instances by
# pydantic-core ValidatorIterator instance. In particular, this
# affects tool_calls defined in ChatCompletionAssistantMessageParam
# model:
# see:
# - https://github.com/pydantic/pydantic/issues/9467
# As a result, tool_calls from assistant messages are never
# deserialized in the request object if the tool_calls iterator is
# not consumed. This affects messages passed to the MistralTokenizer
# since no chat template is applied and therefore the tool_calls
# iterator is not directly consumed.
# Issue is tracked on Pydantic side, with resolution planned for
# v2.11 release. In the meantime, the official workaround is to
# consume the iterator so the tool_calls are correctly deserialized
# in the OpenAI ChatCompletionAssistantMessageParam object
# https://github.com/pydantic/pydantic/issues/9467#issuecomment-2442097291 # noqa: E501
# Official Pydantic Issues:
# - https://github.com/pydantic/pydantic/issues/9541
# TODO: remove when pydantic v2.11 is released
for i, message in enumerate(request.messages):
if message.get("role") == 'assistant':
tool_calls_validator = message.get("tool_calls", ().__iter__())
validated_tool_calls = []
while True:
try:
tool_call = next(tool_calls_validator) # type: ignore
validated_tool_calls.append(tool_call)
except StopIteration:
break

request.messages[i]["tool_calls"] = validated_tool_calls


def list_local_repo_files(repo_id: str, revision: Optional[str]) -> List[str]:
Contributor Author
As proposed by @gcalmettes here: #9059 (comment)

We don't parse away the [TOOL_CALLS] token for either tekken or spm, so that function calls can be correctly parsed.

repo_cache = os.path.join(
huggingface_hub.constants.HF_HUB_CACHE,
@@ -222,7 +260,8 @@ def convert_tokens_to_string(self, tokens: List[str]) -> str:
if self.is_tekken:
tokens = [
t for t in tokens
if t not in self.tokenizer._all_special_tokens
if (t is SpecialTokens.tool_calls
Contributor
Note that after further testing on my end, I found an edge case where not skipping the [TOOL_CALLS] token here can potentially mess up the intended output:

  • when requiring structured output by specifying response_format=json_object or response_format=json_schema, the [TOOL_CALLS] token is still emitted in some cases even though we are not providing any tools to the model, and therefore the generated output is no longer compliant with JSON. I have tested and observed this with all the vLLM-supported structured output backends (lm-format-enforcer / outlines). Note that this only happens if there is no mention in the system prompt that we expect JSON responses from the model.

If we can find a way to not filter out the SpecialTokens.tool_calls token only when function calling is required (based on the presence of tools in the request, for example), that would be best. However, I haven't found a clean way yet to pass this information to the convert_tokens_to_string method without having to change the signature of the method ...

I have an easy reproducible example of this problem that I can share with you.

Contributor Author

Thanks for the note! Would be great if you could share an easy repro

Contributor
@gcalmettes gcalmettes Nov 15, 2024
@patrickvonplaten please find below a scenario where it will break (and further below the small change in prompt that makes the code work, because of the added guidance to the model). Note that the code requires lm-format-enforcer version 0.10.9 so that it is compatible with the MistralTokenizer.

However, after further investigation, I now know how to fix it (I'm preparing a PR, I'll tag you for your review)! In fact the problem was present before but "masked" by the fact that the [TOOL_CALLS] token was skipped in the convert_tokens_to_string method, so your PR made it possible to expose the problem 😉. (The root cause is that all the structured output libraries filter out the special tokens to build their tree of possible tokens, e.g. this check in lm-format-enforcer, but the current vLLM MistralTokenizer does not correctly populate the methods that the libraries use for that. The fix is easy, and I have tested it with success.)

"""
vllm server started with the following arguments:
    --guided-decoding-backend=lm-format-enforcer 
    --enable-auto-tool-choice 
    --tool-call-parser=mistral 
    --tokenizer-mode=mistral
"""

from openai import OpenAI
from pydantic import BaseModel

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="none",
)

class CalendarEvent(BaseModel):
    name: str
    date: str
    participants: list[str]

completion = client.beta.chat.completions.parse(
    model="mistralai/Pixtral-12B-2409",
    messages=[
        {"role": "system", "content": "Extract the event information."},
        {"role": "user", "content": "Alice and Bob are going to a science fair on Friday."},
    ],
    response_format=CalendarEvent,
)

# the response will break as `[TOOL_CALLS]` is present at the beginning of the response
event = completion.choices[0].message.parsed
print(event.__dict__)

Guiding the model to output JSON by changing the system prompt as below is enough so that the model actually does not produce a tool_call token:

{"role": "system", "content": "Extract the event information. Respond as JSON."},

or t not in self.tokenizer._all_special_tokens)
]

if any(isinstance(t, bytes) for t in tokens):
@@ -246,7 +285,27 @@ def _token_to_id(t: str):
else:
decoded = "".join(tokens)
else:
decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type]
# make sure certain special tokens like Tool calls are
# not decoded
special_tokens = {SpecialTokens.tool_calls}
regular_tokens: List[str] = []
decoded_list = []

for token in tokens:
if token in special_tokens:
if regular_tokens:
decoded_list.append(
self.tokenizer.decode(regular_tokens))
regular_tokens = []
decoded_list.append(token)
else:
regular_tokens.append(token)

if regular_tokens:
decoded_list.append(
self.decode(regular_tokens)) # type: ignore

decoded = ''.join(decoded_list)

return decoded
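As a rough illustration of the grouping above, here is a toy stand-in (not the real SentencePiece path): regular tokens are decoded run by run, while the [TOOL_CALLS] special token is emitted verbatim instead of being swallowed by the decoder.

from typing import Callable, List

TOOL_CALLS = "[TOOL_CALLS]"


def decode_keeping_tool_calls(tokens: List[str],
                              decode: Callable[[List[str]], str]) -> str:
    special_tokens = {TOOL_CALLS}
    decoded_list: List[str] = []
    regular_tokens: List[str] = []
    for token in tokens:
        if token in special_tokens:
            # flush the pending run of regular tokens, then keep the
            # special token as-is
            if regular_tokens:
                decoded_list.append(decode(regular_tokens))
                regular_tokens = []
            decoded_list.append(token)
        else:
            regular_tokens.append(token)
    if regular_tokens:
        decoded_list.append(decode(regular_tokens))
    return "".join(decoded_list)


# toy decoder that simply concatenates the pieces; the real tokenizer would
# also handle word boundaries and byte fallback
print(decode_keeping_tool_calls(
    [TOOL_CALLS, '[{"name": ', '"rewrite"}]'],
    decode=lambda ts: "".join(ts)))
# -> [TOOL_CALLS][{"name": "rewrite"}]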

@@ -274,8 +333,11 @@ def convert_ids_to_tokens(
assert self.is_tekken or self.is_spm, type(self.tokenizer)

if self.is_tekken:
# skip special tokens
ids = [i for i in ids if i > self.tokenizer.num_special_tokens]
# skip special tokens except tool call
ids = [
i for i in ids if i > self.tokenizer.num_special_tokens or i ==
self.tokenizer.get_control_token(SpecialTokens.tool_calls)
]

tokens = [self.tokenizer.id_to_piece(id) for id in ids]
