diff --git a/ersilia/core/tracking.py b/ersilia/core/tracking.py index 9f77f2192..64f8c8063 100644 --- a/ersilia/core/tracking.py +++ b/ersilia/core/tracking.py @@ -77,6 +77,12 @@ def track(self, input, result, meta): time = datetime.now() - self.time_start print("Time taken:", time) + # checking for mismatched types + nan_count = result_dataframe.isna().sum() + print("\nNAN Count:\n", nan_count) + + self.check_types(result_dataframe, meta["metadata"]) + self.stats(result) self.get_file_sizes(input_dataframe, result_dataframe) @@ -88,3 +94,23 @@ def read_json(self, result): data = json.load(result) self.log_to_console(result) return data + + def check_types(self, resultDf, metadata): + typeDict = {"float64": "Float", "int64": "Int"} + count = 0 + + # ignore key and input columns + dtypesLst = resultDf.loc[:, ~resultDf.columns.isin(["key", "input"])].dtypes + + for i in dtypesLst: + if typeDict[str(i)] != metadata["Output Type"][0]: + count += 1 + + if len(dtypesLst) > 1 and metadata["Output Shape"] != "List": + print("Not right shape. Expected List but got Single") + elif len(dtypesLst) == 1 and metadata["Output Shape"] != "Single": + print("Not right shape. Expected Single but got List") + else: + print("Output is correct shape.") + + print("Output has", count, "mismatched types.\n")