-
-
Notifications
You must be signed in to change notification settings - Fork 229
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: encapsulation: move multiprocessing implementation to dedicated subpackage
- Loading branch information
Bogdan Budescu
committed
Dec 11, 2024
1 parent
7492c48
commit d9c3867
Showing
4 changed files
with
224 additions
and
194 deletions.
There are no files selected for viewing
125 changes: 125 additions & 0 deletions
125
smac/model/random_forest/multiproc_util/GrowingSharedArray.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
from typing import Optional | ||
|
||
import math | ||
from multiprocessing import Lock, shared_memory | ||
|
||
import numpy as np | ||
from numpy import typing as npt | ||
|
||
|
||
def dtypes_are_equal(dtype1: np.dtype, dtype2: np.dtype) -> bool:
    """Return True when each dtype is a sub-dtype of the other.

    The check is ``np.issubdtype`` applied in both directions, so a concrete
    dtype never equals an abstract one (e.g. ``np.float64`` vs ``np.floating``).
    """
    return np.issubdtype(dtype1, dtype2) and np.issubdtype(dtype2, dtype1)


class GrowingSharedArrayReaderView:
    """Read-only access to a pair of shared-memory segments holding ``X`` and ``y``.

    The segments are named ``f'{basename_X}_{shm_id}'`` and
    ``f'{basename_y}_{shm_id}'`` and are interpreted as ``float64`` arrays.
    The capacity (number of rows) and the row size (number of columns) are
    derived from the byte sizes of the attached segments.

    NOTE(review): the ``multiprocessing.shared_memory`` documentation states
    that a segment may be larger than requested on platforms that allocate in
    page-size chunks; on such platforms the sizes derived here would be off.
    Confirm the target platforms.
    """

    basename_X: str = 'X'
    basename_y: str = 'y'

    def __init__(self, lock: Lock):
        self.lock = lock
        # id of the segment pair currently attached (None until first attach)
        self.shm_id: Optional[int] = None
        self.shm_X: Optional[shared_memory.SharedMemory] = None
        self.shm_y: Optional[shared_memory.SharedMemory] = None

    def __del__(self):
        self._detach()

    def _detach(self) -> None:
        """Close (without unlinking) whichever segments are currently attached."""
        # getattr: __del__ may run on an instance whose __init__ did not finish
        for attr in ('shm_X', 'shm_y'):
            shm = getattr(self, attr, None)
            if shm is not None:
                shm.close()
                setattr(self, attr, None)

    @property
    def capacity(self) -> Optional[int]:
        """Number of rows the attached ``y`` segment can hold, or None if detached."""
        if self.shm_y is None:
            return None
        # np.float64.itemsize on the *class* is a descriptor, not an int;
        # go through np.dtype to obtain the actual byte width
        itemsize = np.dtype(np.float64).itemsize
        assert self.shm_y.size % itemsize == 0
        return self.shm_y.size // itemsize

    @property
    def row_size(self) -> Optional[int]:
        """Number of columns of ``X``, or None if detached.

        ``X`` and ``y`` share the element type, so the ratio of their byte
        sizes equals the number of columns.
        """
        if self.shm_X is None:
            return None
        if self.shm_X.size == 0:
            assert self.shm_y.size == 0
            return 0
        assert self.shm_X.size % self.shm_y.size == 0
        return self.shm_X.size // self.shm_y.size

    def np_view(self, size: int) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Return views (no copies) of the first ``size`` rows of ``X`` and ``y``.

        Segments must already be attached. The views reference the shared
        buffers, so they must be dropped before the segments are closed.
        """
        X = np.ndarray(shape=(self.capacity, self.row_size), dtype=np.float64, buffer=self.shm_X.buf)
        y = np.ndarray(shape=(self.capacity,), dtype=np.float64, buffer=self.shm_y.buf)
        return X[:size], y[:size]

    def get_data(self, shm_id: int, size: int) -> tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
        """Return private copies of the first ``size`` rows stored under ``shm_id``.

        Attaches to the segment pair named by ``shm_id`` when it differs from
        the pair currently attached (or when nothing is attached yet).
        """
        with self.lock:
            if shm_id != self.shm_id or self.shm_X is None or self.shm_y is None:
                # nothing is attached on the first call, so detaching must tolerate None
                self._detach()
                self.shm_X = shared_memory.SharedMemory(f'{self.basename_X}_{shm_id}')
                self.shm_y = shared_memory.SharedMemory(f'{self.basename_y}_{shm_id}')
                # remember what is attached so the next call can reuse it
                self.shm_id = shm_id

            shared_X, shared_y = self.np_view(size)
            # copy while holding the lock so the writer cannot swap/overwrite mid-read
            X, y = np.array(shared_X), np.array(shared_y)

        return X, y


class GrowingSharedArray(GrowingSharedArrayReaderView):
    """Writer side: owns the shared segments and re-allocates them geometrically.

    Whenever the data no longer fits, a new segment pair with a larger
    capacity and an incremented ``shm_id`` is created and the old pair is
    closed and unlinked.

    NOTE(review): segment names depend only on ``basename_X`` / ``basename_y``
    and ``shm_id``, so two live instances would try to create the same names.
    The current pair is also never unlinked when the instance goes away.
    Confirm how ownership is meant to be handled.
    """

    def __init__(self):
        self.growth_rate = 1.5
        super().__init__(lock=Lock())

    def set_data(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> None:
        """Store ``X`` (2-D) and ``y`` (1-D) of equal length in shared memory.

        Re-allocates the segments first if ``len(y)`` exceeds the current
        capacity (or if nothing has been allocated yet).
        """
        assert len(X) == len(y)
        assert X.ndim == 2
        assert y.ndim == 1
        assert dtypes_are_equal(X.dtype, np.float64)
        assert dtypes_are_equal(y.dtype, np.float64)
        assert X.dtype.itemsize == 8
        assert y.dtype.itemsize == 8

        size = len(y)
        n_cols = X.shape[1]
        old_capacity = self.capacity
        # capacity is None before the first allocation; comparing an int with None raises
        grow = old_capacity is None or size > old_capacity

        if self.row_size is not None:
            assert n_cols == self.row_size

        new_X = new_y = None
        new_id = self.shm_id
        if grow:
            if old_capacity:
                n_growth = math.ceil(math.log(size / old_capacity, self.growth_rate))
                grown = int(math.ceil(old_capacity * self.growth_rate ** n_growth))
                # guard against float rounding in log/ceil undershooting the request
                new_capacity = max(size, grown)
                new_id = self.shm_id + 1
            else:
                assert self.shm_X is None
                assert self.shm_y is None
                # a zero-byte segment cannot be created, so keep room for at least one row
                new_capacity = max(size, 1)
                new_id = 0

            # row_size is still None before the first allocation, so size X by n_cols
            new_X = shared_memory.SharedMemory(f'{self.basename_X}_{new_id}', create=True,
                                               size=new_capacity * n_cols * X.dtype.itemsize)
            try:
                new_y = shared_memory.SharedMemory(f'{self.basename_y}_{new_id}', create=True,
                                                   size=new_capacity * y.dtype.itemsize)
            except Exception:
                # do not leak the X segment when its y counterpart cannot be created
                new_X.close()
                new_X.unlink()
                raise

        with self.lock:
            if grow:
                for old in (self.shm_X, self.shm_y):
                    if old is not None:
                        old.close()
                        old.unlink()
                # commit the new pair and its id together, only after creation succeeded
                self.shm_X = new_X
                self.shm_y = new_y
                self.shm_id = new_id
            X_buf, y_buf = self.np_view(size)
            X_buf[...] = X
            y_buf[...] = y
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
from typing import TYPE_CHECKING, Optional | ||
|
||
from multiprocessing import Process, Lock, Queue | ||
|
||
from numpy import typing as npt | ||
import numpy as np | ||
|
||
if TYPE_CHECKING: | ||
from pyrfr.regression import binary_rss_forest as BinaryForest, forest_opts as ForestOpts | ||
|
||
|
||
class RFTrainer(Process):
    """Background process that fits a random forest on the latest submitted data.

    Training requests travel to the process through ``data_queue`` and fitted
    models travel back through ``model_queue``; both hold at most one item.

    NOTE(review): ``run`` relies on ``self._init_data_container`` and
    ``self._rng``, neither of which is defined in this class as shown. They
    must be provided before the trainer can actually fit a model.
    """

    def __init__(self):
        super().__init__(daemon=True)

        # most recent model received from the training process
        self._model: Optional['BinaryForest'] = None
        self.model_lock = Lock()
        self.model_queue = Queue(maxsize=1)

        self.opts = None
        self.data_queue = Queue(maxsize=1)

        self.start()

    @staticmethod
    def _put_latest(q, item) -> None:
        """Put ``item`` on ``q``, first discarding a pending item if there is one."""
        from queue import Empty

        try:
            q.get(block=False)
        except Empty:
            pass
        q.put(item)

    @property
    def model(self):
        """Return the most recent fitted model, or None if none has arrived yet."""
        from queue import Empty

        latest = None
        while True:
            try:
                # a non-blocking get raises Empty (it never returns None) once drained
                candidate = self.model_queue.get(block=False)
            except Empty:
                break
            if candidate is None:
                break
            latest = candidate

        with self.model_lock:
            if latest is not None:
                self._model = latest
            return self._model

    def submit_for_training(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.float64], opts: 'ForestOpts'):
        """Hand the latest training data and forest options to the trainer process.

        A request that is still pending is replaced, so the trainer only ever
        sees the most recent data.
        """
        assert X is not None
        assert y is not None
        self.opts = opts
        # opts travels with the data: attributes set here after start() are
        # not visible inside the already-running process
        self._put_latest(self.data_queue, (X, y, opts))

    def run(self) -> None:
        """Training loop executed in the child process."""
        # ForestOpts/BinaryForest are imported for type checking only at module
        # level, so the runtime module is imported here
        from pyrfr import regression

        while True:
            # block until new data is submitted for training
            X, y, opts = self.data_queue.get()

            # NOTE(review): _init_data_container and _rng are not defined in view
            data = self._init_data_container(X, y)

            _rf = regression.binary_rss_forest()
            _rf.options = opts

            _rf.fit(data, rng=self._rng)

            with self.model_lock:
                self._model = _rf

            # publish to the parent; presumably the forest is picklable — TODO confirm
            self._put_latest(self.model_queue, _rf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#!/usr/bin/env python3 | ||
# -*- coding: utf-8 -*- | ||
|
||
""" | ||
doc of what is in the file | ||
""" | ||
|
||
__author__ = "Iddo Software <[email protected]>" | ||
__copyright__ = "Copyright © 2022, Client & Iddo Software. All Rights Reserved." | ||
__license__ = "Proprietary" | ||
__version__ = "0.1" | ||
__maintainer__ = "Iddo Software <[email protected]>" | ||
__email__ = "[email protected]" | ||
__status__ = "Development" # can also be "Prototype" or "Production" | ||
|
||
|
||
def main():
    """Script entry point; a placeholder that currently does nothing."""
    pass


if __name__ == '__main__':
    main()
Oops, something went wrong.