Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into graphrag_demo
Browse files Browse the repository at this point in the history
  • Loading branch information
marklysze committed Dec 4, 2024
2 parents 5ae2fcc + 0a9e847 commit 4d17b4a
Showing 1 changed file with 35 additions and 6 deletions.
41 changes: 35 additions & 6 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import logging
import sys
import uuid
from pickle import PickleError, PicklingError
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union

from pydantic import BaseModel, schema_json_of
Expand Down Expand Up @@ -42,6 +41,7 @@
ChoiceDeltaToolCall,
ChoiceDeltaToolCallFunction,
)
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion, ParsedChatCompletionMessage
from openai.types.completion import Completion
from openai.types.completion_usage import CompletionUsage

Expand Down Expand Up @@ -238,6 +238,36 @@ def __init__(self, config):
class OpenAIClient:
"""Follows the Client protocol and wraps the OpenAI client."""

@staticmethod
def _convert_to_chat_completion(parsed: ParsedChatCompletion) -> ChatCompletion:
# Helper function to convert ParsedChatCompletionMessage to ChatCompletionMessage
def convert_message(parsed_message: ParsedChatCompletionMessage) -> ChatCompletionMessage:
return ChatCompletionMessage(
role=parsed_message.role,
content=parsed_message.content,
function_call=parsed_message.function_call,
)

# Convert ParsedChatCompletion to ChatCompletion
return ChatCompletion(
id=parsed.id,
choices=[
Choice(
finish_reason=choice.finish_reason,
index=choice.index,
logprobs=choice.logprobs,
message=convert_message(choice.message), # Parse the message
)
for choice in parsed.choices
],
created=parsed.created,
model=parsed.model,
object=parsed.object,
service_tier=parsed.service_tier,
system_fingerprint=parsed.system_fingerprint,
usage=parsed.usage,
)

def __init__(self, client: Union[OpenAI, AzureOpenAI]):
self._oai_client = client
if (
Expand Down Expand Up @@ -290,7 +320,9 @@ def _create_or_parse(*args, **kwargs):
if "stream" in kwargs:
kwargs.pop("stream")
kwargs["response_format"] = response_format
return self._oai_client.beta.chat.completions.parse(*args, **kwargs)
return OpenAIClient._convert_to_chat_completion(
self._oai_client.beta.chat.completions.parse(*args, **kwargs)
)

create_or_parse = _create_or_parse
else:
Expand Down Expand Up @@ -915,10 +947,7 @@ def yes_or_no_filter(context, response):
if cache_client is not None:
# Cache the response
with cache_client as cache:
try:
cache.set(key, response)
except (PicklingError, AttributeError) as e:
logger.info(f"Failed to cache response: {e}")
cache.set(key, response)

if logging_enabled():
# TODO: log the config_id and pass_filter etc.
Expand Down

0 comments on commit 4d17b4a

Please sign in to comment.