From 8458d416c6323eff9668527d5d21b9084ca7d084 Mon Sep 17 00:00:00 2001 From: Bogdan Budescu Date: Thu, 12 Dec 2024 16:49:37 +0200 Subject: [PATCH] minor aesthetic refactors --- .../multiproc_util/GrowingSharedArray.py | 18 +++++++++++------- .../multiproc_util/SharedMemory.py | 3 ++- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/smac/model/random_forest/multiproc_util/GrowingSharedArray.py b/smac/model/random_forest/multiproc_util/GrowingSharedArray.py index 92cf1b51e..1f378829d 100644 --- a/smac/model/random_forest/multiproc_util/GrowingSharedArray.py +++ b/smac/model/random_forest/multiproc_util/GrowingSharedArray.py @@ -5,7 +5,12 @@ import math from multiprocessing import Lock -from .SharedMemory import SharedMemory + +# from multiprocessing.shared_memory import SharedMemory +from .SharedMemory import SharedMemory as TrackableSharedMemory +def SharedMemory(*args, **kwargs) -> TrackableSharedMemory: + return TrackableSharedMemory(*args, track=False, **kwargs) + import numpy as np from numpy import typing as npt @@ -29,8 +34,8 @@ def __init__(self, lock: Lock): def open(self, shm_id: int, size: int): if shm_id != self.shm_id: self.close() - self.shm_X = SharedMemory(f'{self.basename_X}_{shm_id}', track=False) - self.shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', track=False) + self.shm_X = SharedMemory(f'{self.basename_X}_{shm_id}') + self.shm_y = SharedMemory(f'{self.basename_y}_{shm_id}') self.shm_id = shm_id self.size = size @@ -121,15 +126,14 @@ def set_data(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> No assert self.shm_y is None capacity = size - shm_id = uuid.uuid4().int + shm_id = uuid.uuid4().int # self.shm_id + 1 if self.shm_id else 0 row_size = X.shape[1] if self.row_size is not None: assert row_size == self.row_size shm_X = SharedMemory(f'{self.basename_X}_{shm_id}', create=True, - size=capacity * row_size * X.dtype.itemsize, track=False) - shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', create=True, size=capacity * y.dtype.itemsize, - track=False) + size=capacity * row_size * X.dtype.itemsize) + shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', create=True, size=capacity * y.dtype.itemsize) with self.lock: if grow: diff --git a/smac/model/random_forest/multiproc_util/SharedMemory.py b/smac/model/random_forest/multiproc_util/SharedMemory.py index 6420a83ac..31dc5b194 100644 --- a/smac/model/random_forest/multiproc_util/SharedMemory.py +++ b/smac/model/random_forest/multiproc_util/SharedMemory.py @@ -25,7 +25,8 @@ def __init__( # if tracking, normal init will suffice if track: - return super().__init__(name=name, create=create, size=size) + super().__init__(name=name, create=create, size=size) + return # lock so that other threads don't attempt to use the # register function during this time