Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaPurtell committed Dec 18, 2024
1 parent 4cb419e commit 8235068
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 66 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "synth-sdk"
version = "0.2.102"
version = "0.2.103"
description = ""
authors = [{name = "Synth AI", email = "[email protected]"}]
license = {text = "MIT"}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="synth-sdk",
version="0.2.102",
version="0.2.103",
packages=find_packages(),
install_requires=[
"opentelemetry-api",
Expand Down
2 changes: 2 additions & 0 deletions synth_sdk/tracing/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def to_dict(self):

@dataclass
class SystemTrace:
system_name: str
system_id: str
system_instance_id: str
metadata: Optional[Dict[str, Any]]
Expand All @@ -141,6 +142,7 @@ class SystemTrace:

def to_dict(self):
return {
"system_name": self.system_name,
"system_id": self.system_id,
"system_instance_id": self.system_instance_id,
"partition": [element.to_dict() for element in self.partition],
Expand Down
47 changes: 32 additions & 15 deletions synth_sdk/tracing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ def wrapper(*args, **kwargs):
)
if increment_partition:
event.partition_index = event_store.increment_partition(
_local.system_instance_id,
_local.system_name,
_local.system_id,
_local.system_instance_id,
)
logger.debug(
f"Incremented partition to: {event.partition_index}"
)
# logger.debug(
# f"Incremented partition to: {event.partition_index}"
# )
set_current_event(event, decorator_type="sync")
logger.debug(f"Created and set new event: {event_type}")
# logger.debug(f"Created and set new event: {event_type}")

# Automatically trace function inputs
bound_args = inspect.signature(func).bind(*args, **kwargs)
Expand Down Expand Up @@ -227,11 +228,8 @@ def wrapper(*args, **kwargs):
# Store the event
if hasattr(_local, "system_instance_id"):
event_store.add_event(
_local.system_instance_id, _local.system_id, current_event
_local.system_name, _local.system_id, _local.system_instance_id, current_event
)
# logger.debug(
# f"Stored and closed event {event_type} for system {_local.system_instance_id}"
# )
del _local.active_events[event_type]

return result
Expand Down Expand Up @@ -316,14 +314,16 @@ async def async_wrapper(*args, **kwargs):
)
if increment_partition:
event.partition_index = event_store.increment_partition(
system_instance_id_var.get(), system_id_var.get()
)
logger.debug(
f"Incremented partition to: {event.partition_index}"
system_name_var.get(),
system_id_var.get(),
system_instance_id_var.get(),
)
# logger.debug(
# f"Incremented partition to: {event.partition_index}"
# )

set_current_event(event, decorator_type="async")
logger.debug(f"Created and set new event: {event_type}")
# logger.debug(f"Created and set new event: {event_type}")

# Automatically trace function inputs
bound_args = inspect.signature(func).bind(*args, **kwargs)
Expand Down Expand Up @@ -446,14 +446,31 @@ async def async_wrapper(*args, **kwargs):
# Store the event
if system_instance_id_var.get():
event_store.add_event(
system_instance_id_var.get(),
system_name_var.get(),
system_id_var.get(),
system_instance_id_var.get(),
current_event,
)
active_events = active_events_var.get()
del active_events[event_type]
active_events_var.set(active_events)

# Auto-close and store events created with manage_event="create"
if manage_event == "create" and event is not None and event.closed is None:
event.closed = time.time()
active_events_dict = active_events_var.get()
if event_type in active_events_dict:
# Store the event while context vars are still valid
event_store.add_event(
system_name_var.get(),
system_id_var.get(),
system_instance_id_var.get(),
event,
)
# Remove from active events
active_events_dict.pop(event_type, None)
active_events_var.set(active_events_dict)

