Skip to content

Commit

Permalink
Merge pull request #10 from hcs-t4sg/eric/verify-output
Browse files Browse the repository at this point in the history
Eric/verify output
  • Loading branch information
AC-Dap authored Oct 17, 2023
2 parents b390f64 + 1c3591e commit 552e615
Showing 1 changed file with 28 additions and 1 deletion.
29 changes: 28 additions & 1 deletion ersilia/core/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def track(self, input, result, meta):
print(read_csv(input))

print("Run output file:", result)
print(read_csv(result))
resultDf = read_csv(result)
print(resultDf)

print("Model metadata:", meta)

Expand All @@ -64,6 +65,12 @@ def track(self, input, result, meta):
time = datetime.now() - self.time_start
print("Time taken:", time)

# checking for mismatched types
nan_count = resultDf.isna().sum()
print("\nNAN Count:\n", nan_count)

self.check_types(resultDf, meta["metadata"])

self.stats(result)

input_dataframe = self.read_csv(input)
Expand All @@ -85,3 +92,23 @@ def read_json(self, result):
data = json.load(result)
self.log_to_console(result)
return data

def check_types(self, resultDf, metadata):
typeDict = {"float64": "Float", "int64": "Int"}
count = 0

# ignore key and input columns
dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes

for i in dtypesLst:
if typeDict[str(i)] != metadata["Output Type"][0]:
count += 1

if len(dtypesLst) > 1 and metadata["Output Shape"] != "List":
print("Not right shape. Expected List but got Single")
elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single":
print("Not right shape. Expected Single but got List")
else:
print("Output is correct shape.")

print("Output has", count, "mismatched types.\n")

0 comments on commit 552e615

Please sign in to comment.