Skip to content

Commit

Permalink
Determine async-ness of factories only once
Browse files Browse the repository at this point in the history
  • Loading branch information
hynek committed Aug 1, 2023
1 parent e0e2c5b commit 71b3156
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 43 deletions.
13 changes: 6 additions & 7 deletions src/svcs/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class RegisteredService:
svc_type: type
factory: Callable = attrs.field(hash=False)
takes_container: bool
is_async: bool
ping: Callable | None = attrs.field(hash=False)

@property
Expand All @@ -197,12 +198,6 @@ def __repr__(self) -> str:
")>"
)

@property
def is_async(self) -> bool:
return iscoroutinefunction(self.factory) or isasyncgenfunction(
self.factory
)


@attrs.frozen
class ServicePing:
Expand Down Expand Up @@ -246,7 +241,11 @@ def register_factory(
on_registry_close: Callable | None = None,
) -> None:
rs = RegisteredService(
svc_type, factory, _takes_container(factory), ping
svc_type,
factory,
_takes_container(factory),
iscoroutinefunction(factory) or isasyncgenfunction(factory),
ping,
)
self._services[svc_type] = rs

Expand Down
85 changes: 49 additions & 36 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class YetAnotherService:

@pytest.fixture(name="rs")
def _rs(svc):
return svcs.RegisteredService(Service, Service, False, None)
return svcs.RegisteredService(Service, Service, False, False, None)


@pytest.fixture(name="svc")
Expand Down Expand Up @@ -256,41 +256,6 @@ def test_name(self, rs):

assert "tests.test_core.Service" == rs.name

def test_is_async_yep(self):
"""
The is_async property returns True if the factory needs to be awaited.
"""

async def factory():
return 42

async def factory_cleanup():
await asyncio.sleep(0)
yield 42

assert svcs.RegisteredService(object, factory, False, None).is_async
assert svcs.RegisteredService(
object, factory_cleanup, False, None
).is_async

def test_is_async_nope(self):
"""
is_async is False for sync factories.
"""

def factory():
return 42

def factory_cleanup():
yield 42

assert not svcs.RegisteredService(
object, factory, False, None
).is_async
assert not svcs.RegisteredService(
object, factory_cleanup, False, None
).is_async


class TestServicePing:
def test_name(self, rs):
Expand Down Expand Up @@ -406,6 +371,54 @@ def test_close_logs_failures(self, registry, caplog):

assert "tests.test_core.Service" == caplog.records[0].svcs_service_name

def test_detects_async_factories(self, registry):
"""
The is_async property of the RegisteredService is True if the factory
needs to be awaited.
"""

async def factory():
return 42

async def factory_cleanup():
await asyncio.sleep(0)
yield str(42)

registry.register_factory(int, factory)
registry.register_factory(str, factory_cleanup)

assert (
svcs.RegisteredService(int, factory, False, True, None)
== registry._services[int]
)
assert (
svcs.RegisteredService(str, factory_cleanup, False, True, None)
== registry._services[str]
)

def test_no_false_positive_async(self, registry):
"""
is_async is False for sync factories.
"""

def factory():
return 42

def factory_cleanup():
yield "42"

registry.register_factory(int, factory)
registry.register_factory(str, factory_cleanup)

assert (
svcs.RegisteredService(int, factory, False, False, None)
== registry._services[int]
)
assert (
svcs.RegisteredService(str, factory_cleanup, False, False, None)
== registry._services[str]
)

@pytest.mark.skipif(
not hasattr(contextlib, "aclosing"),
reason="Hasn't contextlib.aclosing()",
Expand Down

0 comments on commit 71b3156

Please sign in to comment.