Skip to content

Commit

Permalink
Update OpenAI integration to require SDK v1 and enhance failure hook …
Browse files Browse the repository at this point in the history
…with automatic retries and improved message formatting
  • Loading branch information
safoinme committed Oct 31, 2024
1 parent d15672f commit 46a9ae0
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/zenml/integrations/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class OpenAIIntegration(Integration):
"""Definition of OpenAI integration for ZenML."""

NAME = OPEN_AI
REQUIREMENTS = ["openai>=0.27.0,<1.0.0"]
REQUIREMENTS = ["openai>=1.0.0"]


OpenAIIntegration.check_installation()
52 changes: 37 additions & 15 deletions src/zenml/integrations/openai/hooks/open_ai_failure_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

import io
import sys
from typing import NoReturn, Optional

import openai
from openai import OpenAI
from rich.console import Console

from zenml import get_step_context
Expand All @@ -38,6 +39,9 @@ def openai_alerter_failure_hook_helper(
Args:
exception: The exception that was raised.
model_name: The OpenAI model to use for the chatbot.
Note:
This implementation uses the OpenAI v1 SDK with automatic retries and backoff.
"""
client = Client()
context = get_step_context()
Expand All @@ -47,12 +51,13 @@ def openai_alerter_failure_hook_helper(
openai_secret = client.get_secret(
"openai", allow_partial_name_match=False
)
openai_api_key = openai_secret.secret_values.get("api_key")
openai_api_key: Optional[str] = openai_secret.secret_values.get("api_key")
except (KeyError, NotImplementedError):
openai_api_key = None

alerter = client.active_stack.alerter
if alerter and openai_api_key:
# Capture rich traceback
output_captured = io.StringIO()
original_stdout = sys.stdout
sys.stdout = output_captured
Expand All @@ -62,25 +67,42 @@ def openai_alerter_failure_hook_helper(
sys.stdout = original_stdout
rich_traceback = output_captured.getvalue()

response = openai.ChatCompletion.create( # type: ignore
# Initialize OpenAI client with timeout and retry settings
openai_client = OpenAI(
api_key=openai_api_key,
max_retries=3, # Will retry 3 times with exponential backoff
timeout=60.0, # 60 second timeout
)

# Create chat completion using the new client pattern
response = openai_client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": f"This is an error message (following an exception of type '{type(exception)}') I encountered while executing a ZenML step. Please suggest ways I might fix the problem. Feel free to give code snippets as examples, and note that your response will be piped to a Slack bot so make sure the formatting is appropriate: {exception} -- {rich_traceback}. Thank you!",
"content": f"This is an error message (following an exception of type '{type(exception)}') "
f"I encountered while executing a ZenML step. Please suggest ways I might fix the problem. "
f"Feel free to give code snippets as examples, and note that your response will be piped "
f"to a Slack bot so make sure the formatting is appropriate: {exception} -- {rich_traceback}. "
f"Thank you!",
}
],
)
suggestion = response["choices"][0]["message"]["content"]
message = "*Failure Hook Notification! Step failed!*" + "\n\n"
message += f"Run name: `{context.pipeline_run.name}`" + "\n"
message += f"Step name: `{context.step_run.name}`" + "\n"
message += f"Parameters: `{context.step_run.config.parameters}`" + "\n"
message += f"Exception: `({type(exception)}) {exception}`" + "\n\n"
message += (
f"*OpenAI ChatGPT's suggestion (model = `{model_name}`) on how to fix it:*\n `{suggestion}`"
+ "\n"
)

suggestion = response.choices[0].message.content

# Format the alert message
message = "\n".join([
"*Failure Hook Notification! Step failed!*",
"",
f"Run name: `{context.pipeline_run.name}`",
f"Step name: `{context.step_run.name}`",
f"Parameters: `{context.step_run.config.parameters}`",
f"Exception: `({type(exception)}) {exception}`",
"",
f"*OpenAI ChatGPT's suggestion (model = `{model_name}`) on how to fix it:*\n `{suggestion}`",
])

alerter.post(message)
elif not openai_api_key:
logger.warning(
Expand Down Expand Up @@ -111,4 +133,4 @@ def openai_gpt4_alerter_failure_hook(
Args:
exception: The exception that was raised.
"""
openai_alerter_failure_hook_helper(exception, "gpt-4")
openai_alerter_failure_hook_helper(exception, "gpt-4")

0 comments on commit 46a9ae0

Please sign in to comment.