Skip to content

Commit

Permalink
minor fixes in sync semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
Bogdan Budescu committed Dec 11, 2024
1 parent 984c59f commit 75eb34e
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions smac/model/random_forest/multiproc_util/RFTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def close(self):
self.training_loop_proc = None

if self.model_queue is not None:
_ = self.model # flush the model queue, and return latest model
_ = self.model # try to flush the model queue, and store the latest model
self.model_queue.close()
self.model_queue.join_thread()
del self.model_queue
Expand All @@ -125,15 +125,34 @@ def __del__(self):
self.close()

@property
def model(self) -> Optional[BinaryForest]:
# discard all but the last model in the queue
while True:
try:
self._model = self.model_queue.get(block=False)
except queue.Empty:
break
def model(self) -> BinaryForest:
if self._model is None:
if self.model_queue is None:
raise RuntimeError('rf training loop process has been stopped before being able to train a model')
# wait until the first training is done
self._model = self.model_queue.get()

if self.model_queue is not None:
# discard all but the last model in the queue
while True:
try:
self._model = self.model_queue.get(block=False)
except queue.Empty:
break
return self._model

def submit_for_training(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.float64]):
self.shared_arrs.set_data(X, y)

if self.data_queue is None:
raise RuntimeError('rf training loop process has been stopped, so we cannot submit new training data')

# flush queue before pushing new data onto it
while True:
try:
old_data = self.data_queue.get(block=False)
except queue.Empty:
break
else:
assert old_data != SHUTDOWN
self.data_queue.put((self.shared_arrs.shm_id, len(X)))

0 comments on commit 75eb34e

Please sign in to comment.