feat: shared model manager
hrnn committed Jul 17, 2024
1 parent 1c28e60 commit 61377be
Showing 5 changed files with 222 additions and 2 deletions.
20 changes: 19 additions & 1 deletion poetry.lock

Generated file; diff not rendered.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ pre-commit = "^3.7.1"
pytest = "^8.2.2"
ruff = "^0.5.0"
mypy = "^1.10.1"
pytest-asyncio = "^0.23.7"

[tool.poetry.group.docs.dependencies]
mkdocstrings = {extras = ["python"], version = "^0.25.1"}
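
The new test module below is fully asynchronous, which is why pytest-asyncio joins the dev dependencies. A minimal illustration of the marker-based usage it enables (illustrative sketch only, not part of this commit):

import asyncio

import pytest


@pytest.mark.asyncio
async def test_runs_on_event_loop():
    # pytest-asyncio collects this coroutine and runs it on its own event loop.
    await asyncio.sleep(0)
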
101 changes: 101 additions & 0 deletions tests/tools/test_shared_model_manager.py
@@ -0,0 +1,101 @@
import pytest
from unittest.mock import AsyncMock, MagicMock

from vision_agent_tools.tools.shared_model_manager import SharedModelManager
from vision_agent_tools.tools.shared_types import BaseTool, Device


@pytest.fixture
def model_pool():
    return SharedModelManager()


class MockBaseModel(AsyncMock, BaseTool):
    pass


@pytest.mark.asyncio
async def test_add_model(model_pool):
    def model_creation_fn():
        return MockBaseModel()

    model_pool.add(model_creation_fn)

    assert len(model_pool.models) == 1
    assert model_creation_fn.__name__ in model_pool.models

    model_pool.add(model_creation_fn)  # Duplicate addition

    assert len(model_pool.models) == 1
    assert model_creation_fn.__name__ in model_pool.models


@pytest.mark.asyncio
async def test_get_model_cpu(model_pool):
    def model_creation_fn():
        model = MockBaseModel()
        model.to = MagicMock()
        return model

    model_pool.add(model_creation_fn)

    model_to_get = await model_pool.get_model(model_creation_fn.__name__)

    assert model_to_get is not None
    assert model_to_get.to.call_count == 0  # No device change for CPU


@pytest.mark.asyncio
async def test_get_model_gpu(model_pool):
    def model_creation_fn():
        model = MockBaseModel()
        model.to = MagicMock()
        model.device = Device.GPU
        return model

    model_pool.add(model_creation_fn)

    model_to_get = await model_pool.get_model(model_creation_fn.__name__)

    assert model_to_get.to.call_count == 1
    model_to_get.to.assert_called_once_with(Device.GPU)  # Verify to was called with GPU


@pytest.mark.asyncio
async def test_get_model_not_found(model_pool):
    with pytest.raises(ValueError):
        await model_pool.get_model("NonexistentModel")


@pytest.mark.asyncio
async def test_get_model_multiple_gpu(model_pool):
    def model_creation_fn_a():
        model = MockBaseModel()
        model.to = MagicMock()
        model.device = Device.GPU  # Set device during creation
        return model

    def model_creation_fn_b():
        model = MockBaseModel()
        model.to = MagicMock()
        model.device = Device.GPU  # Set device during creation
        return model

    model_pool.add(model_creation_fn_a)
    model_pool.add(model_creation_fn_b)

    # Get model A first; it should take the GPU
    model1_to_get = await model_pool.get_model(model_creation_fn_a.__name__)
    assert model1_to_get is not None
    assert (
        model_pool._get_current_gpu_model() == model_creation_fn_a.__name__
    )  # Model A now holds the GPU

    # Get model B next; it should take over the GPU
    model2_to_get = await model_pool.get_model(model_creation_fn_b.__name__)
    assert model2_to_get is not None
    assert (
        model_pool._get_current_gpu_model() == model_creation_fn_b.__name__
    )  # Model B now holds the GPU
92 changes: 92 additions & 0 deletions vision_agent_tools/tools/shared_model_manager.py
@@ -0,0 +1,92 @@
import asyncio
from vision_agent_tools.tools.shared_types import Device


class SharedModelManager:
    def __init__(self):
        self.models = {}  # store models with class name as key
        self.model_locks = {}  # store locks for each model
        self.device = Device.CPU
        self.current_gpu_model = None  # Track the model currently using GPU

        # Semaphore for exclusive GPU access
        self.gpu_semaphore = asyncio.Semaphore(1)

    def add(self, model_creation_fn):
        """
        Adds a model to the pool, keyed by its creation function's name.

        Args:
            model_creation_fn (callable): A function that creates the model.
        """
        class_name = model_creation_fn.__name__  # Get class name from function
        if class_name in self.models:
            print(f"Model '{class_name}' already exists in the pool.")
        else:
            model = model_creation_fn()
            self.models[class_name] = model
            self.model_locks[class_name] = asyncio.Lock()

    async def get_model(self, class_name):
        """
        Retrieves a model from the pool for safe execution.

        Args:
            class_name (str): Name of the model class.

        Returns:
            BaseTool: The retrieved model instance.
        """
        if class_name not in self.models:
            raise ValueError(f"Model '{class_name}' not found in the pool.")

        model = self.models[class_name]
        lock = self.model_locks[class_name]

        async def get_model_with_lock():
            async with lock:
                if model.device == Device.GPU:
                    # Acquire semaphore if needed
                    async with self.gpu_semaphore:
                        # Update current GPU model (for testing)
                        self.current_gpu_model = class_name
                        model.to(Device.GPU)
                return model

        return await get_model_with_lock()

    def _get_current_gpu_model(self):
        """
        Returns the class name of the model currently using the GPU (if any).
        """
        return self.current_gpu_model

    async def _move_to_cpu(self, class_name):
        """
        Moves a model to CPU and releases the GPU semaphore (if held).
        """
        model = self.models[class_name]
        model.to(Device.CPU)
        if self.current_gpu_model == class_name:
            self.current_gpu_model = None
            self.gpu_semaphore.release()  # Release semaphore if held

    def __call__(self, class_name, arguments):
        """
        Builds an async wrapper for safe model execution.

        Args:
            class_name (str): Name of the model class.
            arguments (tuple): Arguments for the model call.

        Returns:
            callable: A wrapper coroutine function that retrieves the model and executes it.
        """
        async def wrapper():
            model = await self.get_model(class_name)
            return model(*arguments)  # Unpack the argument tuple into the call

        return wrapper
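
For context, a minimal usage sketch of the manager follows; DummyTool, its device preference, and the ("image.png",) argument tuple are hypothetical and not part of this commit:

import asyncio

from vision_agent_tools.tools.shared_model_manager import SharedModelManager
from vision_agent_tools.tools.shared_types import BaseTool, Device


class DummyTool(BaseTool):
    # Hypothetical tool: prefers the GPU and echoes its inputs.
    device = Device.GPU

    def __call__(self, *args):
        return f"ran with {args!r}"


def dummy_tool():
    return DummyTool()


async def main():
    manager = SharedModelManager()
    manager.add(dummy_tool)  # keyed by the function name, "dummy_tool"

    # Direct retrieval: waits on the per-model lock (and the GPU semaphore).
    model = await manager.get_model("dummy_tool")
    print(model("image.png"))

    # Or build a wrapper via __call__ and await it later.
    wrapper = manager("dummy_tool", ("image.png",))
    print(await wrapper())


asyncio.run(main())
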
10 changes: 9 additions & 1 deletion vision_agent_tools/tools/shared_types.py
@@ -1,8 +1,16 @@
from enum import Enum
from pydantic import BaseModel


class Device(str, Enum):
    GPU = "cuda:0"
    CPU = "cpu"


class BaseTool:
    def to(self, device: Device):
        print(device)


class Point(BaseModel):
    ...
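
BaseTool.to() is only a stub here; a concrete tool would move real weights. A sketch of how a torch-backed tool might honor it (TinyClassifier and the torch dependency are assumptions, not part of this commit):

import torch

from vision_agent_tools.tools.shared_types import BaseTool, Device


class TinyClassifier(BaseTool):
    def __init__(self):
        self.device = Device.GPU  # preferred device, read by SharedModelManager
        self.model = torch.nn.Linear(4, 2)

    def to(self, device: Device):
        # Device values ("cuda:0" / "cpu") are valid torch device strings.
        self.model.to(device.value)

    def __call__(self, x):
        return self.model(x)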
