Skip to content

Commit

Permalink
test: add Injects tests with Factory
Browse files Browse the repository at this point in the history
  • Loading branch information
shakonord committed Sep 22, 2023
1 parent 3f4cf47 commit 38c1e06
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
22 changes: 22 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,23 @@
import functools
import pathlib
from typing import Any, List

import pytest

from esmerald import AsyncDAOProtocol
from esmerald.testclient import EsmeraldTestClient


class FakeDAO(AsyncDAOProtocol):
model = "Awesome"

def __init__(self, conn: str = "awesome_conn"):
self.conn = conn

async def get_all(self, **kwargs: Any) -> List[Any]:
return ["awesome_data"]


@pytest.fixture
def no_trio_support(anyio_backend_name): # pragma: no cover
if anyio_backend_name == "trio":
Expand Down Expand Up @@ -37,3 +49,13 @@ def test_app_client_factory(anyio_backend_name, anyio_backend_options):
@pytest.fixture()
def template_dir(tmp_path: pathlib.Path) -> pathlib.Path:
return tmp_path


@pytest.fixture(scope="module")
def get_fake_dao():
return FakeDAO


@pytest.fixture(scope="module")
def get_fake_dao_instance():
return FakeDAO()
19 changes: 4 additions & 15 deletions tests/test_inject.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,12 @@
from functools import partial
from typing import Any, List
from typing import Any

import pytest

from esmerald import AsyncDAOProtocol
from esmerald.injector import Factory, Inject
from esmerald.typing import Void


class FakeDAO(AsyncDAOProtocol):
model = "Awesome"

def __init__(self, conn: str = "awesome_conn"):
self.conn = conn

async def get_all(self, **kwargs: Any) -> List[Any]:
return ["awesome_data"]


class Test:
__test__ = False

Expand Down Expand Up @@ -149,7 +138,7 @@ async def test_Inject_for_callable(fn: Any, exp: Any) -> None:


@pytest.mark.asyncio()
async def test_if_DAO_is_injectable() -> None:
async def test_if_DAO_is_injectable(get_fake_dao) -> None:
"""
Current:
dependencies={
Expand All @@ -168,13 +157,13 @@ async def test_if_DAO_is_injectable() -> None:
"fake_dao": Inject(Factory(FakeDAO, "nice_conn")),
},
"""
injectable1 = Inject(Factory(FakeDAO))
injectable1 = Inject(Factory(get_fake_dao))
obj = await injectable1()
assert await obj.get_all() == ["awesome_data"]
assert obj.model == "Awesome"
assert obj.conn == "awesome_conn"

injectable2 = Inject(Factory(FakeDAO, "nice_conn"))
injectable2 = Inject(Factory(get_fake_dao, "nice_conn"))
obj = await injectable2()
assert await obj.get_all() == ["awesome_data"]
assert obj.model == "Awesome"
Expand Down
56 changes: 54 additions & 2 deletions tests/test_injects.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import pytest
from starlette.status import HTTP_200_OK, HTTP_500_INTERNAL_SERVER_ERROR

from esmerald.applications import Esmerald
from esmerald.exceptions import ImproperlyConfigured
from esmerald.injector import Inject
from esmerald.injector import Factory, Inject
from esmerald.params import Injects
from esmerald.routing.gateways import Gateway
from esmerald.routing.handlers import get
from esmerald.routing.router import Include
from esmerald.routing.views import APIView
from esmerald.testclient import create_client
from esmerald.utils.constants import IS_DEPENDENCY
from tests.conftest import FakeDAO


def test_is_dependency_inserted_into_field_extra() -> None:
Expand Down Expand Up @@ -127,6 +128,17 @@ def test(value: int = Injects()) -> Dict[str, int]:
assert resp.json() == {"value": 13}


def test_no_default_dependency_Injected_with_Factory(get_fake_dao) -> None:
@get(dependencies={"fake_dao": Inject(Factory(get_fake_dao))})
async def test(fake_dao: FakeDAO = Injects()) -> Dict[str, int]:
result = await fake_dao.get_all()
return {"value": result}

with create_client(routes=[Gateway(handler=test)]) as client:
resp = client.get("/")
assert resp.json() == {"value": ["awesome_data"]}


def test_dependency_not_Injected_and_no_default() -> None:
@get()
def test(value: int = Injects()) -> Dict[str, int]:
Expand All @@ -150,6 +162,21 @@ def test(self, value: int = Injects()) -> Dict[str, int]:
assert resp.json() == {"value": 13}


def test_dependency_Injected_on_APIView_with_Factory(get_fake_dao) -> None:
class C(APIView):
path = ""
dependencies = {"fake_dao": Inject(Factory(get_fake_dao))}

@get()
async def test(self, fake_dao: FakeDAO = Injects()) -> Dict[str, List[str]]:
result = await fake_dao.get_all()
return {"value": result}

with create_client(routes=[Gateway(handler=C)]) as client:
resp = client.get("/")
assert resp.json() == {"value": ["awesome_data"]}


def test_dependency_skip_validation() -> None:
@get("/validated")
def validated(value: int = Injects()) -> Dict[str, int]:
Expand All @@ -172,3 +199,28 @@ def skipped(value: int = Injects(skip_validation=True)) -> Dict[str, int]:
skipped_resp = client.get("/skipped")
assert skipped_resp.status_code == HTTP_200_OK
assert skipped_resp.json() == {"value": "str"}


def test_dependency_skip_validation_with_Factory(get_fake_dao) -> None:
@get("/validated")
def validated(fake_dao: int = Injects()) -> Dict[str, List[str]]:
""" """

@get("/skipped")
async def skipped(fake_dao: FakeDAO = Injects(skip_validation=True)) -> Dict[str, List[str]]:
result = await fake_dao.get_all()
return {"value": result}

with create_client(
routes=[
Gateway(handler=validated),
Gateway(handler=skipped),
],
dependencies={"fake_dao": Inject(Factory(get_fake_dao))},
) as client:
validated_resp = client.get("/validated")
assert validated_resp.status_code == HTTP_500_INTERNAL_SERVER_ERROR

skipped_resp = client.get("/skipped")
assert skipped_resp.status_code == HTTP_200_OK
assert skipped_resp.json() == {"value": ["awesome_data"]}

0 comments on commit 38c1e06

Please sign in to comment.