diff --git a/aimm/client/repl.py b/aimm/client/repl.py index 6e1d478..18c014f 100644 --- a/aimm/client/repl.py +++ b/aimm/client/repl.py @@ -2,8 +2,10 @@ specified by the REPL control.""" from getpass import getpass + from hat import aio from hat import juggler +from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed import base64 import hashlib import numpy @@ -41,7 +43,7 @@ def state(self) -> "JSON": """Current state reported from the AIMM server""" return self._state - async def connect(self, address: str, autoflush_delay: float = 0.2): + async def connect(self, address: str): """Connects to the specified remote address. Login data is received from a user prompt. Passwords are hashed with SHA-256 before sending login request.""" @@ -51,7 +53,13 @@ async def connect(self, address: str, autoflush_delay: float = 0.2): password_hash = hashlib.sha256() password_hash.update(getpass("Password: ").encode("utf-8")) - connection = await juggler.connect(address) + connection = None + async for attempt in AsyncRetrying( + wait=wait_fixed(1), stop=stop_after_attempt(3) + ): + with attempt: + connection = await juggler.connect(address) + await connection.send( "login", {"username": username, "password": password_hash.hexdigest()}, diff --git a/aimm/plugins/common.py b/aimm/plugins/common.py index 35cddd5..e3f061b 100644 --- a/aimm/plugins/common.py +++ b/aimm/plugins/common.py @@ -1,5 +1,4 @@ from typing import Any, ByteString, Callable, Dict, NamedTuple, Optional -from aimm.common import * # NOQA import abc import importlib import logging diff --git a/aimm/plugins/decorators.py b/aimm/plugins/decorators.py index a81c6b5..280b308 100644 --- a/aimm/plugins/decorators.py +++ b/aimm/plugins/decorators.py @@ -169,9 +169,6 @@ def serialize(model_types: List[str]) -> Callable: Args: model_types: types of models supported by the decorated function - instance_arg_name: if set, indicates under which argument name to pass - the concrete model instance. If not set, it is passed in the first - positional argument Returns: Decorated function""" @@ -231,13 +228,26 @@ def model(cls: Type) -> Type: """ model_type = f"{cls.__module__}.{cls.__name__}" + _declare("instantiate", model_type, common.InstantiatePlugin(cls)) - _declare("fit", model_type, common.FitPlugin(cls.fit)) - _declare("predict", model_type, common.PredictPlugin(cls.predict)) - _declare("serialize", model_type, common.SerializePlugin(cls.serialize)) - _declare( - "deserialize", model_type, common.DeserializePlugin(cls.deserialize) - ) + + fit_fn = getattr(cls, "fit") + if isinstance(fit_fn, Callable): + _declare("fit", model_type, common.FitPlugin(fit_fn)) + + predict_fn = getattr(cls, "predict") + if isinstance(predict_fn, Callable): + _declare("predict", model_type, common.PredictPlugin(predict_fn)) + + serialize_fn = getattr(cls, "serialize") + if isinstance(serialize_fn, Callable): + _declare("serialize", model_type, common.SerializePlugin(serialize_fn)) + + deserialize_fn = getattr(cls, "deserialize") + if isinstance(deserialize_fn, Callable): + _declare( + "deserialize", model_type, common.DeserializePlugin(deserialize_fn) + ) return cls diff --git a/aimm/plugins/execute.py b/aimm/plugins/execute.py index 10cb89e..34a4452 100644 --- a/aimm/plugins/execute.py +++ b/aimm/plugins/execute.py @@ -50,15 +50,16 @@ def exec_predict( state_cb: common.StateCallback = lambda state: None, *args: Any, **kwargs: Any -) -> Any: +) -> tuple[Any, Any]: """Uses a loaded plugin to perform a prediction with a given model - instance""" + instance. Also returns the instance because it might be altered during the + prediction, e.g. with reinforcement learning models.""" plugin = decorators.get_predict(model_type) kwargs = _kwargs_add_state_cb(plugin.state_cb_arg_name, state_cb, kwargs) args, kwargs = _args_add_instance( plugin.instance_arg_name, instance, args, kwargs ) - return plugin.function(*args, **kwargs) + return instance, plugin.function(*args, **kwargs) def exec_serialize(model_type: str, instance: Any) -> ByteString: diff --git a/aimm/server/backend/dummy.py b/aimm/server/backend/dummy.py index 2a6dfb3..06060ef 100644 --- a/aimm/server/backend/dummy.py +++ b/aimm/server/backend/dummy.py @@ -4,14 +4,15 @@ from aimm.server import common -def create(conf, _): - backend = DummyBackend() - backend._group = aio.Group() - backend._id_counter = itertools.count(1) - return backend +def create(_, __): + return DummyBackend() class DummyBackend(common.Backend): + def __init__(self): + self._group = aio.Group() + self._id_counter = itertools.count(1) + @property def async_group(self) -> aio.Group: """Async group""" diff --git a/aimm/server/backend/event.py b/aimm/server/backend/event.py index 4d4789f..eb7d06f 100644 --- a/aimm/server/backend/event.py +++ b/aimm/server/backend/event.py @@ -6,6 +6,7 @@ from aimm.server import common from aimm import plugins +from aimm.server.common import Model def create_subscription(conf): @@ -14,29 +15,33 @@ def create_subscription(conf): async def create(conf, event_client): common.json_schema_repo.validate("aimm://server/backend/event.yaml#", conf) - backend = EventBackend() - - backend._model_prefix = conf["model_prefix"] - backend._executor = aio.create_executor() - backend._cbs = util.CallbackRegistry() - backend._async_group = aio.Group() - backend._client = event_client - - models = await backend.get_models() - backend._id_counter = itertools.count( - max((model.instance_id for model in models), default=1) - ) + backend = EventBackend(conf, event_client) + await backend.start() return backend class EventBackend(common.Backend): + def __init__(self, conf, event_client): + self._model_prefix = conf["model_prefix"] + self._executor = aio.create_executor() + self._cbs = util.CallbackRegistry() + self._group = aio.Group() + self._client = event_client + self._id_counter = None + @property def async_group(self) -> aio.Group: """Async group""" - return self._async_group + return self._group + + async def start(self): + models = await self.get_models() + self._id_counter = itertools.count( + max((model.instance_id for model in models), default=1) + ) - async def get_models(self): + async def get_models(self) -> list[Model]: query_result = await self._client.query( hat.event.common.QueryLatestParams( event_types=[(*self._model_prefix, "*")] diff --git a/aimm/server/backend/sqlite.py b/aimm/server/backend/sqlite.py index d97ad6f..347ea08 100644 --- a/aimm/server/backend/sqlite.py +++ b/aimm/server/backend/sqlite.py @@ -11,28 +11,32 @@ async def create(conf, _): common.json_schema_repo.validate( "aimm://server/backend/sqlite.yaml#", conf ) - backend = SQLiteBackend() - - executor = aio.create_executor(1) - connection = await executor(_ext_db_connect, Path(conf["path"])) - connection.row_factory = sqlite3.Row - - group = aio.Group() - group.spawn(aio.call_on_cancel, executor, _ext_db_close, connection) - - backend._executor = executor - backend._connection = connection - backend._group = group - + backend = SQLiteBackend(conf) + await backend.start() return backend class SQLiteBackend(common.Backend): + def __init__(self, conf): + self._conf = conf + self._executor = aio.create_executor(1) + self._connection = None + self._group = aio.Group() + @property def async_group(self) -> aio.Group: """Async group""" return self._group + async def start(self): + self._connection = await self._executor( + _ext_db_connect, Path(self._conf["path"]) + ) + self._connection.row_factory = sqlite3.Row + self._group.spawn( + aio.call_on_cancel, self._executor, _ext_db_close, self._connection + ) + async def get_models(self): query = """SELECT * FROM models""" cursor = await self._execute(query) diff --git a/aimm/server/common.py b/aimm/server/common.py index abe2000..59bef3b 100644 --- a/aimm/server/common.py +++ b/aimm/server/common.py @@ -11,11 +11,12 @@ Collection, ) import abc -import aimm.common import hat.event.eventer.client import hat.event.common import logging +import aimm.common + mlog = logging.getLogger(__name__) json_schema_repo = aimm.common.json_schema_repo @@ -31,7 +32,7 @@ class Model(NamedTuple): """Server's representation of objects returned by - :func:`plugins.exec_instantiate`. Contains all metadata neccessary to + :func:`plugins.exec_instantiate`. Contains all metadata necessary to identify and perform other actions with it.""" instance: Any @@ -49,7 +50,7 @@ class DataAccess(NamedTuple): docstrings.""" name: str - """name of the data acces type, used to identify which plugin to use""" + """name of the data access type, used to identify which plugin to use""" args: Iterable """positional arguments to be passed to the plugin call""" kwargs: Dict[str, Any] @@ -93,12 +94,10 @@ async def update_instance(self, model: Model): """Update existing instance in the state""" @abc.abstractmethod - async def fit( - self, instance_id: int, *args: Any, **kwargs: Any - ) -> "Action": + def fit(self, instance_id: int, *args: Any, **kwargs: Any) -> "Action": """Starts an action that fits an existing model instance. The used fitting function is the one assigned to the model type. The instance, - while it is being fitted, is not accessable by any of the other + while it is being fitted, is not accessible by any of the other functions that would use it (other calls to fit, predictions, etc.). Args: @@ -111,12 +110,10 @@ async def fit( arguments""" @abc.abstractmethod - async def predict( - self, instance_id: int, *args: Any, **kwargs: Any - ) -> "Action": + def predict(self, instance_id: int, *args: Any, **kwargs: Any) -> "Action": """Starts an action that uses an existing model instance to perform a prediction. The used prediction function is the one assigned to model's - type. The instance, while prediction is called, is not accessable by + type. The instance, while prediction is called, is not accessible by any of the other functions that would use it (other calls to predict, fittings, etc.). If instance has changed while predicting, it is updated in the state and database. @@ -131,7 +128,7 @@ async def predict( arguments Returns: - Reference to task of the managable predict call, result of it is + Reference to task of the manageable predict call, result of it is the model's prediction""" @@ -157,7 +154,7 @@ def create_backend( signature""" -class Backend(aio.Resource): +class Backend(aio.Resource, abc.ABC): """Backend interface. In order to integrate in the aimm server, create a module with the implementation and function ``create`` that creates a backend instance. The function should have a signature as the @@ -172,7 +169,7 @@ class Backend(aio.Resource): @abc.abstractmethod async def get_models(self) -> List[Model]: - """Get all persisted models, requries that a deserialization function + """Get all persisted models, requires that a deserialization function is defined for all persisted types Returns: @@ -213,7 +210,7 @@ def create_control( signature""" -class Control(aio.Resource): +class Control(aio.Resource, abc.ABC): """Control interface. In order to integrate in the aimm server, create a module with the implementation and function ``create`` that creates a control instance and should have a signature as the :func:`create_control` diff --git a/aimm/server/control/event.py b/aimm/server/control/event.py index c075263..6a52220 100644 --- a/aimm/server/control/event.py +++ b/aimm/server/control/event.py @@ -21,26 +21,24 @@ async def create(conf, engine, event_client): raise ValueError( "attempting to create event control without hat compatibility" ) - - control = EventControl() - - control._client = event_client - control._engine = engine - control._async_group = aio.Group() - control._event_prefixes = conf["event_prefixes"] - control._state_event_type = conf["state_event_type"] - control._action_state_event_type = conf["action_state_event_type"] - control._executor = aio.create_executor() - control._notified_state = {} - control._in_progress = {} - - control._notify_state() - control._engine.subscribe_to_state_change(control._notify_state) - - return control + return EventControl(conf, engine, event_client) class EventControl(common.Control): + def __init__(self, conf, engine, event_client): + self._client = event_client + self._engine = engine + self._async_group = aio.Group() + self._event_prefixes = conf["event_prefixes"] + self._state_event_type = conf["state_event_type"] + self._action_state_event_type = conf["action_state_event_type"] + self._executor = aio.create_executor() + self._notified_state = {} + self._in_progress = {} + + self._notify_state() + self._engine.subscribe_to_state_change(self._notify_state) + @property def async_group(self) -> aio.Group: """Async group""" diff --git a/aimm/server/control/repl.py b/aimm/server/control/repl.py index 64017b0..2e9ef8a 100644 --- a/aimm/server/control/repl.py +++ b/aimm/server/control/repl.py @@ -15,41 +15,41 @@ async def create(conf, engine, _): common.json_schema_repo.validate("aimm://server/control/repl.yaml#", conf) - control = REPLControl() - - srv_conf = conf["server"] - server = await juggler.listen( - srv_conf["host"], - srv_conf["port"], - connection_cb=control._connection_cb, - request_cb=control._request_cb, - index_path=None, - ws_path="/", - pem_file=srv_conf.get("pem_file"), - autoflush_delay=srv_conf.get("autoflush_delay", 0.2), - shutdown_timeout=srv_conf.get("shutdown_timeout", 0.1), - ) - - async_group = aio.Group() - _bind_resource(async_group, server) - - control._conf = conf - control._engine = engine - control._async_group = async_group - control._server = server - control._connection_session_mapping = {} - + control = REPLControl(conf, engine) return control class REPLControl(common.Control): + def __init__(self, conf, engine): + self._conf = conf + self._engine = engine + self._group = aio.Group() + self._connection_session_mapping = {} + + self._group.spawn(self._run, conf["server"]) + @property def async_group(self) -> aio.Group: """Async group""" - return self._async_group + return self._group + + async def _run(self, conf): + server = await juggler.listen( + conf["host"], + conf["port"], + connection_cb=self._connection_cb, + request_cb=self._request_cb, + index_path=None, + ws_path="/", + pem_file=conf.get("pem_file"), + autoflush_delay=conf.get("autoflush_delay", 0.2), + shutdown_timeout=conf.get("shutdown_timeout", 0.1), + ) + _bind_resource(self._group, server) + await server.wait_closing() def _connection_cb(self, connection): - subgroup = self._async_group.create_subgroup() + subgroup = self._group.create_subgroup() session = Session(connection, self._engine, self._conf, subgroup) self._connection_session_mapping[connection] = session diff --git a/aimm/server/engine.py b/aimm/server/engine.py index 06ea43d..fdb0ecd 100644 --- a/aimm/server/engine.py +++ b/aimm/server/engine.py @@ -14,51 +14,45 @@ mlog = logging.getLogger(__name__) -async def create( - conf: typing.Dict, backend: common.Backend, group: aio.Group -) -> common.Engine: +async def create(conf: typing.Dict, backend: common.Backend) -> common.Engine: """Create engine Args: conf: configuration that follows schema with id ``aimm://server/schema.yaml#/definitions/engine`` backend: backend - group: async group Returns: engine """ - engine = _Engine() - - models = await backend.get_models() - - engine._group = group - engine._backend = backend - engine._conf = conf - engine._state = { - "actions": {}, - "models": {model.instance_id: model for model in models}, - } - engine._locks = { - instance_id: asyncio.Lock() for instance_id in engine._state["models"] - } - - engine._action_id_gen = itertools.count(1) - - engine._pool = mprocess.ProcessManager( - conf["max_children"], - group.create_subgroup(), - conf["check_children_period"], - conf["sigterm_timeout"], - ) - engine._callback_registry = util.CallbackRegistry() - + engine = _Engine(conf, backend) + await engine.start() return engine class _Engine(common.Engine): """Engine implementation, use :func:`create` to instantiate""" + def __init__(self, conf, backend): + self._group = aio.Group() + self._backend = backend + self._conf = conf + self._state = {"actions": {}, "models": {}} + self._locks = { + instance_id: asyncio.Lock() + for instance_id in self._state["models"] + } + + self._action_id_gen = itertools.count(1) + + self._pool = mprocess.ProcessManager( + conf["max_children"], + self._group.create_subgroup(), + conf["check_children_period"], + conf["sigterm_timeout"], + ) + self._callback_registry = util.CallbackRegistry() + @property def async_group(self): return self._group @@ -67,13 +61,20 @@ def async_group(self): def state(self): return self._state + async def start(self): + models = await self._backend.get_models() + self._state = { + "actions": {}, + "models": {model.instance_id: model for model in models}, + } + def subscribe_to_state_change(self, cb): return self._callback_registry.register(cb) def create_instance(self, model_type, *args, **kwargs): action_id = next(self._action_id_gen) state_cb = partial(self._update_action, action_id) - return _Action( + return create_action( self._group.create_subgroup(), self._act_create_instance, model_type, @@ -95,7 +96,7 @@ async def update_instance(self, model: common.Model): def fit(self, instance_id, *args, **kwargs): action_id = next(self._action_id_gen) state_cb = partial(self._update_action, action_id) - return _Action( + return create_action( self._group.create_subgroup(), self._act_fit, instance_id, @@ -107,7 +108,7 @@ def fit(self, instance_id, *args, **kwargs): def predict(self, instance_id, *args, **kwargs): action_id = next(self._action_id_gen) state_cb = partial(self._update_action, action_id) - return _Action( + return create_action( self._group.create_subgroup(), self._act_predict, instance_id, @@ -203,13 +204,8 @@ async def _act_fit(self, instance_id, args, kwargs, state_cb): *args, **kwargs ) - new_model = model._replace(instance=instance) - - reactive.update(dict(reactive.state, progress="storing")) - await self._backend.update_model(new_model) - + new_model = await self._update_model(instance, model, reactive) reactive.update(dict(reactive.state, progress="complete")) - self._set_model(new_model) return new_model async def _act_predict(self, instance_id, args, kwargs, state_cb): @@ -235,7 +231,8 @@ async def _act_predict(self, instance_id, args, kwargs, state_cb): ) async with self._locks[instance_id]: model = self.state["models"][instance_id] - prediction = await handler.run( + reactive.update(dict(reactive.state, progress="executing")) + instance, prediction = await handler.run( plugins.exec_predict, model.model_type, model.instance, @@ -243,9 +240,28 @@ async def _act_predict(self, instance_id, args, kwargs, state_cb): *args, **kwargs ) + await self._update_model(instance, model, reactive) reactive.update(dict(reactive.state, progress="complete")) return prediction + async def _update_model(self, instance, model, reactive): + new_model = common.Model( + instance=instance, + model_type=model.model_type, + instance_id=model.instance_id, + ) + reactive.update(dict(reactive.state, progress="storing")) + await self._backend.update_model(new_model) + + self._set_model(new_model) + return new_model + + +def create_action( + async_group: aio.Group, fn: typing.Callable, *args, **kwargs +) -> common.Action: + return _Action(async_group, fn, *args, **kwargs) + class _Action(common.Action): def __init__(self, async_group, fn, *args, **kwargs): diff --git a/aimm/server/mprocess.py b/aimm/server/mprocess.py index dfe5cfa..a2b040f 100644 --- a/aimm/server/mprocess.py +++ b/aimm/server/mprocess.py @@ -152,16 +152,13 @@ async def run(self, fn: Callable, *args: Any, **kwargs: Any): self._process.start() async def wait_result(): - try: - result = await self._executor( - _ext_closeable_recv, self._result_pipe - ) - if result.success: - return result.result - else: - raise result.exception - except _ProcessTerminatedException: - raise Exception("process terminated") + result = await self._executor( + _ext_closeable_recv, self._result_pipe + ) + if result.success: + return result.result + else: + raise result.exception return await aio.uncancellable(wait_result()) finally: @@ -186,32 +183,32 @@ async def _cleanup(self): await self._executor(_ext_close_pipe, self._state_pipe) +@contextlib.contextmanager +def sigterm_override(): + try: + signal.signal(signal.SIGTERM, _plugin_sigterm_handler) + yield + finally: + signal.signal(signal.SIGTERM, signal.SIG_DFL) + + class _Result(NamedTuple): success: bool result: Optional[Any] = None exception: Optional[Exception] = None -class _ProcessTerminatedException(Exception): +class ProcessTerminatedException(Exception): pass -def _plugin_sigterm_handler(frame, signum): - raise Exception("process terminated") - - -@contextlib.contextmanager -def _sigterm_override(): - try: - signal.signal(signal.SIGTERM, _plugin_sigterm_handler) - yield - finally: - signal.signal(signal.SIGTERM, signal.SIG_DFL) +def _plugin_sigterm_handler(_, __): + raise ProcessTerminatedException("process sigterm") def _proc_run_fn(pipe, fn, *args, **kwargs): try: - with _sigterm_override(): + with sigterm_override(): result = _Result(success=True, result=fn(*args, **kwargs)) except Exception as e: result = _Result(success=False, exception=e) @@ -242,5 +239,5 @@ def _ext_closeable_recv(pipe): recv_conn, _ = pipe value = recv_conn.recv() if value == _PipeSentinel.CLOSE: - raise _ProcessTerminatedException("pipe closed") + raise ProcessTerminatedException("pipe closed") return value diff --git a/aimm/server/runners.py b/aimm/server/runners.py index 2f0623b..2a0fb30 100644 --- a/aimm/server/runners.py +++ b/aimm/server/runners.py @@ -197,7 +197,7 @@ async def _create_resources(self): yield self._backend self._engine = await aimm.server.engine.create( - self._conf["engine"], self._backend, self._group.create_subgroup() + self._conf["engine"], self._backend ) yield self._engine diff --git a/examples/0002/.gitignore b/examples/0002/.gitignore index af0da52..7b65fef 100644 --- a/examples/0002/.gitignore +++ b/examples/0002/.gitignore @@ -1,5 +1,4 @@ /venv /data -/data /src_js /view/login/index.js diff --git a/poetry.lock b/poetry.lock index f4eaa32..78e1ed6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2367,6 +2367,21 @@ lint = ["mypy", "ruff (==0.5.5)", "types-docutils"] standalone = ["Sphinx (>=5)"] test = ["pytest"] +[[package]] +name = "tenacity" +version = "9.0.0" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.8" +files = [ + {file = "tenacity-9.0.0-py3-none-any.whl", hash = "sha256:93de0c98785b27fcf659856aa9f54bfbd399e29969b0621bc7f762bd441b4539"}, + {file = "tenacity-9.0.0.tar.gz", hash = "sha256:807f37ca97d62aa361264d497b0e31e92b8027044942bfa756160d908320d73b"}, +] + +[package.extras] +doc = ["reno", "sphinx"] +test = ["pytest", "tornado (>=4.5)", "typeguard"] + [[package]] name = "tomli" version = "2.0.1" @@ -2566,4 +2581,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "9e0310872782ddf836eeb44dc6f912ce2a7d075f8e50a128feae4b470841371e" +content-hash = "7d2b9c4a920aa6a72ea8b75fce392bd6dc0cb24317bdb1dfbaed0cf522f11e37" diff --git a/pyproject.toml b/pyproject.toml index 321b899..f5216e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ hat-json = "^0.5.28" hat-monitor = "^0.8.11" hat-event = "^0.9.20" psutil = "^6.0.0" +tenacity = "^9.0.0" [tool.poetry.group.dev.dependencies] @@ -60,3 +61,4 @@ ignore = "E203" [tool.pytest.ini_options] asyncio_mode = "auto" +asyncio_default_fixture_loop_scope="function" diff --git a/test/conftest.py b/test/conftest.py index 1cec5f6..568f964 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,7 +4,7 @@ from aimm import plugins -def pytest_configure(config): +def pytest_configure(): aio.init_asyncio() diff --git a/test/test_sys/plugins/basic.py b/test/test_sys/plugins/basic.py index b22896e..2ad965a 100644 --- a/test/test_sys/plugins/basic.py +++ b/test/test_sys/plugins/basic.py @@ -7,7 +7,7 @@ @plugins.model class Model1(plugins.Model): def __init__(self, *args, **kwargs): - self._state = {"init": {"args": args, "kwargs": kwargs}, "fit": None} + self._state = {"init": {"args": args, "kwargs": kwargs}} def fit(self, *args, **kwargs): self._state["fit"] = {"args": args, "kwargs": kwargs} diff --git a/test/test_sys/test_event.py b/test/test_sys/test_event.py index ca40525..96a4123 100644 --- a/test/test_sys/test_event.py +++ b/test/test_sys/test_event.py @@ -227,7 +227,6 @@ async def _create_instance(client, model_type, events_queue): events = await events_queue.get() event = events[0] - payload = event.payload.data assert event.type == ("aimm", "model", "1") events = await events_queue.get() @@ -318,6 +317,14 @@ async def test_predict(aimm_server_proc, event_client_factory): assert payload["status"] == "IN_PROGRESS" assert payload["result"] is None + events = await events_queue.get() + assert len(events) == 1 + event = events[0] + assert event.type == ("aimm", "model", str(model_id)) + assert event.source_timestamp is None + assert event.payload.data["type"] == model_type + assert event.payload.data["instance"] is not None + events = await events_queue.get() assert len(events) == 1 event = events[0] @@ -336,7 +343,7 @@ async def test_cancel(aimm_server_proc, event_client_factory): [("aimm", "action_state"), ("aimm", "model", "*")] ) as (client, events_queue): model_id = await _create_instance(client, model_type, events_queue) - request = await client.register( + await client.register( [ _register_event( ("predict", str(model_id)), @@ -349,7 +356,6 @@ async def test_cancel(aimm_server_proc, event_client_factory): ], with_response=True, ) - request = request[0] events = await events_queue.get() assert len(events) == 1 diff --git a/test/test_unit/test_plugins.py b/test/test_unit/test_plugins.py index ef4ac01..81877a4 100644 --- a/test/test_unit/test_plugins.py +++ b/test/test_unit/test_plugins.py @@ -1,12 +1,16 @@ from aimm import plugins +def dummy_state_cb(state): + print("state:", state) + + def test_instantiate(plugin_teardown): @plugins.instantiate("test", state_cb_arg_name="state_cb") def instantiate(state_cb): return state_cb - assert plugins.exec_instantiate("test", "state_cb") == "state_cb" + assert plugins.exec_instantiate("test", dummy_state_cb) == dummy_state_cb def test_data_access(plugin_teardown): @@ -14,7 +18,7 @@ def test_data_access(plugin_teardown): def data_access(state_cb): return state_cb - assert plugins.exec_data_access("test", "state_cb") == "state_cb" + assert plugins.exec_data_access("test", dummy_state_cb) == dummy_state_cb def test_fit(plugin_teardown): @@ -22,10 +26,10 @@ def test_fit(plugin_teardown): ["test"], state_cb_arg_name="state_cb", instance_arg_name="instance" ) def fit(state_cb, instance): - return (state_cb, instance) + return state_cb, instance - result = plugins.exec_fit("test", "instance", "state_cb") - assert result == ("state_cb", "instance") + result = plugins.exec_fit("test", "instance", dummy_state_cb) + assert result == (dummy_state_cb, "instance") def test_predict(plugin_teardown): @@ -33,11 +37,11 @@ def test_predict(plugin_teardown): ["test"], state_cb_arg_name="state_cb", instance_arg_name="instance" ) def predict(state_cb, instance): - return (state_cb, instance) + return state_cb, instance - assert plugins.exec_predict("test", "instance", "state_cb") == ( - "state_cb", + assert plugins.exec_predict("test", "instance", dummy_state_cb) == ( "instance", + (dummy_state_cb, "instance"), ) @@ -51,12 +55,12 @@ def serialize(instance): def test_deserialize(plugin_teardown): @plugins.deserialize(["test"]) - def deserialize(instance_bytes): - return instance_bytes + def deserialize(i_bytes): + return i_bytes - assert ( - plugins.exec_deserialize("test", "instance_bytes") == "instance_bytes" - ) + instance_bytes = "instance bytes".encode("utf-8") + + assert plugins.exec_deserialize("test", instance_bytes) == instance_bytes def test_model(plugin_teardown): @@ -65,6 +69,8 @@ class Model1(plugins.Model): def __init__(self, *args, **kwargs): self.args = args self.kwargs = kwargs + self.fit_args = [] + self.fit_kwargs = {} def fit(self, *args, **kwargs): self.fit_args = args @@ -78,17 +84,17 @@ def serialize(self): return bytes() @classmethod - def deserialize(cls): + def deserialize(cls, _): return Model1() model_type = "test_plugins.Model1" model = plugins.exec_instantiate( - model_type, None, "a1", "a2", k1="1", k2="2" + model_type, dummy_state_cb, "a1", "a2", k1="1", k2="2" ) assert model.args == ("a1", "a2") assert model.kwargs == {"k1": "1", "k2": "2"} - plugins.exec_fit(model_type, model, None, "fit_a1", fit_k1="1") + plugins.exec_fit(model_type, model, dummy_state_cb, "fit_a1", fit_k1="1") assert model.fit_args == ("fit_a1",) assert model.fit_kwargs == {"fit_k1": "1"} diff --git a/test/test_unit/test_server/test_backend/test_event.py b/test/test_unit/test_server/test_backend/test_event.py index d0aa6dd..8bc03b2 100644 --- a/test/test_unit/test_server/test_backend/test_event.py +++ b/test/test_unit/test_server/test_backend/test_event.py @@ -23,7 +23,7 @@ async def query(self, query_data): self._query_queue.put_nowait(query_data) return self._query_result - async def register(self, events, with_response=False): + async def register(self, events, _=False): self._register_queue.put_nowait(events) return events diff --git a/test/test_unit/test_server/test_backend/test_sqlite.py b/test/test_unit/test_server/test_backend/test_sqlite.py index 53a6733..f20c665 100644 --- a/test/test_unit/test_server/test_backend/test_sqlite.py +++ b/test/test_unit/test_server/test_backend/test_sqlite.py @@ -13,9 +13,11 @@ async def backend(tmp_path): async def test_create(tmp_path): - backend = await sqlite.create({"path": str(tmp_path / "backend.db")}, None) - assert backend - await backend.async_close() + backend_object = await sqlite.create( + {"path": str(tmp_path / "backend.db")}, None + ) + assert backend_object + await backend_object.async_close() async def test_models(backend, plugin_teardown): diff --git a/test/test_unit/test_server/test_control/test_event.py b/test/test_unit/test_server/test_control/test_event.py index 424296e..09c4d4b 100644 --- a/test/test_unit/test_server/test_control/test_event.py +++ b/test/test_unit/test_server/test_control/test_event.py @@ -15,20 +15,22 @@ def __init__(self): self._register_queue = aio.Queue() self._receive_queue = aio.Queue() - async def register(self, events, with_response=False): + async def register(self, events, _=False): self._register_queue.put_nowait(events) class MockEngine(common.Engine): def __init__( self, - state={"models": {}, "actions": {}}, + state=None, create_instance_cb=None, add_instance_cb=None, update_instance_cb=None, fit_cb=None, predict_cb=None, ): + if state is None: + state = {"models": {}, "actions": {}} self._state = state self._cb = None self._create_instance_cb = create_instance_cb @@ -57,7 +59,7 @@ def subscribe_to_state_change(self, cb): def create_instance(self, *args, **kwargs): if self._create_instance_cb: - return aimm.server.engine._Action( + return aimm.server.engine.create_action( self._group.create_subgroup(), aio.call, self._create_instance_cb, @@ -78,7 +80,7 @@ async def update_instance(self, *args, **kwargs): def fit(self, *args, **kwargs): if self._fit_cb: - return aimm.server.engine._Action( + return aimm.server.engine.create_action( self._group.create_subgroup(), aio.call, self._fit_cb, @@ -89,7 +91,7 @@ def fit(self, *args, **kwargs): def predict(self, *args, **kwargs): if self._predict_cb: - return aimm.server.engine._Action( + return aimm.server.engine.create_action( self._group.create_subgroup(), aio.call, self._predict_cb, @@ -136,13 +138,13 @@ async def test_state(): async def test_create_instance(): create_queue = aio.Queue() - async def create_instance_cb(model_type, *args, **kwargs): + async def create_instance_cb(model_type, *c_args, **c_kwargs): complete_future = asyncio.Future() create_queue.put_nowait( { "model_type": model_type, - "args": args, - "kwargs": kwargs, + "args": c_args, + "kwargs": c_kwargs, "complete_future": complete_future, } ) @@ -218,7 +220,7 @@ async def add_instance_cb(model_type, instance): client = MockClient() engine = MockEngine(add_instance_cb=add_instance_cb) control = await aimm.server.control.event.create(conf(), engine, client) - events = await client._register_queue.get() # state + await client._register_queue.get() # state req_event = _event( ("add_instance",), @@ -270,7 +272,7 @@ async def update_instance_cb(model): client = MockClient() engine = MockEngine(update_instance_cb=update_instance_cb) control = await aimm.server.control.event.create(conf(), engine, client) - events = await client._register_queue.get() # state + await client._register_queue.get() # state req_event = _event( ("update_instance", "10"), @@ -320,12 +322,12 @@ async def fit_cb(model_id, *args, **kwargs): client = MockClient() engine = MockEngine( - {"models": {11: common.Model("M", None, 11)}, "actions": {}}, + {"models": {11: common.Model("M", "test", 11)}, "actions": {}}, fit_cb=fit_cb, ) control = await aimm.server.control.event.create(conf(), engine, client) - events = await client._register_queue.get() # state + await client._register_queue.get() # state req_event = _event( ("fit", "11"), @@ -384,12 +386,12 @@ async def predict_cb(model_id, *args, **kwargs): client = MockClient() engine = MockEngine( - {"models": {12: common.Model("M", None, 12)}, "actions": {}}, + {"models": {12: common.Model("M", "test", 12)}, "actions": {}}, predict_cb=predict_cb, ) control = await aimm.server.control.event.create(conf(), engine, client) - events = await client._register_queue.get() # state + await client._register_queue.get() # state req_event = _event( ("predict", "12"), @@ -434,19 +436,19 @@ async def predict_cb(model_id, *args, **kwargs): async def test_cancel(): future_queue = aio.Queue() - async def predict_cb(model_id, *args, **kwargs): + async def predict_cb(_, *__, **___): done_future = asyncio.Future() future_queue.put_nowait(done_future) return await done_future client = MockClient() engine = MockEngine( - {"models": {12: common.Model("M", None, 12)}, "actions": {}}, + {"models": {12: common.Model("M", "test", 12)}, "actions": {}}, predict_cb=predict_cb, ) control = await aimm.server.control.event.create(conf(), engine, client) - events = await client._register_queue.get() # state + await client._register_queue.get() # state req_event = _event( ("predict", "12"), {"args": [], "kwargs": {}, "request_id": "1"} diff --git a/test/test_unit/test_server/test_control/test_repl.py b/test/test_unit/test_server/test_control/test_repl.py index 01cbce0..0d98e1c 100644 --- a/test/test_unit/test_server/test_control/test_repl.py +++ b/test/test_unit/test_server/test_control/test_repl.py @@ -14,13 +14,15 @@ class MockEngine(common.Engine): def __init__( self, - state={"models": {}, "actions": {}}, + state=None, create_instance_cb=None, add_instance_cb=None, update_instance_cb=None, fit_cb=None, predict_cb=None, ): + if state is None: + state = {"models": {}, "actions": {}} self._state = state self._cb = lambda: None self._create_instance_cb = create_instance_cb diff --git a/test/test_unit/test_server/test_engine.py b/test/test_unit/test_server/test_engine.py index aaa727f..06c9ba3 100644 --- a/test/test_unit/test_server/test_engine.py +++ b/test/test_unit/test_server/test_engine.py @@ -43,9 +43,8 @@ async def update_model(self, model): self._queue.put_nowait(("update", model)) -async def create_engine(backend=None, group=None): +async def create_engine(backend=None): backend = backend or MockBackend() - group = group or aio.Group() return await engine.create( { "sigterm_timeout": 1, @@ -53,7 +52,6 @@ async def create_engine(backend=None, group=None): "check_children_period": 0.2, }, backend, - group, ) @@ -82,8 +80,8 @@ async def test_create_instance(plugin_teardown): eng.subscribe_to_state_change(lambda: state_queue.put_nowait(eng.state)) @plugins.instantiate("test") - def create(*args, **kwargs): - return "test", args, kwargs + def create(*c_args, **c_kwargs): + return "test", c_args, c_kwargs args = (1, 2, 3) kwargs = {"p1": 4, "p2": 5} @@ -101,7 +99,7 @@ def create(*args, **kwargs): break assert state["models"][model_id] == expected_model assert await action.wait_result() == expected_model - await backend.queue.get() == ("create", expected_model) + assert await backend.queue.get() == ("create", expected_model) await eng.async_close() @@ -133,7 +131,7 @@ def state_change_cb(): }, } ] - await backend.queue.get() == ( + assert await backend.queue.get() == ( "create", common.Model( instance=None, model_type="test", instance_id=instance_id @@ -159,39 +157,29 @@ def state_change_cb(): await backend.queue.get() @plugins.fit(["test"]) - def fit(*args, **kwargs): - return ("instance_fitted", args, kwargs) + def fit(*f_args, **f_kwargs): + return "instance_fitted", f_args, f_kwargs args = (1, 2) kwargs = {"p1": 3, "p2": 4} action = eng.fit(1, *args, **kwargs) - expected_instance = common.Model( + expected_model = common.Model( instance=("instance_fitted", ("instance", *args), kwargs), model_type="test", instance_id=1, ) - await queue.get() == { - "models": { - 1: common.Model( - instance=expected_instance, model_type="test", instance_id=1 - ) - } - } model = await action.wait_result() - assert model == expected_instance - await backend.queue.get() == ( - "update", - common.Model( - instance=expected_instance, model_type="test", instance_id=1 - ), - ) + assert model == expected_model + + assert queue.get_nowait_until_empty()["models"] == {1: expected_model} + assert await backend.queue.get() == ("update", expected_model) # allow model lock to release await eng.async_close() -@pytest.mark.timeout(2) +# @pytest.mark.timeout(2) async def test_predict(plugin_teardown): backend = MockBackend() eng = await create_engine(backend) @@ -208,23 +196,25 @@ def state_change_cb(): await backend.queue.get() @plugins.predict(["test"]) - def predict(instance, *args, **kwargs): + def predict(instance, *p_args, **p_kwargs): instance.append(1) - return (instance, args, kwargs) + return instance, p_args, p_kwargs args = (1, 2) kwargs = {"p1": 3, "p2": 4} action = eng.predict(1, *args, **kwargs) - expected_instance = (["instance", 1], args, kwargs) - await queue.get() == { - "models": { - 1: common.Model( - instance=expected_instance, model_type="test", instance_id=1 - ) - } - } + expected_result = (["instance", 1], args, kwargs) result = await action.wait_result() - assert result == expected_instance + assert result == expected_result + + expected_model = common.Model( + instance=["instance", 1], model_type="test", instance_id=1 + ) + + state = queue.get_nowait_until_empty() + assert state["models"] == {1: expected_model} + predict_backend_update = await backend.queue.get() + assert predict_backend_update == ("update", expected_model) await eng.async_close() diff --git a/test/test_unit/test_server/test_mprocess.py b/test/test_unit/test_server/test_mprocess.py index b965b2a..3bef8d4 100644 --- a/test/test_unit/test_server/test_mprocess.py +++ b/test/test_unit/test_server/test_mprocess.py @@ -7,11 +7,12 @@ import time from aimm.server import mprocess +from aimm.server.mprocess import ProcessTerminatedException @pytest.fixture def disable_sigterm_handler(monkeypatch): - default = mprocess._sigterm_override + default = mprocess.sigterm_override @contextlib.contextmanager def handler_patch(): @@ -20,15 +21,15 @@ def handler_patch(): yield with monkeypatch.context() as ctx: - ctx.setattr(mprocess, "_sigterm_override", handler_patch) + ctx.setattr(mprocess, "sigterm_override", handler_patch) yield @pytest.mark.timeout(2) @pytest.mark.parametrize("action_count", [1, 2, 10]) async def test_process_regular(action_count, disable_sigterm_handler): - def fn(*args, **kwargs): - return args, kwargs + def fn(*f_args, **f_kwargs): + return f_args, f_kwargs args = ("arg1", "arg2") kwargs = {"k1": "v1", "k2": "v2"} @@ -69,17 +70,14 @@ async def test_process_sigterm(disable_sigterm_handler): def fn(): time.sleep(10) - queue = aio.Queue() - - def state_cb(state): - queue.put_nowait(state) - pa_pool = mprocess.ProcessManager(1, aio.Group(), 0.1, 5) process_action = pa_pool.create_handler(lambda _: None) async with aio.Group() as group: async def _run(): - with pytest.raises(Exception, match="process terminated"): + with pytest.raises( + ProcessTerminatedException, match="process sigterm" + ): await process_action.run(fn) task = group.spawn(_run) @@ -96,20 +94,16 @@ async def test_process_sigkill(): def fn(): try: time.sleep(10) - except Exception: + except ProcessTerminatedException: time.sleep(10) - - queue = aio.Queue() - - def state_cb(state): - queue.put_nowait(state) + raise Exception("unexpected exception") pa_pool = mprocess.ProcessManager(1, aio.Group(), 0.1, 0.2) process_action = pa_pool.create_handler(lambda _: None) async with aio.Group() as group: async def _run(): - with pytest.raises(Exception, match="process terminated"): + with pytest.raises(ProcessTerminatedException): await process_action.run(fn) task = group.spawn(_run)