Skip to content

Commit

Permalink
Refactor persistent tracking code
Browse files Browse the repository at this point in the history
  • Loading branch information
AC-Dap committed Oct 28, 2023
1 parent b0b79a7 commit 48fdd2c
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 24 deletions.
10 changes: 3 additions & 7 deletions ersilia/cli/commands/close.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .. import echo
from ... import ErsiliaModel
from ...core.session import Session
from ...core.tracking import close_persistent_file


def close_cmd():
Expand All @@ -20,10 +21,5 @@ def close():
mdl.close()
echo(":no_entry: Model {0} closed".format(mdl.model_id), fg="green")

# renames current_session to timestamp
old_file_path = "current_session.txt"
new_file_path = os.path.join(
os.path.dirname(old_file_path),
datetime.datetime.now().strftime("%Y%m%d%H%M%S"),
)
os.rename(old_file_path, new_file_path)
# Close our persistent tracking file
close_persistent_file()
6 changes: 4 additions & 2 deletions ersilia/cli/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .. import echo
from ... import ErsiliaModel
from ..messages import ModelNotFound
from ...core.tracking import open_persistent_file


def serve_cmd():
Expand Down Expand Up @@ -62,6 +63,7 @@ def serve(model, lake, docker, port, track_serve):
echo("")
echo(":person_tipping_hand: Information:", fg="blue")
echo(" - info", fg="blue")

# Setup persistent tracking
if track_serve:
with open("current_session.txt", "w") as f:
f.write("Session started for model: {0}".format(mdl.model_id))
open_persistent_file(mdl.model_id)
51 changes: 36 additions & 15 deletions ersilia/core/tracking.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from datetime import datetime
import json
import pandas as pd
import os

PERSISTENT_FILE_PATH = os.path.abspath("current_session.txt")


def read_csv(file):
Expand All @@ -12,6 +15,29 @@ def read_json(result):
data = json.load(result)
return data


def open_persistent_file(model_id):
with open(PERSISTENT_FILE_PATH, "w") as f:
f.write("Session started for model: {0}\n".format(model_id))


def write_persistent_file(contents):
# Only write to file if it already exists (we're meant to be tracking this run)
if os.path.isfile(PERSISTENT_FILE_PATH):
with open(PERSISTENT_FILE_PATH, "a") as f:
f.write(f"{contents}\n")


def close_persistent_file():
# Make sure the file actually exists before we try renaming
if os.path.isfile(PERSISTENT_FILE_PATH):
new_file_path = os.path.join(
os.path.dirname(PERSISTENT_FILE_PATH),
datetime.now().strftime("%Y-%m-%d%_H-%M-%S.txt"),
)
os.rename(PERSISTENT_FILE_PATH, new_file_path)


class RunTracker:
"""
This class will be responsible for tracking model runs. It calculates the desired metadata based on a model's
Expand Down Expand Up @@ -43,14 +69,14 @@ def stats(self, result):
stats = {}
for column in dat:
column_stats = {}
column_stats['mean'] = dat[column].mean()
column_stats["mean"] = dat[column].mean()
if len(dat[column].mode()) == 1:
column_stats['mode'] = dat[column].mode()
column_stats["mode"] = dat[column].mode().iloc[0]
else:
column_stats['mode'] = None
column_stats['min'] = dat[column].min()
column_stats['max'] = dat[column].max()
column_stats['std'] = dat[column].std()
column_stats["mode"] = None
column_stats["min"] = dat[column].min()
column_stats["max"] = dat[column].max()
column_stats["std"] = dat[column].std()

stats[column] = column_stats

Expand Down Expand Up @@ -97,15 +123,13 @@ def track(self, input, result, meta):

json_dict["stats"] = self.stats(result)

json_dict['file_sizes'] = self.get_file_sizes(input_dataframe, result_dataframe)
json_dict["file_sizes"] = self.get_file_sizes(input_dataframe, result_dataframe)

json_object = json.dumps(json_dict, indent=4)
print("\nJSON Dictionary:\n", json_object)

# log results to console
with open("../cli/commands/current_session.txt", "a") as f:
# write the print statements to a file
f.write(json_object)
# log results to persistent tracking file
write_persistent_file(json_object)

def check_types(self, resultDf, metadata):
typeDict = {"float64": "Float", "int64": "Int"}
Expand All @@ -130,7 +154,4 @@ def check_types(self, resultDf, metadata):

print("Output has", count, "mismatched types.\n")

return {
"mismatched_types": count,
"correct_shape": correct_shape
}
return {"mismatched_types": count, "correct_shape": correct_shape}

0 comments on commit 48fdd2c

Please sign in to comment.