Skip to content

Commit

Permalink
warning corrections, constructor reorg
Browse files Browse the repository at this point in the history
  • Loading branch information
zlatsic committed Sep 19, 2024
1 parent c834eb0 commit 8ef3bee
Show file tree
Hide file tree
Showing 26 changed files with 321 additions and 267 deletions.
12 changes: 10 additions & 2 deletions aimm/client/repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
specified by the REPL control."""

from getpass import getpass

from hat import aio
from hat import juggler
from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed
import base64
import hashlib
import numpy
Expand Down Expand Up @@ -41,7 +43,7 @@ def state(self) -> "JSON":
"""Current state reported from the AIMM server"""
return self._state

async def connect(self, address: str, autoflush_delay: float = 0.2):
async def connect(self, address: str):
"""Connects to the specified remote address. Login data is received
from a user prompt. Passwords are hashed with SHA-256 before sending
login request."""
Expand All @@ -51,7 +53,13 @@ async def connect(self, address: str, autoflush_delay: float = 0.2):
password_hash = hashlib.sha256()
password_hash.update(getpass("Password: ").encode("utf-8"))

connection = await juggler.connect(address)
connection = None
async for attempt in AsyncRetrying(
wait=wait_fixed(1), stop=stop_after_attempt(3)
):
with attempt:
connection = await juggler.connect(address)

await connection.send(
"login",
{"username": username, "password": password_hash.hexdigest()},
Expand Down
1 change: 0 additions & 1 deletion aimm/plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Any, ByteString, Callable, Dict, NamedTuple, Optional
from aimm.common import * # NOQA
import abc
import importlib
import logging
Expand Down
28 changes: 19 additions & 9 deletions aimm/plugins/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,6 @@ def serialize(model_types: List[str]) -> Callable:
Args:
model_types: types of models supported by the decorated function
instance_arg_name: if set, indicates under which argument name to pass
the concrete model instance. If not set, it is passed in the first
positional argument
Returns:
Decorated function"""
Expand Down Expand Up @@ -231,13 +228,26 @@ def model(cls: Type) -> Type:
"""

model_type = f"{cls.__module__}.{cls.__name__}"

_declare("instantiate", model_type, common.InstantiatePlugin(cls))
_declare("fit", model_type, common.FitPlugin(cls.fit))
_declare("predict", model_type, common.PredictPlugin(cls.predict))
_declare("serialize", model_type, common.SerializePlugin(cls.serialize))
_declare(
"deserialize", model_type, common.DeserializePlugin(cls.deserialize)
)

fit_fn = getattr(cls, "fit")
if isinstance(fit_fn, Callable):
_declare("fit", model_type, common.FitPlugin(fit_fn))

predict_fn = getattr(cls, "predict")
if isinstance(predict_fn, Callable):
_declare("predict", model_type, common.PredictPlugin(predict_fn))

serialize_fn = getattr(cls, "serialize")
if isinstance(serialize_fn, Callable):
_declare("serialize", model_type, common.SerializePlugin(serialize_fn))

deserialize_fn = getattr(cls, "deserialize")
if isinstance(deserialize_fn, Callable):
_declare(
"deserialize", model_type, common.DeserializePlugin(deserialize_fn)
)

return cls

Expand Down
7 changes: 4 additions & 3 deletions aimm/plugins/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ def exec_predict(
state_cb: common.StateCallback = lambda state: None,
*args: Any,
**kwargs: Any
) -> Any:
) -> tuple[Any, Any]:
"""Uses a loaded plugin to perform a prediction with a given model
instance"""
instance. Also returns the instance because it might be altered during the
prediction, e.g. with reinforcement learning models."""
plugin = decorators.get_predict(model_type)
kwargs = _kwargs_add_state_cb(plugin.state_cb_arg_name, state_cb, kwargs)
args, kwargs = _args_add_instance(
plugin.instance_arg_name, instance, args, kwargs
)
return plugin.function(*args, **kwargs)
return instance, plugin.function(*args, **kwargs)


