From 5514b7cd68c946b6b62fe7c69caa353516d3dec3 Mon Sep 17 00:00:00 2001 From: Anthony Cui Date: Mon, 16 Oct 2023 23:56:11 -0400 Subject: [PATCH] Refactor code to be cleaner --- ersilia/core/tracking.py | 43 +++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/ersilia/core/tracking.py b/ersilia/core/tracking.py index 9013d0eb..9f77f219 100644 --- a/ersilia/core/tracking.py +++ b/ersilia/core/tracking.py @@ -16,12 +16,6 @@ class RunTracker: NOTE: Currently, the Splunk connection is not set up. For now, we will print tracking results to the console. """ - def sample_df(self, df, num_rows, num_cols): - """ - Returns a sample of the dataframe, with the specified number of rows and columns. - """ - return df.sample(num_rows, axis=0).sample(num_cols, axis=1) - def __init__(self): self.time_start = None @@ -29,8 +23,14 @@ def __init__(self): def start_tracking(self): self.time_start = datetime.now() + def sample_df(self, df, num_rows, num_cols): + """ + Returns a sample of the dataframe, with the specified number of rows and columns. + """ + return df.sample(num_rows, axis=0).sample(num_cols, axis=1) + def stats(self, result): - dat = self.read_csv(result) + dat = read_csv(result) # drop first two columns (key, input) dat = dat.drop(["key", "input"], axis=1) @@ -46,15 +46,28 @@ def stats(self, result): print("Max %s: %s" % (column, dat[column].max())) print("Standard deviation %s: %s" % (column, dat[column].std())) + def get_file_sizes(self, input_df, output_df): + input_size = input_df.memory_usage(deep=True).sum() / 1024 + output_size = output_df.memory_usage(deep=True).sum() / 1024 + + input_avg_row_size = input_size / len(input_df) + output_avg_row_size = output_size / len(output_df) + + print("Average Input Row Size (KB):", input_avg_row_size) + print("Average Output Row Size (KB):", output_avg_row_size) + def track(self, input, result, meta): """ Tracks the results after a model run. """ + input_dataframe = read_csv(input) + result_dataframe = read_csv(result) + print("Run input file:", input) - print(read_csv(input)) + print(input_dataframe) print("Run output file:", result) - print(read_csv(result)) + print(result_dataframe) print("Model metadata:", meta) @@ -66,17 +79,7 @@ def track(self, input, result, meta): self.stats(result) - input_dataframe = self.read_csv(input) - result_dataframe = self.read_csv(result) - - input_size = input_dataframe.memory_usage(deep=True).sum() / 1024 - output_size = result_dataframe.memory_usage(deep=True).sum() / 1024 - - input_avg_row_size = input_size / len(input_dataframe) - output_avg_row_size = output_size / len(result_dataframe) - - print("Average Input Row Size (KB):", input_avg_row_size) - print("Average Output Row Size (KB):", output_avg_row_size) + self.get_file_sizes(input_dataframe, result_dataframe) def log_to_console(self, data): print(f"\n{json.dumps(data)}\n")