From e78b4ab3e3923433944499197253740ec424bb30 Mon Sep 17 00:00:00 2001 From: Sternakt Date: Tue, 3 Dec 2024 13:41:45 +0100 Subject: [PATCH] Add conversion from ParsedChatCompletion to ChatCompletion --- autogen/oai/client.py | 41 +++++++++++++++++++++++++++++++++++------ 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 02f2f888c2..60ead265fc 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -10,7 +10,6 @@ import logging import sys import uuid -from pickle import PickleError, PicklingError from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union from pydantic import BaseModel, schema_json_of @@ -42,6 +41,7 @@ ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) + from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage from openai.types.completion import Completion from openai.types.completion_usage import CompletionUsage @@ -238,6 +238,36 @@ def __init__(self, config): class OpenAIClient: """Follows the Client protocol and wraps the OpenAI client.""" + @staticmethod + def _convert_to_chat_completion(parsed: ParsedChatCompletion) -> ChatCompletion: + # Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage + def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage: + return ChatCompletionMessage( + role=parsed_message.role, + content=parsed_message.content, + function_call=parsed_message.function_call, + ) + + # Convert ParsedChatCompletion to ChatCompletion + return ChatCompletion( + id=parsed.id, + choices=[ + Choice( + finish_reason=choice.finish_reason, + index=choice.index, + logprobs=choice.logprobs, + message=convert_message(choice.message), # Parse the message + ) + for choice in parsed.choices + ], + created=parsed.created, + model=parsed.model, + object=parsed.object, + service_tier=parsed.service_tier, + system_fingerprint=parsed.system_fingerprint, + usage=parsed.usage, + ) + def __init__(self, client: Union[OpenAI, AzureOpenAI]): self._oai_client = client if ( @@ -290,7 +320,9 @@ def _create_or_parse(*args, **kwargs): if "stream" in kwargs: kwargs.pop("stream") kwargs["response_format"] = response_format - return self._oai_client.beta.chat.completions.parse(*args, **kwargs) + return OpenAIClient._convert_to_chat_completion( + self._oai_client.beta.chat.completions.parse(*args, **kwargs) + ) create_or_parse = _create_or_parse else: @@ -915,10 +947,7 @@ def yes_or_no_filter(context, response): if cache_client is not None: # Cache the response with cache_client as cache: - try: - cache.set(key, response) - except (PicklingError, AttributeError) as e: - logger.info(f"Failed to cache response: {e}") + cache.set(key, response) if logging_enabled(): # TODO: log the config_id and pass_filter etc.