Skip to content

Commit

Permalink
made upload unit test much faster
Browse files Browse the repository at this point in the history
  • Loading branch information
DoKu88 committed Nov 6, 2024
1 parent db10856 commit a71f947
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 33 deletions.
15 changes: 9 additions & 6 deletions synth_sdk/tracing/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ def validate_json(data: dict) -> None:
except (TypeError, OverflowError) as e:
raise ValueError(f"Contains non-JSON-serializable values: {e}. {data}")

def create_payload(dataset: Dataset, traces: list) -> Dict[str, Any]:
    """Build the upload payload from a dataset and its system traces.

    Args:
        dataset: Object whose ``to_dict()`` result is stored under ``"dataset"``.
        traces: Collection of trace objects, each providing ``to_dict()``.
            (This was annotated as ``str``, which contradicted how the
            argument is iterated below.)

    Returns:
        A dict with two keys: ``"traces"`` (one dict per trace, in iteration
        order) and ``"dataset"`` (the dataset's dict form).
    """
    return {
        # Convert SystemTrace objects to dicts
        "traces": [trace.to_dict() for trace in traces],
        "dataset": dataset.to_dict(),
    }

def send_system_traces(
dataset: Dataset, base_url: str, api_key: str
Expand All @@ -43,12 +51,7 @@ def send_system_traces(
# Send the traces with the token
api_url = f"{base_url}/upload/"

payload = {
"traces": [
trace.to_dict() for trace in traces
], # Convert SystemTrace objects to dicts
"dataset": dataset.to_dict(),
}
payload = create_payload(dataset, traces) # Create the payload

validate_json(payload) # Validate the entire payload

Expand Down
62 changes: 35 additions & 27 deletions testing/upload_sync_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from zyk import LM
from synth_sdk.tracing.decorators import trace_system_sync, _local
from synth_sdk.tracing.trackers import SynthTrackerSync
from synth_sdk.tracing.upload import upload
from synth_sdk.tracing.upload import upload, create_payload
from synth_sdk.tracing.abstractions import TrainingQuestion, RewardSignal, Dataset
from synth_sdk.tracing.events.store import event_store
from typing import Dict
Expand All @@ -11,6 +11,8 @@
#import json
import logging
import pytest
from unittest.mock import MagicMock, Mock, patch
import requests

# Configure logging
logging.basicConfig(
Expand All @@ -19,19 +21,40 @@
)
logger = logging.getLogger(__name__)

questions = ["What's the capital of France?",
"What's 2+2?",
"Who wrote Romeo and Juliet?",]
# Unit Test Configuration:
# ===============================
questions = ["What's the capital of France?"]
mock_llm_response = "The capital of France is Paris."

# Builds the expected payload from the dataset and the traces held in event_store, so the test can compare it with the payload that was sent
def generate_payload_from_data(dataset: Dataset) -> Dict:
    """Build the reference payload that the uploaded payload is compared with.

    Reads the system traces currently returned by ``event_store`` and pairs
    their dict forms with the dict form of ``dataset``.
    """
    trace_dicts = []
    for system_trace in event_store.get_system_traces():
        # Convert SystemTrace objects to dicts
        trace_dicts.append(system_trace.to_dict())
    return {"traces": trace_dicts, "dataset": dataset.to_dict()}

def create_payload_wrapper(dataset: Dataset, base_url: str, api_key: str) -> tuple:
    """Stand-in for ``send_system_traces`` that builds the payload locally.

    Used as the ``side_effect`` of the patched ``send_system_traces`` in
    ``test_upload``. No HTTP call is issued here; a ``requests.Response`` is
    constructed by hand with status code 200.

    Args:
        dataset: Dataset forwarded to ``create_payload``.
        base_url: Unused; present so the signature matches the patched function.
        api_key: Unused; present so the signature matches the patched function.

    Returns:
        A ``(response, payload)`` tuple. (The annotation previously claimed
        ``Dict``, which did not match the two-element return value.)
    """
    payload = create_payload(dataset, event_store.get_system_traces())

    response = requests.Response()
    response.status_code = 200

    return response, payload

# ===============================

class TestAgent:
def __init__(self):
    """Set up the agent's system id and a mocked language model.

    The language model is a ``MagicMock`` whose ``respond_sync`` returns the
    module-level ``mock_llm_response``. The earlier code built a real ``LM``
    first and then overwrote it with the mock, so that construction was
    redundant and has been removed.
    """
    self.system_id = "test_agent_upload"
    logger.debug("Initializing TestAgent with system_id: %s", self.system_id)
    # Mocked LM: respond_sync always returns the canned reply.
    self.lm = MagicMock()
    self.lm.respond_sync.return_value = mock_llm_response
    logger.debug("LM initialized")

@trace_system_sync(
Expand All @@ -49,7 +72,6 @@ def make_lm_call(self, user_message: str) -> str: # Calls an LLM to respond to a
response = self.lm.respond_sync(
system_message="You are a helpful assistant.", user_message=user_message
)

SynthTrackerSync.track_output(response, variable_name="response", origin="agent")

logger.debug("LM response received: %s", response)
Expand All @@ -65,26 +87,13 @@ def make_lm_call(self, user_message: str) -> str: # Calls an LLM to respond to a
def process_environment(self, input_data: str) -> dict:
    """Wrap ``input_data`` in a timestamped dict, tracking input and output.

    Returns:
        A dict with ``"processed"`` set to ``input_data`` and ``"timestamp"``
        set to the value of ``time.time()`` at the moment of the call.
    """
    # Only pass the input data, not self
    SynthTrackerSync.track_input(
        [input_data], variable_name="input_data", origin="environment"
    )

    outcome = dict(processed=input_data, timestamp=time.time())

    SynthTrackerSync.track_output(
        outcome, variable_name="result", origin="environment"
    )
    return outcome

# This function generates a payload from the data in the dataset to compare the sent payload against
def generate_payload_from_data(self, dataset: Dataset) -> Dict:
    """Return the payload expected for ``dataset`` and the stored traces.

    The traces are read from ``event_store``; ``self`` is not consulted.
    """
    system_traces = event_store.get_system_traces()
    # Convert SystemTrace objects to dicts
    serialized_traces = [entry.to_dict() for entry in system_traces]
    serialized_dataset = dataset.to_dict()
    return {"traces": serialized_traces, "dataset": serialized_dataset}

@pytest.mark.asyncio
async def test_upload():
@patch("synth_sdk.tracing.upload.send_system_traces", side_effect=create_payload_wrapper)
async def test_upload(mock_send_system_traces):
logger.info("Starting run_test")
agent = TestAgent() # Create test agent

Expand Down Expand Up @@ -132,7 +141,6 @@ async def test_upload():
logger.info("Attempting to upload traces")
response, payload = await upload(dataset=dataset, verbose=True)
logger.info("Upload successful!")
logger.info("Payload sent to server:")

# Pytest assertion
assert payload == agent.generate_payload_from_data(dataset)
assert payload == generate_payload_from_data(dataset)

0 comments on commit a71f947

Please sign in to comment.