return result
except Exception as e:
logger.error(f"Exception in traced function '{func.__name__}': {e}")
Expand Down
7 changes: 4 additions & 3 deletions synth_sdk/tracing/events/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def set_current_event(
)
try:
event_store.add_event(
_local.system_instance_id, _local.system_id, existing_event
_local.system_name, _local.system_id, _local.system_instance_id, existing_event
)
logger.debug("Successfully stored closed event")
except Exception as e:
Expand All @@ -78,6 +78,7 @@ def set_current_event(
else:
from synth_sdk.tracing.local import (
active_events_var,
system_name_var,
system_id_var,
system_instance_id_var,
)
Expand Down Expand Up @@ -105,7 +106,7 @@ def set_current_event(
)
try:
event_store.add_event(
system_instance_id, system_id_var.get(), existing_event
system_name_var.get(), system_id_var.get(), system_instance_id, existing_event
)
logger.debug("Successfully stored closed event")
except Exception as e:
Expand All @@ -132,7 +133,7 @@ def end_event(event_type: str) -> Optional[Event]:
# Store the event
if hasattr(_local, "system_instance_id"):
event_store.add_event(
_local.system_instance_id, _local.system_id, current_event
_local.system_name, _local.system_id, _local.system_instance_id, current_event
)
clear_current_event(event_type)
return current_event
5 changes: 3 additions & 2 deletions synth_sdk/tracing/events/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from synth_sdk.tracing.abstractions import Event
from synth_sdk.tracing.decorators import _local, clear_current_event, set_current_event
from synth_sdk.tracing.events.store import event_store
from synth_sdk.tracing.local import system_id_var, system_instance_id_var
from synth_sdk.tracing.local import system_name_var, system_id_var, system_instance_id_var


@contextmanager
Expand Down Expand Up @@ -32,6 +32,7 @@ def event_scope(event_type: str):
else getattr(_local, "system_instance_id", None)
)
system_id = system_id_var.get() if is_async else getattr(_local, "system_id", None)
system_name = system_name_var.get() if is_async else getattr(_local, "system_name", None)

event = Event(
system_instance_id=system_instance_id,
Expand All @@ -51,4 +52,4 @@ def event_scope(event_type: str):
clear_current_event(event_type)
# Store the event if system_instance_id is available
if system_instance_id:
event_store.add_event(system_instance_id, system_id, event)
event_store.add_event(system_name, system_id, system_instance_id, event)
22 changes: 14 additions & 8 deletions synth_sdk/tracing/events/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ def __init__(self):
self.logger = logging.getLogger(__name__)

def get_or_create_system_trace(
self, system_instance_id: str, system_id: str, _already_locked: bool = False
self,
system_name: str,
system_id: str,
system_instance_id: str,
_already_locked: bool = False
) -> SystemTrace:
"""Get or create a SystemTrace for the given system_instance_id."""
logger = logging.getLogger(__name__)
Expand All @@ -32,6 +36,7 @@ def _get_or_create():
if system_instance_id not in self._traces:
# logger.debug(f"Creating new system trace for {system_instance_id}")
self._traces[system_instance_id] = SystemTrace(
system_name=system_name,
system_id=system_id,
system_instance_id=system_instance_id,
metadata={},
Expand All @@ -48,15 +53,15 @@ def _get_or_create():
# logger.debug("Lock acquired in get_or_create_system_trace")
return _get_or_create()

def increment_partition(self, system_instance_id: str, system_id: str) -> int:
def increment_partition(self, system_name: str, system_id: str, system_instance_id: str) -> int:
"""Increment the partition index for a system and create new partition element."""
logger = logging.getLogger(__name__)
# logger.debug(f"Starting increment_partition for system {system_instance_id}")

with self._lock:
# logger.debug("Lock acquired in increment_partition")
system_trace = self.get_or_create_system_trace(
system_instance_id, system_id, _already_locked=True
system_name, system_id, system_instance_id, _already_locked=True
)
# logger.debug(
# f"Got system trace, current index: {system_trace.current_partition_index}"
Expand All @@ -76,7 +81,7 @@ def increment_partition(self, system_instance_id: str, system_id: str) -> int:

return system_trace.current_partition_index

def add_event(self, system_instance_id: str, system_id: str, event: Event):
def add_event(self, system_name: str, system_id: str, system_instance_id: str, event: Event):
"""Add an event to the appropriate partition of the system trace."""
# self.#logger.debug(f"Adding event type {event.event_type} to system {system_instance_id}")
# self.#logger.debug(
Expand All @@ -91,7 +96,7 @@ def add_event(self, system_instance_id: str, system_id: str, event: Event):

try:
system_trace = self.get_or_create_system_trace(
system_instance_id, system_id
system_name, system_id, system_instance_id, _already_locked=True
)
# self.#logger.debug(
# f"Got system trace with {len(system_trace.partition)} partitions"
Expand Down Expand Up @@ -139,25 +144,26 @@ def end_all_active_events(self):
if hasattr(_local, "active_events"):
active_events = _local.active_events
system_instance_id = getattr(_local, "system_instance_id", None)
system_name = getattr(_local, "system_name", None)
system_id = getattr(_local, "system_id", None)
if active_events: # and system_instance_id:
for event_type, event in list(active_events.items()):
if event.closed is None:
event.closed = time.time()
self.add_event(event.system_instance_id, system_id, event)
self.add_event(system_name, system_id, event.system_instance_id, event)
# self.#logger.debug(f"Stored and closed event {event_type}")
_local.active_events.clear()

else:
# For asynchronous code
active_events_async = active_events_var.get()

print("Active events async:", active_events_async.items())
if active_events_async: # and system_instance_id_async:
for event_type, event in list(active_events_async.items()):
system_id = system_id_var.get()
if event.closed is None:
event.closed = time.time()
self.add_event(event.system_instance_id, system_id, event)
self.add_event(system_name, system_id, event.system_instance_id, event)
# self.#logger.debug(f"Stored and closed event {event_type}")
active_events_var.set({})

Expand Down
26 changes: 13 additions & 13 deletions synth_sdk/tracing/trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def track_lm(
"model_name": model_name,
"finetune": finetune,
})
#logger.debug("Tracked LM interaction")
# logger.debug("Tracked LM interaction")
else:
pass
# raise RuntimeError(
Expand Down Expand Up @@ -67,7 +67,7 @@ def track_state(
"variable_value": variable_value,
"annotation": annotation,
})
#logger.debug(f"Tracked state: {variable_name}")
# logger.debug(f"Tracked state: {variable_name}")
else:
pass
# raise RuntimeError(
Expand All @@ -84,7 +84,7 @@ def finalize(cls):
cls._local.initialized = False
cls._local.inputs = []
cls._local.outputs = []
#logger.debug("Finalized trace data")
# logger.debug("Finalized trace data")


