From a3b4117c69b3a3cbad8521ba0fa2c4ddce0483e3 Mon Sep 17 00:00:00 2001 From: cedonley Date: Thu, 5 Dec 2024 18:28:39 +0000 Subject: [PATCH] [Bugfix] Added additional failure checks and fixed mistral tool_id generator to be consistent with non-streaming Signed-off-by: cedonley --- .../entrypoints/openai/tool_parsers/hermes_tool_parser.py | 8 ++++++++ .../openai/tool_parsers/mistral_tool_parser.py | 3 +-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index f4d8c654b0655..869d15ac359ea 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -199,6 +199,11 @@ def extract_tool_calls_streaming( # case -- the current tool call is being closed. elif (cur_tool_start_count == cur_tool_end_count and cur_tool_end_count >= prev_tool_end_count): + if (self.prev_tool_call_arr is None + or len(self.prev_tool_call_arr) == 0): + logger.debug( + "attempting to close tool call, but no tool call") + return None diff = self.prev_tool_call_arr[self.current_tool_id].get( "arguments") if diff: @@ -236,6 +241,9 @@ def extract_tool_calls_streaming( except partial_json_parser.core.exceptions.MalformedJSON: logger.debug('not enough tokens to parse into JSON yet') return None + except json.decoder.JSONDecodeError: + logger.debug("unable to parse JSON") + return None # case - we haven't sent the tool name yet. If it's available, send # it. otherwise, wait until it's available. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index d2b609dc4dde3..1738edcaf7f15 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -19,7 +19,6 @@ extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import random_uuid logger = init_logger(__name__) @@ -233,7 +232,7 @@ def extract_tool_calls_streaming( delta = DeltaMessage(tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", - id=f"chatcmpl-tool-{random_uuid()}", + id="".join(choices(ALPHANUMERIC, k=9)), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True))