Skip to content

Commit

Permalink
Merge pull request #16 from hcs-t4sg/theresa/memoryUsage
Browse files Browse the repository at this point in the history
Write outputs to JSON and upload to S3 Bucket
  • Loading branch information
AC-Dap authored Nov 14, 2023
2 parents 48fdd2c + d1c2adb commit ae33842
Showing 1 changed file with 77 additions and 26 deletions.
103 changes: 77 additions & 26 deletions ersilia/core/tracking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from datetime import datetime
import json
import pandas as pd
import tracemalloc
import tempfile
import logging
import boto3
from botocore.exceptions import ClientError
import os

PERSISTENT_FILE_PATH = os.path.abspath("current_session.txt")
Expand Down Expand Up @@ -38,6 +43,39 @@ def close_persistent_file():
os.rename(PERSISTENT_FILE_PATH, new_file_path)


def upload_to_s3(json_dict, bucket="t4sg-ersilia", object_name=None):
"""Upload a file to an S3 bucket
:param json_dict: JSON object to upload
:param bucket: Bucket to upload to
:param object_name: S3 object name. If not specified then we generate a name based on the timestamp and model id.
:return: True if file was uploaded, else False
"""

# If S3 object_name was not specified, use file_name
if object_name is None:
object_name = (
datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "-" + json_dict["model_id"]
)

# Dump JSON into a temporary file to upload
json_str = json.dumps(json_dict, indent=4)
tmp = tempfile.NamedTemporaryFile()

with open(tmp.name, "w") as f:
f.write(json_str)
f.flush()

# Upload the file
s3_client = boto3.client("s3")
try:
s3_client.upload_file(tmp.name, bucket, f"{object_name}.json")
except ClientError as e:
logging.error(e)
return False
return True


class RunTracker:
"""
This class will be responsible for tracking model runs. It calculates the desired metadata based on a model's
Expand All @@ -48,10 +86,13 @@ class RunTracker:

def __init__(self):
self.time_start = None
self.memory_usage_start = 0

# function to be called before model is run
def start_tracking(self):
self.time_start = datetime.now()
tracemalloc.start()
self.memory_usage_start = tracemalloc.get_traced_memory()[0]

def sample_df(self, df, num_rows, num_cols):
"""
Expand Down Expand Up @@ -96,6 +137,38 @@ def get_file_sizes(self, input_df, output_df):
"avg_output_size": output_avg_row_size,
}

def check_types(self, resultDf, metadata):
typeDict = {"float64": "Float", "int64": "Int"}
count = 0

# ignore key and input columns
dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes

for i in dtypesLst:
if typeDict[str(i)] != metadata["Output Type"][0]:
count += 1

if len(dtypesLst) > 1 and metadata["Output Shape"] != "List":
print("Not right shape. Expected List but got Single")
correct_shape = False
elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single":
print("Not right shape. Expected Single but got List")
correct_shape = False
else:
print("Output is correct shape.")
correct_shape = True

print("Output has", count, "mismatched types.\n")

return {"mismatched_types": count, "correct_shape": correct_shape}

def get_peak_memory(self):
# Compare memory between peak and amount when we started
peak_memory = tracemalloc.get_traced_memory()[1] - self.memory_usage_start
tracemalloc.stop()

return peak_memory

def track(self, input, result, meta):
"""
Tracks the results after a model run.
Expand Down Expand Up @@ -125,33 +198,11 @@ def track(self, input, result, meta):

json_dict["file_sizes"] = self.get_file_sizes(input_dataframe, result_dataframe)

json_object = json.dumps(json_dict, indent=4)
print("\nJSON Dictionary:\n", json_object)
json_dict["peak_memory_use"] = self.get_peak_memory()

# log results to persistent tracking file
json_object = json.dumps(json_dict, indent=4)
write_persistent_file(json_object)

def check_types(self, resultDf, metadata):
typeDict = {"float64": "Float", "int64": "Int"}
count = 0

# ignore key and input columns
dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes

for i in dtypesLst:
if typeDict[str(i)] != metadata["Output Type"][0]:
count += 1

if len(dtypesLst) > 1 and metadata["Output Shape"] != "List":
print("Not right shape. Expected List but got Single")
correct_shape = False
elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single":
print("Not right shape. Expected Single but got List")
correct_shape = False
else:
print("Output is correct shape.")
correct_shape = True

print("Output has", count, "mismatched types.\n")

return {"mismatched_types": count, "correct_shape": correct_shape}
# Upload run stats to s3
upload_to_s3(json_dict)

0 comments on commit ae33842

Please sign in to comment.