From 63bfbebbd288c8669d6bce7f44f8c9a3a82facd5 Mon Sep 17 00:00:00 2001 From: Aron Bahram Date: Wed, 7 Dec 2022 09:59:21 +0100 Subject: [PATCH] refactor: track model_ids in cv_results (#1628) --- autosklearn/automl.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/autosklearn/automl.py b/autosklearn/automl.py index ffcf1fb033..1b2b08f74f 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -1921,15 +1921,17 @@ def cv_results_(self): metric_dict[metric.name] = [] metric_mask[metric.name] = [] + model_ids = [] mean_fit_time = [] params = [] status = [] budgets = [] - for run_key in self.runhistory_.data: - run_value = self.runhistory_.data[run_key] + for run_key, run_value in self.runhistory_.data.items(): config_id = run_key.config_id config = self.runhistory_.ids_config[config_id] + if run_value.additional_info and "num_run" in run_value.additional_info: + model_ids.append(run_value.additional_info["num_run"]) s = run_value.status if s == StatusType.SUCCESS: @@ -1990,6 +1992,8 @@ def cv_results_(self): metric_dict[metric.name].append(metric_value) metric_mask[metric.name].append(mask_value) + results["model_ids"] = model_ids + if len(self._metrics) == 1: results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name]) rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"] @@ -2165,14 +2169,11 @@ def show_models(self) -> dict[int, Any]: warnings.warn("No ensemble found. Returning empty dictionary.") return ensemble_dict - def has_key(rv, key): - return rv.additional_info and key in rv.additional_info - table_dict = {} - for run_key, run_val in self.runhistory_.data.items(): - if has_key(run_val, "num_run"): - model_id = run_val.additional_info["num_run"] - table_dict[model_id] = {"model_id": model_id, "cost": run_val.cost} + for run_key, run_value in self.runhistory_.data.items(): + if run_value.additional_info and "num_run" in run_value.additional_info: + model_id = run_value.additional_info["num_run"] + table_dict[model_id] = {"model_id": model_id, "cost": run_value.cost} # Checking if the dictionary is empty if not table_dict: