Skip to content

Commit

Permalink
[Feature] - Allow importing the object from string in the factory (#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil authored Oct 4, 2023
2 parents dc25969 + ab26f59 commit 1606fd7
Show file tree
Hide file tree
Showing 9 changed files with 247 additions and 9 deletions.
4 changes: 2 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ services:
MONGO_INITDB_ROOT_PASSWORD: mongoadmin
MONGO_INITDB_DATABASE: mongodb
volumes:
- "mongo_db_data:/data/db"
- "mongo_esmerald_db_data:/data/db"
expose:
- 27017
ports:
Expand All @@ -56,5 +56,5 @@ services:
volumes:
esmerald:
external: true
mongo_db_data:
mongo_esmerald_db_data:
external: true
42 changes: 39 additions & 3 deletions docs/dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,17 @@ and checks if the value is bigger or equal than 5 and that result `is_valid` is

The same is applied also to [exception handlers](./exception-handlers.md).

## More Real World example
## More real world examples

Now let us imagine that we have a web application with one of the views. Something like this:

```python hl_lines="17"
{!> ../docs_src/dependencies/views.py !}
```

As you can notice, the user_dao is injected automatically using the appropriate level of dependency injection..
As you can notice, the `user_dao`` is injected automatically using the appropriate level of dependency injection.

Let's see the `urls.py` and understand from where we got the `user_dao`:
Let us see the `urls.py` and understand from where we got the `user_dao`:

```python hl_lines="14-16 32-34"
{!> ../docs_src/dependencies/urls.py !}
Expand All @@ -85,5 +85,41 @@ The Factory is a clean wrapper around any callable (classes usually are callable
No need to explicitly instantiate the class, just pass the class definition to the `Factory`
and Esmerald takes care of the rest for you.

### Importing using strings

Like everything is Esmerald, there are different ways of achieving the same results and the `Factory`
is no exception.

In the previous examples we were passing the `UserDAO`, `ArticleDAO` and `PostDAO` classes directly
into the `Factory` object and that also means that **you will need to import the objects to then be passed**.

What can happen with this process? Majority of the times nothing **but** you can also have the classic
`partially imported ...` annoying error, right?

Well, the good news is that Esmerald got you covered, as usual.

The `Factory` **also allows import via string** without the need of importing directly the object
to the place where it is needed.

Let us then see how it would look like and let us then assume:

1. The `UserDAO` is located somewhere in the codebase like `myapp.accounts.daos`.
2. The `ArticleDAO` is located somewhere in the codebase like `myapp.articles.daos`.
3. The `PostDAO` is located somewhere in the codebase like `myapp.posts.daos`.

Ok, now that we know this, let us see how it would look like in the codebase importing it inside the
`Factory`.

```python hl_lines="13-15"
{!> ../docs_src/dependencies/urls_factory_import.py !}
```

Now, this is a beauty is it not? This way, the codebase is cleaner and without all of those imported
objects from the top.

!!! Tip
Both cases work well within Esmerald, this is simply an alternative in case the complexity of
the codebase increases and you would like to tidy it up a bit more.

In conclusion, if your views/routes expect dependencies, you can define them in the upper level as described
and Esmerald will make sure that they will be automatically injected.
18 changes: 18 additions & 0 deletions docs_src/dependencies/urls_factory_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from esmerald import Factory, Include, Inject

route_patterns = [
Include(
"/api/v1",
routes=[
Include("/accounts", namespace="accounts.v1.urls"),
Include("/articles", namespace="articles.v1.urls"),
Include("/posts", namespace="posts.v1.urls"),
],
interceptors=[LoggingInterceptor], # Custom interceptor
dependencies={
"user_dao": Inject(Factory("myapp.accounts.daos.UserDAO")),
"article_dao": Inject(Factory("myapp.articles.daos.ArticleDAO")),
"post_dao": Inject(Factory("myapp.posts.daos.PostDAO")),
},
)
]
Empty file added esmerald/core/di/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions esmerald/core/di/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
Functions to use with the Factory dependency injection.
"""
from typing import Any, Callable, Tuple, cast

from esmerald.exceptions import ImproperlyConfigured
from esmerald.utils.module_loading import import_string


def _lookup(klass: Any, comp: Any, import_path: Any) -> Any:
"""
Runs a lookup via __import__ and returns the component.
"""
try:
return getattr(klass, comp)
except AttributeError:
__import__(import_path)
return getattr(klass, comp)


def _importer(target: Any, attribute: Any) -> Any:
"""
Gets the attribute from the target.
"""
components = target.split(".")
import_path = components.pop(0)
klass = __import__(import_path)

for comp in components:
import_path += ".%s" % comp
klass = _lookup(klass, comp, import_path)
return getattr(klass, attribute)


def _get_provider_callable(target: str) -> Any:
try:
target, attribute = target.rsplit(".", 1)
except (TypeError, ValueError, AttributeError):
raise TypeError(f"Need a valid target to lookup. You supplied: {target!r}") from None

def getter() -> Any:
return _importer(target, attribute)

return getter


def load_provider(provider: str) -> Tuple[Callable, bool]:
"""
Loads any callable by string import. This will make
sure that there is no need to have all the imports in one
file to use the `esmerald.injector.Factory`.
Example:
# myapp.daos.py
from esmerald import AsyncDAOProtocol
class MyDAO(AsyncDAOProtocol):
...
# myapp.urls.py
from esmerald import Inject, Factory, Gateway
route_patterns = [
Gateway(
...,
dependencies={"my_dao": Inject(Factory("myapp.daos.MyDAO"))}
)
]
"""
if not isinstance(provider, str):
raise ImproperlyConfigured(
"The `provider` should be a string with the format <module>.<file>"
)

is_nested: bool = False
try:
provider_callable = import_string(provider)
except ModuleNotFoundError:
target = _get_provider_callable(provider)
provider_callable = target
is_nested = True

if not callable(provider_callable):
raise ImproperlyConfigured(
f"The `provider` specified must be a callable, got {type(provider_callable)} instead."
)

return cast(Callable, provider_callable), is_nested
27 changes: 24 additions & 3 deletions esmerald/injector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

from esmerald.core.di.provider import load_provider
from esmerald.parsers import ArbitraryHashableBaseModel
from esmerald.transformers.datastructures import Signature
from esmerald.typing import Void
Expand All @@ -10,10 +11,19 @@


class Factory:
def __init__(self, provides: "AnyCallable", *args: Any) -> None:
self.provides = provides
def __init__(self, provides: Union["AnyCallable", str], *args: Any) -> None:
"""
The provider can be passed in separate ways. Via direct callable
or via string value where it will be automatically imported by the application.
"""
self.__args: Tuple[Any, ...] = ()
self.set_args(*args)
self.is_nested: bool = False

if isinstance(provides, str):
self.provides, self.is_nested = load_provider(provides)
else:
self.provides = provides

def set_args(self, *args: Any) -> None:
self.__args = args
Expand All @@ -24,6 +34,17 @@ def cls(self) -> "AnyCallable":
return self.provides

async def __call__(self) -> Any:
"""
This handles with normal and nested imports.
Example:
1. MyClass.func
2. MyClass.AnotherClass.func
"""
if self.is_nested:
self.provides = self.provides()

if is_async_callable(self.provides):
value = await self.provides(*self.__args)
else:
Expand Down
2 changes: 1 addition & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ nav:
- Interceptors: "interceptors.md"
- Permissions: "permissions.md"
- Middleware: "middleware/middleware.md"
- Dependencies: "dependencies.md"
- Exceptions: "exceptions.md"
- Exception Handlers: "exception-handlers.md"
- Dependencies: "dependencies.md"
- Pluggables: "pluggables.md"
- Datastructures: "datastructures.md"
- Password Hashers: "password-hashers.md"
Expand Down
23 changes: 23 additions & 0 deletions tests/test_inject.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,20 @@ def __init__(self) -> None:
async def async_class(cls) -> int:
return cls.val

class InsideTest:
val = 56

@classmethod
async def async_class(cls) -> int:
return cls.val

class NestedInsideTest:
val = 92

@classmethod
async def async_class(cls) -> int:
return cls.val

@classmethod
def sync_class(cls) -> int:
return cls.val
Expand Down Expand Up @@ -59,6 +73,7 @@ def sync_fn(val: str = "three-one") -> str:
[
(async_fn, "three-one"),
(Factory(async_fn), "three-one"),
(Factory("tests.test_inject.async_fn"), "three-one"),
],
)
@pytest.mark.asyncio()
Expand All @@ -73,6 +88,7 @@ async def test_Inject_default(_callable, exp) -> None:
[
(async_fn, "three-one"),
(Factory(async_fn), "three-one"),
(Factory("tests.test_inject.async_fn"), "three-one"),
],
)
@pytest.mark.asyncio()
Expand Down Expand Up @@ -130,6 +146,13 @@ def test_Injectr_equality_check_Factory() -> None:
(Factory(Test.sync_static), "one-three"),
(Factory(Test().async_instance), 13),
(Factory(Test().sync_instance), 13),
(Factory("tests.test_inject.async_fn"), "three-one"),
(Factory("tests.test_inject.Test.async_class"), 31),
(Factory("tests.test_inject.Test.sync_class"), 31),
(Factory("tests.test_inject.Test.async_static"), "one-three"),
(Factory("tests.test_inject.Test.sync_static"), "one-three"),
(Factory("tests.test_inject.Test.InsideTest.async_class"), 56),
(Factory("tests.test_inject.Test.InsideTest.NestedInsideTest.async_class"), 92),
],
)
@pytest.mark.asyncio()
Expand Down
51 changes: 51 additions & 0 deletions tests/test_injects.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,17 @@ async def test(fake_dao: FakeDAO = Injects()) -> Dict[str, int]:
assert resp.json() == {"value": ["awesome_data"]}


def test_no_default_dependency_Injected_with_Factory_from_string() -> None:
@get(dependencies={"fake_dao": Inject(Factory("tests.conftest.FakeDAO"))})
async def test(fake_dao: FakeDAO = Injects()) -> Dict[str, int]:
result = await fake_dao.get_all()
return {"value": result}

with create_client(routes=[Gateway(handler=test)]) as client:
resp = client.get("/")
assert resp.json() == {"value": ["awesome_data"]}


def test_dependency_not_Injected_and_no_default() -> None:
@get()
def test(value: int = Injects()) -> Dict[str, int]:
Expand Down Expand Up @@ -177,6 +188,21 @@ async def test(self, fake_dao: FakeDAO = Injects()) -> Dict[str, List[str]]:
assert resp.json() == {"value": ["awesome_data"]}


def test_dependency_Injected_on_APIView_with_Factory_from_string() -> None:
class C(APIView):
path = ""
dependencies = {"fake_dao": Inject(Factory("tests.conftest.FakeDAO"))}

@get()
async def test(self, fake_dao: FakeDAO = Injects()) -> Dict[str, List[str]]:
result = await fake_dao.get_all()
return {"value": result}

with create_client(routes=[Gateway(handler=C)]) as client:
resp = client.get("/")
assert resp.json() == {"value": ["awesome_data"]}


def test_dependency_skip_validation() -> None:
@get("/validated")
def validated(value: int = Injects()) -> Dict[str, int]:
Expand Down Expand Up @@ -224,3 +250,28 @@ async def skipped(fake_dao: FakeDAO = Injects(skip_validation=True)) -> Dict[str
skipped_resp = client.get("/skipped")
assert skipped_resp.status_code == HTTP_200_OK
assert skipped_resp.json() == {"value": ["awesome_data"]}


def test_dependency_skip_validation_with_Factory_from_string() -> None:
@get("/validated")
def validated(fake_dao: int = Injects()) -> Dict[str, List[str]]:
""" """

@get("/skipped")
async def skipped(fake_dao: FakeDAO = Injects(skip_validation=True)) -> Dict[str, List[str]]:
result = await fake_dao.get_all()
return {"value": result}

with create_client(
routes=[
Gateway(handler=validated),
Gateway(handler=skipped),
],
dependencies={"fake_dao": Inject(Factory("tests.conftest.FakeDAO"))},
) as client:
validated_resp = client.get("/validated")
assert validated_resp.status_code == HTTP_500_INTERNAL_SERVER_ERROR

skipped_resp = client.get("/skipped")
assert skipped_resp.status_code == HTTP_200_OK
assert skipped_resp.json() == {"value": ["awesome_data"]}

0 comments on commit 1606fd7

Please sign in to comment.