diff --git a/poetry.lock b/poetry.lock index a4716e86..bc379e93 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2375,6 +2375,24 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.7" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_asyncio-0.23.7-py3-none-any.whl", hash = "sha256:009b48127fbe44518a547bddd25611551b0e43ccdbf1e67d12479f569832c20b"}, + {file = "pytest_asyncio-0.23.7.tar.gz", hash = "sha256:5f5c72948f4c49e7db4f29f2521d4031f1c27f86e57b046126654083d4770268"}, +] + +[package.dependencies] +pytest = ">=7.0.0,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -4070,4 +4088,4 @@ qr-reader = ["qreader"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0" -content-hash = "cae864510b9c2e399a01ab2ae429c18873077075ce6402ab903408dd8aab812e" +content-hash = "39b434a9c2e57663cf1ca3c1f26be49fd2d708e5ce0adcceb664b3448bda8ebe" diff --git a/pyproject.toml b/pyproject.toml index 1b9a708c..ac710d75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ pre-commit = "^3.7.1" pytest = "^8.2.2" ruff = "^0.5.0" mypy = "^1.10.1" +pytest-asyncio = "^0.23.7" [tool.poetry.group.docs.dependencies] mkdocstrings = {extras = ["python"], version = "^0.25.1"} diff --git a/tests/tools/test_shared_model_manager.py b/tests/tools/test_shared_model_manager.py new file mode 100644 index 00000000..edac7c3a --- /dev/null +++ b/tests/tools/test_shared_model_manager.py @@ -0,0 +1,101 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock + +from vision_agent_tools.tools.shared_model_manager import SharedModelManager +from 
vision_agent_tools.tools.shared_types import BaseTool, Device + + +@pytest.fixture +def model_pool(): + return SharedModelManager() + + +class MockBaseModel(AsyncMock, BaseTool): + pass + + +@pytest.mark.asyncio +async def test_add_model(model_pool): + def model_creation_fn(): + return MockBaseModel() + + model_pool.add(model_creation_fn) + + assert len(model_pool.models) == 1 + assert model_creation_fn.__name__ in model_pool.models + + model_pool.add(model_creation_fn) # Duplicate addition + + assert len(model_pool.models) == 1 + assert model_creation_fn.__name__ in model_pool.models + + +@pytest.mark.asyncio +async def test_get_model_cpu(model_pool): + def model_creation_fn(): + model = MockBaseModel() + model.to = MagicMock() + return model + + model_pool.add(model_creation_fn) + + model_to_get = await model_pool.get_model(model_creation_fn.__name__) + + assert model_to_get is not None + assert model_to_get.to.call_count == 0 # No device change for CPU + + +@pytest.mark.asyncio +async def test_get_model_gpu(model_pool): + def model_creation_fn(): + model = MockBaseModel() + model.to = MagicMock() + model.device = Device.GPU + return model + + model_pool.add(model_creation_fn) + + model_to_get = await model_pool.get_model(model_creation_fn.__name__) + + assert model_to_get.to.call_count == 1 + model_to_get.to.assert_called_once_with(Device.GPU) # Verify to was called with GPU + + +@pytest.mark.asyncio +async def test_get_model_not_found(model_pool): + with pytest.raises(ValueError): + await model_pool.get_model("NonexistentModel") + + +@pytest.mark.asyncio +async def test_get_model_multiple_gpu(model_pool): + def model_creation_fn_a(): + model = MockBaseModel() + model.to = MagicMock() + model.device = Device.GPU # Set device during creation + return model + + def model_creation_fn_b(): + model = MockBaseModel() + model.to = MagicMock() + model.device = Device.GPU # Set device during creation + return model + + model_pool.add(model_creation_fn_a) + 
model_pool.add(model_creation_fn_b) + + # Get Model1 first, should use GPU + model1_to_get = await model_pool.get_model(model_creation_fn_a.__name__) + assert model1_to_get is not None + # assert model1_to_get is model1 + assert ( + model_pool._get_current_gpu_model() == model_creation_fn_a.__name__ + ) # Assert device on retrieved model + + # Get Model2, should move Model1 to CPU and use GPU + model2_to_get = await model_pool.get_model(model_creation_fn_b.__name__) + assert model2_to_get is not None + # assert model2_to_get is model2 + assert ( + model_pool._get_current_gpu_model() == model_creation_fn_b.__name__ + ) # Assert device change on Model1 diff --git a/vision_agent_tools/tools/shared_model_manager.py b/vision_agent_tools/tools/shared_model_manager.py new file mode 100644 index 00000000..ed041acb --- /dev/null +++ b/vision_agent_tools/tools/shared_model_manager.py @@ -0,0 +1,92 @@ +import asyncio +from vision_agent_tools.tools.shared_types import Device + + +class SharedModelManager: + def __init__(self): + self.models = {} # store models with class name as key + self.model_locks = {} # store locks for each model + self.device = Device.CPU + self.current_gpu_model = None # Track the model currently using GPU + + # Semaphore for exclusive GPU access + self.gpu_semaphore = asyncio.Semaphore(1) + + def add(self, model_creation_fn): + """ + Adds a model to the pool with a device preference. + + Args: + model_creation_fn (callable): A function that creates the model. + """ + + class_name = model_creation_fn.__name__ # Get class name from function + if class_name in self.models: + print(f"Model '{class_name}' already exists in the pool.") + else: + model = model_creation_fn() + self.models[class_name] = model + self.model_locks[class_name] = asyncio.Lock() + + async def get_model(self, class_name): + """ + Retrieves a model from the pool for safe execution. + + Args: + class_name (str): Name of the model class. 
 + + Returns: + BaseTool: The retrieved model instance. + """ + + if class_name not in self.models: + raise ValueError(f"Model '{class_name}' not found in the pool.") + + model = self.models[class_name] + lock = self.model_locks[class_name] + + async def get_model_with_lock(): + async with lock: + if model.device == Device.GPU: + # Acquire semaphore if needed + async with self.gpu_semaphore: + # Update current GPU model (for testing) + self.current_gpu_model = class_name + model.to(Device.GPU) + return model + + return await get_model_with_lock() + + def _get_current_gpu_model(self): + """ + Returns the class name of the model currently using the GPU (if any). + """ + return self.current_gpu_model + + async def _move_to_cpu(self, class_name): + """ + Moves a model to CPU and releases the GPU semaphore (if held). + """ + model = self.models[class_name] + model.to(Device.CPU) + if self.current_gpu_model == class_name: + self.current_gpu_model = None + self.gpu_semaphore.release() # Release semaphore if held + + def __call__(self, class_name, arguments): + """ + Decorator for safe and efficient model execution. + + Args: + class_name (str): Name of the model class. + arguments (tuple): Arguments for the model call. + + Returns: + callable: A wrapper function that retrieves the model and executes it. + """ + + async def wrapper(): + model = await self.get_model(class_name) + return model(*arguments) + + return wrapper diff --git a/vision_agent_tools/tools/shared_types.py b/vision_agent_tools/tools/shared_types.py index 1f6e7328..d7d57ba5 100644 --- a/vision_agent_tools/tools/shared_types.py +++ b/vision_agent_tools/tools/shared_types.py @@ -1,8 +1,16 @@ +from enum import Enum from pydantic import BaseModel +class Device(str, Enum): + GPU = "cuda:0" + CPU = "cpu" + + class BaseTool: - pass + def to(self, device: Device): + print(device) + pass class Point(BaseModel):