From 4eba13cd0f6a6c10d2b1a01cd7da7c6d7a859b05 Mon Sep 17 00:00:00 2001 From: Josh Purtell Date: Sun, 15 Dec 2024 20:06:17 -0800 Subject: [PATCH] save --- pyproject.toml | 7 +-- setup.py | 2 +- synth_sdk/tracing/abstractions.py | 6 ++ synth_sdk/tracing/events/store.py | 1 + synth_sdk/tracing/upload.py | 101 +++++++++++++++++++----------- 5 files changed, 75 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6edd26f..ea05531 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "synth-sdk" -version = "0.2.84" +version = "0.2.93" description = "" authors = [{name = "Synth AI", email = "josh@usesynth.ai"}] license = {text = "MIT"} @@ -12,7 +12,7 @@ dependencies = [ "pydantic", "requests", "asyncio", - "zyk==0.2.21", + "zyk>=0.2.24", "build>=1.2.2.post1", "pypi", "twine>=4.0.0", @@ -22,8 +22,7 @@ dependencies = [ "pytest>=8.3.3", "pydantic-openapi-schema>=1.5.1", "pytest-asyncio>=0.24.0", - "apropos-ai>=0.4.5", - "craftaxlm>=0.0.5", + "craftaxlm>=0.0.7", ] classifiers = [] diff --git a/setup.py b/setup.py index 891e383..5c47b73 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="synth-sdk", - version="0.2.84", + version="0.2.93", packages=find_packages(), install_requires=[ "opentelemetry-api", diff --git a/synth_sdk/tracing/abstractions.py b/synth_sdk/tracing/abstractions.py index 249ff91..c794cfa 100644 --- a/synth_sdk/tracing/abstractions.py +++ b/synth_sdk/tracing/abstractions.py @@ -72,6 +72,11 @@ class AgentComputeStep(ComputeStep): compute_input: List[Union[MessageInputs, ArbitraryInputs]] compute_output: List[Union[MessageOutputs, ArbitraryOutputs]] + def to_dict(self): + base_dict = super().to_dict() # Get the parent class serialization + base_dict["model_name"] = self.model_name # Add model_name + return base_dict + @dataclass class EnvironmentComputeStep(ComputeStep): @@ -128,6 +133,7 @@ def to_dict(self): "system_id": self.system_id, "partition": [element.to_dict() for element in self.partition], "current_partition_index": self.current_partition_index, + "metadata": self.metadata if self.metadata else None } diff --git a/synth_sdk/tracing/events/store.py b/synth_sdk/tracing/events/store.py index 2893410..7972b33 100644 --- a/synth_sdk/tracing/events/store.py +++ b/synth_sdk/tracing/events/store.py @@ -31,6 +31,7 @@ def _get_or_create(): #logger.debug(f"Creating new system trace for {system_id}") self._traces[system_id] = SystemTrace( system_id=system_id, + metadata={}, partition=[EventPartitionElement(partition_index=0, events=[])], current_partition_index=0, ) diff --git a/synth_sdk/tracing/upload.py b/synth_sdk/tracing/upload.py index 0b9dc2a..1dd9ce7 100644 --- a/synth_sdk/tracing/upload.py +++ b/synth_sdk/tracing/upload.py @@ -1,31 +1,33 @@ -from typing import List, Dict, Any, Union, Tuple, Coroutine -from pydantic import BaseModel, validator -import synth_sdk.config.settings -import requests +import asyncio +import json import logging import os import time -from synth_sdk.tracing.events.store import event_store -from synth_sdk.tracing.abstractions import Dataset, SystemTrace -import json from pprint import pprint -import asyncio +from typing import Any, Dict, List + +import requests +from pydantic import BaseModel, validator + +from synth_sdk.tracing.abstractions import Dataset, SystemTrace +from synth_sdk.tracing.events.store import event_store def validate_json(data: dict) -> None: - #Validate that a dictionary contains only JSON-serializable values. + # Validate that a dictionary contains only JSON-serializable values. - #Args: + # Args: # data: Dictionary to validate for JSON serialization - #Raises: + # Raises: # ValueError: If the dictionary contains non-serializable values - + try: json.dumps(data) except (TypeError, OverflowError) as e: raise ValueError(f"Contains non-JSON-serializable values: {e}. {data}") + def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any]: payload = { "traces": [ @@ -35,8 +37,12 @@ def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any] } return payload + def send_system_traces( - dataset: Dataset, traces: List[SystemTrace], base_url: str, api_key: str, + dataset: Dataset, + traces: List[SystemTrace], + base_url: str, + api_key: str, ): # Send all system traces and dataset metadata to the server. # Get the token using the API key @@ -50,7 +56,7 @@ def send_system_traces( # Send the traces with the token api_url = f"{base_url}/v1/uploads/" - payload = createPayload(dataset, traces) # Create the payload + payload = createPayload(dataset, traces) # Create the payload validate_json(payload) # Validate the entire payload @@ -91,6 +97,11 @@ def validate_traces(cls, traces): if "partition" not in trace: raise ValueError("Each trace must have a partition") + # Validate metadata if present + if "metadata" in trace and trace["metadata"] is not None: + if not isinstance(trace["metadata"], dict): + raise ValueError("Metadata must be a dictionary") + # Validate partition structure partition = trace["partition"] if not isinstance(partition, list): @@ -157,8 +168,8 @@ def validate_dataset(cls, dataset): def validate_upload(traces: List[Dict[str, Any]], dataset: Dict[str, Any]): - #Validate the upload format before sending to server. - #Raises ValueError if validation fails. + # Validate the upload format before sending to server. + # Raises ValueError if validation fails. try: UploadValidator(traces=traces, dataset=dataset) return True @@ -174,47 +185,55 @@ def is_event_loop_running(): # This exception is raised if no event loop is running return False + def format_upload_output(dataset, traces): # Format questions array questions_data = [ - { - "intent": q.intent, - "criteria": q.criteria, - "question_id": q.question_id - } for q in dataset.questions + {"intent": q.intent, "criteria": q.criteria, "question_id": q.question_id} + for q in dataset.questions ] - + # Format reward signals array with error handling reward_signals_data = [ { "system_id": rs.system_id, "reward": rs.reward, "question_id": rs.question_id, - "annotation": rs.annotation if hasattr(rs, 'annotation') else None - } for rs in dataset.reward_signals + "annotation": rs.annotation if hasattr(rs, "annotation") else None, + } + for rs in dataset.reward_signals ] - + # Format traces array traces_data = [ { "system_id": t.system_id, + "metadata": t.metadata if t.metadata else None, "partition": [ { "partition_index": p.partition_index, - "events": [e.to_dict() for e in p.events] - } for p in t.partition - ] - } for t in traces + "events": [e.to_dict() for e in p.events], + } + for p in t.partition + ], + } + for t in traces ] return questions_data, reward_signals_data, traces_data + # Supports calls from both async and sync contexts -def upload(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False): +def upload( + dataset: Dataset, + traces: List[SystemTrace] = [], + verbose: bool = False, + show_payload: bool = False, +): """Upload all system traces and dataset to the server. Returns a tuple of (response, questions_json, reward_signals_json, traces_json) Note that you can directly upload questions, reward_signals, and traces to the server using the Website - + response is the response from the server. questions_json is the formatted questions array reward_signals_json is the formatted reward signals array @@ -222,7 +241,13 @@ def upload(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False return upload_helper(dataset, traces, verbose, show_payload) -def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False): + +def upload_helper( + dataset: Dataset, + traces: List[SystemTrace] = [], + verbose: bool = False, + show_payload: bool = False, +): api_key = os.getenv("SYNTH_API_KEY") if not api_key: raise ValueError("SYNTH_API_KEY environment variable not set") @@ -245,8 +270,8 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool # Also close any unclosed events in existing traces logged_traces = event_store.get_system_traces() - traces = logged_traces+ traces - #traces = event_store.get_system_traces() if len(traces) == 0 else traces + traces = logged_traces + traces + # traces = event_store.get_system_traces() if len(traces) == 0 else traces current_time = time.time() for trace in traces: for partition in trace.partition: @@ -291,10 +316,12 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool print("Payload sent to server: ") pprint(payload) - #return response, payload, dataset, traces - questions_json, reward_signals_json, traces_json = format_upload_output(dataset, traces) + # return response, payload, dataset, traces + questions_json, reward_signals_json, traces_json = format_upload_output( + dataset, traces + ) return response, questions_json, reward_signals_json, traces_json - + except ValueError as e: if verbose: print("Validation error:", str(e))