Skip to content

Commit

Permalink
Update documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
AC-Dap committed Dec 21, 2023
1 parent a58f7e8 commit 26d9bd4
Showing 1 changed file with 54 additions and 36 deletions.
90 changes: 54 additions & 36 deletions ersilia/core/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re

PERSISTENT_FILE_PATH = os.path.abspath("current_session.txt")
# Temporary path to log files
# Temporary path to log files until log files are fixed
TEMP_FILE_LOGS = os.path.abspath("")


Expand All @@ -21,17 +21,17 @@ def create_csv(output_df):
file has two columns: the first column is the input molecules and the
second column is the ISO-formatted time of the run.
:param file: The output dataframe from the model run
:param output_df: The output dataframe from the model run
:return: A new temporary csv file
"""

new_df = output_df[['input']].copy()
new_df = output_df[["input"]].copy()
current_time = datetime.now().isoformat()

new_df['time'] = current_time
new_df["time"] = current_time
csv_file = tempfile.NamedTemporaryFile(mode="w", suffix=".csv")
new_df.to_csv(csv_file.name, index=False)

return csv_file


Expand All @@ -43,7 +43,6 @@ def log_files_metrics(file):
:return: None (writes to file)
"""


error_count = 0
warning_count = 0

Expand Down Expand Up @@ -88,7 +87,7 @@ def log_files_metrics(file):
else:
# other errors are pretty self-descriptive and short. Will cap by character
misc_error_flag = True
error_name = line.split('| ERROR | ')[1].rstrip()
error_name = line.split("| ERROR | ")[1].rstrip()
elif "| WARNING" in line:
warning_count += 1
if line is not None:
Expand All @@ -109,29 +108,32 @@ def log_files_metrics(file):
logging.warning("Log file not found")


def read_csv(file):
# reads csv file and returns Pandas dataframe
return pd.read_csv(file)


def read_json(result):
data = json.load(result)
return data


def open_persistent_file(model_id):
"""
Opens a new persistent file, specifically for a run of model_id
:param model_id: The currently running model
"""
with open(PERSISTENT_FILE_PATH, "w") as f:
f.write("Session started for model: {0}\n".format(model_id))


def write_persistent_file(contents):
"""
Writes contents to the current persistent file. Only writes if the file actually exists.
:param contents: The contents to write to the file.
"""

# Only write to file if it already exists (we're meant to be tracking this run)
if os.path.isfile(PERSISTENT_FILE_PATH):
with open(PERSISTENT_FILE_PATH, "a") as f:
f.write(f"{contents}\n")


def close_persistent_file():
"""
Closes the persistent file, renaming it to a unique name.
"""

# Make sure the file actually exists before we try renaming
if os.path.isfile(PERSISTENT_FILE_PATH):
log_files_metrics(TEMP_FILE_LOGS)
Expand Down Expand Up @@ -171,7 +173,9 @@ def upload_to_s3(json_dict, bucket="t4sg-ersilia", object_name=None):
try:
s3_client.upload_file(tmp.name, bucket, f"{object_name}.json")
except NoCredentialsError:
logging.error("Unable to upload tracking data to AWS: Credentials not found")
logging.error(
"Unable to upload tracking data to AWS: Credentials not found"
)
except ClientError as e:
logging.error(e)
return False
Expand All @@ -181,9 +185,8 @@ def upload_to_s3(json_dict, bucket="t4sg-ersilia", object_name=None):
class RunTracker:
"""
This class will be responsible for tracking model runs. It calculates the desired metadata based on a model's
inputs, outputs, and other run-specific features, before uploading them to Ersilia's Splunk dashboard.
NOTE: Currently, the Splunk connection is not set up. For now, we will print tracking results to the console.
inputs, outputs, and other run-specific features, before uploading them to AWS to be ingested
to Ersilia's Splunk dashboard.
"""

def __init__(self):
Expand All @@ -192,6 +195,10 @@ def __init__(self):

# function to be called before model is run
def start_tracking(self):
"""
Runs any code necessary for the beginning of the run.
Currently necessary for tracking the runtime and memory usage of a run.
"""
self.time_start = datetime.now()
tracemalloc.start()
self.memory_usage_start = tracemalloc.get_traced_memory()[0]
Expand All @@ -202,10 +209,16 @@ def sample_df(self, df, num_rows, num_cols):
"""
return df.sample(num_rows, axis=0).sample(num_cols, axis=1)

# Stats function: calculates the basic statistics of the output file from a model. This includes the
# mode (if applicable), minimum, maximum, and standard deviation.
def stats(self, result):
dat = read_csv(result)
"""
Stats function: calculates the basic statistics of the output file from a model. This includes the
mode (if applicable), minimum, maximum, and standard deviation.
:param result: The path to the model's output file.
:return: A dictionary containing the stats for each column of the result.
"""

dat = pd.read_csv(result)

# drop first two columns (key, input)
dat = dat.drop(["key", "input"], axis=1)
Expand Down Expand Up @@ -248,42 +261,47 @@ def get_file_sizes(self, input_df, output_df):
"avg_output_size": output_avg_row_size,
}

def check_types(self, resultDf, metadata):
def check_types(self, result_df, metadata):
"""
This class is responsible for checking the types of the output dataframe against the expected types.
This includes checking the shape of the output dataframe (list vs single) and the types of each column.
:param resultDf: The output dataframe
:param result_df: The output dataframe
:param metadata: The metadata dictionary
:return: A dictionary containing the number of mismatched types and a boolean for whether the shape is correct
"""

typeDict = {"float64": "Float", "int64": "Int"}
type_dict = {"float64": "Float", "int64": "Int"}
count = 0

# ignore key and input columns
dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes
dtypes_list = result_df.loc[:, ~result_df.columns.isin(["key", "input"])].dtypes

for i in dtypesLst:
if typeDict[str(i)] != metadata["Output Type"][0]:
for i in dtypes_list:
if type_dict[str(i)] != metadata["Output Type"][0]:
count += 1

if len(dtypesLst) > 1 and metadata["Output Shape"] != "List":
if len(dtypes_list) > 1 and metadata["Output Shape"] != "List":
print("Not right shape. Expected List but got Single")
correct_shape = False
elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single":
elif len(dtypes_list) == 1 and metadata["Output Shape"] != "Single":
print("Not right shape. Expected Single but got List")
correct_shape = False
else:
print("Output is correct shape.")
correct_shape = True
print(resultDf)

print(result_df)
print("Output has", count, "mismatched types.\n")

return {"mismatched_types": count, "correct_shape": correct_shape}

def get_peak_memory(self):
"""
Calculates the peak memory usage of ersilia's Python instance during the run.
:return: The peak memory usage in bytes.
"""

# Compare memory between peak and amount when we started
peak_memory = tracemalloc.get_traced_memory()[1] - self.memory_usage_start
tracemalloc.stop()
Expand All @@ -295,8 +313,8 @@ def track(self, input, result, meta):
Tracks the results after a model run.
"""
json_dict = {}
input_dataframe = read_csv(input)
result_dataframe = read_csv(result)
input_dataframe = pd.read_csv(input)
result_dataframe = pd.read_csv(result)

json_dict["input_dataframe"] = input_dataframe.to_dict()
json_dict["result_dataframe"] = result_dataframe.to_dict()
Expand Down

0 comments on commit 26d9bd4

Please sign in to comment.