-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
222 additions
and
2 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import pytest | ||
from unittest.mock import AsyncMock, MagicMock | ||
|
||
from vision_agent_tools.tools.shared_model_manager import SharedModelManager | ||
from vision_agent_tools.tools.shared_types import BaseTool, Device | ||
|
||
|
||
@pytest.fixture
def model_pool():
    """Provide a fresh, empty SharedModelManager for each test."""
    manager = SharedModelManager()
    return manager
|
||
|
||
class MockBaseModel(AsyncMock, BaseTool):
    """Test double: an AsyncMock that also satisfies the BaseTool interface,
    so it can be registered and retrieved by SharedModelManager."""

    pass
|
||
|
||
@pytest.mark.asyncio
async def test_add_model(model_pool):
    """Registering the same creation function twice must not duplicate the entry."""

    def make_model():
        return MockBaseModel()

    model_pool.add(make_model)

    assert len(model_pool.models) == 1
    assert make_model.__name__ in model_pool.models

    # The pool is keyed by the creation function's name, so a second add
    # with the same function is a no-op.
    model_pool.add(make_model)

    assert len(model_pool.models) == 1
    assert make_model.__name__ in model_pool.models
|
||
|
||
@pytest.mark.asyncio
async def test_get_model_cpu(model_pool):
    """A model with no GPU preference is returned without any device move."""

    def build_cpu_model():
        mock = MockBaseModel()
        mock.to = MagicMock()
        return mock

    model_pool.add(build_cpu_model)

    retrieved = await model_pool.get_model(build_cpu_model.__name__)

    assert retrieved is not None
    # `.to` must never be invoked for a CPU-resident model.
    assert retrieved.to.call_count == 0
|
||
|
||
@pytest.mark.asyncio
async def test_get_model_gpu(model_pool):
    """A GPU-preferring model is moved to the GPU exactly once on retrieval."""

    def build_gpu_model():
        mock = MockBaseModel()
        mock.to = MagicMock()
        mock.device = Device.GPU
        return mock

    model_pool.add(build_gpu_model)

    retrieved = await model_pool.get_model(build_gpu_model.__name__)

    # Exactly one `.to(Device.GPU)` call: verifies both the call count and
    # the target device in a single assertion.
    retrieved.to.assert_called_once_with(Device.GPU)
|
||
|
||
@pytest.mark.asyncio
async def test_get_model_not_found(model_pool):
    """Requesting a name that was never registered raises ValueError."""
    with pytest.raises(ValueError):
        await model_pool.get_model("NonexistentModel")
|
||
|
||
@pytest.mark.asyncio
async def test_get_model_multiple_gpu(model_pool):
    """The manager tracks the most recently fetched GPU model.

    Fetching a second GPU-preferring model must displace the first as the
    recorded GPU occupant.
    """

    def model_creation_fn_a():
        model = MockBaseModel()
        model.to = MagicMock()
        model.device = Device.GPU  # request GPU placement at creation time
        return model

    def model_creation_fn_b():
        model = MockBaseModel()
        model.to = MagicMock()
        model.device = Device.GPU  # request GPU placement at creation time
        return model

    model_pool.add(model_creation_fn_a)
    model_pool.add(model_creation_fn_b)

    # Fetching model A makes it the current GPU occupant.
    model_a = await model_pool.get_model(model_creation_fn_a.__name__)
    assert model_a is not None
    assert model_pool._get_current_gpu_model() == model_creation_fn_a.__name__

    # Fetching model B displaces A as the recorded GPU occupant.
    model_b = await model_pool.get_model(model_creation_fn_b.__name__)
    assert model_b is not None
    assert model_pool._get_current_gpu_model() == model_creation_fn_b.__name__
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import asyncio | ||
from vision_agent_tools.tools.shared_types import Device | ||
|
||
|
||
class SharedModelManager:
    """Pool of model instances with serialized, exclusive access to the GPU.

    Models are registered via creation functions and keyed by the creation
    function's ``__name__``.  A per-model ``asyncio.Lock`` serializes access
    to each model, and a single-permit semaphore guarantees that at most one
    model is being placed on the GPU at a time.
    """

    def __init__(self):
        self.models = {}  # model instances keyed by creation-function name
        self.model_locks = {}  # per-model asyncio.Lock, same keys as `models`
        self.device = Device.CPU
        self.current_gpu_model = None  # name of the model currently on the GPU

        # Single permit => exclusive GPU placement.
        self.gpu_semaphore = asyncio.Semaphore(1)

    def add(self, model_creation_fn):
        """
        Adds a model to the pool with a device preference.

        Args:
            model_creation_fn (callable): A function that creates the model.
                Its ``__name__`` becomes the pool key.
        """
        class_name = model_creation_fn.__name__  # pool key
        if class_name in self.models:
            print(f"Model '{class_name}' already exists in the pool.")
        else:
            self.models[class_name] = model_creation_fn()
            self.model_locks[class_name] = asyncio.Lock()

    async def get_model(self, class_name):
        """
        Retrieves a model from the pool for safe execution.

        If the model prefers the GPU, the previous GPU occupant (if any) is
        first moved back to the CPU, then this model is moved to the GPU
        while holding the exclusive-access semaphore.

        Args:
            class_name (str): Name of the model class.

        Returns:
            BaseTool: The retrieved model instance.

        Raises:
            ValueError: If no model was registered under `class_name`.
        """
        if class_name not in self.models:
            raise ValueError(f"Model '{class_name}' not found in the pool.")

        model = self.models[class_name]
        lock = self.model_locks[class_name]

        async with lock:
            if model.device == Device.GPU:
                async with self.gpu_semaphore:
                    # Evict the previous occupant so only one model lives on
                    # the GPU at a time.  (We intentionally do not take the
                    # evicted model's lock here: we already hold this model's
                    # lock plus the semaphore, and acquiring a second model
                    # lock could deadlock.)
                    previous = self.current_gpu_model
                    if previous is not None and previous != class_name:
                        self._evict_from_gpu(previous)
                    self.current_gpu_model = class_name
                    model.to(Device.GPU)
            return model

    def _get_current_gpu_model(self):
        """
        Returns the class name of the model currently using the GPU (if any).
        """
        return self.current_gpu_model

    def _evict_from_gpu(self, class_name):
        # Move `class_name`'s model back to the CPU and clear the GPU marker.
        #
        # NOTE(review): the original `_move_to_cpu` also called
        # `self.gpu_semaphore.release()` here, but `get_model` releases the
        # semaphore itself on leaving its `async with` block.  The extra
        # release inflated the single-permit semaphore above 1 and silently
        # disabled GPU exclusivity, so it has been removed.
        model = self.models[class_name]
        model.to(Device.CPU)
        if self.current_gpu_model == class_name:
            self.current_gpu_model = None

    async def _move_to_cpu(self, class_name):
        """
        Moves a model to CPU and clears the GPU-occupant marker if it was set.

        Kept async for backward compatibility with existing callers.
        """
        self._evict_from_gpu(class_name)

    def __call__(self, class_name, arguments):
        """
        Decorator for safe and efficient model execution.

        Args:
            class_name (str): Name of the model class.
            arguments (tuple): Arguments for the model call.

        Returns:
            callable: An async wrapper that retrieves the model and executes it.
        """

        async def wrapper():
            model = await self.get_model(class_name)
            # NOTE(review): `arguments` is passed as a single positional
            # argument; if callers expect tuple unpacking
            # (`model(*arguments)`) this needs confirming against call sites.
            return model(arguments)

        return wrapper
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters