Skip to content

Commit

Permalink
minor aesthetic refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
Bogdan Budescu committed Dec 12, 2024
1 parent 880e0cc commit 8458d41
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
18 changes: 11 additions & 7 deletions smac/model/random_forest/multiproc_util/GrowingSharedArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@

import math
from multiprocessing import Lock
from .SharedMemory import SharedMemory

# from multiprocessing.shared_memory import SharedMemory
from .SharedMemory import SharedMemory as TrackableSharedMemory
def SharedMemory(*args, **kwargs) -> TrackableSharedMemory:
return TrackableSharedMemory(*args, track=False, **kwargs)


import numpy as np
from numpy import typing as npt
Expand All @@ -29,8 +34,8 @@ def __init__(self, lock: Lock):
def open(self, shm_id: int, size: int):
if shm_id != self.shm_id:
self.close()
self.shm_X = SharedMemory(f'{self.basename_X}_{shm_id}', track=False)
self.shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', track=False)
self.shm_X = SharedMemory(f'{self.basename_X}_{shm_id}')
self.shm_y = SharedMemory(f'{self.basename_y}_{shm_id}')
self.shm_id = shm_id
self.size = size

Expand Down Expand Up @@ -121,15 +126,14 @@ def set_data(self, X: npt.NDArray[np.float64], y: npt.NDArray[np.float64]) -> No
assert self.shm_y is None
capacity = size

shm_id = uuid.uuid4().int
shm_id = uuid.uuid4().int # self.shm_id + 1 if self.shm_id else 0

row_size = X.shape[1]
if self.row_size is not None:
assert row_size == self.row_size
shm_X = SharedMemory(f'{self.basename_X}_{shm_id}', create=True,
size=capacity * row_size * X.dtype.itemsize, track=False)
shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', create=True, size=capacity * y.dtype.itemsize,
track=False)
size=capacity * row_size * X.dtype.itemsize)
shm_y = SharedMemory(f'{self.basename_y}_{shm_id}', create=True, size=capacity * y.dtype.itemsize)

with self.lock:
if grow:
Expand Down
3 changes: 2 additions & 1 deletion smac/model/random_forest/multiproc_util/SharedMemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def __init__(

# if tracking, normal init will suffice
if track:
return super().__init__(name=name, create=create, size=size)
super().__init__(name=name, create=create, size=size)
return

# lock so that other threads don't attempt to use the
# register function during this time
Expand Down

0 comments on commit 8458d41

Please sign in to comment.