Skip to content

Commit

Permalink
Fix/context get local default (#1752)
Browse files Browse the repository at this point in the history
* fix: correct default value return in context.get_local

* refactor: use correct get_local syntax

* lint: correct docstrings

---------

Co-authored-by: Kumaran Rajendhiran <[email protected]>
  • Loading branch information
Lancetnik and kumaranvpl authored Sep 3, 2024
1 parent 4fe975d commit fabd5ff
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 11 deletions.
4 changes: 2 additions & 2 deletions faststream/broker/middlewares/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def on_consume(
msg: "StreamMessage[Any]",
) -> "StreamMessage[Any]":
if self.logger is not None:
c = context.get_local("log_context") or {}
c = context.get_local("log_context", {})
self.logger.log(self.log_level, "Received", extra=c)

return await super().on_consume(msg)
Expand All @@ -49,7 +49,7 @@ async def after_processed(
) -> bool:
"""Asynchronously called after processing."""
if self.logger is not None:
c = context.get_local("log_context") or {}
c = context.get_local("log_context", {})

if exc_type:
if issubclass(exc_type, IgnoredException):
Expand Down
4 changes: 2 additions & 2 deletions faststream/log/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(

def filter(self, record: LogRecord) -> bool:
if is_suitable := super().filter(record):
log_context: Mapping[str, str] = (
context.get_local("log_context") or self.default_context
log_context: Mapping[str, str] = context.get_local(
"log_context", self.default_context
)

for k, v in log_context.items():
Expand Down
2 changes: 1 addition & 1 deletion faststream/nats/broker/registrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


class NatsRegistrator(ABCBroker["Msg"]):
"""Includable to RabbitBroker router."""
"""Includable to NatsBroker router."""

_subscribers: Dict[int, "AsyncAPISubscriber"]
_publishers: Dict[int, "AsyncAPIPublisher"]
Expand Down
2 changes: 1 addition & 1 deletion faststream/redis/broker/registrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


class RedisRegistrator(ABCBroker[UnifyRedisDict]):
"""Includable to RabbitBroker router."""
"""Includable to RedisBroker router."""

_subscribers: Dict[int, "SubsciberType"]
_publishers: Dict[int, "PublisherType"]
Expand Down
12 changes: 7 additions & 5 deletions faststream/utils/context/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def set_local(self, key: str, value: Any) -> "Token[Any]":
"""
context_var = self._scope_context.get(key)
if context_var is None:
context_var = ContextVar(key, default=None)
context_var = ContextVar(key, default=EMPTY)
self._scope_context[key] = context_var
return context_var.set(value)

Expand All @@ -92,10 +92,12 @@ def get_local(self, key: str, default: Any = None) -> Any:
Returns:
The value of the local variable.
"""
if (context_var := self._scope_context.get(key)) is not None:
return context_var.get()
else:
return default
value = default
if (context_var := self._scope_context.get(key)) is not None and (
context_value := context_var.get()
) is not EMPTY:
value = context_value
return value

@contextmanager
def scope(self, key: str, value: Any) -> Iterator[None]:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ minversion = "7.0"
addopts = "-q -m 'not slow'"
testpaths = ["tests"]
markers = ["rabbit", "kafka", "confluent", "nats", "redis", "slow", "all"]
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.run]
parallel = true
Expand Down
9 changes: 9 additions & 0 deletions tests/utils/context/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ def use(
use()


def test_local_default(context: ContextRepo):
key = "some-key"

tag = context.set_local(key, "useless")
context.reset_local(key, tag)

assert context.get_local(key, 1) == 1


def test_initial():
@apply_types
def use(
Expand Down

0 comments on commit fabd5ff

Please sign in to comment.