Skip to content

Commit

Permalink
refactor: track model_ids in cv_results (#1628)
Browse files Browse the repository at this point in the history
  • Loading branch information
aron-bram authored Dec 7, 2022
1 parent a978478 commit 63bfbeb
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,15 +1921,17 @@ def cv_results_(self):
metric_dict[metric.name] = []
metric_mask[metric.name] = []

model_ids = []
mean_fit_time = []
params = []
status = []
budgets = []

for run_key in self.runhistory_.data:
run_value = self.runhistory_.data[run_key]
for run_key, run_value in self.runhistory_.data.items():
config_id = run_key.config_id
config = self.runhistory_.ids_config[config_id]
if run_value.additional_info and "num_run" in run_value.additional_info:
model_ids.append(run_value.additional_info["num_run"])

s = run_value.status
if s == StatusType.SUCCESS:
Expand Down Expand Up @@ -1990,6 +1992,8 @@ def cv_results_(self):
metric_dict[metric.name].append(metric_value)
metric_mask[metric.name].append(mask_value)

results["model_ids"] = model_ids

if len(self._metrics) == 1:
results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name])
rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"]
Expand Down Expand Up @@ -2165,14 +2169,11 @@ def show_models(self) -> dict[int, Any]:
warnings.warn("No ensemble found. Returning empty dictionary.")
return ensemble_dict

def has_key(rv, key):
return rv.additional_info and key in rv.additional_info

table_dict = {}
for run_key, run_val in self.runhistory_.data.items():
if has_key(run_val, "num_run"):
model_id = run_val.additional_info["num_run"]
table_dict[model_id] = {"model_id": model_id, "cost": run_val.cost}
for run_key, run_value in self.runhistory_.data.items():
if run_value.additional_info and "num_run" in run_value.additional_info:
model_id = run_value.additional_info["num_run"]
table_dict[model_id] = {"model_id": model_id, "cost": run_value.cost}

# Checking if the dictionary is empty
if not table_dict:
Expand Down

0 comments on commit 63bfbeb

Please sign in to comment.