Skip to content

Commit

Permalink
Fixed ContextVar issue; added tutorials; upload sync for simplicity
Browse files Browse the repository at this point in the history
  • Loading branch information
DoKu88 committed Nov 22, 2024
1 parent 6ec5831 commit cbd2891
Show file tree
Hide file tree
Showing 12 changed files with 645 additions and 47 deletions.
7 changes: 3 additions & 4 deletions synth_sdk/tracing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def wrapper(*args, **kwargs):
logger.debug(
f"Incremented partition to: {event.partition_index}"
)

set_current_event(event)
set_current_event(event, decorator_type="sync")
logger.debug(f"Created and set new event: {event_type}")

# Automatically trace function inputs
Expand Down Expand Up @@ -309,7 +308,7 @@ async def async_wrapper(*args, **kwargs):
f"Incremented partition to: {event.partition_index}"
)

set_current_event(event)
set_current_event(event, decorator_type="async")
logger.debug(f"Created and set new event: {event_type}")

# Automatically trace function inputs
Expand Down Expand Up @@ -467,7 +466,7 @@ def trace_system(
"""
def decorator(func: Callable) -> Callable:
# Check if the function is async or sync
if inspect.iscoroutinefunction(func):
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
# Use async tracing
logger.debug("Using async tracing")
async_decorator = trace_system_async(
Expand Down
56 changes: 28 additions & 28 deletions synth_sdk/tracing/events/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def get_current_event(event_type: str) -> "Event":
return events[event_type]


def set_current_event(event: Optional["Event"]):
def set_current_event(event: Optional["Event"], decorator_type: Literal["sync", "async"]=None):
"""
Set the current event, ending any existing events of the same type.
If event is None, it clears the current event of that type.
Expand All @@ -40,64 +40,64 @@ def set_current_event(event: Optional["Event"]):
except RuntimeError:
is_async = False

if is_async:
from synth_sdk.tracing.local import active_events_var, system_id_var
# Get current active events from context var
active_events = active_events_var.get()

if decorator_type == "sync" or not is_async:
# Original thread-local storage logic
if not hasattr(_local, "active_events"):
_local.active_events = {}
logger.debug("Initialized active_events in thread local storage")

# If there's an existing event of the same type, end it
if event.event_type in active_events:
if event.event_type in _local.active_events:
logger.debug(f"Found existing event of type {event.event_type}")
existing_event = active_events[event.event_type]
existing_event = _local.active_events[event.event_type]
existing_event.closed = time.time()
logger.debug(
f"Closed existing event of type {event.event_type} at {existing_event.closed}"
)

# Store the closed event if system_id is present
system_id = system_id_var.get()
if system_id:
logger.debug(f"Storing closed event for system {system_id}")
if hasattr(_local, "system_id"):
logger.debug(f"Storing closed event for system {_local.system_id}")
try:
event_store.add_event(system_id, existing_event)
event_store.add_event(_local.system_id, existing_event)
logger.debug("Successfully stored closed event")
except Exception as e:
logger.error(f"Failed to store closed event: {str(e)}")
raise

# Set the new event
active_events[event.event_type] = event
active_events_var.set(active_events)
logger.debug("New event set as current in context vars")
else:
# Original thread-local storage logic
if not hasattr(_local, "active_events"):
_local.active_events = {}
logger.debug("Initialized active_events in thread local storage")
_local.active_events[event.event_type] = event
logger.debug("New event set as current in thread local")

else:
from synth_sdk.tracing.local import active_events_var, system_id_var
# Get current active events from context var
active_events = active_events_var.get()

# If there's an existing event of the same type, end it
if event.event_type in _local.active_events:
if event.event_type in active_events:
logger.debug(f"Found existing event of type {event.event_type}")
existing_event = _local.active_events[event.event_type]
existing_event = active_events[event.event_type]
existing_event.closed = time.time()
logger.debug(
f"Closed existing event of type {event.event_type} at {existing_event.closed}"
)

# Store the closed event if system_id is present
if hasattr(_local, "system_id"):
logger.debug(f"Storing closed event for system {_local.system_id}")
system_id = system_id_var.get()
if system_id:
logger.debug(f"Storing closed event for system {system_id}")
try:
event_store.add_event(_local.system_id, existing_event)
event_store.add_event(system_id, existing_event)
logger.debug("Successfully stored closed event")
except Exception as e:
logger.error(f"Failed to store closed event: {str(e)}")
raise

# Set the new event
_local.active_events[event.event_type] = event
logger.debug("New event set as current in thread local")

active_events[event.event_type] = event
active_events_var.set(active_events)
logger.debug("New event set as current in context vars")

def clear_current_event(event_type: str):
if hasattr(_local, "active_events"):
Expand Down
19 changes: 4 additions & 15 deletions synth_sdk/tracing/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,21 +219,10 @@ def upload(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False
questions_json is the formatted questions array
reward_signals_json is the formatted reward signals array
traces_json is the formatted traces array"""
async def upload_wrapper(dataset, traces, verbose, show_payload):
response, payload, dataset, traces = await upload_helper(dataset, traces, verbose, show_payload)

# If we're in an async context (event loop is running)
if is_event_loop_running():
logging.info("Event loop is already running")
# Return the coroutine directly for async contexts
return upload_helper(dataset, traces, verbose, show_payload)
else:
# In sync context, run the coroutine and return the result
logging.info("Event loop is not running")
return asyncio.run(upload_helper(dataset, traces, verbose, show_payload))


async def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False):

return upload_helper(dataset, traces, verbose, show_payload)

def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False):
api_key = os.getenv("SYNTH_API_KEY")
if not api_key:
raise ValueError("SYNTH_API_KEY environment variable not set")
Expand Down
197 changes: 197 additions & 0 deletions tutorials/AsyncAgentExample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from synth_sdk.tracing.decorators import trace_system, _local
from synth_sdk.tracing.trackers import SynthTracker
from synth_sdk.tracing.upload import upload
from synth_sdk.tracing.abstractions import TrainingQuestion, RewardSignal, Dataset
from synth_sdk.tracing.events.store import event_store
import asyncio
import time
import json
import logging
from openai import AsyncOpenAI
from dotenv import load_dotenv
import os

# Load environment variables from a local .env file (e.g. OPENAI_API_KEY).
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
# NOTE(review): SYNTH_API_KEY must also be present in the environment —
# upload() reads it later; it is not loaded explicitly here.

# Configure logging; CRITICAL suppresses all lower-severity records for the demo.
logging.basicConfig(
    level=logging.CRITICAL,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


class TestAgent:
    """Toy agent demonstrating async tracing with the synth SDK.

    Wraps an ``AsyncOpenAI`` client and exposes two traced coroutines:
    an LM call (``origin="agent"``) and an environment-processing step
    (``origin="environment"``).
    """

    def __init__(self):
        # system_id groups every traced event for this agent in the event store.
        self.system_id = "test_agent_async"
        logger.debug("Initializing TestAgent with system_id: %s", self.system_id)
        # AsyncOpenAI picks up OPENAI_API_KEY from the environment.
        self.client = AsyncOpenAI()
        logger.debug("OpenAI client initialized")

    @trace_system(
        origin="agent",
        event_type="lm_call",
        manage_event="create",
        increment_partition=True,
        verbose=True,
    )
    async def make_lm_call(self, user_message: str) -> str:
        """Send *user_message* to the chat model and return the reply text."""
        # Only pass the user message, not self.
        SynthTracker.track_state(variable_name="user_message", variable_value=user_message, origin="agent")

        logger.debug("Starting LM call with message: %s", user_message)
        response = await self.client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": user_message}
            ]
        )

        response_text = response.choices[0].message.content
        SynthTracker.track_state(variable_name="response", variable_value=response_text, origin="agent")

        logger.debug("LM response received: %s", response_text)
        # BUG FIX: the original called time.sleep(0.1) here, which blocks the
        # entire event loop inside a coroutine; asyncio.sleep yields instead.
        await asyncio.sleep(0.1)
        return response_text

    @trace_system(
        origin="environment",
        event_type="environment_processing",
        manage_event="create",
        verbose=True,
    )
    async def process_environment(self, input_data: str) -> dict:
        """Wrap *input_data* in a dict together with a processing timestamp."""
        # Only pass the input data, not self.
        SynthTracker.track_state(variable_name="input_data", variable_value=input_data, origin="environment")

        result = {"processed": input_data, "timestamp": time.time()}

        SynthTracker.track_state(variable_name="result", variable_value=result, origin="environment")
        return result


async def run_test():
    """End-to-end demo: run traced agent calls, then upload the traces.

    For each test question: process it through the environment step, then
    make an LM call. Afterwards build a Dataset of training questions and
    reward signals, upload it (with the collected traces), dump the upload
    payloads to JSON files, and print the stored traces. The ``finally``
    block closes and stores any events still open in thread-local storage.
    """
    logger.info("Starting run_test")
    # Create test agent
    agent = TestAgent()

    try:
        # List of test questions
        questions = [
            "What's the capital of France?",
            "What's 2+2?",
            "Who wrote Romeo and Juliet?",
        ]
        logger.debug("Test questions initialized: %s", questions)

        # Make multiple LM calls with environment processing.
        # A failure on one question is logged and skipped, not fatal.
        responses = []
        for i, question in enumerate(questions):
            logger.info("Processing question %d: %s", i, question)
            try:
                # First process in environment
                env_result = await agent.process_environment(question)
                logger.debug("Environment processing result: %s", env_result)

                # Then make LM call
                response = await agent.make_lm_call(question)
                responses.append(response)
                logger.debug("Response received and stored: %s", response)
            except Exception as e:
                logger.error("Error during processing: %s", str(e), exc_info=True)
                continue

        logger.info("Creating dataset for upload")
        # Create dataset for upload: one placeholder question and one
        # constant reward signal per test question, keyed q0..qN.
        dataset = Dataset(
            questions=[
                TrainingQuestion(
                    intent="Test question",
                    criteria="Testing tracing functionality",
                    question_id=f"q{i}",
                )
                for i in range(len(questions))
            ],
            reward_signals=[
                RewardSignal(
                    question_id=f"q{i}",
                    system_id=agent.system_id,
                    reward=1.0,
                    annotation="Test reward",
                )
                for i in range(len(questions))
            ],
        )
        logger.debug(
            "Dataset created with %d questions and %d reward signals",
            len(dataset.questions),
            len(dataset.reward_signals),
        )

        # Upload traces; failure here is reported but does not abort the demo.
        try:
            logger.info("Attempting to upload traces")
            response, questions_json, reward_signals_json, traces_json = upload(dataset=dataset, verbose=True)
            logger.info("Upload successful!")
            print("Upload successful!")

            # Save JSON files with error handling
            try:
                with open("tutorials/questions_async.json", "w") as f:
                    json.dump(questions_json, f)
                with open("tutorials/reward_signals_async.json", "w") as f:
                    json.dump(reward_signals_json, f)
                with open("tutorials/traces_async.json", "w") as f:
                    json.dump(traces_json, f)
            except Exception as e:
                logger.error(f"Error saving JSON files: {str(e)}")
                print(f"Error saving JSON files: {str(e)}")

        except Exception as e:
            logger.error("Upload failed: %s", str(e), exc_info=True)
            print(f"Upload failed: {str(e)}")

        # Print debug information
        traces = event_store.get_system_traces()
        logger.debug("Retrieved %d system traces", len(traces))
        print("\nTraces:")
        print(json.dumps([trace.to_dict() for trace in traces], indent=2))

        print("\nDataset:")
        print(json.dumps(dataset.to_dict(), indent=2))
    finally:
        logger.info("Starting cleanup")
        # Cleanup: close any events still open in thread-local storage and
        # persist them to the event store so no trace data is lost.
        if hasattr(_local, "active_events"):
            for event_type, event in _local.active_events.items():
                logger.debug("Cleaning up event: %s", event_type)
                if event.closed is None:
                    event.closed = time.time()
                    # NOTE(review): source indentation was mangled; storing is
                    # assumed to apply only to events closed here — confirm.
                    if hasattr(_local, "system_id"):
                        try:
                            event_store.add_event(_local.system_id, event)
                            logger.debug(
                                "Successfully cleaned up event: %s", event_type
                            )
                        except Exception as e:
                            logger.error(
                                "Error during cleanup of event %s: %s",
                                event_type,
                                str(e),
                                exc_info=True,
                            )
                            print(
                                f"Error during cleanup of event {event_type}: {str(e)}"
                            )
        logger.info("Cleanup completed")

def main() -> None:
    """Entry point: run the sample async agent demo under asyncio."""
    logger.info("Starting main execution")
    asyncio.run(run_test())
    logger.info("Main execution completed")


# Run a sample agent using the async decorator and tracker
if __name__ == "__main__":
    main()

Loading

0 comments on commit cbd2891

Please sign in to comment.