From 48fdd2ca9d347a6cd5abcc923c6d2cd46816f312 Mon Sep 17 00:00:00 2001 From: Anthony Cui Date: Sat, 28 Oct 2023 15:40:52 -0400 Subject: [PATCH] Refactor persistent tracking code --- ersilia/cli/commands/close.py | 10 +++---- ersilia/cli/commands/serve.py | 6 +++-- ersilia/core/tracking.py | 51 ++++++++++++++++++++++++----------- 3 files changed, 43 insertions(+), 24 deletions(-) diff --git a/ersilia/cli/commands/close.py b/ersilia/cli/commands/close.py index d053ceda4..a87f144ac 100644 --- a/ersilia/cli/commands/close.py +++ b/ersilia/cli/commands/close.py @@ -4,6 +4,7 @@ from .. import echo from ... import ErsiliaModel from ...core.session import Session +from ...core.tracking import close_persistent_file def close_cmd(): @@ -20,10 +21,5 @@ def close(): mdl.close() echo(":no_entry: Model {0} closed".format(mdl.model_id), fg="green") - # renames current_session to timestamp - old_file_path = "current_session.txt" - new_file_path = os.path.join( - os.path.dirname(old_file_path), - datetime.datetime.now().strftime("%Y%m%d%H%M%S"), - ) - os.rename(old_file_path, new_file_path) + # Close our persistent tracking file + close_persistent_file() diff --git a/ersilia/cli/commands/serve.py b/ersilia/cli/commands/serve.py index d4018357a..d91b60b82 100644 --- a/ersilia/cli/commands/serve.py +++ b/ersilia/cli/commands/serve.py @@ -3,6 +3,7 @@ from .. import echo from ... import ErsiliaModel from ..messages import ModelNotFound +from ...core.tracking import open_persistent_file def serve_cmd(): @@ -62,6 +63,7 @@ def serve(model, lake, docker, port, track_serve): echo("") echo(":person_tipping_hand: Information:", fg="blue") echo(" - info", fg="blue") + + # Setup persistent tracking if track_serve: - with open("current_session.txt", "w") as f: - f.write("Session started for model: {0}".format(mdl.model_id)) + open_persistent_file(mdl.model_id) diff --git a/ersilia/core/tracking.py b/ersilia/core/tracking.py index ad0f800cd..58758aa26 100644 --- a/ersilia/core/tracking.py +++ b/ersilia/core/tracking.py @@ -1,6 +1,9 @@ from datetime import datetime import json import pandas as pd +import os + +PERSISTENT_FILE_PATH = os.path.abspath("current_session.txt") def read_csv(file): @@ -12,6 +15,29 @@ def read_json(result): data = json.load(result) return data + +def open_persistent_file(model_id): + with open(PERSISTENT_FILE_PATH, "w") as f: + f.write("Session started for model: {0}\n".format(model_id)) + + +def write_persistent_file(contents): + # Only write to file if it already exists (we're meant to be tracking this run) + if os.path.isfile(PERSISTENT_FILE_PATH): + with open(PERSISTENT_FILE_PATH, "a") as f: + f.write(f"{contents}\n") + + +def close_persistent_file(): + # Make sure the file actually exists before we try renaming + if os.path.isfile(PERSISTENT_FILE_PATH): + new_file_path = os.path.join( + os.path.dirname(PERSISTENT_FILE_PATH), + datetime.now().strftime("%Y-%m-%d%_H-%M-%S.txt"), + ) + os.rename(PERSISTENT_FILE_PATH, new_file_path) + + class RunTracker: """ This class will be responsible for tracking model runs. It calculates the desired metadata based on a model's @@ -43,14 +69,14 @@ def stats(self, result): stats = {} for column in dat: column_stats = {} - column_stats['mean'] = dat[column].mean() + column_stats["mean"] = dat[column].mean() if len(dat[column].mode()) == 1: - column_stats['mode'] = dat[column].mode() + column_stats["mode"] = dat[column].mode().iloc[0] else: - column_stats['mode'] = None - column_stats['min'] = dat[column].min() - column_stats['max'] = dat[column].max() - column_stats['std'] = dat[column].std() + column_stats["mode"] = None + column_stats["min"] = dat[column].min() + column_stats["max"] = dat[column].max() + column_stats["std"] = dat[column].std() stats[column] = column_stats @@ -97,15 +123,13 @@ def track(self, input, result, meta): json_dict["stats"] = self.stats(result) - json_dict['file_sizes'] = self.get_file_sizes(input_dataframe, result_dataframe) + json_dict["file_sizes"] = self.get_file_sizes(input_dataframe, result_dataframe) json_object = json.dumps(json_dict, indent=4) print("\nJSON Dictionary:\n", json_object) - # log results to console - with open("../cli/commands/current_session.txt", "a") as f: - # write the print statements to a file - f.write(json_object) + # log results to persistent tracking file + write_persistent_file(json_object) def check_types(self, resultDf, metadata): typeDict = {"float64": "Float", "int64": "Int"} @@ -130,7 +154,4 @@ def check_types(self, resultDf, metadata): print("Output has", count, "mismatched types.\n") - return { - "mismatched_types": count, - "correct_shape": correct_shape - } + return {"mismatched_types": count, "correct_shape": correct_shape}