Skip to content

Commit

Permalink
Merge pull request #129 from ldeflandre/support-inheritance-in-model-…
Browse files Browse the repository at this point in the history
…definition

Intuitive abstract support by stopping support of no CONFIGURATIONS
  • Loading branch information
ldeflandre authored Jan 6, 2022
2 parents c23e66d + 5c8d720 commit 2287c4a
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 48 deletions.
8 changes: 4 additions & 4 deletions docs/library/models/organizing.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ service = ModelLibrary(

### Abstract models

It is possible to define models that inherits from an abstract model in order to share common behavior.
It is possible to define models that inherits from an abstract model in order to share common behavior. It only requires to not set CONFIGURATIONS dict for those models to be ignored from the configuration steps.

For instance, it can be usefull to implement common prediction algorithm on different data assets
For instance, it can be usefull to implement common prediction algorithm on different data assets.

```python
class BaseModel(AbstractMixin, Model):
class BaseModel(Model):
def _predict(self, item, **kwargs):
...

class DerivedModel(ConcreteMixin, BaseModel):
class DerivedModel(BaseModel):
CONFIGURATIONS = {"derived": {"asset": "something.txt"}}

def _load(self):
Expand Down
3 changes: 1 addition & 2 deletions docs/library/special/distant.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ Sometimes models will simply need to call another microservice, in this case `Di

```python
from modelkit.core.models.distant_model import DistantHTTPModel
from modelkit.core.model import ConcreteMixin

class SomeDistantHTTPModel(ConcreteMixin, DistantHTTPModel):
class SomeDistantHTTPModel(DistantHTTPModel):
CONFIGURATIONS = {
"some_model": {
"model_settings": {
Expand Down
9 changes: 0 additions & 9 deletions modelkit/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ class Asset:
"""

CONFIGURATIONS: Dict[str, Dict[str, Any]] = {}
_abstract = False

def __init__(
self,
Expand Down Expand Up @@ -786,11 +785,3 @@ def __init__(self, async_model: AsyncModel[ItemType, ReturnType]):
# The following does not currently work, because AsyncToSync does not
# seem to correctly wrap asynchronous generators
# self.predict_gen = AsyncToSync(self.async_model.predict_gen)


class AbstractMixin:
_abstract = True


class ConcreteMixin:
_abstract = False
21 changes: 4 additions & 17 deletions modelkit/core/model_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import inspect
import os
import pkgutil
import re
from collections import ChainMap
from types import ModuleType
from typing import Any, Dict, List, Mapping, Optional, Set, Type, Union
Expand Down Expand Up @@ -53,24 +52,12 @@ def walk_objects(mod):
yield from walk_module_objects(mod, already_seen)


TO_SNAKE_CASE_PATTERN = re.compile(r"(?<!^)(?=[A-Z])")


def to_snake_case(name):
return TO_SNAKE_CASE_PATTERN.sub("_", name).lower()


def _configurations_from_objects(m) -> Dict[str, ModelConfiguration]:
if inspect.isclass(m) and issubclass(m, Asset):
if m._abstract:
return {}
configs = {}
if m.CONFIGURATIONS:
for key, config in m.CONFIGURATIONS.items():
configs[key] = ModelConfiguration(**{**config, "model_type": m})
else:
configs[to_snake_case(m.__name__)] = ModelConfiguration(model_type=m)
return configs
return {
key: ModelConfiguration(**{**config, "model_type": m})
for key, config in m.CONFIGURATIONS.items()
}
elif isinstance(m, (list, tuple)):
return dict(ChainMap(*(_configurations_from_objects(sub_m) for sub_m in m)))
elif isinstance(m, ModuleType):
Expand Down
6 changes: 3 additions & 3 deletions modelkit/core/models/distant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
wait_random_exponential,
)

from modelkit.core.model import AbstractMixin, AsyncModel, Model
from modelkit.core.model import AsyncModel, Model
from modelkit.core.types import ItemType, ReturnType

logger = get_logger(__name__)
Expand Down Expand Up @@ -46,7 +46,7 @@ def retriable_error(exception):
}


class AsyncDistantHTTPModel(AbstractMixin, AsyncModel[ItemType, ReturnType]):
class AsyncDistantHTTPModel(AsyncModel[ItemType, ReturnType]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.endpoint = self.model_settings["endpoint"]
Expand Down Expand Up @@ -77,7 +77,7 @@ async def close(self):
return self.aiohttp_session.close()


class DistantHTTPModel(AbstractMixin, Model[ItemType, ReturnType]):
class DistantHTTPModel(Model[ItemType, ReturnType]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.endpoint = self.model_settings["endpoint"]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
download_assets,
load_model,
)
from modelkit.core.model import AbstractMixin, Asset, AsyncModel, ConcreteMixin, Model
from modelkit.core.model import Asset, AsyncModel, Model
from modelkit.core.model_configuration import (
ModelConfiguration,
_configurations_from_objects,
Expand Down Expand Up @@ -149,7 +149,7 @@ class ModelNoConf(Asset):
assert "les simpsons" in configurations

configurations = _configurations_from_objects(ModelNoConf)
assert "model_no_conf" in configurations
assert {} == configurations

configurations = _configurations_from_objects([SomeModel, SomeModel2, ModelNoConf])
assert "yolo" in configurations
Expand Down Expand Up @@ -727,11 +727,11 @@ def test_model_sub_class(working_dir, monkeypatch):
with open(os.path.join(working_dir, "something.txt"), "w") as f:
f.write("OK")

class BaseAsset(AbstractMixin, Asset):
class BaseAsset(Asset):
def _load(self):
assert self.asset_path

class DerivedAsset(ConcreteMixin, BaseAsset):
class DerivedAsset(BaseAsset):
CONFIGURATIONS = {"derived": {"asset": "something.txt"}}

def _predict(self, item):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_distant_http_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import requests

from modelkit.core.library import ModelLibrary
from modelkit.core.model import ConcreteMixin
from modelkit.core.models.distant_model import AsyncDistantHTTPModel, DistantHTTPModel
from tests import TEST_DIR

Expand Down Expand Up @@ -69,12 +68,12 @@ async def test_distant_http_model(
"async_mode": False,
}

class SomeDistantHTTPModel(ConcreteMixin, DistantHTTPModel):
class SomeDistantHTTPModel(DistantHTTPModel):
CONFIGURATIONS = {
"some_model_sync": {"model_settings": sync_model_settings},
}

class SomeAsyncDistantHTTPModel(ConcreteMixin, AsyncDistantHTTPModel):
class SomeAsyncDistantHTTPModel(AsyncDistantHTTPModel):
CONFIGURATIONS = {"some_model_async": {"model_settings": async_model_settings}}

lib_without_params = ModelLibrary(
Expand Down
6 changes: 3 additions & 3 deletions tests/testmodels/some_assets.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from modelkit.core.model import AbstractMixin, Asset, ConcreteMixin
from modelkit.core.model import Asset


class BaseAsset(AbstractMixin, Asset):
class BaseAsset(Asset):
def _load(self):
assert self.asset_path


class DerivedAsset(ConcreteMixin, BaseAsset):
class DerivedAsset(BaseAsset):
CONFIGURATIONS = {"derived_asset": {"asset": "something.txt"}}
6 changes: 3 additions & 3 deletions tests/testmodels/some_models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from modelkit.core.model import AbstractMixin, ConcreteMixin, Model
from modelkit.core.model import Model


class BaseModel(AbstractMixin, Model):
class BaseModel(Model):
def _load(self):
assert self.asset_path


class DerivedModel(ConcreteMixin, BaseModel):
class DerivedModel(BaseModel):
CONFIGURATIONS = {"derived_model": {"asset": "something.txt"}}

def _predict(self, item):
Expand Down

0 comments on commit 2287c4a

Please sign in to comment.