diff --git a/pyproject.toml b/pyproject.toml index d7e3631..cf60e16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ dependencies = [ "twine>=4.0.0", "keyring>=24.0.0", "python-dotenv>=1.0.1", + "langfuse>=2.53.9", + "pytest>=8.3.3", ] classifiers = [] diff --git a/synth_sdk/provider_support/openai_lf.py b/synth_sdk/provider_support/openai_lf.py new file mode 100644 index 0000000..8f41d69 --- /dev/null +++ b/synth_sdk/provider_support/openai_lf.py @@ -0,0 +1,978 @@ +import copy +import logging +from inspect import isclass +import types + +from collections import defaultdict +from dataclasses import dataclass +from typing import List, Optional + +import openai.resources +from openai._types import NotGiven +from packaging.version import Version +from wrapt import wrap_function_wrapper + +from langfuse import Langfuse +from langfuse.client import StatefulGenerationClient +from langfuse.decorators import langfuse_context +from langfuse.utils import _get_timestamp +from langfuse.utils.langfuse_singleton import LangfuseSingleton +from synth_sdk.tracing.trackers import synth_tracker_sync, synth_tracker_async +from pydantic import BaseModel +from synth_sdk.tracing.abstractions import MessageInputs, MessageOutputs + +try: + import openai +except ImportError: + raise ModuleNotFoundError( + "Please install OpenAI to use this feature: 'pip install openai'" + ) + +try: + from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI # noqa: F401 +except ImportError: + AsyncAzureOpenAI = None + AsyncOpenAI = None + AzureOpenAI = None + OpenAI = None + +log = logging.getLogger("langfuse") + + +@dataclass +class OpenAiDefinition: + module: str + object: str + method: str + type: str + sync: bool + min_version: Optional[str] = None + + +OPENAI_METHODS_V0 = [ + OpenAiDefinition( + module="openai", + object="ChatCompletion", + method="create", + type="chat", + sync=True, + ), + OpenAiDefinition( + module="openai", + object="Completion", + method="create", + type="completion", + sync=True, + ), +] + + +OPENAI_METHODS_V1 = [ + OpenAiDefinition( + module="openai.resources.chat.completions", + object="Completions", + method="create", + type="chat", + sync=True, + ), + OpenAiDefinition( + module="openai.resources.completions", + object="Completions", + method="create", + type="completion", + sync=True, + ), + OpenAiDefinition( + module="openai.resources.chat.completions", + object="AsyncCompletions", + method="create", + type="chat", + sync=False, + ), + OpenAiDefinition( + module="openai.resources.completions", + object="AsyncCompletions", + method="create", + type="completion", + sync=False, + ), + OpenAiDefinition( + module="openai.resources.beta.chat.completions", + object="Completions", + method="parse", + type="chat", + sync=True, + min_version="1.50.0", + ), + OpenAiDefinition( + module="openai.resources.beta.chat.completions", + object="AsyncCompletions", + method="parse", + type="chat", + sync=False, + min_version="1.50.0", + ), +] + + +class OpenAiArgsExtractor: + def __init__( + self, + name=None, + metadata=None, + trace_id=None, + session_id=None, + user_id=None, + tags=None, + parent_observation_id=None, + langfuse_prompt=None, # we cannot use prompt because it's an argument of the old OpenAI completions API + **kwargs, + ): + self.args = {} + self.args["name"] = name + self.args["metadata"] = ( + metadata + if "response_format" not in kwargs + else { + **(metadata or {}), + "response_format": kwargs["response_format"].model_json_schema() + if isclass(kwargs["response_format"]) + and issubclass(kwargs["response_format"], BaseModel) + else kwargs["response_format"], + } + ) + self.args["trace_id"] = trace_id + self.args["session_id"] = session_id + self.args["user_id"] = user_id + self.args["tags"] = tags + self.args["parent_observation_id"] = parent_observation_id + self.args["langfuse_prompt"] = langfuse_prompt + self.kwargs = kwargs + + def get_langfuse_args(self): + return {**self.args, **self.kwargs} + + def get_openai_args(self): + return self.kwargs + + +def _langfuse_wrapper(func): + def _with_langfuse(open_ai_definitions, initialize): + def wrapper(wrapped, instance, args, kwargs): + return func(open_ai_definitions, initialize, wrapped, args, kwargs) + + return wrapper + + return _with_langfuse + + +def _extract_chat_prompt(kwargs: any): + """Extracts the user input from prompts. Returns an array of messages or dict with messages and functions""" + prompt = {} + + if kwargs.get("functions") is not None: + prompt.update({"functions": kwargs["functions"]}) + + if kwargs.get("function_call") is not None: + prompt.update({"function_call": kwargs["function_call"]}) + + if kwargs.get("tools") is not None: + prompt.update({"tools": kwargs["tools"]}) + + if prompt: + # uf user provided functions, we need to send these together with messages to langfuse + prompt.update( + { + "messages": _filter_image_data(kwargs.get("messages", [])), + } + ) + return prompt + else: + # vanilla case, only send messages in openai format to langfuse + return _filter_image_data(kwargs.get("messages", [])) + + +def _extract_chat_response(kwargs: any): + """Extracts the llm output from the response.""" + response = { + "role": kwargs.get("role", None), + } + + if kwargs.get("function_call") is not None: + response.update({"function_call": kwargs["function_call"]}) + + if kwargs.get("tool_calls") is not None: + response.update({"tool_calls": kwargs["tool_calls"]}) + + response.update( + { + "content": kwargs.get("content", None), + } + ) + return response + + +def _get_langfuse_data_from_kwargs( + resource: OpenAiDefinition, langfuse: Langfuse, start_time, kwargs +): + name = kwargs.get("name", "OpenAI-generation") + + if name is None: + name = "OpenAI-generation" + + if name is not None and not isinstance(name, str): + raise TypeError("name must be a string") + + decorator_context_observation_id = langfuse_context.get_current_observation_id() + decorator_context_trace_id = langfuse_context.get_current_trace_id() + + trace_id = kwargs.get("trace_id", None) or decorator_context_trace_id + if trace_id is not None and not isinstance(trace_id, str): + raise TypeError("trace_id must be a string") + + session_id = kwargs.get("session_id", None) + if session_id is not None and not isinstance(session_id, str): + raise TypeError("session_id must be a string") + + user_id = kwargs.get("user_id", None) + if user_id is not None and not isinstance(user_id, str): + raise TypeError("user_id must be a string") + + tags = kwargs.get("tags", None) + if tags is not None and ( + not isinstance(tags, list) or not all(isinstance(tag, str) for tag in tags) + ): + raise TypeError("tags must be a list of strings") + + # Update trace params in decorator context if specified in openai call + if decorator_context_trace_id: + langfuse_context.update_current_trace( + session_id=session_id, user_id=user_id, tags=tags + ) + + parent_observation_id = kwargs.get("parent_observation_id", None) or ( + decorator_context_observation_id + if decorator_context_observation_id != decorator_context_trace_id + else None + ) + if parent_observation_id is not None and not isinstance(parent_observation_id, str): + raise TypeError("parent_observation_id must be a string") + if parent_observation_id is not None and trace_id is None: + raise ValueError("parent_observation_id requires trace_id to be set") + + metadata = kwargs.get("metadata", {}) + + if metadata is not None and not isinstance(metadata, dict): + raise TypeError("metadata must be a dictionary") + + model = kwargs.get("model", None) or None + + prompt = None + + if resource.type == "completion": + prompt = kwargs.get("prompt", None) + elif resource.type == "chat": + prompt = _extract_chat_prompt(kwargs) + + is_nested_trace = False + if trace_id: + is_nested_trace = True + langfuse.trace(id=trace_id, session_id=session_id, user_id=user_id, tags=tags) + else: + trace_id = ( + decorator_context_trace_id + or langfuse.trace( + session_id=session_id, + user_id=user_id, + tags=tags, + name=name, + input=prompt, + metadata=metadata, + ).id + ) + + parsed_temperature = ( + kwargs.get("temperature", 1) + if not isinstance(kwargs.get("temperature", 1), NotGiven) + else 1 + ) + + parsed_max_tokens = ( + kwargs.get("max_tokens", float("inf")) + if not isinstance(kwargs.get("max_tokens", float("inf")), NotGiven) + else float("inf") + ) + + parsed_top_p = ( + kwargs.get("top_p", 1) + if not isinstance(kwargs.get("top_p", 1), NotGiven) + else 1 + ) + + parsed_frequency_penalty = ( + kwargs.get("frequency_penalty", 0) + if not isinstance(kwargs.get("frequency_penalty", 0), NotGiven) + else 0 + ) + + parsed_presence_penalty = ( + kwargs.get("presence_penalty", 0) + if not isinstance(kwargs.get("presence_penalty", 0), NotGiven) + else 0 + ) + + parsed_seed = ( + kwargs.get("seed", None) + if not isinstance(kwargs.get("seed", None), NotGiven) + else None + ) + + modelParameters = { + "temperature": parsed_temperature, + "max_tokens": parsed_max_tokens, # casing? + "top_p": parsed_top_p, + "frequency_penalty": parsed_frequency_penalty, + "presence_penalty": parsed_presence_penalty, + } + if parsed_seed is not None: + modelParameters["seed"] = parsed_seed + + langfuse_prompt = kwargs.get("langfuse_prompt", None) + + return { + "name": name, + "metadata": metadata, + "trace_id": trace_id, + "parent_observation_id": parent_observation_id, + "user_id": user_id, + "start_time": start_time, + "input": prompt, + "model_parameters": modelParameters, + "model": model or None, + "prompt": langfuse_prompt, + }, is_nested_trace + + +def _create_langfuse_update( + completion, + generation: StatefulGenerationClient, + completion_start_time, + model=None, + usage=None, +): + update = { + "end_time": _get_timestamp(), + "output": completion, + "completion_start_time": completion_start_time, + } + if model is not None: + update["model"] = model + + if usage is not None: + update["usage"] = usage + + generation.update(**update) + + +def _extract_streamed_openai_response(resource, chunks): + completion = defaultdict(str) if resource.type == "chat" else "" + model = None + + for chunk in chunks: + if _is_openai_v1(): + chunk = chunk.__dict__ + + model = model or chunk.get("model", None) or None + usage = chunk.get("usage", None) + + choices = chunk.get("choices", []) + + for choice in choices: + if _is_openai_v1(): + choice = choice.__dict__ + if resource.type == "chat": + delta = choice.get("delta", None) + + if _is_openai_v1(): + delta = delta.__dict__ + + if delta.get("role", None) is not None: + completion["role"] = delta["role"] + + if delta.get("content", None) is not None: + completion["content"] = ( + delta.get("content", None) + if completion["content"] is None + else completion["content"] + delta.get("content", None) + ) + elif delta.get("function_call", None) is not None: + curr = completion["function_call"] + tool_call_chunk = delta.get("function_call", None) + + if not curr: + completion["function_call"] = { + "name": getattr(tool_call_chunk, "name", ""), + "arguments": getattr(tool_call_chunk, "arguments", ""), + } + + else: + curr["name"] = curr["name"] or getattr( + tool_call_chunk, "name", None + ) + curr["arguments"] += getattr(tool_call_chunk, "arguments", "") + + elif delta.get("tool_calls", None) is not None: + curr = completion["tool_calls"] + tool_call_chunk = getattr( + delta.get("tool_calls", None)[0], "function", None + ) + + if not curr: + completion["tool_calls"] = [ + { + "name": getattr(tool_call_chunk, "name", ""), + "arguments": getattr(tool_call_chunk, "arguments", ""), + } + ] + + elif getattr(tool_call_chunk, "name", None) is not None: + curr.append( + { + "name": getattr(tool_call_chunk, "name", None), + "arguments": getattr( + tool_call_chunk, "arguments", None + ), + } + ) + + else: + curr[-1]["name"] = curr[-1]["name"] or getattr( + tool_call_chunk, "name", None + ) + curr[-1]["arguments"] += getattr( + tool_call_chunk, "arguments", None + ) + + if resource.type == "completion": + completion += choice.get("text", None) + + def get_response_for_chat(): + return ( + completion["content"] + or ( + completion["function_call"] + and { + "role": "assistant", + "function_call": completion["function_call"], + } + ) + or ( + completion["tool_calls"] + and { + "role": "assistant", + # "tool_calls": [{"function": completion["tool_calls"]}], + "tool_calls": [ + {"function": data} for data in completion["tool_calls"] + ], + } + ) + or None + ) + + return ( + model, + get_response_for_chat() if resource.type == "chat" else completion, + usage.__dict__ if _is_openai_v1() and usage is not None else usage, + ) + + +def _get_langfuse_data_from_default_response(resource: OpenAiDefinition, response): + if response is None: + return None, "", None + + model = response.get("model", None) or None + + completion = None + if resource.type == "completion": + choices = response.get("choices", []) + if len(choices) > 0: + choice = choices[-1] + + completion = choice.text if _is_openai_v1() else choice.get("text", None) + elif resource.type == "chat": + choices = response.get("choices", []) + if len(choices) > 0: + choice = choices[-1] + completion = ( + _extract_chat_response(choice.message.__dict__) + if _is_openai_v1() + else choice.get("message", None) + ) + + usage = response.get("usage", None) + + return ( + model, + completion, + usage.__dict__ if _is_openai_v1() and usage is not None else usage, + ) + + +def _is_openai_v1(): + return Version(openai.__version__) >= Version("1.0.0") + + +def _is_streaming_response(response): + return ( + isinstance(response, types.GeneratorType) + or isinstance(response, types.AsyncGeneratorType) + or (_is_openai_v1() and isinstance(response, openai.Stream)) + or (_is_openai_v1() and isinstance(response, openai.AsyncStream)) + ) + + +@_langfuse_wrapper +def _wrap(open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs): + new_langfuse: Langfuse = initialize() + + start_time = _get_timestamp() + arg_extractor = OpenAiArgsExtractor(*args, **kwargs) + + generation, is_nested_trace = _get_langfuse_data_from_kwargs( + open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args() + ) + generation = new_langfuse.generation(**generation) + try: + openai_response = wrapped(**arg_extractor.get_openai_args()) + + if _is_streaming_response(openai_response): + return LangfuseResponseGeneratorSync( + resource=open_ai_resource, + response=openai_response, + generation=generation, + langfuse=new_langfuse, + is_nested_trace=is_nested_trace, + kwargs=arg_extractor.get_openai_args() + ) + + else: + model, completion, usage = _get_langfuse_data_from_default_response( + open_ai_resource, + (openai_response and openai_response.__dict__) + if _is_openai_v1() + else openai_response, + ) + + # Collect messages + if open_ai_resource.type == "completion": + user_prompt = arg_extractor.get_openai_args().get("prompt", "") + messages = [ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": completion}, + ] + message_input = MessageInputs(messages=messages) + elif open_ai_resource.type == "chat": + messages = arg_extractor.get_openai_args().get("messages", []) + messages.append({"role": "assistant", "content": completion["content"]}) + message_input = MessageInputs(messages=messages) + else: + message_input = MessageInputs(messages=[]) + + # Use track_lm + synth_tracker_sync.track_lm( + messages=message_input.messages, + model_name=model, + finetune=False + ) + + generation.update( + model=model, output=completion, end_time=_get_timestamp(), usage=usage + ) + + # Avoiding the trace-update if trace-id is provided by user. + if not is_nested_trace: + new_langfuse.trace(id=generation.trace_id, output=completion) + + return openai_response + except Exception as ex: + log.warning(ex) + model = kwargs.get("model", None) or None + generation.update( + end_time=_get_timestamp(), + status_message=str(ex), + level="ERROR", + model=model, + usage={"input_cost": 0, "output_cost": 0, "total_cost": 0}, + ) + raise ex + + +@_langfuse_wrapper +async def _wrap_async( + open_ai_resource: OpenAiDefinition, initialize, wrapped, args, kwargs +): + new_langfuse = initialize() + start_time = _get_timestamp() + arg_extractor = OpenAiArgsExtractor(*args, **kwargs) + + generation, is_nested_trace = _get_langfuse_data_from_kwargs( + open_ai_resource, new_langfuse, start_time, arg_extractor.get_langfuse_args() + ) + generation = new_langfuse.generation(**generation) + try: + openai_response = await wrapped(**arg_extractor.get_openai_args()) + + if _is_streaming_response(openai_response): + return LangfuseResponseGeneratorAsync( + resource=open_ai_resource, + response=openai_response, + generation=generation, + langfuse=new_langfuse, + is_nested_trace=is_nested_trace, + kwargs=arg_extractor.get_openai_args() + ) + + else: + model, completion, usage = _get_langfuse_data_from_default_response( + open_ai_resource, + (openai_response and openai_response.__dict__) + if _is_openai_v1() + else openai_response, + ) + + # Collect messages + if open_ai_resource.type == "completion": + user_prompt = arg_extractor.get_openai_args().get("prompt", "") + messages = [ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": completion}, + ] + message_input = MessageInputs(messages=messages) + elif open_ai_resource.type == "chat": + messages = arg_extractor.get_openai_args().get("messages", []) + messages.append({"role": "assistant", "content": completion["content"]}) + message_input = MessageInputs(messages=messages) + else: + message_input = MessageInputs(messages=[]) + + # Use track_lm + synth_tracker_async.track_lm( + messages=message_input.messages, + model_name=model, + finetune=False + ) + + generation.update( + model=model, + output=completion, + end_time=_get_timestamp(), + usage=usage, + ) + # Avoiding the trace-update if trace-id is provided by user. + if not is_nested_trace: + new_langfuse.trace(id=generation.trace_id, output=completion) + + return openai_response + except Exception as ex: + model = kwargs.get("model", None) or None + generation.update( + end_time=_get_timestamp(), + status_message=str(ex), + level="ERROR", + model=model, + usage={"input_cost": 0, "output_cost": 0, "total_cost": 0}, + ) + raise ex + + +class OpenAILangfuse: + _langfuse: Optional[Langfuse] = None + + def initialize(self): + self._langfuse = LangfuseSingleton().get( + public_key=openai.langfuse_public_key, + secret_key=openai.langfuse_secret_key, + host=openai.langfuse_host, + debug=openai.langfuse_debug, + enabled=openai.langfuse_enabled, + sdk_integration="openai", + sample_rate=openai.langfuse_sample_rate, + ) + + return self._langfuse + + def flush(cls): + cls._langfuse.flush() + + def langfuse_auth_check(self): + """Check if the provided Langfuse credentials (public and secret key) are valid. + + Raises: + Exception: If no projects were found for the provided credentials. + + Note: + This method is blocking. It is discouraged to use it in production code. + """ + if self._langfuse is None: + self.initialize() + + return self._langfuse.auth_check() + + def register_tracing(self): + resources = OPENAI_METHODS_V1 if _is_openai_v1() else OPENAI_METHODS_V0 + + for resource in resources: + if resource.min_version is not None and Version( + openai.__version__ + ) < Version(resource.min_version): + continue + + wrap_function_wrapper( + resource.module, + f"{resource.object}.{resource.method}", + _wrap(resource, self.initialize) + if resource.sync + else _wrap_async(resource, self.initialize), + ) + + setattr(openai, "langfuse_public_key", None) + setattr(openai, "langfuse_secret_key", None) + setattr(openai, "langfuse_host", None) + setattr(openai, "langfuse_debug", None) + setattr(openai, "langfuse_enabled", True) + setattr(openai, "langfuse_sample_rate", None) + setattr(openai, "langfuse_mask", None) + setattr(openai, "langfuse_auth_check", self.langfuse_auth_check) + setattr(openai, "flush_langfuse", self.flush) + + +modifier = OpenAILangfuse() +modifier.register_tracing() + + +# DEPRECATED: Use `openai.langfuse_auth_check()` instead +def auth_check(): + if modifier._langfuse is None: + modifier.initialize() + + return modifier._langfuse.auth_check() + + +def _filter_image_data(messages: List[dict]): + """https://platform.openai.com/docs/guides/vision?lang=python + + The messages array remains the same, but the 'image_url' is removed from the 'content' array. + It should only be removed if the value starts with 'data:image/jpeg;base64,' + + """ + output_messages = copy.deepcopy(messages) + + for message in output_messages: + content = ( + message.get("content", None) + if isinstance(message, dict) + else getattr(message, "content", None) + ) + + if content is not None: + for index, item in enumerate(content): + if isinstance(item, dict) and item.get("image_url", None) is not None: + url = item["image_url"]["url"] + if url.startswith("data:image/"): + del content[index]["image_url"] + + return output_messages + + +class LangfuseResponseGeneratorSync: + def __init__( + self, + *, + resource, + response, + generation, + langfuse, + is_nested_trace, + kwargs, + ): + self.items = [] + self.resource = resource + self.response = response + self.generation = generation + self.langfuse = langfuse + self.is_nested_trace = is_nested_trace + self.kwargs = kwargs + self.completion_start_time = None + + def __iter__(self): + try: + for i in self.response: + self.items.append(i) + + if self.completion_start_time is None: + self.completion_start_time = _get_timestamp() + + yield i + finally: + self._finalize() + + def __next__(self): + try: + item = self.response.__next__() + self.items.append(item) + + if self.completion_start_time is None: + self.completion_start_time = _get_timestamp() + + return item + + except StopIteration: + self._finalize() + + raise + + def __enter__(self): + return self.__iter__() + + def __exit__(self, exc_type, exc_value, traceback): + pass + + def _finalize(self): + model, completion, usage = _extract_streamed_openai_response( + self.resource, self.items + ) + + # Collect messages + if self.resource.type == "completion": + user_prompt = self.kwargs.get("prompt", "") + messages = [ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": completion}, + ] + message_input = MessageInputs(messages=messages) + elif self.resource.type == "chat": + messages = self.kwargs.get("messages", []) + messages.append({"role": "assistant", "content": completion["content"]}) + message_input = MessageInputs(messages=messages) + else: + message_input = MessageInputs(messages=[]) + + # Use track_lm + synth_tracker_sync.track_lm( + messages=message_input.messages, + model_name=model, + finetune=False + ) + + # Avoiding the trace-update if trace-id is provided by user. + if not self.is_nested_trace: + self.langfuse.trace(id=self.generation.trace_id, output=completion) + + _create_langfuse_update( + completion, + self.generation, + self.completion_start_time, + model=model, + usage=usage, + ) + + +class LangfuseResponseGeneratorAsync: + def __init__( + self, + *, + resource, + response, + generation, + langfuse, + is_nested_trace, + kwargs, + ): + self.items = [] + + self.resource = resource + self.response = response + self.generation = generation + self.langfuse = langfuse + self.is_nested_trace = is_nested_trace + self.kwargs = kwargs + self.completion_start_time = None + + async def __aiter__(self): + try: + async for i in self.response: + self.items.append(i) + + if self.completion_start_time is None: + self.completion_start_time = _get_timestamp() + + yield i + finally: + await self._finalize() + + async def __anext__(self): + try: + item = await self.response.__anext__() + self.items.append(item) + + if self.completion_start_time is None: + self.completion_start_time = _get_timestamp() + + return item + + except StopAsyncIteration: + await self._finalize() + + raise + + async def __aenter__(self): + return self.__aiter__() + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + async def _finalize(self): + model, completion, usage = _extract_streamed_openai_response( + self.resource, self.items + ) + + # Collect messages + if self.resource.type == "completion": + user_prompt = self.kwargs.get("prompt", "") + messages = [ + {"role": "user", "content": user_prompt}, + {"role": "assistant", "content": completion}, + ] + message_input = MessageInputs(messages=messages) + elif self.resource.type == "chat": + messages = self.kwargs.get("messages", []) + messages.append({"role": "assistant", "content": completion["content"]}) + message_input = MessageInputs(messages=messages) + else: + message_input = MessageInputs(messages=[]) + + # Use track_lm + synth_tracker_async.track_lm( + messages=message_input.messages, + model_name=model, + finetune=False + ) + + # Avoiding the trace-update if trace-id is provided by user. + if not self.is_nested_trace: + self.langfuse.trace(id=self.generation.trace_id, output=completion) + + _create_langfuse_update( + completion, + self.generation, + self.completion_start_time, + model=model, + usage=usage, + ) + + async def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.response.close() diff --git a/synth_sdk/tracing/abstractions.py b/synth_sdk/tracing/abstractions.py index 369f502..b6a20b0 100644 --- a/synth_sdk/tracing/abstractions.py +++ b/synth_sdk/tracing/abstractions.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, List, Dict, Optional, Union +from typing import Any, List, Dict, Optional, Union, Literal from pydantic import BaseModel import logging from synth_sdk.tracing.config import VALID_TYPES @@ -7,33 +7,54 @@ logger = logging.getLogger(__name__) +@dataclass +class MessageInputs: + messages: List[Dict[str, str]] # {"role": "", "content": ""} + + +@dataclass +class ArbitraryInputs: + inputs: Dict[str, Any] + + +@dataclass +class MessageOutputs: + messages: List[Dict[str, str]] + + +@dataclass +class ArbitraryOutputs: + outputs: Dict[str, Any] + + @dataclass class ComputeStep: event_order: int - compute_ended: Any # time step - compute_began: Any # time step - compute_input: Dict[str, Any] # {variable_name: value} - compute_output: Dict[str, Any] # {variable_name: value} + compute_ended: Any # timestamp + compute_began: Any # timestamp + compute_input: List[Any] + compute_output: List[Any] def to_dict(self): - # Define serializable types - #serializable_types = (str, int, float, bool, list, dict, type(None)) - - # Filter compute_input - serializable_input = {} - for name, value in self.compute_input.items(): - if isinstance(value, VALID_TYPES): - serializable_input[name] = value - else: - logger.warning(f"Skipping non-serializable input: {name}={value}") - - # Filter compute_output - serializable_output = {} - for name, value in self.compute_output.items(): - if isinstance(value, VALID_TYPES): - serializable_output[name] = value - else: - logger.warning(f"Skipping non-serializable output: {name}={value}") + # Serialize compute_input + serializable_input = [ + input_item.__dict__ for input_item in self.compute_input + if isinstance(input_item, (MessageInputs, ArbitraryInputs)) + ] + + # Serialize compute_output + serializable_output = [ + output_item.__dict__ for output_item in self.compute_output + if isinstance(output_item, (MessageOutputs, ArbitraryOutputs)) + ] + + # Warn about non-serializable inputs/outputs + for item in self.compute_input: + if not isinstance(item, (MessageInputs, ArbitraryInputs)): + logger.warning(f"Skipping non-serializable input: {item}") + for item in self.compute_output: + if not isinstance(item, (MessageOutputs, ArbitraryOutputs)): + logger.warning(f"Skipping non-serializable output: {item}") return { "event_order": self.event_order, @@ -44,20 +65,26 @@ def to_dict(self): } +@dataclass class AgentComputeStep(ComputeStep): - pass + model_name: Optional[str] = None + compute_input: List[Union[MessageInputs, ArbitraryInputs]] + compute_output: List[Union[MessageOutputs, ArbitraryOutputs]] +@dataclass class EnvironmentComputeStep(ComputeStep): - pass + compute_input: List[ArbitraryInputs] + compute_output: List[ArbitraryOutputs] @dataclass class Event: + system_id: str event_type: str - opened: Any # time stamp - closed: Any # time stamp - partition_index: int # New field + opened: Any # timestamp + closed: Any # timestamp + partition_index: int agent_compute_steps: List[AgentComputeStep] environment_compute_steps: List[EnvironmentComputeStep] diff --git a/synth_sdk/tracing/decorators.py b/synth_sdk/tracing/decorators.py index 58b4896..93f624e 100644 --- a/synth_sdk/tracing/decorators.py +++ b/synth_sdk/tracing/decorators.py @@ -1,5 +1,5 @@ # synth_sdk/tracing/decorators.py -from typing import Callable, Optional, Set, Literal, Any, Dict, Tuple, Union +from typing import Callable, Optional, Set, Literal, Any, Dict, Tuple, Union, List from functools import wraps import threading import time @@ -14,7 +14,7 @@ ) from synth_sdk.tracing.events.store import event_store from synth_sdk.tracing.local import _local, logger -from synth_sdk.tracing.trackers import synth_tracker_sync, synth_tracker_async +from synth_sdk.tracing.trackers import synth_tracker_sync, synth_tracker_async, SynthTracker from synth_sdk.tracing.events.manage import set_current_event from typing import Callable, Optional, Set, Literal, Any, Dict, Tuple, Union @@ -24,6 +24,10 @@ from pydantic import BaseModel from synth_sdk.tracing.abstractions import ( + ArbitraryInputs, + ArbitraryOutputs, + MessageInputs, + MessageOutputs, Event, AgentComputeStep, EnvironmentComputeStep, @@ -45,6 +49,7 @@ def trace_system_sync( manage_event: Literal["create", "end", "lazy_end", None] = None, increment_partition: bool = False, verbose: bool = False, + finetune_step: bool = True, ) -> Callable: def decorator(func: Callable) -> Callable: @wraps(func) @@ -79,6 +84,7 @@ def wrapper(*args, **kwargs): if manage_event == "create": logger.debug("Creating new event") event = Event( + system_id=_local.system_id, event_type=event_type, opened=compute_began, closed=None, @@ -103,30 +109,69 @@ def wrapper(*args, **kwargs): for param, value in bound_args.arguments.items(): if param == "self": continue - synth_tracker_sync.track_input(value, param, origin) + synth_tracker_sync.track_state( + variable_name=param, + variable_value=value, + origin=origin + ) # Execute the function result = func(*args, **kwargs) # Automatically trace function output - if log_result: - synth_tracker_sync.track_output(result, "result", origin) + synth_tracker_sync.track_state( + variable_name="result", + variable_value=result, + origin=origin + ) # Collect traced inputs and outputs traced_inputs, traced_outputs = synth_tracker_sync.get_traced_data() compute_steps_by_origin: Dict[ - Literal["agent", "environment"], Dict[str, Dict[str, Any]] + Literal["agent", "environment"], Dict[str, List[Any]] ] = { - "agent": {"inputs": {}, "outputs": {}}, - "environment": {"inputs": {}, "outputs": {}}, + "agent": {"inputs": [], "outputs": []}, + "environment": {"inputs": [], "outputs": []}, } # Organize traced data by origin - for var_origin, var, var_name, _ in traced_inputs: - compute_steps_by_origin[var_origin]["inputs"][var_name] = var - for var_origin, var, var_name, _ in traced_outputs: - compute_steps_by_origin[var_origin]["outputs"][var_name] = var + for item in traced_inputs: + var_origin = item['origin'] + if 'variable_value' in item and 'variable_name' in item: + # Standard variable input + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={item['variable_name']: item['variable_value']}) + ) + elif 'messages' in item: + # Message input from track_lm + compute_steps_by_origin[var_origin]["inputs"].append( + MessageInputs(messages=item['messages']) + ) + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"model_name": item["model_name"]}) + ) + finetune = item["finetune"] or finetune_step + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"finetune": finetune}) + ) + else: + logger.warning(f"Unhandled traced input item: {item}") + + for item in traced_outputs: + var_origin = item['origin'] + if 'variable_value' in item and 'variable_name' in item: + # Standard variable output + compute_steps_by_origin[var_origin]["outputs"].append( + ArbitraryOutputs(outputs={item['variable_name']: item['variable_value']}) + ) + elif 'messages' in item: + # Message output from track_lm + compute_steps_by_origin[var_origin]["outputs"].append( + MessageOutputs(messages=item['messages']) + ) + else: + logger.warning(f"Unhandled traced output item: {item}") # Capture compute end time compute_ended = time.time() @@ -209,6 +254,7 @@ def trace_system_async( manage_event: Literal["create", "end", "lazy_end", None] = None, increment_partition: bool = False, verbose: bool = False, + finetune_step: bool = True, ) -> Callable: """Decorator for tracing asynchronous functions.""" @@ -247,6 +293,7 @@ async def async_wrapper(*args, **kwargs): if manage_event == "create": logger.debug("Creating new event") event = Event( + system_id=self_instance.system_id, event_type=event_type, opened=compute_began, closed=None, @@ -271,31 +318,73 @@ async def async_wrapper(*args, **kwargs): for param, value in bound_args.arguments.items(): if param == "self": continue - synth_tracker_async.track_input(value, param, origin) + synth_tracker_async.track_state( + variable_name=param, + variable_value=value, + origin=origin, + io_type="input" + ) # Execute the coroutine result = await func(*args, **kwargs) # Automatically trace function output - if log_result: - synth_tracker_async.track_output(result, "result", origin) + synth_tracker_async.track_state( + variable_name="result", + variable_value=result, + origin=origin, + io_type="output" + ) # Collect traced inputs and outputs traced_inputs, traced_outputs = synth_tracker_async.get_traced_data() compute_steps_by_origin: Dict[ - Literal["agent", "environment"], Dict[str, Dict[str, Any]] + Literal["agent", "environment"], Dict[str, List[Any]] ] = { - "agent": {"inputs": {}, "outputs": {}}, - "environment": {"inputs": {}, "outputs": {}}, + "agent": {"inputs": [], "outputs": []}, + "environment": {"inputs": [], "outputs": []}, } # Organize traced data by origin - for var_origin, var, var_name, _ in traced_inputs: - compute_steps_by_origin[var_origin]["inputs"][var_name] = var - for var_origin, var, var_name, _ in traced_outputs: - compute_steps_by_origin[var_origin]["outputs"][var_name] = var + for item in traced_inputs: + var_origin = item['origin'] + if 'variable_value' in item and 'variable_name' in item: + # Standard variable input + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={item['variable_name']: item['variable_value']}) + ) + elif 'messages' in item: + # Message input from track_lm + compute_steps_by_origin[var_origin]["inputs"].append( + MessageInputs(messages=item['messages']) + ) + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"model_name": item["model_name"]}) + ) + finetune = finetune_step or item["finetune"] + compute_steps_by_origin[var_origin]["inputs"].append( + ArbitraryInputs(inputs={"finetune": finetune}) + ) + else: + logger.warning(f"Unhandled traced input item: {item}") + + for item in traced_outputs: + var_origin = item['origin'] + if 'variable_value' in item and 'variable_name' in item: + # Standard variable output + compute_steps_by_origin[var_origin]["outputs"].append( + ArbitraryOutputs(outputs={item['variable_name']: item['variable_value']}) + ) + elif 'messages' in item: + # Message output from track_lm + compute_steps_by_origin[var_origin]["outputs"].append( + MessageOutputs(messages=item['messages']) + ) + else: + logger.warning(f"Unhandled traced output item: {item}") + print("COMPUTE STEPS BY ORIGIN", compute_steps_by_origin) # Capture compute end time compute_ended = time.time() @@ -337,6 +426,7 @@ async def async_wrapper(*args, **kwargs): f"Added compute step for {var_origin}: {compute_step.to_dict()}" ) + print("EVENT", event) # Optionally log the function result if log_result: logger.info(f"Function result: {result}") diff --git a/synth_sdk/tracing/events/manage.py b/synth_sdk/tracing/events/manage.py index 7f2f182..2267eff 100644 --- a/synth_sdk/tracing/events/manage.py +++ b/synth_sdk/tracing/events/manage.py @@ -32,35 +32,71 @@ def set_current_event(event: Optional["Event"]): logger.debug(f"Setting current event of type {event.event_type}") - if not hasattr(_local, "active_events"): - _local.active_events = {} - logger.debug("Initialized active_events in thread local storage") - - # If there's an existing event of the same type, end it - if event.event_type in _local.active_events: - logger.debug(f"Found existing event of type {event.event_type}") - existing_event = _local.active_events[event.event_type] - existing_event.closed = time.time() - logger.debug( - f"Closed existing event of type {event.event_type} at {existing_event.closed}" - ) - - # Store the closed event if system_id is present - if hasattr(_local, "system_id"): - logger.debug(f"Storing closed event for system {_local.system_id}") - try: - event_store.add_event(_local.system_id, existing_event) - logger.debug("Successfully stored closed event") - except Exception as e: - logger.error(f"Failed to store closed event: {str(e)}") - raise + # Check if we're in an async context + try: + import asyncio + asyncio.get_running_loop() + is_async = True + except RuntimeError: + is_async = False + + if is_async: + from synth_sdk.tracing.local import active_events_var, system_id_var + # Get current active events from context var + active_events = active_events_var.get() + + # If there's an existing event of the same type, end it + if event.event_type in active_events: + logger.debug(f"Found existing event of type {event.event_type}") + existing_event = active_events[event.event_type] + existing_event.closed = time.time() + logger.debug( + f"Closed existing event of type {event.event_type} at {existing_event.closed}" + ) + + # Store the closed event if system_id is present + system_id = system_id_var.get() + if system_id: + logger.debug(f"Storing closed event for system {system_id}") + try: + event_store.add_event(system_id, existing_event) + logger.debug("Successfully stored closed event") + except Exception as e: + logger.error(f"Failed to store closed event: {str(e)}") + raise + # Set the new event + active_events[event.event_type] = event + active_events_var.set(active_events) + logger.debug("New event set as current in context vars") else: - logger.debug(f"No existing event of type {event.event_type}") + # Original thread-local storage logic + if not hasattr(_local, "active_events"): + _local.active_events = {} + logger.debug("Initialized active_events in thread local storage") + + # If there's an existing event of the same type, end it + if event.event_type in _local.active_events: + logger.debug(f"Found existing event of type {event.event_type}") + existing_event = _local.active_events[event.event_type] + existing_event.closed = time.time() + logger.debug( + f"Closed existing event of type {event.event_type} at {existing_event.closed}" + ) + + # Store the closed event if system_id is present + if hasattr(_local, "system_id"): + logger.debug(f"Storing closed event for system {_local.system_id}") + try: + event_store.add_event(_local.system_id, existing_event) + logger.debug("Successfully stored closed event") + except Exception as e: + logger.error(f"Failed to store closed event: {str(e)}") + raise - # Set the new event - _local.active_events[event.event_type] = event - logger.debug("New event set as current") + # Set the new event + _local.active_events[event.event_type] = event + logger.debug("New event set as current in thread local") def clear_current_event(event_type: str): diff --git a/synth_sdk/tracing/events/scope.py b/synth_sdk/tracing/events/scope.py index 771a2ad..bdefb86 100644 --- a/synth_sdk/tracing/events/scope.py +++ b/synth_sdk/tracing/events/scope.py @@ -1,10 +1,9 @@ from contextlib import contextmanager import time -from synth_sdk.tracing.abstractions import ( - Event, -) +from synth_sdk.tracing.abstractions import Event from synth_sdk.tracing.decorators import set_current_event, clear_current_event, _local from synth_sdk.tracing.events.store import event_store +from synth_sdk.tracing.local import system_id_var @contextmanager @@ -16,7 +15,19 @@ def event_scope(event_type: str): with event_scope("my_event_type"): # do stuff """ + # Check if we're in an async context + try: + import asyncio + asyncio.get_running_loop() + is_async = True + except RuntimeError: + is_async = False + + # Get system_id from appropriate source + system_id = system_id_var.get() if is_async else getattr(_local, "system_id", None) + event = Event( + system_id=system_id, event_type=event_type, opened=time.time(), closed=None, @@ -31,6 +42,6 @@ def event_scope(event_type: str): finally: event.closed = time.time() clear_current_event(event_type) - # Store the event - if hasattr(_local, "system_id"): - event_store.add_event(_local.system_id, event) + # Store the event if system_id is available + if system_id: + event_store.add_event(system_id, event) diff --git a/synth_sdk/tracing/events/store.py b/synth_sdk/tracing/events/store.py index 372b25a..ad29785 100644 --- a/synth_sdk/tracing/events/store.py +++ b/synth_sdk/tracing/events/store.py @@ -1,10 +1,12 @@ import json import threading import logging +import time from typing import Dict, List, Optional from synth_sdk.tracing.abstractions import SystemTrace, EventPartitionElement, Event from synth_sdk.tracing.config import tracer # Update this import line from threading import RLock # Change this import +from synth_sdk.tracing.local import _local, system_id_var, active_events_var # Import context variables logger = logging.getLogger(__name__) @@ -76,50 +78,85 @@ def add_event(self, system_id: str, event: Event): self.logger.debug( f"Event details: opened={event.opened}, closed={event.closed}, partition={event.partition_index}" ) + print("Adding event to partition") + + #try: + if not self._lock.acquire(timeout=5): + self.logger.error("Failed to acquire lock within timeout period") + return try: - if not self._lock.acquire(timeout=5): - self.logger.error("Failed to acquire lock within timeout period") - return - - try: - system_trace = self.get_or_create_system_trace(system_id) - self.logger.debug( - f"Got system trace with {len(system_trace.partition)} partitions" - ) + system_trace = self.get_or_create_system_trace(system_id) + self.logger.debug( + f"Got system trace with {len(system_trace.partition)} partitions" + ) - current_partition = next( - ( - p - for p in system_trace.partition - if p.partition_index == event.partition_index - ), - None, - ) + current_partition = next( + ( + p + for p in system_trace.partition + if p.partition_index == event.partition_index + ), + None, + ) - if current_partition is None: - self.logger.error( - f"No partition found for index {event.partition_index} - existing partitions: {set([p.partition_index for p in system_trace.partition])}" - ) - raise ValueError( - f"No partition found for index {event.partition_index}" - ) - - current_partition.events.append(event) - self.logger.debug( - f"Added event to partition {event.partition_index}. Total events: {len(current_partition.events)}" + if current_partition is None: + self.logger.error( + f"No partition found for index {event.partition_index} - existing partitions: {set([p.partition_index for p in system_trace.partition])}" ) - finally: - self._lock.release() - except Exception as e: - self.logger.error(f"Error in add_event: {str(e)}", exc_info=True) - raise + raise ValueError( + f"No partition found for index {event.partition_index}" + ) + + + current_partition.events.append(event) + self.logger.debug( + f"Added event to partition {event.partition_index}. Total events: {len(current_partition.events)}" + ) + finally: + self._lock.release() + # except Exception as e: + # self.logger.error(f"Error in add_event: {str(e)}", exc_info=True) + # raise def get_system_traces(self) -> List[SystemTrace]: """Get all system traces.""" with self._lock: + self.end_all_active_events() + return list(self._traces.values()) + def end_all_active_events(self): + """End all active events and store them.""" + self.logger.debug("Ending all active events") + + # For synchronous code + if hasattr(_local, "active_events"): + active_events = _local.active_events + system_id = getattr(_local, "system_id", None) + if active_events:# and system_id: + for event_type, event in list(active_events.items()): + if event.closed is None: + event.closed = time.time() + self.add_event(event.system_id, event) + self.logger.debug(f"Stored and closed event {event_type}") + _local.active_events.clear() + + # For asynchronous code + active_events_async = active_events_var.get() + # Use preserved system ID if available, otherwise try to get from context + # system_id_async = preserved_system_id or system_id_var.get(None) + # print("System ID async: ", system_id_async) + # raise ValueError("Test error") + + if active_events_async:# and system_id_async: + for event_type, event in list(active_events_async.items()): + if event.closed is None: + event.closed = time.time() + self.add_event(event.system_id, event) + self.logger.debug(f"Stored and closed event {event_type}") + active_events_var.set({}) + def get_system_traces_json(self) -> str: """Get all system traces as JSON.""" with self._lock: diff --git a/synth_sdk/tracing/trackers.py b/synth_sdk/tracing/trackers.py index 86e3c15..2afaa14 100644 --- a/synth_sdk/tracing/trackers.py +++ b/synth_sdk/tracing/trackers.py @@ -1,68 +1,66 @@ -from typing import Union, Optional, Tuple, Literal +from typing import Union, Optional, Tuple, List, Dict, Literal import asyncio -import threading, contextvars +import threading import contextvars from pydantic import BaseModel from synth_sdk.tracing.local import logger, _local from synth_sdk.tracing.config import VALID_TYPES +from synth_sdk.tracing.abstractions import MessageInputs, MessageOutputs + +# Existing SynthTrackerSync and SynthTrackerAsync classes... -# This tracker ought to be used for synchronous tracing class SynthTrackerSync: _local = _local @classmethod def initialize(cls): cls._local.initialized = True - cls._local.inputs = [] # List of tuples: (origin, var) - cls._local.outputs = [] # List of tuples: (origin, var) + cls._local.inputs = [] + cls._local.outputs = [] @classmethod - def track_input( + def track_lm( cls, - var: Union[BaseModel, str, dict, int, float, bool, list, None], - variable_name: str, - origin: Literal["agent", "environment"], - annotation: Optional[str] = None, + messages: List[Dict[str, str]], + model_name: str, + finetune: bool = False, ): - if not isinstance(var, VALID_TYPES): - raise TypeError( - f"Variable {variable_name} must be one of {VALID_TYPES}, got {type(var)}" - ) - if getattr(cls._local, "initialized", False): - # Convert Pydantic models to dict schema - if isinstance(var, BaseModel): - var = var.model_dump() - cls._local.inputs.append((origin, var, variable_name, annotation)) - logger.debug( - f"Traced input: origin={origin}, var_name={variable_name}, annotation={annotation}" - ) + cls._local.inputs.append({ + "origin": "agent", + "messages": messages, + "model_name": model_name, + "finetune": finetune, + }) + logger.debug("Tracked LM interaction") else: raise RuntimeError( "Trace not initialized. Use within a function decorated with @trace_system_sync." ) @classmethod - def track_output( + def track_state( cls, - var: Union[BaseModel, str, dict, int, float, bool, list, None], variable_name: str, + variable_value: Union[BaseModel, str, dict, int, float, bool, list, None], origin: Literal["agent", "environment"], annotation: Optional[str] = None, ): - if not isinstance(var, VALID_TYPES): + if not isinstance(variable_value, VALID_TYPES): raise TypeError( - f"Variable {variable_name} must be one of {VALID_TYPES}, got {type(var)}" + f"Variable {variable_name} must be one of {VALID_TYPES}, got {type(variable_value)}" ) if getattr(cls._local, "initialized", False): - # Convert Pydantic models to dict schema - if isinstance(var, BaseModel): - var = var.model_dump() - cls._local.outputs.append((origin, var, variable_name, annotation)) - logger.debug( - f"Traced output: origin={origin}, var_name={variable_name}, annotation={annotation}" - ) + if isinstance(variable_value, BaseModel): + variable_value = variable_value.model_dump() + cls._local.outputs.append({ + "origin": origin, + "variable_name": variable_name, + "variable_value": variable_value, + "annotation": annotation, + }) + logger.debug(f"Tracked state: {variable_name}") else: raise RuntimeError( "Trace not initialized. Use within a function decorated with @trace_system_sync." @@ -86,8 +84,6 @@ def finalize(cls): trace_outputs_var = contextvars.ContextVar("trace_outputs", default=None) trace_initialized_var = contextvars.ContextVar("trace_initialized", default=False) - -# This tracker ought to be used for asynchronous tracing class SynthTrackerAsync: @classmethod def initialize(cls): @@ -97,64 +93,73 @@ def initialize(cls): logger.debug("AsyncTrace initialized") @classmethod - def track_input( + def track_lm( cls, - var: Union[BaseModel, str, dict, int, float, bool, list, None], - variable_name: str, - origin: Literal["agent", "environment"], - annotation: Optional[str] = None, + messages: List[Dict[str, str]], + model_name: str, + finetune: bool = False, ): - if not isinstance(var, VALID_TYPES): - raise TypeError( - f"Variable {variable_name} must be one of {VALID_TYPES}, got {type(var)}" - ) - if trace_initialized_var.get(): - # Convert Pydantic models to dict schema - if isinstance(var, BaseModel): - var = var.model_dump() trace_inputs = trace_inputs_var.get() - trace_inputs.append((origin, var, variable_name, annotation)) + trace_inputs.append({ + "origin": "agent", + "messages": messages, + "model_name": model_name, + "finetune": finetune, + }) trace_inputs_var.set(trace_inputs) - logger.debug( - f"Traced input: origin={origin}, var_name={variable_name}, annotation={annotation}" - ) + logger.debug("Tracked LM interaction") else: raise RuntimeError( "Trace not initialized. Use within a function decorated with @trace_system_async." ) @classmethod - def track_output( + def track_state( cls, - var: Union[BaseModel, str, dict, int, float, bool, list, None], variable_name: str, + variable_value: Union[BaseModel, str, dict, int, float, bool, list, None], origin: Literal["agent", "environment"], annotation: Optional[str] = None, + io_type: Literal["input", "output"] = "output", ): - if not isinstance(var, VALID_TYPES): + if not isinstance(variable_value, VALID_TYPES): raise TypeError( - f"Variable {variable_name} must be one of {VALID_TYPES}, got {type(var)}" + f"Variable {variable_name} must be one of {VALID_TYPES}, got {type(variable_value)}" ) if trace_initialized_var.get(): - # Convert Pydantic models to dict schema - if isinstance(var, BaseModel): - var = var.model_dump() + if isinstance(variable_value, BaseModel): + variable_value = variable_value.model_dump() trace_outputs = trace_outputs_var.get() - trace_outputs.append((origin, var, variable_name, annotation)) - trace_outputs_var.set(trace_outputs) - logger.debug( - f"Traced output: origin={origin}, var_name={variable_name}, annotation={annotation}" - ) + if io_type == "input": + trace_inputs = trace_inputs_var.get() + trace_inputs.append({ + "origin": origin, + "variable_name": variable_name, + "variable_value": variable_value, + "annotation": annotation, + }) + trace_inputs_var.set(trace_inputs) + else: + trace_outputs.append({ + "origin": origin, + "variable_name": variable_name, + "variable_value": variable_value, + "annotation": annotation, + }) + trace_outputs_var.set(trace_outputs) + logger.debug(f"Tracked state: {variable_name}") else: raise RuntimeError( "Trace not initialized. Use within a function decorated with @trace_system_async." ) @classmethod - def get_traced_data(cls) -> Tuple[list, list]: - return trace_inputs_var.get(), trace_outputs_var.get() + def get_traced_data(cls): + traced_inputs = trace_inputs_var.get() + traced_outputs = trace_outputs_var.get() + return traced_inputs, traced_outputs @classmethod def finalize(cls): @@ -163,86 +168,67 @@ def finalize(cls): trace_outputs_var.set([]) logger.debug("Finalized async trace data") + # Make traces available globally synth_tracker_sync = SynthTrackerSync synth_tracker_async = SynthTrackerAsync -# Generalized SynthTracker class, depending on if an event loop is running (called from async) -# & if the specified tracker is initalized will determine the appropriate tracker to use class SynthTracker: - def is_called_by_async(): + @classmethod + def is_called_by_async(cls): try: asyncio.get_running_loop() # Attempt to get the running event loop return True # If successful, we are in an async context except RuntimeError: return False # If there's no running event loop, we are in a sync context - - # SynthTracker Async & Sync are initalized by the decorators that wrap the - # respective async & sync functions - @classmethod - def initialize(cls): - pass @classmethod - def track_input( + def track_lm( cls, - var: Union[BaseModel, str, dict, int, float, bool, list, None], - variable_name: str, - origin: Literal["agent", "environment"], - annotation: Optional[str] = None, - async_sync: Literal["async", "sync", ""] = "", # Force the tracker to be async or sync + messages: List[Dict[str, str]], + model_name: str, + finetune: bool = False, ): - - if async_sync == "async" or cls.is_called_by_async() and trace_initialized_var.get(): - logger.debug("Using async tracker to track input") - synth_tracker_async.track_input(var, variable_name, origin, annotation) - - # don't want to add the same event to both trackers - elif async_sync == "sync" or hasattr(synth_tracker_sync._local, "initialized"): - logger.debug("Using sync tracker to track input") - synth_tracker_sync.track_input(var, variable_name, origin, annotation) + if cls.is_called_by_async() and trace_initialized_var.get(): + logger.debug("Using async tracker to track LM") + synth_tracker_async.track_lm(messages, model_name, finetune) + elif getattr(synth_tracker_sync._local, "initialized", False): + logger.debug("Using sync tracker to track LM") + synth_tracker_sync.track_lm(messages, model_name, finetune) else: - raise RuntimeError("track_input() \n Trace not initialized. Use within a function decorated with @trace_system_async or @trace_system_sync.") - + raise RuntimeError("Trace not initialized in track_lm.") + @classmethod - def track_output( + def track_state( cls, - var: Union[BaseModel, str, dict, int, float, bool, list, None], variable_name: str, + variable_value: Union[BaseModel, str, dict, int, float, bool, list, None], origin: Literal["agent", "environment"], annotation: Optional[str] = None, - async_sync: Literal["async", "sync", ""] = "", # Force the tracker to be async or sync ): - if async_sync == "async" or cls.is_called_by_async() and trace_initialized_var.get(): - logger.debug("Using async tracker to track output") - synth_tracker_async.track_output(var, variable_name, origin, annotation) - - # don't want to add the same event to both trackers - elif async_sync == "sync" or hasattr(synth_tracker_sync._local, "initialized"): - logger.debug("Using sync tracker to track output") - synth_tracker_sync.track_output(var, variable_name, origin, annotation) + if cls.is_called_by_async() and trace_initialized_var.get(): + logger.debug("Using async tracker to track state") + synth_tracker_async.track_state(variable_name, variable_value, origin, annotation) + elif getattr(synth_tracker_sync._local, "initialized", False): + logger.debug("Using sync tracker to track state") + synth_tracker_sync.track_state(variable_name, variable_value, origin, annotation) else: - raise RuntimeError("track_output() \n Trace not initialized. Use within a function decorated with @trace_system_async or @trace_system_sync.") + raise RuntimeError("Trace not initialized in track_state.") - - # if both trackers have been used, want to return both sets of data @classmethod def get_traced_data( cls, - async_sync: Literal["async", "sync", ""] = "", # Force only async or sync data to be returned + async_sync: Literal["async", "sync", ""] = "", # Force only async or sync data to be returned ) -> Tuple[list, list]: - traced_inputs, traced_outputs = [], [] - if async_sync == "async" or async_sync == "": - # Attempt to get the traced data from the async tracker + if async_sync in ["async", ""]: logger.debug("Getting traced data from async tracker") traced_inputs_async, traced_outputs_async = synth_tracker_async.get_traced_data() traced_inputs.extend(traced_inputs_async) traced_outputs.extend(traced_outputs_async) - if async_sync == "sync" or async_sync == "": - # Attempt to get the traced data from the sync tracker + if async_sync in ["sync", ""]: logger.debug("Getting traced data from sync tracker") traced_inputs_sync, traced_outputs_sync = synth_tracker_sync.get_traced_data() traced_inputs.extend(traced_inputs_sync) @@ -251,17 +237,3 @@ def get_traced_data( # TODO: Test that the order of the inputs and outputs is correct wrt # the order of events since we are combining the two trackers return traced_inputs, traced_outputs - - # Finalize both trackers - @classmethod - def finalize( - cls, - async_sync: Literal["async", "sync", ""] = "", - ): - if async_sync == "async" or async_sync == "": - logger.debug("Finalizing async tracker") - synth_tracker_async.finalize() - - if async_sync == "sync" or async_sync == "": - logger.debug("Finalizing sync tracker") - synth_tracker_sync.finalize() diff --git a/testing/ai_agent_async.py b/testing/ai_agent_async.py index 3dfa624..7d5215f 100644 --- a/testing/ai_agent_async.py +++ b/testing/ai_agent_async.py @@ -38,14 +38,25 @@ def __init__(self): ) async def make_lm_call(self, user_message: str) -> str: # Only pass the user message, not self - SynthTracker.track_input([user_message], variable_name="user_message", origin="agent") + #SynthTracker.track_input([user_message], variable_name="user_message", origin="agent") logger.debug("Starting LM call with message: %s", user_message) response = await self.lm.respond_async( system_message="You are a helpful assistant.", user_message=user_message ) + SynthTracker.track_lm( + messages = [{"role": "user", "content": user_message}, {"role": "assistant", "content": response}], + model_name = self.lm.model_name, + finetune = False + ) + SynthTracker.track_state( + variable_name = "minecraft_screen_description", + variable_value = None, + origin = "environment", + annotation = "Minecraft screen description" + ) - SynthTracker.track_output(response, variable_name="response", origin="agent") + #SynthTracker.track_output(response, variable_name="response", origin="agent") logger.debug("LM response received: %s", response) time.sleep(0.1) @@ -59,11 +70,20 @@ async def make_lm_call(self, user_message: str) -> str: ) async def process_environment(self, input_data: str) -> dict: # Only pass the input data, not self - SynthTracker.track_input([input_data], variable_name="input_data", origin="environment") + SynthTracker.track_state( + variable_name="input_data", + variable_value=input_data, + origin="environment", + annotation=None # Optional: you can add an annotation if needed + ) result = {"processed": input_data, "timestamp": time.time()} - SynthTracker.track_output(result, variable_name="result", origin="environment") + SynthTracker.track_state( + variable_name="result", + variable_value=result, + origin="environment" + ) return result diff --git a/testing/ai_agent_sync.py b/testing/ai_agent_sync.py index 0f2f082..0a6d488 100644 --- a/testing/ai_agent_sync.py +++ b/testing/ai_agent_sync.py @@ -37,15 +37,26 @@ def __init__(self): verbose=True, ) def make_lm_call(self, user_message: str) -> str: - # Only pass the user message, not self - SynthTracker.track_input([user_message], variable_name="user_message", origin="agent") - logger.debug("Starting LM call with message: %s", user_message) response = self.lm.respond_sync( system_message="You are a helpful assistant.", user_message=user_message ) - - SynthTracker.track_output(response, variable_name="response", origin="agent") + + # Track LM interaction + SynthTracker.track_lm( + messages=[{"role": "user", "content": user_message}, + {"role": "assistant", "content": response}], + model_name=self.lm.model_name, + finetune=False + ) + + # Track additional state if needed + SynthTracker.track_state( + variable_name="minecraft_screen_description", + variable_value=None, + origin="environment", + annotation="Minecraft screen description" + ) logger.debug("LM response received: %s", response) time.sleep(0.1) @@ -58,22 +69,30 @@ def make_lm_call(self, user_message: str) -> str: verbose=True, ) def process_environment(self, input_data: str) -> dict: - # Only pass the input data, not self - SynthTracker.track_input([input_data], variable_name="input_data", origin="environment") + # Track input state + SynthTracker.track_state( + variable_name="input_data", + variable_value=input_data, + origin="environment", + annotation=None + ) result = {"processed": input_data, "timestamp": time.time()} - SynthTracker.track_output(result, variable_name="result", origin="environment") + # Track result state + SynthTracker.track_state( + variable_name="result", + variable_value=result, + origin="environment" + ) return result async def run_test(): logger.info("Starting run_test") - # Create test agent agent = TestAgent() try: - # List of test questions questions = [ "What's the capital of France?", "What's 2+2?", @@ -81,16 +100,13 @@ async def run_test(): ] logger.debug("Test questions initialized: %s", questions) - # Make multiple LM calls with environment processing responses = [] for i, question in enumerate(questions): logger.info("Processing question %d: %s", i, question) try: - # First process in environment env_result = agent.process_environment(question) logger.debug("Environment processing result: %s", env_result) - # Then make LM call response = agent.make_lm_call(question) responses.append(response) logger.debug("Response received and stored: %s", response) @@ -169,6 +185,7 @@ async def run_test(): ) logger.info("Cleanup completed") + # Run a sample agent using the sync decorator and tracker if __name__ == "__main__": logger.info("Starting main execution") diff --git a/testing/openai_autologging.py b/testing/openai_autologging.py new file mode 100644 index 0000000..c1f81ff --- /dev/null +++ b/testing/openai_autologging.py @@ -0,0 +1,67 @@ +import asyncio +import logging +import json +import os +from typing import List +from synth_sdk.provider_support.openai_lf import AsyncOpenAI +from synth_sdk.tracing.decorators import trace_system +from synth_sdk.tracing.events.store import event_store +from synth_sdk.tracing.abstractions import Event, SystemTrace + +from dotenv import load_dotenv + +# Configure logging +logging.basicConfig( + level=logging.DEBUG, # Set to DEBUG to capture all logs + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +class OpenAIAgent: + def __init__(self): + self.system_id = "openai_agent_async_test" + logger.debug("Initializing OpenAIAgent with system_id: %s", self.system_id) + self.openai = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY")) # Replace with your actual API key + load_dotenv() + @trace_system( + origin="agent", + event_type="openai_completion", + manage_event="create", + increment_partition=True, + verbose=True, + ) + async def get_completion(self, prompt: str) -> str: + logger.debug("Sending prompt to OpenAI: %s", prompt) + try: + response = await self.openai.chat.completions.create( + model="gpt-4o-mini-2024-07-18", + messages = [{"role": "user", "content": prompt}], + max_tokens=50, + ) + completion_text = response.choices[0].message.content + logger.debug("Received completion: %s", completion_text) + return completion_text + except Exception as e: + logger.error("Error during OpenAI call: %s", str(e), exc_info=True) + raise + +async def run_test(): + logger.info("Starting OpenAI Agent Async Test") + agent = OpenAIAgent() + prompt = "Explain the theory of relativity in simple terms." + + try: + completion = await agent.get_completion(prompt) + print(f"OpenAI Completion:\n{completion}") + except Exception as e: + print(f"An error occurred: {str(e)}") + + # Retrieve and display traces from the event store + logger.info("Retrieving system traces from event store") + traces: List[SystemTrace] = event_store.get_system_traces() + print("\nRetrieved System Traces:") + for trace in traces: + print(json.dumps(trace.to_dict(), indent=2)) + +if __name__ == "__main__": + asyncio.run(run_test()) diff --git a/testing/traces_test.py b/testing/traces_test.py index e7c6eb9..de2ad5f 100644 --- a/testing/traces_test.py +++ b/testing/traces_test.py @@ -4,7 +4,10 @@ from synth_sdk.tracing.upload import upload from synth_sdk.tracing.upload import validate_json from synth_sdk.tracing.upload import createPayload -from synth_sdk.tracing.abstractions import TrainingQuestion, RewardSignal, Dataset, SystemTrace, EventPartitionElement +from synth_sdk.tracing.abstractions import ( + TrainingQuestion, RewardSignal, Dataset, SystemTrace, EventPartitionElement, + MessageInputs, MessageOutputs, ArbitraryInputs, ArbitraryOutputs +) from synth_sdk.tracing.events.store import event_store from typing import Dict, List import asyncio @@ -104,8 +107,8 @@ class TestAgent: def __init__(self): self.system_id = "test_agent_upload" logger.debug("Initializing TestAgent with system_id: %s", self.system_id) - #self.lm = LM(model_name="gpt-4o-mini-2024-07-18", formatting_model_name="gpt-4o-mini-2024-07-18", temperature=1,) self.lm = MagicMock() + self.lm.model_name = "gpt-4o-mini-2024-07-18" self.lm.respond_sync.return_value = mock_llm_response logger.debug("LM initialized") @@ -116,18 +119,30 @@ def __init__(self): increment_partition=True, verbose=False, ) - def make_lm_call(self, user_message: str) -> str: # Calls an LLM to respond to a user message - # Only pass the user message, not self - SynthTracker.track_input([user_message], variable_name="user_message", origin="agent") + def make_lm_call(self, user_message: str) -> str: + # Create MessageInputs + message_input = MessageInputs(messages=[{"role": "user", "content": user_message}]) + SynthTracker.track_lm( + messages=message_input.messages, + model_name=self.lm.model_name, + finetune=False + ) logger.debug("Starting LM call with message: %s", user_message) response = self.lm.respond_sync( system_message="You are a helpful assistant.", user_message=user_message ) - SynthTracker.track_output(response, variable_name="response", origin="agent") + + # Create MessageOutputs + message_output = MessageOutputs(messages=[{"role": "assistant", "content": response}]) + SynthTracker.track_state( + variable_name="response", + variable_value=message_output.messages, + origin="agent", + annotation="LLM response" + ) logger.debug("LM response received: %s", response) - #time.sleep(0.1) return response @trace_system( @@ -137,10 +152,25 @@ def make_lm_call(self, user_message: str) -> str: # Calls an LLM to respond to a verbose=False, ) def process_environment(self, input_data: str) -> dict: - # Only pass the input data, not self - SynthTracker.track_input([input_data], variable_name="input_data", origin="environment") + # Create ArbitraryInputs + arbitrary_input = ArbitraryInputs(inputs={"input_data": input_data}) + SynthTracker.track_state( + variable_name="input_data", + variable_value=arbitrary_input.inputs, + origin="environment", + annotation="Environment input data" + ) + result = {"processed": input_data, "timestamp": time.time()} - SynthTracker.track_output(result, variable_name="result", origin="environment") + + # Create ArbitraryOutputs + arbitrary_output = ArbitraryOutputs(outputs=result) + SynthTracker.track_state( + variable_name="result", + variable_value=arbitrary_output.outputs, + origin="environment", + annotation="Environment processing result" + ) return result