From 840a4dd825748e45e83fe26a4641e9d1a18f5309 Mon Sep 17 00:00:00 2001 From: Josh Purtell Date: Thu, 19 Dec 2024 21:41:55 -0800 Subject: [PATCH] update sdk --- pyproject.toml | 3 +- requirements.txt | 1 - setup.py | 2 +- synth_sdk/tracing/d.py | 561 ++++++++++++++++++ synth_sdk/tracing/decorators.py | 32 +- .../records/episode_classic_0.json | 2 +- .../records/episode_classic_2.json | 2 +- 7 files changed, 573 insertions(+), 30 deletions(-) create mode 100644 synth_sdk/tracing/d.py diff --git a/pyproject.toml b/pyproject.toml index 6b08933..660355d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "synth-sdk" -version = "0.2.103" +version = "0.2.105" description = "" authors = [{name = "Synth AI", email = "josh@usesynth.ai"}] license = {text = "MIT"} @@ -22,7 +22,6 @@ dependencies = [ "pytest>=8.3.3", "pydantic-openapi-schema>=1.5.1", "pytest-asyncio>=0.24.0", - "apropos-ai>=0.4.5", "boto3>=1.35.71", "botocore>=1.35.71", "tqdm>=4.66.4", diff --git a/requirements.txt b/requirements.txt index 3eabcfd..571a290 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -#apropos-ai==0.4.5 fastapi>=0.115.0 opentelemetry-api>=1.27.0 opentelemetry-instrumentation>=0.48b0 diff --git a/setup.py b/setup.py index 1c6d20c..48feabe 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="synth-sdk", - version="0.2.103", + version="0.2.104", packages=find_packages(), install_requires=[ "opentelemetry-api", diff --git a/synth_sdk/tracing/d.py b/synth_sdk/tracing/d.py new file mode 100644 index 0000000..9a4d508 --- /dev/null +++ b/synth_sdk/tracing/d.py @@ -0,0 +1,561 @@ +# synth_sdk/tracing/decorators.py +import inspect +import logging +import time +from functools import wraps +from typing import Any, Callable, Dict, List, Literal + +from synth_sdk.tracing.abstractions import ( + AgentComputeStep, + ArbitraryInputs, + ArbitraryOutputs, + EnvironmentComputeStep, + Event, + MessageInputs, + MessageOutputs, +) +from synth_sdk.tracing.events.manage import set_current_event +from synth_sdk.tracing.events.store import event_store +from synth_sdk.tracing.local import ( + _local, + active_events_var, + logger, + system_id_var, + system_instance_id_var, + system_name_var, +) +from synth_sdk.tracing.trackers import ( + synth_tracker_async, + synth_tracker_sync, +) +from synth_sdk.tracing.utils import get_system_id + +logger = logging.getLogger(__name__) + + +# # This decorator is used to trace synchronous functions +def trace_system_sync( + origin: Literal["agent", "environment"], + event_type: str, + log_result: bool = False, + manage_event: Literal["create", "end", None] = None, + increment_partition: bool = False, + verbose: bool = False, + finetune_step: bool = True, +) -> Callable: + """Decorator for tracing synchronous functions. + + Purpose is to keep track of inputs and outputs for compute steps for sync functions. + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + def wrapper(*args, **kwargs): + # Determine the instance (self) if it's a method + if not hasattr(func, "__self__") or not func.__self__: + if not args: + raise ValueError( + "Instance method expected, but no arguments were passed." + ) + self_instance = args[0] + else: + self_instance = func.__self__ + + # Ensure required attributes are present + required_attrs = ["system_instance_id", "system_name"] + for attr in required_attrs: + if not hasattr(self_instance, attr): + raise ValueError(f"Instance missing required attribute '{attr}'") + + # Set thread-local variables + _local.system_instance_id = self_instance.system_instance_id + _local.system_name = self_instance.system_name + _local.system_id = get_system_id( + self_instance.system_name + ) # self_instance.system_id + + # Initialize Trace + synth_tracker_sync.initialize() + + # Initialize active_events if not present + if not hasattr(_local, "active_events"): + _local.active_events = {} + # logger.debug("Initialized active_events in thread local storage") + + event = None + compute_began = time.time() + try: + if manage_event == "create": + # logger.debug("Creating new event") + event = Event( + system_instance_id=_local.system_instance_id, + event_type=event_type, + opened=compute_began, + closed=None, + partition_index=0, + agent_compute_steps=[], + environment_compute_steps=[], + ) + if increment_partition: + event.partition_index = event_store.increment_partition( + _local.system_name, + _local.system_id, + _local.system_instance_id, + ) + # logger.debug( + # f"Incremented partition to: {event.partition_index}" + # ) + set_current_event(event, decorator_type="sync") + # logger.debug(f"Created and set new event: {event_type}") + + # Automatically trace function inputs + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + for param, value in bound_args.arguments.items(): + if param == "self": + continue + synth_tracker_sync.track_state( + variable_name=param, variable_value=value, origin=origin + ) + + # Execute the function + result = func(*args, **kwargs) + + # Automatically trace function output + track_result(result, synth_tracker_sync, origin) + + # Collect traced inputs and outputs + traced_inputs, traced_outputs = synth_tracker_sync.get_traced_data() + + compute_steps_by_origin: Dict[ + Literal["agent", "environment"], Dict[str, List[Any]] + ] = { + "agent": {"inputs": [], "outputs": []}, + "environment": {"inputs": [], "outputs": []}, + } + + # Organize traced data by origin + for item in traced_inputs: + var_origin = item["origin"] + if "variable_value" in item and "variable_name" in item: + # Standard variable input + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs( + inputs={item["variable_name"]: item["variable_value"]} + ) + ) + elif "messages" in item: + # Message input from track_lm + compute_steps_by_origin[var_origin]["inputs"].append( + MessageInputs(messages=item["messages"]) + ) + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"model_name": item["model_name"]}) + ) + finetune = item["finetune"] or finetune_step + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"finetune": finetune}) + ) + else: + logger.warning(f"Unhandled traced input item: {item}") + + for item in traced_outputs: + var_origin = item["origin"] + if "variable_value" in item and "variable_name" in item: + # Standard variable output + compute_steps_by_origin[var_origin]["outputs"].append( + ArbitraryOutputs( + outputs={item["variable_name"]: item["variable_value"]} + ) + ) + elif "messages" in item: + # Message output from track_lm + compute_steps_by_origin[var_origin]["outputs"].append( + MessageOutputs(messages=item["messages"]) + ) + else: + logger.warning(f"Unhandled traced output item: {item}") + + # Capture compute end time + compute_ended = time.time() + + # Create compute steps grouped by origin + for var_origin in ["agent", "environment"]: + inputs = compute_steps_by_origin[var_origin]["inputs"] + outputs = compute_steps_by_origin[var_origin]["outputs"] + if inputs or outputs: + event_order = ( + len(event.agent_compute_steps) + + len(event.environment_compute_steps) + + 1 + if event + else 1 + ) + compute_step = ( + AgentComputeStep( + event_order=event_order, + compute_began=compute_began, + compute_ended=compute_ended, + compute_input=inputs, + compute_output=outputs, + ) + if var_origin == "agent" + else EnvironmentComputeStep( + event_order=event_order, + compute_began=compute_began, + compute_ended=compute_ended, + compute_input=inputs, + compute_output=outputs, + ) + ) + if event: + if var_origin == "agent": + event.agent_compute_steps.append(compute_step) + else: + event.environment_compute_steps.append(compute_step) + # logger.debug( + # f"Added compute step for {var_origin}: {compute_step.to_dict()}" + # ) + + # Optionally log the function result + if log_result: + logger.info(f"Function result: {result}") + + # Handle event management after function execution + if manage_event == "end" and event_type in _local.active_events: + current_event = _local.active_events[event_type] + current_event.closed = compute_ended + # Store the event + if hasattr(_local, "system_instance_id"): + event_store.add_event( + _local.system_name, + _local.system_id, + _local.system_instance_id, + current_event, + ) + del _local.active_events[event_type] + + return result + except Exception as e: + logger.error(f"Exception in traced function '{func.__name__}': {e}") + raise + finally: + # synth_tracker_sync.finalize() + if hasattr(_local, "system_instance_id"): + # logger.debug(f"Cleaning up system_instance_id: {_local.system_instance_id}") + delattr(_local, "system_instance_id") + + return wrapper + + return decorator + + +def trace_system_async( + origin: Literal["agent", "environment"], + event_type: str, + log_result: bool = False, + manage_event: Literal["create", "end", "lazy_end", None] = None, + increment_partition: bool = False, + verbose: bool = False, + finetune_step: bool = True, +) -> Callable: + """Decorator for tracing asynchronous functions. + + Purpose is to keep track of inputs and outputs for compute steps for async functions. + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def async_wrapper(*args, **kwargs): + # Determine the instance (self) if it's a method + if not hasattr(func, "__self__") or not func.__self__: + if not args: + raise ValueError( + "Instance method expected, but no arguments were passed." + ) + self_instance = args[0] + else: + self_instance = func.__self__ + + # Ensure required attributes are present + required_attrs = ["system_instance_id", "system_name"] + for attr in required_attrs: + if not hasattr(self_instance, attr): + raise ValueError(f"Instance missing required attribute '{attr}'") + + # Set context variables + system_instance_id_token = system_instance_id_var.set( + self_instance.system_instance_id + ) + system_name_token = system_name_var.set(self_instance.system_name) + system_id_token = system_id_var.set( + get_system_id(self_instance.system_name) + ) + + # Initialize AsyncTrace + synth_tracker_async.initialize() + + # Initialize active_events if not present + current_active_events = active_events_var.get() + if not current_active_events: + active_events_var.set({}) + # logger.debug("Initialized active_events in context vars") + + event = None + compute_began = time.time() + try: + if manage_event == "create": + # logger.debug("Creating new event") + event = Event( + system_instance_id=self_instance.system_instance_id, + event_type=event_type, + opened=compute_began, + closed=None, + partition_index=0, + agent_compute_steps=[], + environment_compute_steps=[], + ) + if increment_partition: + event.partition_index = event_store.increment_partition( + system_name_var.get(), + system_id_var.get(), + system_instance_id_var.get(), + ) + # logger.debug( + # f"Incremented partition to: {event.partition_index}" + # ) + + set_current_event(event, decorator_type="async") + # logger.debug(f"Created and set new event: {event_type}") + + # Automatically trace function inputs + bound_args = inspect.signature(func).bind(*args, **kwargs) + bound_args.apply_defaults() + for param, value in bound_args.arguments.items(): + if param == "self": + continue + synth_tracker_async.track_state( + variable_name=param, + variable_value=value, + origin=origin, + io_type="input", + ) + + # Execute the coroutine + result = await func(*args, **kwargs) + + # Automatically trace function output + track_result(result, synth_tracker_async, origin) + + # Collect traced inputs and outputs + traced_inputs, traced_outputs = synth_tracker_async.get_traced_data() + + compute_steps_by_origin: Dict[ + Literal["agent", "environment"], Dict[str, List[Any]] + ] = { + "agent": {"inputs": [], "outputs": []}, + "environment": {"inputs": [], "outputs": []}, + } + + # Organize traced data by origin + for item in traced_inputs: + var_origin = item["origin"] + if "variable_value" in item and "variable_name" in item: + # Standard variable input + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs( + inputs={item["variable_name"]: item["variable_value"]} + ) + ) + elif "messages" in item: + # Message input from track_lm + compute_steps_by_origin[var_origin]["inputs"].append( + MessageInputs(messages=item["messages"]) + ) + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"model_name": item["model_name"]}) + ) + finetune = finetune_step or item["finetune"] + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"finetune": finetune}) + ) + else: + logger.warning(f"Unhandled traced input item: {item}") + + for item in traced_outputs: + var_origin = item["origin"] + if "variable_value" in item and "variable_name" in item: + # Standard variable output + compute_steps_by_origin[var_origin]["outputs"].append( + ArbitraryOutputs( + outputs={item["variable_name"]: item["variable_value"]} + ) + ) + elif "messages" in item: + # Message output from track_lm + compute_steps_by_origin[var_origin]["outputs"].append( + MessageOutputs(messages=item["messages"]) + ) + else: + logger.warning(f"Unhandled traced output item: {item}") + + compute_ended = time.time() + + # Create compute steps grouped by origin + for var_origin in ["agent", "environment"]: + inputs = compute_steps_by_origin[var_origin]["inputs"] + outputs = compute_steps_by_origin[var_origin]["outputs"] + if inputs or outputs: + event_order = ( + len(event.agent_compute_steps) + + len(event.environment_compute_steps) + + 1 + if event + else 1 + ) + compute_step = ( + AgentComputeStep( + event_order=event_order, + compute_began=compute_began, + compute_ended=compute_ended, + compute_input=inputs, + compute_output=outputs, + ) + if var_origin == "agent" + else EnvironmentComputeStep( + event_order=event_order, + compute_began=compute_began, + compute_ended=compute_ended, + compute_input=inputs, + compute_output=outputs, + ) + ) + if event: + if var_origin == "agent": + event.agent_compute_steps.append(compute_step) + else: + event.environment_compute_steps.append(compute_step) + # logger.debug( + # f"Added compute step for {var_origin}: {compute_step.to_dict()}" + # ) + # Optionally log the function result + if log_result: + logger.info(f"Function result: {result}") + + # Handle event management after function execution + if manage_event == "end" and event_type in active_events_var.get(): + current_event = active_events_var.get()[event_type] + current_event.closed = compute_ended + # Store the event + if system_instance_id_var.get(): + event_store.add_event( + system_name_var.get(), + system_id_var.get(), + system_instance_id_var.get(), + current_event, + ) + active_events = active_events_var.get() + del active_events[event_type] + active_events_var.set(active_events) + + # Auto-close and store events created with manage_event="create" + if ( + manage_event == "create" + and event is not None + and event.closed is None + ): + event.closed = time.time() + active_events_dict = active_events_var.get() + if event_type in active_events_dict: + # Store the event while context vars are still valid + event_store.add_event( + system_name_var.get(), + system_id_var.get(), + system_instance_id_var.get(), + event, + ) + # Remove from active events + active_events_dict.pop(event_type, None) + active_events_var.set(active_events_dict) + + return result + except Exception as e: + logger.error(f"Exception in traced function '{func.__name__}': {e}") + raise + finally: + # synth_tracker_async.finalize() + # Reset context variables + system_instance_id_var.reset(system_instance_id_token) + system_name_var.reset(system_name_token) + system_id_var.reset(system_id_token) + # logger.debug("Cleaning up system_instance_id from context vars") + + return async_wrapper + + return decorator + + +def trace_system( + origin: Literal["agent", "environment"], + event_type: str, + log_result: bool = False, + manage_event: Literal["create", "end", "lazy_end", None] = None, + increment_partition: bool = False, + verbose: bool = False, +) -> Callable: + """ + Decorator that chooses the correct tracing method (sync or async) based on + whether the wrapped function is synchronous or asynchronous. + + Purpose is to keep track of inputs and outputs for compute steps for both sync and async functions. + """ + + def decorator(func: Callable) -> Callable: + # Check if the function is async or sync + if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): + # Use async tracing + # logger.debug("Using async tracing") + async_decorator = trace_system_async( + origin, + event_type, + log_result, + manage_event, + increment_partition, + verbose, + ) + return async_decorator(func) + else: + # Use sync tracing + # logger.debug("Using sync tracing") + sync_decorator = trace_system_sync( + origin, + event_type, + log_result, + manage_event, + increment_partition, + verbose, + ) + return sync_decorator(func) + + return decorator + + +def track_result(result, tracker, origin): + # Helper function to track results, including tuple unpacking + if isinstance(result, tuple): + # Track each element of the tuple that matches valid types + for i, item in enumerate(result): + try: + tracker.track_state( + variable_name=f"result_{i}", variable_value=item, origin=origin + ) + except Exception as e: + logger.warning(f"Could not track tuple element {i}: {str(e)}") + else: + # Track single result as before + try: + tracker.track_state( + variable_name="result", variable_value=result, origin=origin + ) + except Exception as e: + logger.warning(f"Could not track result: {str(e)}") diff --git a/synth_sdk/tracing/decorators.py b/synth_sdk/tracing/decorators.py index 96e43c1..c071ece 100644 --- a/synth_sdk/tracing/decorators.py +++ b/synth_sdk/tracing/decorators.py @@ -102,11 +102,11 @@ def wrapper(*args, **kwargs): _local.system_id, _local.system_instance_id, ) - # logger.debug( - # f"Incremented partition to: {event.partition_index}" - # ) + logger.debug( + f"Incremented partition to: {event.partition_index}" + ) set_current_event(event, decorator_type="sync") - # logger.debug(f"Created and set new event: {event_type}") + logger.debug(f"Created and set new event: {event_type}") # Automatically trace function inputs bound_args = inspect.signature(func).bind(*args, **kwargs) @@ -318,12 +318,12 @@ async def async_wrapper(*args, **kwargs): system_id_var.get(), system_instance_id_var.get(), ) - # logger.debug( - # f"Incremented partition to: {event.partition_index}" - # ) + logger.debug( + f"Incremented partition to: {event.partition_index}" + ) set_current_event(event, decorator_type="async") - # logger.debug(f"Created and set new event: {event_type}") + logger.debug(f"Created and set new event: {event_type}") # Automatically trace function inputs bound_args = inspect.signature(func).bind(*args, **kwargs) @@ -455,22 +455,6 @@ async def async_wrapper(*args, **kwargs): del active_events[event_type] active_events_var.set(active_events) - # Auto-close and store events created with manage_event="create" - if manage_event == "create" and event is not None and event.closed is None: - event.closed = time.time() - active_events_dict = active_events_var.get() - if event_type in active_events_dict: - # Store the event while context vars are still valid - event_store.add_event( - system_name_var.get(), - system_id_var.get(), - system_instance_id_var.get(), - event, - ) - # Remove from active events - active_events_dict.pop(event_type, None) - active_events_var.set(active_events_dict) - return result except Exception as e: logger.error(f"Exception in traced function '{func.__name__}': {e}") diff --git a/tests/iteration/craftax/generate_data/records/episode_classic_0.json b/tests/iteration/craftax/generate_data/records/episode_classic_0.json index 1ae3717..b0a8f50 100644 --- a/tests/iteration/craftax/generate_data/records/episode_classic_0.json +++ b/tests/iteration/craftax/generate_data/records/episode_classic_0.json @@ -1,5 +1,5 @@ { - "Collect Wood": true, + "Collect Wood": false, "Place Table": false, "Eat Cow": false, "Collect Sapling": false, diff --git a/tests/iteration/craftax/generate_data/records/episode_classic_2.json b/tests/iteration/craftax/generate_data/records/episode_classic_2.json index 1ae3717..b0a8f50 100644 --- a/tests/iteration/craftax/generate_data/records/episode_classic_2.json +++ b/tests/iteration/craftax/generate_data/records/episode_classic_2.json @@ -1,5 +1,5 @@ { - "Collect Wood": true, + "Collect Wood": false, "Place Table": false, "Eat Cow": false, "Collect Sapling": false,