Skip to content

Commit

Permalink
add previous runs to *.json when warm starting with result_logger
Browse files Browse the repository at this point in the history
  • Loading branch information
rosea-tf committed Aug 27, 2020
1 parent 841db4b commit 397850c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
10 changes: 8 additions & 2 deletions hpbandster/core/base_iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ class WarmStartIteration(BaseIteration):
iteration that imports a privious Result for warm starting
"""

def __init__(self, Result, config_generator):
def __init__(self, Result, config_generator, result_logger=None):

self.is_finished=False
self.stage = 0
Expand All @@ -263,11 +263,14 @@ def __init__(self, Result, config_generator):
id2conf = Result.get_id2config_mapping()
delta_t = - max(map(lambda r: r.time_stamps['finished'], Result.get_all_runs()))

super().__init__(-1, [len(id2conf)] , [None], None)
super().__init__(-1, [len(id2conf)], [None],
None,
result_logger=result_logger)


for i, id in enumerate(id2conf):
new_id = self.add_configuration(config=id2conf[id]['config'], config_info=id2conf[id]['config_info'])
# if result_logger exists, add this config to configs.json

for r in Result.get_runs_by_id(id):

Expand All @@ -281,6 +284,9 @@ def __init__(self, Result, config_generator):
j.timestamps[k] = v + delta_t

self.register_result(j , skip_sanity_checks=True)

if self.result_logger:
self.result_logger(j) # add prev jobs to results.json

config_generator.new_result(j, update_model=(i==len(id2conf)-1))

Expand Down
6 changes: 5 additions & 1 deletion hpbandster/core/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,11 @@ def __init__(self,
self.warmstart_iteration = []

else:
self.warmstart_iteration = [WarmStartIteration(previous_result, self.config_generator)]
self.warmstart_iteration = [
WarmStartIteration(previous_result,
self.config_generator,
result_logger=self.result_logger)
]

# condition to synchronize the job_callback and the queue
self.thread_cond = threading.Condition()
Expand Down

0 comments on commit 397850c

Please sign in to comment.