Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

documentation for check_types and log_files_metrics #19

Merged
merged 2 commits into from
Dec 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion ersilia/core/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,36 @@
TEMP_FILE_LOGS = os.path.abspath("")


def create_csv(output_df):
    """
    Build a temporary CSV file recording which inputs were run and when.

    The CSV has two columns: ``input`` (copied from the model output
    dataframe) and ``time`` (a single ISO-formatted timestamp, taken once and
    shared by every row). The file is intended to be passed on to CDD vault
    later.

    :param output_df: The output dataframe from the model run; it must
        contain an ``input`` column
    :return: The open temporary file object; its path is available as
        ``.name``. The file is removed from disk as soon as this object is
        closed or garbage-collected, so the caller must keep a reference to
        it for as long as the CSV is needed.
    :raises KeyError: If ``output_df`` has no ``input`` column
    """
    new_df = output_df[["input"]].copy()
    new_df["time"] = datetime.now().isoformat()

    # newline="" lets pandas control the line terminator (as it does when it
    # opens a path itself); UTF-8 matches the default encoding of to_csv.
    csv_file = tempfile.NamedTemporaryFile(
        mode="w", suffix=".csv", newline="", encoding="utf-8"
    )
    try:
        # Write through the already-open handle: reopening csv_file.name
        # while the handle is still open is not possible on every platform
        # (notably Windows).
        new_df.to_csv(csv_file, index=False)
        # Make the rows visible to anyone reading the file by name.
        csv_file.flush()
    except Exception:
        # Do not leak the handle (and the file on disk) if writing fails.
        csv_file.close()
        raise

    return csv_file


def log_files_metrics(file):
"""
This function will log the number of errors and warnings in the given log file.

:param file: The log file to be read
:return: None (writes to file)
"""


error_count = 0
warning_count = 0

Expand Down Expand Up @@ -157,6 +186,15 @@ def get_file_sizes(self, input_df, output_df):
}

def check_types(self, resultDf, metadata):
"""
This method is responsible for checking the types of the output dataframe against the expected types.
This includes checking the shape of the output dataframe (list vs single) and the types of each column.

:param resultDf: The output dataframe
:param metadata: The metadata dictionary
:return: A dictionary containing the number of mismatched types and a boolean for whether the shape is correct
"""

typeDict = {"float64": "Float", "int64": "Int"}
count = 0

Expand All @@ -176,7 +214,8 @@ def check_types(self, resultDf, metadata):
else:
print("Output is correct shape.")
correct_shape = True


print(resultDf)
print("Output has", count, "mismatched types.\n")

return {"mismatched_types": count, "correct_shape": correct_shape}
Expand Down Expand Up @@ -223,5 +262,7 @@ def track(self, input, result, meta):
json_object = json.dumps(json_dict, indent=4)
write_persistent_file(json_object)

create_csv(result_dataframe)

# Upload run stats to s3
upload_to_s3(json_dict)
Loading