Skip to content

Commit

Permalink
feat: add __contains__ to Container
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Aug 22, 2024
1 parent 5e7ac7f commit e8cc4c2
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 55 deletions.
2 changes: 0 additions & 2 deletions docs/source/by-examples/050_dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,8 @@ Lightbulb will enable dependency injection on a specific subset of your methods

These are listed below:
- {meth}`@lightbulb.invoke <lightbulb.commands.execution.invoke>`
- {meth}`@Client.register <lightbulb.client.Client.register>`
- {meth}`@Client.error_handler <lightbulb.client.Client.error_handler>`
- {meth}`@Client.task <lightbulb.client.Client.task>`
- {meth}`@Loader.command <lightbulb.loaders.Loader.command>` (due to it calling `Client.register` internally)
- {meth}`@Loader.listener <lightbulb.loaders.Loader.listener>`
- {meth}`@Loader.task <lightbulb.loaders.Loader.task>`

Expand Down
11 changes: 11 additions & 0 deletions lightbulb/di/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ def __init__(self, registry: registry_.Registry, *, parent: Container | None = N

self.add_value(Container, self)

def __contains__(self, item: type[t.Any]) -> bool:
dep_id = di_utils.get_dependency_id(item)
if dep_id not in self._graph:
return False

container = self._graph.nodes[dep_id]["container"]
if dep_id in container._instances:
return True

return container._graph.nodes[dep_id].get("factory") is not None

async def __aenter__(self) -> Container:
return self

Expand Down
136 changes: 84 additions & 52 deletions tests/di/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,84 @@ def f(_: A) -> object:
async with di.Container(registry) as container:
await container.get(B)

@pytest.mark.asyncio
async def test_non_direct_circular_dependency_raises_exception(self) -> None:
# fmt: off
def f_a(_: B) -> object: return object()

def f_b(_: A) -> object: return object()

# fmt: on

registry = di.Registry()
registry.register_factory(A, f_a)
registry.register_factory(B, f_b)

with pytest.raises(di.CircularDependencyException):
async with di.Container(registry) as c:
await c.get(A)

@pytest.mark.asyncio
async def test_get_transient_dependency_raises_exception(self) -> None:
def f_a(_: B) -> object:
return object()

registry = di.Registry()
registry.register_factory(A, f_a)

with pytest.raises(di.DependencyNotSatisfiableException):
async with di.Container(registry) as c:
await c.get(B)

@pytest.mark.asyncio
async def test_get_from_closed_container_raises_exception(self) -> None:
registry = di.Registry()
registry.register_factory(object, lambda: object())

with pytest.raises(di.ContainerClosedException):
async with di.Container(registry) as c:
pass
await c.get(object)

@pytest.mark.asyncio
async def test_get_with_default_when_dependency_not_available_returns_default(self) -> None:
registry = di.Registry()

async with di.Container(registry) as c:
assert await c.get(object, default=None) is None

@pytest.mark.asyncio
async def test_get_with_default_when_sub_dependency_not_available_returns_default(self) -> None:
registry = di.Registry()

def f1(_: str) -> object:
return object()

registry.register_factory(object, f1)

async with di.Container(registry) as c:
assert await c.get(object, default=None) is None

@pytest.mark.asyncio
async def test__contains__returns_true_when_dependency_known_by_value(self) -> None:
registry = di.Registry()
async with di.Container(registry) as container:
container.add_value(object, object())
assert object in container

@pytest.mark.asyncio
async def test__contains__returns_true_when_dependency_known_by_factory(self) -> None:
registry = di.Registry()
async with di.Container(registry) as container:
container.add_factory(object, lambda: object())
assert object in container

@pytest.mark.asyncio
async def test__contains__returns_false_when_dependency_not_known(self) -> None:
registry = di.Registry()
async with di.Container(registry) as container:
assert object not in container


class TestContainerWithParent:
@pytest.mark.asyncio
Expand Down Expand Up @@ -313,57 +391,11 @@ def f(_: A) -> object:
await cc.get(B)

@pytest.mark.asyncio
async def test_non_direct_circular_dependency_raises_exception(self) -> None:
# fmt: off
def f_a(_: B) -> object: return object()
def f_b(_: A) -> object: return object()
# fmt: on
async def test__contains__returns_true_when_dependency_known_by_parent(self) -> None:
r1 = di.Registry()
r1.register_value(object, object())

registry = di.Registry()
registry.register_factory(A, f_a)
registry.register_factory(B, f_b)

with pytest.raises(di.CircularDependencyException):
async with di.Container(registry) as c:
await c.get(A)
r2 = di.Registry()

@pytest.mark.asyncio
async def test_get_transient_dependency_raises_exception(self) -> None:
def f_a(_: B) -> object:
return object()

registry = di.Registry()
registry.register_factory(A, f_a)

with pytest.raises(di.DependencyNotSatisfiableException):
async with di.Container(registry) as c:
await c.get(B)

@pytest.mark.asyncio
async def test_get_from_closed_container_raises_exception(self) -> None:
registry = di.Registry()
registry.register_factory(object, lambda: object())

with pytest.raises(di.ContainerClosedException):
async with di.Container(registry) as c:
pass
await c.get(object)

@pytest.mark.asyncio
async def test_get_with_default_when_dependency_not_available_returns_default(self) -> None:
registry = di.Registry()

async with di.Container(registry) as c:
assert await c.get(object, default=None) is None

@pytest.mark.asyncio
async def test_get_with_default_when_sub_dependency_not_available_returns_default(self) -> None:
registry = di.Registry()

def f1(_: str) -> object:
return object()

registry.register_factory(object, f1)

async with di.Container(registry) as c:
assert await c.get(object, default=None) is None
async with di.Container(r1) as r1, di.Container(r2, parent=r1) as r2:
assert object in r2
2 changes: 1 addition & 1 deletion tests/di/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def m() -> None: ...

assert m.__lb_foo__ == "bar" # type: ignore[reportFunctionMemberAccess]

def test___get___within_class_does_not_assign_self(self) -> None:
def test__get__within_class_does_not_assign_self(self) -> None:
class Foo:
@lightbulb.di.with_di
def m(self) -> None: ...
Expand Down

0 comments on commit e8cc4c2

Please sign in to comment.