Merge pull request #14 from synth-laboratories/wasabi_upload
Wasabi upload
DoKu88 authored Dec 16, 2024
2 parents f734745 + 2f05ef2 commit 6d328a1
Showing 3 changed files with 101 additions and 93 deletions.
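In short: this change replaces the chunked, direct-to-server upload (send_system_traces and send_system_traces_chunked, both removed below) with a two-step flow in which the payload is written to Wasabi's S3-compatible storage via boto3 and a presigned URL is then POSTed to the server's process-upload endpoint. A minimal sketch of calling the new entry point, assuming a Dataset and traces are already in hand and that the API key lives in a SYNTH_API_KEY environment variable (that variable name is an assumption, not part of this commit):

import os
from synth_sdk.tracing.upload import send_system_traces_s3_wrapper

# WASABI_ACCESS_KEY, WASABI_SECRET_KEY and WASABI_BUCKET_NAME must be set in
# the environment; send_system_traces_s3 reads them via os.getenv.
upload_id, signed_url = send_system_traces_s3_wrapper(
    dataset=dataset,  # a synth_sdk Dataset built elsewhere
    traces=traces,    # a list of SystemTrace objects
    base_url="https://agent-learning.onrender.com",
    api_key=os.environ["SYNTH_API_KEY"],  # assumed variable name
)
# The wrapper returns the upload_id and signed_url echoed back by the server.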
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -24,6 +24,9 @@ dependencies = [
"pytest-asyncio>=0.24.0",
"apropos-ai>=0.4.5",
"craftaxlm>=0.0.5",
"boto3>=1.35.71",
"botocore>=1.35.71",
"tqdm>=4.66.4"
]
classifiers = []

3 changes: 3 additions & 0 deletions requirements.txt
@@ -28,3 +28,6 @@ alembic>=1.13.3
zyk>=0.2.10
#synth_sdk>=0.2.61
#smallbench>=0.2.11
boto3>=1.35.71
botocore>=1.35.71
tqdm>=4.66.4
188 changes: 95 additions & 93 deletions synth_sdk/tracing/upload.py
@@ -12,8 +12,11 @@
import asyncio
import sys
from pympler import asizeof
from tqdm import tqdm
import boto3
from datetime import datetime


# NOTE: This may cause memory issues in the future
def validate_json(data: dict) -> None:
# Validate that a dictionary contains only JSON-serializable values.

@@ -37,70 +40,100 @@ def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any]
}
return payload

async def send_system_traces(
dataset: Dataset, traces: List[SystemTrace], base_url: str, api_key: str, upload_id: str
):
# Send all system traces and dataset metadata to the server.
# Get the token using the API key
token_url = f"{base_url}/v1/auth/token"
token_response = requests.get(token_url, headers={"customer_specific_api_key": api_key})
token_response.raise_for_status()
access_token = token_response.json()["access_token"]

# Send the traces with the token
api_url = f"{base_url}/v1/uploads/{upload_id}"

payload = createPayload(dataset, traces) # Create the payload
async def send_system_traces_s3(dataset: Dataset, traces: List[SystemTrace]):
# 1. Create S3 client
s3_client = boto3.client(
"s3",
endpoint_url="https://s3.wasabisys.com",
aws_access_key_id=os.getenv("WASABI_ACCESS_KEY"),
aws_secret_access_key=os.getenv("WASABI_SECRET_KEY"),
)

# 2. Create and validate payload
payload = createPayload(dataset, traces)
validate_json(payload)

# 3. Create bucket path with datetime
bucket_name = os.getenv("WASABI_BUCKET_NAME")
current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
bucket_path = f"uploads/upload_{current_time}.json"

# 4. Upload payload to Wasabi
s3_client.put_object(
Bucket=bucket_name,
Key=bucket_path,
Body=json.dumps(payload),
)

# 5. Generate a signed URL
signed_url = s3_client.generate_presigned_url(
'get_object',
Params={
'Bucket': bucket_name,
'Key': bucket_path
},
ExpiresIn=14400 # URL expires in 4 hours
)

return {
'bucket_path': bucket_path,
'signed_url': signed_url
}
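# The presigned URL above lets the backend fetch the uploaded object with a
# plain GET and no Wasabi credentials; ExpiresIn=14400 keeps it valid for
# 4 hours, which bounds how long the server has to process the upload.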

validate_json(payload) # Validate the entire payload
def send_system_traces_s3_wrapper(dataset: Dataset, traces: List[SystemTrace], base_url: str, api_key: str):
# Create async function that contains all async operations
async def _async_operations():

memory_size = asizeof.asizeof(payload) / 1024 # Memory size in KB
logging.info(f"Payload size (in memory): {memory_size:.2f} KB")
result = await send_system_traces_s3(dataset, traces)
bucket_path, signed_url = result['bucket_path'], result['signed_url']

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}"
}

try:
response = requests.post(api_url, json=payload, headers=headers)
response.raise_for_status()
logging.info(f"Response status code: {response.status_code}")
logging.info(f"Upload ID: {response.json().get('upload_id')}")
return response, payload
except requests.exceptions.HTTPError as http_err:
logging.error(
f"HTTP error occurred: {http_err} - Response Content: {response.text}"
)
raise
except Exception as err:
logging.error(f"An error occurred: {err}")
raise
upload_id = await get_upload_id(base_url, api_key)

def chunk_traces(traces: List[SystemTrace], chunk_size_kb: int = 1024):
"""Split traces into chunks that won't exceed approximately chunk_size_kb when serialized"""
chunks = []
current_chunk = []
current_size = 0

for trace in traces:
trace_dict = trace.to_dict()
trace_size = asizeof.asizeof(trace_dict) / 1024 # Memory size in KB
logging.info(f"Trace size (in memory): {trace_size:.2f} KB")

if current_size + trace_size > chunk_size_kb:
# Current chunk would exceed size limit, start new chunk
chunks.append(current_chunk)
current_chunk = [trace]
current_size = trace_size
else:
current_chunk.append(trace)
current_size += trace_size

if current_chunk:
chunks.append(current_chunk)

return chunks
token_url = f"{base_url}/v1/auth/token"
token_response = requests.get(token_url, headers={"customer_specific_api_key": api_key})
token_response.raise_for_status()
access_token = token_response.json()["access_token"]

api_url = f"{base_url}/v1/uploads/process-upload/{upload_id}"
data = {"signed_url": signed_url}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {access_token}",
}

try:
response = requests.post(api_url, headers=headers, json=data)
response.raise_for_status()

upload_id = response.json()["upload_id"]
signed_url = response.json()["signed_url"]
status = response.json()["status"]

print(f"Status: {status}")
print(f"Upload ID retrieved: {upload_id}")
print(f"Signed URL: {signed_url}")

return upload_id, signed_url
except requests.exceptions.HTTPError as e:
logging.error(f"HTTP error occurred: {e}")
raise
except Exception as e:
logging.error(f"An error occurred: {e}")
raise

# Run the async operations in an event loop
if not is_event_loop_running():
# If no event loop is running, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(_async_operations())
finally:
loop.close()
else:
# If an event loop is already running, use it
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_operations())
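# is_event_loop_running is used above but not defined in the hunks shown here;
# a minimal sketch of such a helper (an assumption, not taken from this diff):
def _is_event_loop_running_sketch() -> bool:
    try:
        asyncio.get_running_loop()  # raises RuntimeError when no loop is running
        return True
    except RuntimeError:
        return False
# Note: the else-branch above still calls loop.run_until_complete on a loop
# that is already running, which asyncio rejects with a RuntimeError.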

async def get_upload_id(base_url: str, api_key: str):
token_url = f"{base_url}/v1/auth/token"
@@ -127,36 +160,6 @@ async def get_upload_id(base_url: str, api_key: str):
logging.error(f"An error occurred: {e}")
raise

def send_system_traces_chunked(dataset: Dataset, traces: List[SystemTrace],
base_url: str, api_key: str, chunk_size_kb: int = 1024):
"""Upload traces in chunks to avoid memory issues"""

async def _async_upload():
trace_chunks = chunk_traces(traces, chunk_size_kb)
upload_id = await get_upload_id(base_url, api_key)

tasks = []
for chunk in trace_chunks:
task = send_system_traces(dataset, chunk, base_url, api_key, upload_id)
tasks.append(task)

results = await asyncio.gather(*tasks)
return results[0] if results else (None, None) # Return first result or None tuple

# Handle the event loop
try:
if not is_event_loop_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(_async_upload())
else:
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_upload())
finally:
# Only close the loop if we created it
if 'loop' in locals() and not is_event_loop_running():
loop.close()

class UploadValidator(BaseModel):
traces: List[Dict[str, Any]]
dataset: Dict[str, Any]
@@ -354,7 +357,7 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool
print("Upload format validation successful")

# Send to server
response, payload = send_system_traces_chunked(
response, payload = send_system_traces_s3_wrapper(
dataset=dataset,
traces=traces,
base_url="https://agent-learning.onrender.com",
@@ -373,7 +376,6 @@
print("Payload sent to server: ")
pprint(payload)
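# Note: send_system_traces_s3_wrapper returns (upload_id, signed_url), so the
# names response and payload above now hold those two values rather than an
# HTTP response object and a JSON payload.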

#return response, payload, dataset, traces
questions_json, reward_signals_json, traces_json = format_upload_output(dataset, traces)
return response, questions_json, reward_signals_json, traces_json
