feat: add RMQ channels options, support for prefix for routing_key, add public API for middlewares
Lancetnik committed May 15, 2024
1 parent d100d5f commit 93ac180
Showing 16 changed files with 247 additions and 17 deletions.
2 changes: 1 addition & 1 deletion faststream/__about__.py
@@ -1,6 +1,6 @@
"""Simple and fast framework to create message brokers based microservices."""

__version__ = "0.5.5"
__version__ = "0.5.6"

SERVICE_NAME = f"faststream-{__version__}"

13 changes: 13 additions & 0 deletions faststream/broker/core/abc.py
@@ -46,6 +46,19 @@ def __init__(
self._parser = parser
self._decoder = decoder

def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
"""Append BrokerMiddleware to the end of middlewares list.
Current middleware will be used as a most inner of already existed ones.
"""
self._middlewares = (*self._middlewares, middleware)

for sub in self._subscribers.values():
sub.add_middleware(middleware)

for pub in self._publishers.values():
pub.add_middleware(middleware)

@abstractmethod
def subscriber(
self,
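A minimal usage sketch of the new public `add_middleware` API (broker, queue, and middleware names are illustrative; the hook below assumes FastStream's standard `BaseMiddleware` interface): the middleware is appended to the broker and propagated to every already-registered subscriber and publisher, ending up as the innermost one.

from faststream import BaseMiddleware
from faststream.rabbit import RabbitBroker

class LoggingMiddleware(BaseMiddleware):
    async def on_receive(self) -> None:
        # self.msg holds the raw incoming message
        print("received:", self.msg)
        return await super().on_receive()

broker = RabbitBroker()

@broker.subscriber("in-queue")
async def handler(body: str) -> None:
    ...

# Added after registration: still applied to the existing subscriber
# (and any publishers), as the innermost of the configured middlewares.
broker.add_middleware(LoggingMiddleware)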
4 changes: 4 additions & 0 deletions faststream/broker/publisher/proto.py
@@ -56,6 +56,10 @@ class PublisherProto(
_middlewares: Iterable["PublisherMiddleware"]
_producer: Optional["ProducerProto"]

@abstractmethod
def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
...

@staticmethod
@abstractmethod
def create() -> "PublisherProto[MsgType]":
10 changes: 9 additions & 1 deletion faststream/broker/publisher/usecase.py
@@ -19,7 +19,12 @@
from faststream.asyncapi.message import get_response_schema
from faststream.asyncapi.utils import to_camelcase
from faststream.broker.publisher.proto import PublisherProto
from faststream.broker.types import MsgType, P_HandlerParams, T_HandlerReturn
from faststream.broker.types import (
BrokerMiddleware,
MsgType,
P_HandlerParams,
T_HandlerReturn,
)
from faststream.broker.wrapper.call import HandlerCallWrapper

if TYPE_CHECKING:
@@ -87,6 +92,9 @@ def __init__(
self.include_in_schema = include_in_schema
self.schema_ = schema_

def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
self._broker_middlewares = (*self._broker_middlewares, middleware)

@override
def setup( # type: ignore[override]
self,
4 changes: 4 additions & 0 deletions faststream/broker/subscriber/proto.py
@@ -35,6 +35,10 @@ class SubscriberProto(
_broker_middlewares: Iterable["BrokerMiddleware[MsgType]"]
_producer: Optional["ProducerProto"]

@abstractmethod
def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
...

@staticmethod
@abstractmethod
def create() -> "SubscriberProto[MsgType]":
3 changes: 3 additions & 0 deletions faststream/broker/subscriber/usecase.py
@@ -131,6 +131,9 @@ def __init__(
self.description_ = description_
self.include_in_schema = include_in_schema

def add_middleware(self, middleware: "BrokerMiddleware[MsgType]") -> None:
self._broker_middlewares = (*self._broker_middlewares, middleware)

@override
def setup( # type: ignore[override]
self,
4 changes: 4 additions & 0 deletions faststream/nats/parser.py
@@ -104,13 +104,17 @@ async def parse_batch(
) -> "StreamMessage[List[Msg]]":
if first_msg := next(iter(message), None):
path = self.get_path(first_msg.subject)
headers = first_msg.headers

else:
path = None
headers = None

return NatsBatchMessage(
raw_message=message,
body=[m.data for m in message],
path=path or {},
headers=headers or {},
)

async def decode_batch(
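With this change `parse_batch` also carries over the headers of the first message into the resulting `NatsBatchMessage` (falling back to `{}` for an empty batch). A rough sketch of how that surfaces in a handler, assuming a JetStream pull subscription with batch delivery enabled (subject, stream, and handler names are illustrative):

from typing import List

from faststream import Context, FastStream
from faststream.nats import NatsBroker, PullSub

broker = NatsBroker()
app = FastStream(broker)

@broker.subscriber("demo.subject", stream="demo-stream", pull_sub=PullSub(batch=True))
async def handle(body: List[str], message=Context()) -> None:
    # message.headers now reflects the headers of the first message in the batch
    print(message.headers)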
1 change: 1 addition & 0 deletions faststream/rabbit/__init__.py
@@ -21,5 +21,6 @@
"ReplyConfig",
"RabbitExchange",
"RabbitQueue",
# Annotations
"RabbitMessage",
)
13 changes: 13 additions & 0 deletions faststream/rabbit/annotations.py
@@ -1,3 +1,4 @@
from aio_pika import RobustChannel, RobustConnection
from typing_extensions import Annotated

from faststream.annotations import ContextRepo, Logger, NoCast
@@ -13,8 +14,20 @@
"RabbitMessage",
"RabbitBroker",
"RabbitProducer",
"Channel",
"Connection",
)

RabbitMessage = Annotated[RM, Context("message")]
RabbitBroker = Annotated[RB, Context("broker")]
RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")]

Channel = Annotated[RobustChannel, Context("broker._channel")]
Connection = Annotated[RobustConnection, Context("broker._connection")]

# NOTE: transactions are not for public usage yet
# async def _get_transaction(connection: Connection) -> RabbitTransaction:
# async with connection.channel(publisher_confirms=False) as channel:
# yield channel.transaction()

# Transaction = Annotated[RabbitTransaction, Depends(_get_transaction)]
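The new `Channel` and `Connection` annotations resolve to the broker's underlying aio-pika objects and can be injected straight into handlers; a minimal sketch (queue and handler names are illustrative):

from faststream.rabbit import RabbitBroker
from faststream.rabbit.annotations import Channel, Connection

broker = RabbitBroker()

@broker.subscriber("demo-queue")
async def handle(body: str, channel: Channel, connection: Connection) -> None:
    # channel    -> the broker's RobustChannel   (Context("broker._channel"))
    # connection -> the broker's RobustConnection (Context("broker._connection"))
    print(channel, connection)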
62 changes: 61 additions & 1 deletion faststream/rabbit/broker/broker.py
@@ -100,6 +100,26 @@ def __init__(
"TimeoutType",
Doc("Connection establishement timeout."),
] = None,
# channel args
channel_number: Annotated[
Optional[int],
Doc("Specify the channel number explicit."),
] = None,
publisher_confirms: Annotated[
bool,
Doc(
"If `True`, the `publish` method will "
"return `bool` type after the publish is complete. "
"Otherwise it will return `None`."
),
] = True,
on_return_raises: Annotated[
bool,
Doc(
"Raise an :class:`aio_pika.exceptions.DeliveryError` "
"when a mandatory message is returned."
)
] = False,
# broker args
max_consumers: Annotated[
Optional[int],
@@ -220,6 +240,10 @@ def __init__(
url=str(amqp_url),
ssl_context=security_args.get("ssl_context"),
timeout=timeout,
# channel args
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
# Basic args
graceful_timeout=graceful_timeout,
dependencies=dependencies,
@@ -303,13 +327,42 @@ async def connect( # type: ignore[override]
"TimeoutType",
Doc("Connection establishement timeout."),
] = None,
# channel args
channel_number: Annotated[
Union[int, None, object],
Doc("Specify the channel number explicit."),
] = Parameter.empty,
publisher_confirms: Annotated[
Union[bool, object],
Doc(
"If `True`, the `publish` method will "
"return `bool` type after the publish is complete. "
"Otherwise it will return `None`."
),
] = Parameter.empty,
on_return_raises: Annotated[
Union[bool, object],
Doc(
"Raise an :class:`aio_pika.exceptions.DeliveryError` "
"when a mandatory message is returned."
)
] = Parameter.empty,
) -> "RobustConnection":
"""Connect broker object to RabbitMQ.
To startup subscribers too you should use `broker.start()` after/instead this method.
"""
kwargs: AnyDict = {}

if channel_number is not Parameter.empty:
kwargs["channel_number"] = channel_number

if publisher_confirms is not Parameter.empty:
kwargs["publisher_confirms"] = publisher_confirms

if on_return_raises is not Parameter.empty:
kwargs["on_return_raises"] = on_return_raises

if timeout:
kwargs["timeout"] = timeout

@@ -346,6 +399,9 @@ async def _connect( # type: ignore[override]
*,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
channel_number: Optional[int],
publisher_confirms: bool,
on_return_raises: bool,
) -> "RobustConnection":
connection = cast(
"RobustConnection",
@@ -360,7 +416,11 @@
max_consumers = self._max_consumers
channel = self._channel = cast(
"RobustChannel",
await connection.channel(),
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
)

declarer = self.declarer = RabbitDeclarer(channel)
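A minimal sketch of the new channel options on the broker (the URL is illustrative): `channel_number` pins the channel id explicitly, `publisher_confirms=False` makes `publish` return `None` instead of waiting for a broker confirmation, and `on_return_raises=True` raises `aio_pika.exceptions.DeliveryError` when a mandatory message is returned.

from faststream.rabbit import RabbitBroker

broker = RabbitBroker(
    "amqp://guest:guest@localhost:5672/",  # illustrative connection URL
    channel_number=1,
    publisher_confirms=False,
    on_return_raises=True,
)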
23 changes: 23 additions & 0 deletions faststream/rabbit/fastapi/router.py
@@ -96,6 +96,26 @@ def __init__(
"TimeoutType",
Doc("Connection establishement timeout."),
] = None,
# channel args
channel_number: Annotated[
Optional[int],
Doc("Specify the channel number explicit."),
] = None,
publisher_confirms: Annotated[
bool,
Doc(
"If `True`, the `publish` method will "
"return `bool` type after the publish is complete. "
"Otherwise it will return `None`."
),
] = True,
on_return_raises: Annotated[
bool,
Doc(
"Raise an :class:`aio_pika.exceptions.DeliveryError` "
"when a mandatory message is returned."
)
] = False,
# broker args
max_consumers: Annotated[
Optional[int],
@@ -408,6 +428,9 @@ def __init__(
graceful_timeout=graceful_timeout,
decoder=decoder,
parser=parser,
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
middlewares=middlewares,
security=security,
asyncapi_url=asyncapi_url,
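The same channel options are now forwarded by the FastAPI integration as well; a sketch, assuming the usual `RabbitRouter` setup (the URL is illustrative):

from fastapi import FastAPI
from faststream.rabbit.fastapi import RabbitRouter

router = RabbitRouter(
    "amqp://guest:guest@localhost:5672/",
    publisher_confirms=False,
    on_return_raises=True,
)

app = FastAPI()
app.include_router(router)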
11 changes: 11 additions & 0 deletions faststream/rabbit/schemas/queue.py
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Optional

from typing_extensions import Annotated, Doc
@@ -115,3 +116,13 @@ def __init__(
self.auto_delete = auto_delete
self.arguments = arguments
self.timeout = timeout

def add_prefix(self, prefix: str) -> "RabbitQueue":
new_q: RabbitQueue = deepcopy(self)

new_q.name = "".join((prefix, new_q.name))

if new_q.routing_key:
new_q.routing_key = "".join((prefix, new_q.routing_key))

return new_q
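This is what gives routers prefix support for `routing_key` as well: `add_prefix` returns a prefixed copy and leaves the original queue untouched. A small sketch (queue name, routing key, and prefix are illustrative):

from faststream.rabbit import RabbitQueue

queue = RabbitQueue("orders", routing_key="orders.created")
prefixed = queue.add_prefix("billing.")

assert prefixed.name == "billing.orders"
assert prefixed.routing_key == "billing.orders.created"
assert queue.name == "orders"  # the original instance is not mutated (deepcopy)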
5 changes: 1 addition & 4 deletions faststream/rabbit/subscriber/usecase.py
@@ -1,4 +1,3 @@
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
@@ -223,6 +222,4 @@ def get_log_context(

def add_prefix(self, prefix: str) -> None:
"""Include Subscriber in router."""
new_q = deepcopy(self.queue)
new_q.name = "".join((prefix, new_q.name))
self.queue = new_q
self.queue = self.queue.add_prefix(prefix)
27 changes: 17 additions & 10 deletions faststream/redis/parser.py
@@ -1,6 +1,7 @@
from typing import (
TYPE_CHECKING,
Any,
Dict,
Mapping,
Optional,
Sequence,
@@ -183,9 +184,13 @@ class RedisBatchListParser(SimpleParser):

@staticmethod
def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
data = [_decode_batch_body_item(x) for x in message["data"]]
return (
dump_json(_decode_batch_body_item(x) for x in message["data"]),
{"content-type": ContentTypes.json},
dump_json(i[0] for i in data),
{
**data[0][1],
"content-type": ContentTypes.json,
},
)


@@ -203,17 +208,19 @@ class RedisBatchStreamParser(SimpleParser):

@staticmethod
def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
data = [_decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"]]
return (
dump_json(
_decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"]
),
{"content-type": ContentTypes.json},
dump_json(i[0] for i in data),
{
**data[0][1],
"content-type": ContentTypes.json,
},
)


def _decode_batch_body_item(msg_content: bytes) -> Any:
msg_body, _ = RawMessage.parse(msg_content)
def _decode_batch_body_item(msg_content: bytes) -> Tuple[Any, Dict[str, str]]:
msg_body, headers = RawMessage.parse(msg_content)
try:
return json_loads(msg_body)
return json_loads(msg_body), headers
except Exception:
return msg_body
return msg_body, headers
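Both Redis batch parsers now propagate the headers decoded from the first item of the batch (merged with the JSON content-type) instead of discarding them. A rough sketch with a batch list subscriber (list and handler names are illustrative):

from typing import List

from faststream import Context, FastStream
from faststream.redis import ListSub, RedisBroker

broker = RedisBroker()
app = FastStream(broker)

@broker.subscriber(list=ListSub("demo-list", batch=True))
async def handle(body: List[str], message=Context()) -> None:
    # message.headers carries the first item's headers plus the JSON content-type
    print(message.headers)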