Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/eric/create-csv'
Browse files Browse the repository at this point in the history
  • Loading branch information
AC-Dap committed Dec 21, 2023
2 parents 6060d9f + 731dc20 commit fc6dc94
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion ersilia/core/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,36 @@
TEMP_FILE_LOGS = os.path.abspath("")


def create_csv(output_df):
    """
    Build a temporary CSV file recording which inputs were run and when.

    The CSV has two columns: ``input`` (copied from ``output_df``) and
    ``time`` (the ISO-formatted timestamp taken when this function is called,
    repeated on every row). The file is intended to be passed to CDD vault
    later.

    :param output_df: The output dataframe from the model run. It must contain
        an ``input`` column and is not modified.
    :return: The open ``tempfile.NamedTemporaryFile`` holding the CSV; its path
        is available as ``.name``. The file is deleted as soon as the returned
        object is closed or garbage-collected, so the caller must keep a
        reference to it for as long as the CSV is needed.
    :raises KeyError: If ``output_df`` has no ``input`` column.
    """
    # Copy so that adding the "time" column never touches the caller's frame.
    new_df = output_df[["input"]].copy()
    new_df["time"] = datetime.now().isoformat()

    # newline="" is what pandas asks for when it is handed a file object, and
    # encoding="utf-8" matches what to_csv uses when it is given a path.
    csv_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", newline="", encoding="utf-8"
    )
    try:
        # Write through the handle we already hold instead of reopening the
        # file by name: a second open of a NamedTemporaryFile fails on Windows
        # while the first handle is still open.
        new_df.to_csv(csv_file, index=False)
        # Make the rows visible to anything that reads the file via .name.
        csv_file.flush()
    except Exception:
        # Do not leave a half-written temp file behind if writing fails.
        csv_file.close()
        raise

    return csv_file


def log_files_metrics(file):
"""
This function will log the number of errors and warnings in the log files.
:param file: The log file to be read
:return: None (writes to file)
"""


error_count = 0
warning_count = 0

Expand Down Expand Up @@ -218,6 +247,15 @@ def get_file_sizes(self, input_df, output_df):
}

def check_types(self, resultDf, metadata):
"""
This method is responsible for checking the types of the output dataframe against the expected types.
This includes checking the shape of the output dataframe (list vs single) and the types of each column.
:param resultDf: The output dataframe
:param metadata: The metadata dictionary
:return: A dictionary containing the number of mismatched types and a boolean for whether the shape is correct
"""

typeDict = {"float64": "Float", "int64": "Int"}
count = 0

Expand All @@ -237,7 +275,8 @@ def check_types(self, resultDf, metadata):
else:
print("Output is correct shape.")
correct_shape = True


print(resultDf)
print("Output has", count, "mismatched types.\n")

return {"mismatched_types": count, "correct_shape": correct_shape}
Expand Down Expand Up @@ -284,5 +323,7 @@ def track(self, input, result, meta):
json_object = json.dumps(json_dict, indent=4)
write_persistent_file(json_object)

create_csv(result_dataframe)

# Upload run stats to s3
upload_to_s3(json_dict)

0 comments on commit fc6dc94

Please sign in to comment.