def exec_serialize(model_type: str, instance: Any) -> ByteString:
Expand Down
11 changes: 6 additions & 5 deletions aimm/server/backend/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from aimm.server import common


def create(conf, _):
backend = DummyBackend()
backend._group = aio.Group()
backend._id_counter = itertools.count(1)
return backend
def create(_, __):
return DummyBackend()


class DummyBackend(common.Backend):
def __init__(self):
self._group = aio.Group()
self._id_counter = itertools.count(1)

@property
def async_group(self) -> aio.Group:
"""Async group"""
Expand Down
33 changes: 19 additions & 14 deletions aimm/server/backend/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from aimm.server import common
from aimm import plugins
from aimm.server.common import Model


def create_subscription(conf):
Expand All @@ -14,29 +15,33 @@ def create_subscription(conf):

async def create(conf, event_client):
common.json_schema_repo.validate("aimm://server/backend/event.yaml#", conf)
backend = EventBackend()

backend._model_prefix = conf["model_prefix"]
backend._executor = aio.create_executor()
backend._cbs = util.CallbackRegistry()
backend._async_group = aio.Group()
backend._client = event_client

models = await backend.get_models()
backend._id_counter = itertools.count(
max((model.instance_id for model in models), default=1)
)
backend = EventBackend(conf, event_client)
await backend.start()

return backend


class EventBackend(common.Backend):
def __init__(self, conf, event_client):
self._model_prefix = conf["model_prefix"]
self._executor = aio.create_executor()
self._cbs = util.CallbackRegistry()
self._group = aio.Group()
self._client = event_client
self._id_counter = None

@property
def async_group(self) -> aio.Group:
"""Async group"""
return self._async_group
return self._group

async def start(self):
models = await self.get_models()
self._id_counter = itertools.count(
max((model.instance_id for model in models), default=1)
)

