Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐛 handle missing function calls for openai (#35) #38

Merged
merged 2 commits into from
Jan 9, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 62 additions & 17 deletions aide/backend/backend_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,65 @@
openai.InternalServerError,
)

# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
SUPPORTED_FUNCTION_CALL_MODELS = {
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
}


@once
def _setup_openai_client():
global _client
_client = openai.OpenAI(max_retries=0)


def is_function_call_supported(model_name: str) -> bool:
"""Return True if the model supports function calling."""
return model_name in SUPPORTED_FUNCTION_CALL_MODELS


def query(
system_message: str | None,
user_message: str | None,
func_spec: FunctionSpec | None = None,
**model_kwargs,
) -> tuple[OutputType, float, int, int, dict]:
"""
Query the OpenAI API, optionally with function calling.
Function calling support is only checked for feedback/review operations.
"""
_setup_openai_client()
filtered_kwargs: dict = select_values(notnone, model_kwargs) # type: ignore
filtered_kwargs: dict = select_values(notnone, model_kwargs)
model_name = filtered_kwargs.get("model", "")
logger.debug(f"OpenAI query called with model='{model_name}'")

messages = opt_messages_to_list(system_message, user_message)

if func_spec is not None:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
# force the model the use the function
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
# Only check function call support for feedback/search operations
if func_spec.name == "submit_review":
if not is_function_call_supported(model_name):
logger.warning(
f"Review function calling was requested, but model '{model_name}' "
"does not support function calling. Falling back to plain text generation."
)
filtered_kwargs.pop("tools", None)
filtered_kwargs.pop("tool_choice", None)
else:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict

t0 = time.time()
completion = backoff_create(
Expand All @@ -53,22 +90,30 @@ def query(

choice = completion.choices[0]

if func_spec is None:
if func_spec is None or "tools" not in filtered_kwargs:
output = choice.message.content
else:
assert (
choice.message.tool_calls
), f"function_call is empty, it is not a function call: {choice.message}"
assert (
choice.message.tool_calls[0].function.name == func_spec.name
), "Function name mismatch"
try:
output = json.loads(choice.message.tool_calls[0].function.arguments)
except json.JSONDecodeError as e:
logger.error(
f"Error decoding the function arguments: {choice.message.tool_calls[0].function.arguments}"
tool_calls = getattr(choice.message, "tool_calls", None)

if not tool_calls:
logger.warning(
f"No function call used despite function spec. Fallback to text. "
f"Message content: {choice.message.content}"
)
output = choice.message.content
else:
first_call = tool_calls[0]
assert first_call.function.name == func_spec.name, (
f"Function name mismatch: expected {func_spec.name}, "
f"got {first_call.function.name}"
)
raise e
try:
output = json.loads(first_call.function.arguments)
except json.JSONDecodeError as e:
logger.error(
f"Error decoding function arguments:\n{first_call.function.arguments}"
)
raise e

in_tokens = completion.usage.prompt_tokens
out_tokens = completion.usage.completion_tokens
Expand Down
Loading