diff --git a/ersilia/core/tracking.py b/ersilia/core/tracking.py
index c811e59d4..6df065bd8 100644
--- a/ersilia/core/tracking.py
+++ b/ersilia/core/tracking.py
@@ -1,6 +1,11 @@
 from datetime import datetime
 import json
 import pandas as pd
+import tracemalloc
+import tempfile
+import logging
+import boto3
+from botocore.exceptions import ClientError
 import os
 
 PERSISTENT_FILE_PATH = os.path.abspath("current_session.txt")
@@ -57,6 +62,39 @@ def close_persistent_file():
     log_files_metrics(TEMP_FILE_LOGS)
 
 
+def upload_to_s3(json_dict, bucket="t4sg-ersilia", object_name=None):
+    """Upload a JSON object to an S3 bucket
+
+    :param json_dict: JSON object to upload
+    :param bucket: Bucket to upload to
+    :param object_name: S3 object name. If not specified then we generate a name based on the timestamp and model id.
+    :return: True if file was uploaded, else False
+    """
+
+    # If S3 object_name was not specified, generate one from the timestamp and model id
+    if object_name is None:
+        object_name = (
+            datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + "-" + json_dict["model_id"]
+        )
+
+    # Dump JSON into a temporary file to upload
+    json_str = json.dumps(json_dict, indent=4)
+    tmp = tempfile.NamedTemporaryFile()
+
+    with open(tmp.name, "w") as f:
+        f.write(json_str)
+        f.flush()
+
+    # Upload the file
+    s3_client = boto3.client("s3")
+    try:
+        s3_client.upload_file(tmp.name, bucket, f"{object_name}.json")
+    except ClientError as e:
+        logging.error(e)
+        return False
+    return True
+
+
 class RunTracker:
     """
     This class will be responsible for tracking model runs. It calculates the desired metadata based on a model's
@@ -67,10 +105,13 @@ class RunTracker:
 
     def __init__(self):
         self.time_start = None
+        self.memory_usage_start = 0
 
     # function to be called before model is run
     def start_tracking(self):
         self.time_start = datetime.now()
+        tracemalloc.start()
+        self.memory_usage_start = tracemalloc.get_traced_memory()[0]
 
     def sample_df(self, df, num_rows, num_cols):
         """
@@ -115,6 +156,38 @@ def get_file_sizes(self, input_df, output_df):
             "avg_output_size": output_avg_row_size,
         }
 
+    def check_types(self, resultDf, metadata):
+        typeDict = {"float64": "Float", "int64": "Int"}
+        count = 0
+
+        # ignore key and input columns
+        dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes
+
+        for i in dtypesLst:
+            if typeDict[str(i)] != metadata["Output Type"][0]:
+                count += 1
+
+        if len(dtypesLst) > 1 and metadata["Output Shape"] != "List":
+            print("Not right shape. Expected List but got Single")
+            correct_shape = False
+        elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single":
+            print("Not right shape. Expected Single but got List")
+            correct_shape = False
+        else:
+            print("Output is correct shape.")
+            correct_shape = True
+
+        print("Output has", count, "mismatched types.\n")
+
+        return {"mismatched_types": count, "correct_shape": correct_shape}
+
+    def get_peak_memory(self):
+        # Compare memory between peak and amount when we started
+        peak_memory = tracemalloc.get_traced_memory()[1] - self.memory_usage_start
+        tracemalloc.stop()
+
+        return peak_memory
+
     def track(self, input, result, meta):
         """
         Tracks the results after a model run.
@@ -144,33 +217,11 @@ def track(self, input, result, meta):
 
         json_dict["file_sizes"] = self.get_file_sizes(input_dataframe, result_dataframe)
 
-        json_object = json.dumps(json_dict, indent=4)
-        print("\nJSON Dictionary:\n", json_object)
+        json_dict["peak_memory_use"] = self.get_peak_memory()
 
         # log results to persistent tracking file
+        json_object = json.dumps(json_dict, indent=4)
         write_persistent_file(json_object)
 
-    def check_types(self, resultDf, metadata):
-        typeDict = {"float64": "Float", "int64": "Int"}
-        count = 0
-
-        # ignore key and input columns
-        dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes
-
-        for i in dtypesLst:
-            if typeDict[str(i)] != metadata["Output Type"][0]:
-                count += 1
-
-        if len(dtypesLst) > 1 and metadata["Output Shape"] != "List":
-            print("Not right shape. Expected List but got Single")
-            correct_shape = False
-        elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single":
-            print("Not right shape. Expected Single but got List")
-            correct_shape = False
-        else:
-            print("Output is correct shape.")
-            correct_shape = True
-
-        print("Output has", count, "mismatched types.\n")
-
-        return {"mismatched_types": count, "correct_shape": correct_shape}
+        # Upload run stats to s3
+        upload_to_s3(json_dict)
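As a quick way to exercise the new upload path in isolation, here is a minimal usage sketch. The payload contents and model ID are hypothetical; `upload_to_s3` only requires that the dict contain a `"model_id"` key when no `object_name` is given, and it assumes boto3 can find credentials through its standard chain (environment variables, `~/.aws/credentials`, or an instance role) with write access to the default `t4sg-ersilia` bucket.

```python
from ersilia.core.tracking import upload_to_s3

# Hypothetical stats payload; in production this dict is assembled by
# RunTracker.track() before being uploaded.
stats = {
    "model_id": "eos0abc",          # hypothetical model identifier
    "peak_memory_use": 10485760,    # bytes, as reported by tracemalloc
    "file_sizes": {"avg_input_size": 120.0, "avg_output_size": 64.0},
}

# With no object_name, the key becomes "<timestamp>-eos0abc.json" in the
# t4sg-ersilia bucket; the function returns False and logs the error if
# boto3 raises a ClientError, rather than propagating it.
if upload_to_s3(stats):
    print("Run stats uploaded.")
else:
    print("Upload failed; see error log.")
```

Note that `peak_memory_use` comes from `tracemalloc`, which only traces allocations made through Python's memory allocator; memory used by native extensions (e.g., NumPy buffers allocated outside the Python allocator) is not included.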