Skip to content

Commit

Permalink
add flag to allow user to switch to emulating the old behavior, i.e., to…
Browse files Browse the repository at this point in the history
… wait until training is done before being able to query the model to suggest a new config to try
  • Loading branch information
Bogdan Budescu committed Dec 12, 2024
1 parent ed50624 commit 65edb8d
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions smac/model/random_forest/multiproc_util/RFTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ..util import init_data_container


SHUTDOWN = -1
SHUTDOWN = None


def rf_training_loop(
Expand All @@ -25,7 +25,7 @@ def rf_training_loop(
# rf opts
n_trees: int, bootstrapping: bool, max_features: int, min_samples_split: int, min_samples_leaf: int,
max_depth: int, eps_purity: float, max_nodes: int, n_points_per_tree: int
):
) -> None:
rf_opts = regression.forest_opts()
rf_opts.num_trees = n_trees
rf_opts.do_bootstrapping = bootstrapping
Expand All @@ -51,31 +51,27 @@ def rf_training_loop(
msg = data_queue.get(block=False)
except queue.Empty:
break

if msg == SHUTDOWN:
break
else:
if msg == SHUTDOWN:
return

shm_id, size = msg

X, y = shared_arrs.get_data(shm_id, size)

data = init_data_container(X, y, bounds)

if n_points_per_tree <= 0:
rf_opts.num_data_points_per_tree = len(X)

rf = BinaryForest()
rf.options = rf_opts

rf.fit(data, rng)

# remove previous models from queue, if any, and replace them with the latest
# remove previous models from queue, if any, before pushing the latest model
while True:
try:
old_rf = model_queue.get(block=False)
_ = model_queue.get(block=False)
except queue.Empty:
break

model_queue.put(rf)


Expand All @@ -85,10 +81,13 @@ def __init__(
self, bounds: Iterable[tuple[float, float]], seed: int,
# rf opts
n_trees: int, bootstrapping: bool, max_features: int, min_samples_split: int, min_samples_leaf: int,
max_depth: int, eps_purity: float, max_nodes: int, n_points_per_tree: int
max_depth: int, eps_purity: float, max_nodes: int, n_points_per_tree: int,
# process synchronization
sync: bool = False
) -> None:
self._model: Optional[BinaryForest] = None
self.shared_arrs = GrowingSharedArray()
self.sync = sync

self.model_queue = Queue(maxsize=1)
self.data_queue = Queue(maxsize=1)
Expand Down Expand Up @@ -159,3 +158,6 @@ def submit_for_training(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.floa
else:
assert old_data != SHUTDOWN
self.data_queue.put((self.shared_arrs.shm_id, len(X)))

if self.sync:
self._model = self.model_queue.get()

0 comments on commit 65edb8d

Please sign in to comment.