Skip to content

Commit

Permalink
Allow factories to receive the current container (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
hynek authored Aug 1, 2023
1 parent 77cf867 commit d2697cc
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 15 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/

## [Unreleased](https://github.com/hynek/svcs/compare/23.6.0...HEAD)

### Added

- Factories now may take a parameter called `svcs_container` or that is annotated to be `svcs.Container`.
In this case the factory will receive the current container as a first positional argument.
This allows for recursive factories without global state.
[#10](https://github.com/hynek/svcs/pull/10)


## [23.6.0](https://github.com/hynek/svcs/compare/23.5.0...23.6.0) - 2023-07-31

Expand Down
23 changes: 19 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,26 @@ True

A container lives as long as you want the instances to live – e.g., as long as a request lives.

Importantly: It is possible to overwrite registered service factories later – e.g., for testing – **without monkey-patching**.
You have to remove possibly cached instances from the container though (`Container.forget_about()`).
The Flask integration takes care of this for you.
If a factory takes a first argument called `svcs_container` or the first argument of any name that is annotated as being `svcs.Container`, the current container instance is passed into the factory as the first *positional* argument:

How to achieve this in other frameworks elegantly is TBD.
```python
>>> def factory(svcs_container) -> str:
... return svcs_container.get(uuid.UUID).hex

>>> reg.register_factory(str, factory)

% skip: next

>>> container.get(str)
'86d342d6652d4d7faa912769dff0793e'
```

> [!NOTE]
> It is possible to overwrite registered service factories later – e.g., for testing – **without monkey-patching**.
> You have to remove possibly cached instances from the container though (`Container.forget_about()`).
> The Flask integration takes care of this for you.
>
> How to achieve this in other frameworks elegantly is TBD.

#### Cleanup
Expand Down
10 changes: 9 additions & 1 deletion src/svcs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"Container",
"RegisteredService",
"Registry",
"ServiceNotFoundError",
"ServicePing",
"exceptions",
]
Expand All @@ -26,3 +25,12 @@
from . import flask # noqa: F401
except ImportError:
__all__.append("flask")


# Make nicer public names.
__locals = locals()
for __name in __all__:
if not __name.startswith("__") and not __name.islower():
__locals[__name].__module__ = "svcs"
del __locals
del __name
37 changes: 33 additions & 4 deletions src/svcs/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import inspect
import logging
import sys
import warnings
Expand Down Expand Up @@ -57,7 +58,7 @@ def get(self, svc_type: type) -> Any:
return svc

rs = self.registry.get_registered_service_for(svc_type)
svc = rs.factory()
svc = rs.factory(self) if rs.takes_container else rs.factory()

if isinstance(svc, Generator):
self._on_close.append((rs.name, svc))
Expand Down Expand Up @@ -179,6 +180,7 @@ def get_pings(self) -> list[ServicePing]:
class RegisteredService:
svc_type: type
factory: Callable = attrs.field(hash=False)
takes_container: bool
ping: Callable | None = attrs.field(hash=False)

@property
Expand All @@ -188,8 +190,11 @@ def name(self) -> str:
def __repr__(self) -> str:
return (
f"<RegisteredService(svc_type="
f"{ self.name}, "
f"has_ping={ self.ping is not None})>"
f"{self.name}, "
f"{self.factory}, "
f"takes_container={self.takes_container}, "
f"has_ping={ self.ping is not None}"
")>"
)

@property
Expand Down Expand Up @@ -240,7 +245,9 @@ def register_factory(
ping: Callable | None = None,
on_registry_close: Callable | None = None,
) -> None:
rs = RegisteredService(svc_type, factory, ping)
rs = RegisteredService(
svc_type, factory, _takes_container(factory), ping
)
self._services[svc_type] = rs

if on_registry_close is not None:
Expand Down Expand Up @@ -321,3 +328,25 @@ async def aclose(self) -> None:

self._services.clear()
self._on_close.clear()


def _takes_container(factory: Callable) -> bool:
"""
Return True if *factory* takes a svcs.Container as its first argument.
"""
sig = inspect.signature(factory)
if not sig.parameters:
return False

if len(sig.parameters) != 1:
msg = "Factories must take 0 or 1 parameters."
raise TypeError(msg)

((name, p),) = tuple(sig.parameters.items())
if name == "svcs_container":
return True

if (annot := p.annotation) is Container or annot == "svcs.Container":
return True

return False
128 changes: 122 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,52 @@ class YetAnotherService:

@pytest.fixture(name="rs")
def _rs(svc):
return svcs.RegisteredService(Service, Service, None)
return svcs.RegisteredService(Service, Service, False, None)


@pytest.fixture(name="svc")
def _svc():
return Service()


class TestIntegration:
def test_passes_container_bc_name(self, registry, container):
"""
If the factory takes an argument called `svcs_container`, it is passed
on instantiation.
"""
called = False

def factory(svcs_container):
assert container is svcs_container
nonlocal called
called = True

registry.register_factory(Service, factory)

container.get(Service)

assert called

def test_passes_container_bc_annotation(self, registry, container):
"""
If the factory takes an argument annotated with svcs.Container, it is
passed on instantiation.
"""
called = False

def factory(foo: svcs.Container):
assert container is foo
nonlocal called
called = True

registry.register_factory(Service, factory)

container.get(Service)

assert called


class TestContainer:
def test_register_factory_get(self, registry, container):
"""
Expand Down Expand Up @@ -213,7 +251,10 @@ def test_repr(self, rs):
"""

assert (
"<RegisteredService(svc_type=tests.test_core.Service, has_ping=False)>"
"<RegisteredService(svc_type=tests.test_core.Service, "
"<class 'tests.test_core.Service'>, takes_container=False, "
"has_ping=False"
")>"
) == repr(rs)

def test_name(self, rs):
Expand All @@ -235,8 +276,10 @@ async def factory_cleanup():
await asyncio.sleep(0)
yield 42

assert svcs.RegisteredService(object, factory, None).is_async
assert svcs.RegisteredService(object, factory_cleanup, None).is_async
assert svcs.RegisteredService(object, factory, False, None).is_async
assert svcs.RegisteredService(
object, factory_cleanup, False, None
).is_async

def test_is_async_nope(self):
"""
Expand All @@ -249,9 +292,11 @@ def factory():
def factory_cleanup():
yield 42

assert not svcs.RegisteredService(object, factory, None).is_async
assert not svcs.RegisteredService(
object, factory_cleanup, None
object, factory, False, None
).is_async
assert not svcs.RegisteredService(
object, factory_cleanup, False, None
).is_async


Expand Down Expand Up @@ -423,3 +468,74 @@ async def test_aclose_logs_failures(self, registry, caplog):

close_mock.assert_awaited_once()
assert "tests.test_core.Service" == caplog.records[0].svcs_service_name


def factory_wrong_annotation(foo: svcs.Registry) -> int:
return 42


class TestTakesContainer:
@pytest.mark.parametrize(
"factory",
[lambda: None, lambda container: None, factory_wrong_annotation],
)
def test_nope(self, factory):
"""
Functions with different names and annotations are ignored.
"""
assert not svcs._core._takes_container(factory)

def test_name(self):
"""
Return True if the name is `svcs_container`.
"""

def factory(svcs_container):
return 42

assert svcs._core._takes_container(factory)

def test_annotation(self):
"""
Return true if the first argument is annotated as `svcs.Container`.
"""

def factory(foo: svcs.Container):
return 42

assert svcs._core._takes_container(factory)

def test_annotation_str(self):
"""
Return true if the first argument is annotated as `svcs.Container`
using a string.
"""

def factory(bar: "svcs.Container"):
return 42

assert svcs._core._takes_container(factory)

def test_catches_invalid_sigs(self):
"""
If the factory takes more than one parameter, raise an TypeError.
"""

def factory(foo, bar):
return 42

with pytest.raises(
TypeError, match="Factories must take 0 or 1 parameters."
):
svcs._core._takes_container(factory)

def test_call_works(self):
"""
Does not raise if the factory is a class with __call__.
"""

class Factory:
def __call__(self, svcs_container):
return 42

assert svcs._core._takes_container(Factory())
4 changes: 4 additions & 0 deletions tests/typing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ def factory_with_cleanup() -> Generator[int, None, None]:
yield 1


def factory_that_takes_container_by_annotation(foo: svcs.Container) -> int:
return 1


async def async_ping() -> None:
pass

Expand Down

0 comments on commit d2697cc

Please sign in to comment.