# Context variables for asynchronous tracing
Expand All @@ -104,7 +104,7 @@ def initialize(cls):
trace_initialized_var.set(True)
trace_inputs_var.set([]) # List of tuples: (origin, var)
trace_outputs_var.set([]) # List of tuples: (origin, var)
#logger.debug("AsyncTrace initialized")
# logger.debug("AsyncTrace initialized")

@classmethod
def track_lm(
Expand All @@ -122,7 +122,7 @@ def track_lm(
"finetune": finetune,
})
trace_inputs_var.set(trace_inputs)
#logger.debug("Tracked LM interaction")
# logger.debug("Tracked LM interaction")
else:
pass
# raise RuntimeError(
Expand Down Expand Up @@ -164,7 +164,7 @@ def track_state(
"annotation": annotation,
})
trace_outputs_var.set(trace_outputs)
#logger.debug(f"Tracked state: {variable_name}")
# logger.debug(f"Tracked state: {variable_name}")
else:
pass
# raise RuntimeError(
Expand All @@ -182,7 +182,7 @@ def finalize(cls):
trace_initialized_var.set(False)
trace_inputs_var.set([])
trace_outputs_var.set([])
logger.debug("Finalized async trace data")
# logger.debug("Finalized async trace data")


# Make traces available globally
Expand Down Expand Up @@ -246,10 +246,10 @@ def process_chat(self, user_input: str):
```
"""
if cls.is_called_by_async() and trace_initialized_var.get():
logger.debug("Using async tracker to track LM")
# logger.debug("Using async tracker to track LM")
synth_tracker_async.track_lm(messages, model_name, finetune)
elif getattr(synth_tracker_sync._local, "initialized", False):
logger.debug("Using sync tracker to track LM")
# logger.debug("Using sync tracker to track LM")
synth_tracker_sync.track_lm(messages, model_name, finetune)
else:
# raise RuntimeError("Trace not initialized in track_lm.")
Expand Down Expand Up @@ -302,10 +302,10 @@ def update_state(self, new_value: dict):
```
"""
if cls.is_called_by_async() and trace_initialized_var.get():
logger.debug("Using async tracker to track state")
# logger.debug("Using async tracker to track state")
synth_tracker_async.track_state(variable_name, variable_value, origin, annotation)
elif getattr(synth_tracker_sync._local, "initialized", False):
logger.debug("Using sync tracker to track state")
# logger.debug("Using sync tracker to track state")
synth_tracker_sync.track_state(variable_name, variable_value, origin, annotation)
else:
#raise RuntimeError("Trace not initialized in track_state.")
Expand All @@ -319,13 +319,13 @@ def get_traced_data(
traced_inputs, traced_outputs = [], []

if async_sync in ["async", ""]:
#logger.debug("Getting traced data from async tracker")
# logger.debug("Getting traced data from async tracker")
traced_inputs_async, traced_outputs_async = synth_tracker_async.get_traced_data()
traced_inputs.extend(traced_inputs_async)
traced_outputs.extend(traced_outputs_async)

if async_sync in ["sync", ""]:
#logger.debug("Getting traced data from sync tracker")
# logger.debug("Getting traced data from sync tracker")
traced_inputs_sync, traced_outputs_sync = synth_tracker_sync.get_traced_data()
traced_inputs.extend(traced_inputs_sync)
traced_outputs.extend(traced_outputs_sync)
Expand Down
Loading

0 comments on commit 8235068

Please sign in to comment.