async def get_models(self):
async def get_models(self) -> list[Model]:
query_result = await self._client.query(
hat.event.common.QueryLatestParams(
event_types=[(*self._model_prefix, "*")]
Expand Down
30 changes: 17 additions & 13 deletions aimm/server/backend/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,32 @@ async def create(conf, _):
common.json_schema_repo.validate(
"aimm://server/backend/sqlite.yaml#", conf
)
backend = SQLiteBackend()

executor = aio.create_executor(1)
connection = await executor(_ext_db_connect, Path(conf["path"]))
connection.row_factory = sqlite3.Row

group = aio.Group()
group.spawn(aio.call_on_cancel, executor, _ext_db_close, connection)

backend._executor = executor
backend._connection = connection
backend._group = group

backend = SQLiteBackend(conf)
await backend.start()
return backend


class SQLiteBackend(common.Backend):
def __init__(self, conf):
self._conf = conf
self._executor = aio.create_executor(1)
self._connection = None
self._group = aio.Group()

@property
def async_group(self) -> aio.Group:
"""Async group"""
return self._group

async def start(self):
self._connection = await self._executor(
_ext_db_connect, Path(self._conf["path"])
)
self._connection.row_factory = sqlite3.Row
self._group.spawn(
aio.call_on_cancel, self._executor, _ext_db_close, self._connection
)

async def get_models(self):
query = """SELECT * FROM models"""
cursor = await self._execute(query)
Expand Down
27 changes: 12 additions & 15 deletions aimm/server/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
Collection,
)
import abc
import aimm.common
import hat.event.eventer.client
import hat.event.common
import logging

import aimm.common

mlog = logging.getLogger(__name__)

json_schema_repo = aimm.common.json_schema_repo
Expand All @@ -31,7 +32,7 @@

class Model(NamedTuple):
"""Server's representation of objects returned by
:func:`plugins.exec_instantiate`. Contains all metadata neccessary to
:func:`plugins.exec_instantiate`. Contains all metadata necessary to
identify and perform other actions with it."""

instance: Any
Expand All @@ -49,7 +50,7 @@ class DataAccess(NamedTuple):
docstrings."""

name: str
"""name of the data acces type, used to identify which plugin to use"""
"""name of the data access type, used to identify which plugin to use"""
args: Iterable
"""positional arguments to be passed to the plugin call"""
kwargs: Dict[str, Any]
Expand Down Expand Up @@ -93,12 +94,10 @@ async def update_instance(self, model: Model):
"""Update existing instance in the state"""

@abc.abstractmethod
async def fit(
self, instance_id: int, *args: Any, **kwargs: Any
) -> "Action":
def fit(self, instance_id: int, *args: Any, **kwargs: Any) -> "Action":
"""Starts an action that fits an existing model instance. The used
fitting function is the one assigned to the model type. The instance,
while it is being fitted, is not accessable by any of the other
while it is being fitted, is not accessible by any of the other
functions that would use it (other calls to fit, predictions, etc.).
Args:
Expand All @@ -111,12 +110,10 @@ async def fit(
arguments"""

@abc.abstractmethod
async def predict(
self, instance_id: int, *args: Any, **kwargs: Any
) -> "Action":
def predict(self, instance_id: int, *args: Any, **kwargs: Any) -> "Action":
"""Starts an action that uses an existing model instance to perform a
prediction. The used prediction function is the one assigned to model's
type. The instance, while prediction is called, is not accessable by
type. The instance, while prediction is called, is not accessible by
any of the other functions that would use it (other calls to predict,
fittings, etc.). If instance has changed while predicting, it is
updated in the state and database.
Expand All @@ -131,7 +128,7 @@ async def predict(
arguments
Returns:
Reference to task of the managable predict call, result of it is
Reference to task of the manageable predict call, result of it is
the model's prediction"""


Expand All @@ -157,7 +154,7 @@ def create_backend(
signature"""


class Backend(aio.Resource):
class Backend(aio.Resource, abc.ABC):
"""Backend interface. In order to integrate in the aimm server, create a
module with the implementation and function ``create`` that creates a
backend instance. The function should have a signature as the
Expand All @@ -172,7 +169,7 @@ class Backend(aio.Resource):

@abc.abstractmethod
async def get_models(self) -> List[Model]:
"""Get all persisted models, requries that a deserialization function
"""Get all persisted models, requires that a deserialization function
is defined for all persisted types
Returns:
Expand Down Expand Up @@ -213,7 +210,7 @@ def create_control(
signature"""


class Control(aio.Resource):
class Control(aio.Resource, abc.ABC):
"""Control interface. In order to integrate in the aimm server, create a
module with the implementation and function ``create`` that creates a
control instance and should have a signature as the :func:`create_control`
Expand Down
32 changes: 15 additions & 17 deletions aimm/server/control/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,24 @@ async def create(conf, engine, event_client):
raise ValueError(
"attempting to create event control without hat compatibility"
)

control = EventControl()

control._client = event_client
control._engine = engine
control._async_group = aio.Group()
control._event_prefixes = conf["event_prefixes"]
control._state_event_type = conf["state_event_type"]
control._action_state_event_type = conf["action_state_event_type"]
control._executor = aio.create_executor()
control._notified_state = {}
control._in_progress = {}

control._notify_state()
control._engine.subscribe_to_state_change(control._notify_state)

return control
return EventControl(conf, engine, event_client)


class EventControl(common.Control):
def __init__(self, conf, engine, event_client):
self._client = event_client
self._engine = engine
self._async_group = aio.Group()
self._event_prefixes = conf["event_prefixes"]
self._state_event_type = conf["state_event_type"]
self._action_state_event_type = conf["action_state_event_type"]
self._executor = aio.create_executor()
self._notified_state = {}
self._in_progress = {}

self._notify_state()
self._engine.subscribe_to_state_change(self._notify_state)

@property
def async_group(self) -> aio.Group:
"""Async group"""
Expand Down
Loading

0 comments on commit 8ef3bee

Please sign in to comment.