Skip to content

Commit

Permalink
refactor: make BaseMiddleware generic
Browse files Browse the repository at this point in the history
  • Loading branch information
Lancetnik committed Dec 7, 2024
1 parent 9e57990 commit 31d913f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion faststream/_internal/state/logger/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def log(
message: str,
log_level: Optional[int] = None,
extra: Optional["AnyDict"] = None,
exc_info: Optional[Exception] = None,
exc_info: Optional[BaseException] = None,
) -> None:
self.logger.log(
(log_level or self.log_level),
Expand Down
33 changes: 21 additions & 12 deletions faststream/middlewares/base.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from collections.abc import Awaitable
from typing import TYPE_CHECKING, Any, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional

from typing_extensions import Self
# We should use typing_extensions.TypeVar until python3.13 due default
from typing_extensions import Self, TypeVar

from faststream.response.response import PublishCommand

if TYPE_CHECKING:
from types import TracebackType

from faststream._internal.basic_types import AsyncFuncAny
from faststream._internal.context.repository import ContextRepo
from faststream.message import StreamMessage
from faststream.response.response import PublishCommand


class BaseMiddleware:
_PublishCommand_T = TypeVar(
"_PublishCommand_T",
bound=PublishCommand,
default=PublishCommand,
)


class BaseMiddleware(Generic[_PublishCommand_T]):
"""A base middleware class."""

def __init__(
Expand Down Expand Up @@ -54,11 +63,11 @@ async def on_consume(
self,
msg: "StreamMessage[Any]",
) -> "StreamMessage[Any]":
"""Asynchronously consumes a message."""
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
return msg

async def after_consume(self, err: Optional[Exception]) -> None:
"""A function to handle the result of consuming a resource asynchronously."""
"""This option was deprecated and will be removed in 0.7.0. Please, use `consume_scope` instead."""
if err is not None:
raise err

Expand All @@ -83,23 +92,23 @@ async def consume_scope(

async def on_publish(
self,
msg: "PublishCommand",
) -> "PublishCommand":
"""Asynchronously handle a publish event."""
msg: _PublishCommand_T,
) -> _PublishCommand_T:
"""This option was deprecated and will be removed in 0.7.0. Please, use `publish_scope` instead."""
return msg

async def after_publish(
self,
err: Optional[Exception],
) -> None:
"""Asynchronous function to handle the after publish event."""
"""This option was deprecated and will be removed in 0.7.0. Please, use `publish_scope` instead."""
if err is not None:
raise err

async def publish_scope(
self,
call_next: Callable[["PublishCommand"], Awaitable[Any]],
cmd: "PublishCommand",
call_next: Callable[[_PublishCommand_T], Awaitable[Any]],
cmd: _PublishCommand_T,
) -> Any:
"""Publish a message and return an async iterator."""
err: Optional[Exception] = None
Expand Down
4 changes: 2 additions & 2 deletions faststream/middlewares/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def consume_scope(
self,
call_next: "AsyncFuncAny",
msg: "StreamMessage[Any]",
) -> "StreamMessage[Any]":
) -> Any:
source_type = self._source_type = msg._source_type

if source_type is not SourceType.RESPONSE:
Expand All @@ -78,7 +78,7 @@ async def __aexit__(
if issubclass(exc_type, IgnoredException):
self.logger.log(
log_level=logging.INFO,
message=exc_val,
message=str(exc_val),
extra=c,
)

Expand Down

0 comments on commit 31d913f

Please sign in to comment.