diff --git a/docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md b/docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md new file mode 100644 index 0000000000..8ce5838961 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/subscriber/asyncapi/AsyncAPIConcurrentDefaultSubscriber.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.subscriber.asyncapi.AsyncAPIConcurrentDefaultSubscriber diff --git a/docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md b/docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md new file mode 100644 index 0000000000..16f09d9334 --- /dev/null +++ b/docs/docs/en/api/faststream/kafka/subscriber/usecase/ConcurrentDefaultSubscriber.md @@ -0,0 +1,11 @@ +--- +# 0.5 - API +# 2 - Release +# 3 - Contributing +# 5 - Template Page +# 10 - Default +search: + boost: 0.5 +--- + +::: faststream.kafka.subscriber.usecase.ConcurrentDefaultSubscriber diff --git a/docs/docs/en/getting-started/acknowlegment.md b/docs/docs/en/getting-started/acknowlegment.md new file mode 100644 index 0000000000..e26179e30f --- /dev/null +++ b/docs/docs/en/getting-started/acknowlegment.md @@ -0,0 +1,24 @@ +# Acknowledgment + +Since unexpected errors may occur during message processing, **FastStream** has an `ack_policy` parameter. + +`AckPolicy` have 4 variants: + +- `ACK` means that the message will be acked anyway. + +- `NACK_ON_ERROR` means that the message will be nacked if an error occurs during processing and consumer will receive this message one more time. + +- `REJECT_ON_ERROR` means that the message will be rejected if an error occurs during processing and consumer will not receive this message again. + +- `DO_NOTHING` in this case *FastStream* will do nothing with the message. You must ack/nack/reject the message manually. + + +You must provide this parameter when initializing the subscriber. + +```python linenums="1" hl_lines="5" title="main.py" +from faststream import AckPolicy +from faststream.nats import NatsBroker + +broker = NatsBroker() +@broker.subscriber(ack_policy=AckPolicy.REJECT_ON_ERROR) +``` diff --git a/docs/docs/en/getting-started/logging.md b/docs/docs/en/getting-started/logging.md index 4c4b9e77b2..c6d56c478a 100644 --- a/docs/docs/en/getting-started/logging.md +++ b/docs/docs/en/getting-started/logging.md @@ -213,9 +213,9 @@ app = FastStream(broker, logger=logger) And the job is done! Now you have a perfectly structured logs using **Structlog**. ```{.shell .no-copy} -TIMESPAMP [info ] FastStream app starting... extra={} -TIMESPAMP [debug ] `Handler` waiting for messages extra={'topic': 'topic', 'group_id': 'group', 'message_id': ''} -TIMESPAMP [debug ] `Handler` waiting for messages extra={'topic': 'topic', 'group_id': 'group2', 'message_id': ''} -TIMESPAMP [info ] FastStream app started successfully! To exit, press CTRL+C extra={'topic': '', 'group_id': '', 'message_id': ''} +TIMESTAMP [info ] FastStream app starting... extra={} +TIMESTAMP [debug ] `Handler` waiting for messages extra={'topic': 'topic', 'group_id': 'group', 'message_id': ''} +TIMESTAMP [debug ] `Handler` waiting for messages extra={'topic': 'topic', 'group_id': 'group2', 'message_id': ''} +TIMESTAMP [info ] FastStream app started successfully! To exit, press CTRL+C extra={'topic': '', 'group_id': '', 'message_id': ''} ``` { data-search-exclude } diff --git a/docs/docs/en/nats/jetstream/ack.md b/docs/docs/en/nats/jetstream/ack.md index f003966493..a1dbc41168 100644 --- a/docs/docs/en/nats/jetstream/ack.md +++ b/docs/docs/en/nats/jetstream/ack.md @@ -16,29 +16,6 @@ In most cases, **FastStream** automatically acknowledges (*acks*) messages on yo However, there are situations where you might want to use different acknowledgement logic. -## Retries - -If you prefer to use a *nack* instead of a *reject* when there's an error in message processing, you can specify the `retry` flag in the `#!python @broker.subscriber(...)` method, which is responsible for error handling logic. - -By default, this flag is set to `False`, indicating that if an error occurs during message processing, the message can still be retrieved from the queue: - -```python -@broker.subscriber("test", retry=False) # don't handle exceptions -async def base_handler(body: str): - ... -``` - -If this flag is set to `True`, the message will be *nack*ed and placed back in the queue each time an error occurs. In this scenario, the message can be processed by another consumer (if there are several of them) or by the same one: - -```python -@broker.subscriber("test", retry=True) # try again indefinitely -async def base_handler(body: str): - ... -``` - -!!! tip - For more complex error handling cases, you can use [tenacity](https://tenacity.readthedocs.io/en/latest/){.external-link target="_blank"} - ## Manual Acknowledgement If you want to acknowledge a message manually, you can get access directly to the message object via the [Context](../../getting-started/context/existed.md){.internal-link} and call the method. diff --git a/docs/docs/en/rabbit/ack.md b/docs/docs/en/rabbit/ack.md index d68b66a0cd..e7632742dc 100644 --- a/docs/docs/en/rabbit/ack.md +++ b/docs/docs/en/rabbit/ack.md @@ -16,44 +16,6 @@ In most cases, **FastStream** automatically acknowledges (*acks*) messages on yo However, there are situations where you might want to use a different acknowledgement logic. -## Retries - -If you prefer to use a *nack* instead of a *reject* when there's an error in message processing, you can specify the `retry` flag in the `#!python @broker.subscriber(...)` method, which is responsible for error handling logic. - -By default, this flag is set to `False`, indicating that if an error occurs during message processing, the message can still be retrieved from the queue: - -```python -@broker.subscriber("test", retry=False) # don't handle exceptions -async def base_handler(body: str): - ... -``` - -If this flag is set to `True`, the message will be *nack*ed and placed back in the queue each time an error occurs. In this scenario, the message can be processed by another consumer (if there are several of them) or by the same one: - -```python -@broker.subscriber("test", retry=True) # try again indefinitely -async def base_handler(body: str): - ... -``` - -If the `retry` flag is set to an `int`, the message will be placed back in the queue, and the number of retries will be limited to this number: - -```python -@broker.subscriber("test", retry=3) # make up to 3 attempts -async def base_handler(body: str): - ... -``` - -!!! tip - **FastStream** identifies the message by its `message_id`. To make this option work, you should manually set this field on the producer side (if your library doesn't set it automatically). - -!!! bug - At the moment, attempts are counted only by the current consumer. If the message goes to another consumer, it will have its own counter. - Subsequently, this logic will be reworked. - -!!! tip - For more complex error handling cases, you can use [tenacity](https://tenacity.readthedocs.io/en/latest/){.external-link target="_blank"} - ## Manual acknowledgement If you want to acknowledge a message manually, you can get access directly to the message object via the [Context](../getting-started/context/existed.md){.internal-link} and call the method. diff --git a/docs/docs_src/confluent/ack/errors.py b/docs/docs_src/confluent/ack/errors.py index 36ceb61424..72bc9e1aba 100644 --- a/docs/docs_src/confluent/ack/errors.py +++ b/docs/docs_src/confluent/ack/errors.py @@ -1,4 +1,4 @@ -from faststream import FastStream +from faststream import FastStream, AckPolicy from faststream.exceptions import AckMessage from faststream.confluent import KafkaBroker @@ -7,7 +7,7 @@ @broker.subscriber( - "test-error-topic", group_id="test-error-group", auto_commit=False, auto_offset_reset="earliest" + "test-error-topic", group_id="test-error-group", ack_policy=AckPolicy.REJECT_ON_ERROR, auto_offset_reset="earliest" ) async def handle(body): smth_processing(body) diff --git a/docs/docs_src/getting_started/asyncapi/asyncapi_customization/custom_info.py b/docs/docs_src/getting_started/asyncapi/asyncapi_customization/custom_info.py index 4121bebe29..d177e86909 100644 --- a/docs/docs_src/getting_started/asyncapi/asyncapi_customization/custom_info.py +++ b/docs/docs_src/getting_started/asyncapi/asyncapi_customization/custom_info.py @@ -1,7 +1,6 @@ from faststream import FastStream from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.license import License -from faststream.specification.schema.contact import Contact +from faststream.specification import License, Contact from faststream.kafka import KafkaBroker broker = KafkaBroker("localhost:9092") diff --git a/docs/docs_src/kafka/ack/errors.py b/docs/docs_src/kafka/ack/errors.py index 19d333976d..6f293ab681 100644 --- a/docs/docs_src/kafka/ack/errors.py +++ b/docs/docs_src/kafka/ack/errors.py @@ -1,4 +1,4 @@ -from faststream import FastStream +from faststream import FastStream, AckPolicy from faststream.exceptions import AckMessage from faststream.kafka import KafkaBroker @@ -7,7 +7,7 @@ @broker.subscriber( - "test-topic", group_id="test-group", auto_commit=False + "test-topic", group_id="test-group", ack_policy=AckPolicy.REJECT_ON_ERROR, ) async def handle(body): smth_processing(body) diff --git a/examples/kafka/ack_after_process.py b/examples/kafka/ack_after_process.py index 7a00b7fac7..97550fdb87 100644 --- a/examples/kafka/ack_after_process.py +++ b/examples/kafka/ack_after_process.py @@ -1,14 +1,13 @@ -from faststream import FastStream, Logger +from faststream import FastStream, Logger, AckPolicy from faststream.kafka import KafkaBroker broker = KafkaBroker() app = FastStream(broker) - @broker.subscriber( "test", group_id="group", - auto_commit=False, + ack_policy=AckPolicy.REJECT_ON_ERROR, ) async def handler(msg: str, logger: Logger): logger.info(msg) diff --git a/faststream/_internal/_compat.py b/faststream/_internal/_compat.py index 8965fc951b..ba38326ac9 100644 --- a/faststream/_internal/_compat.py +++ b/faststream/_internal/_compat.py @@ -1,6 +1,7 @@ import json import sys import warnings +from collections import UserString from collections.abc import Iterable, Mapping from importlib.metadata import version as get_version from importlib.util import find_spec @@ -17,8 +18,6 @@ from faststream._internal.basic_types import AnyDict -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") - IS_WINDOWS = ( sys.platform == "win32" or sys.platform == "cygwin" or sys.platform == "msys" ) @@ -76,9 +75,13 @@ def json_dumps(*a: Any, **kw: Any) -> bytes: JsonSchemaValue = Mapping[str, Any] +major, minor, *_ = PYDANTIC_VERSION.split(".") +_PYDANTCI_MAJOR, _PYDANTIC_MINOR = int(major), int(minor) + +PYDANTIC_V2 = _PYDANTCI_MAJOR >= 2 if PYDANTIC_V2: - if PYDANTIC_VERSION >= "2.4.0": + if _PYDANTIC_MINOR >= 4: from pydantic.annotated_handlers import ( GetJsonSchemaHandler, ) @@ -86,14 +89,9 @@ def json_dumps(*a: Any, **kw: Any) -> bytes: with_info_plain_validator_function, ) else: - if PYDANTIC_VERSION >= "2.10": - from pydantic.annotated_handlers import ( - GetJsonSchemaHandler, - ) - else: - from pydantic._internal._annotated_handlers import ( # type: ignore[no-redef] - GetJsonSchemaHandler, - ) + from pydantic._internal._annotated_handlers import ( # type: ignore[no-redef] + GetJsonSchemaHandler, + ) from pydantic_core.core_schema import ( general_plain_validator_function as with_info_plain_validator_function, ) @@ -112,7 +110,7 @@ def model_to_jsonable( def dump_json(data: Any) -> bytes: return json_dumps(model_to_jsonable(data)) - def get_model_fields(model: type[BaseModel]) -> dict[str, Any]: + def get_model_fields(model: type[BaseModel]) -> AnyDict: return model.model_fields def model_to_json(model: BaseModel, **kwargs: Any) -> str: @@ -142,7 +140,7 @@ def model_schema(model: type[BaseModel], **kwargs: Any) -> AnyDict: def dump_json(data: Any) -> bytes: return json_dumps(data, default=pydantic_encoder) - def get_model_fields(model: type[BaseModel]) -> dict[str, Any]: + def get_model_fields(model: type[BaseModel]) -> AnyDict: return model.__fields__ # type: ignore[return-value] def model_to_json(model: BaseModel, **kwargs: Any) -> str: @@ -175,19 +173,19 @@ def with_info_plain_validator_function( # type: ignore[misc] return {} -anyio_major = int(get_version("anyio").split(".")[0]) -ANYIO_V3 = anyio_major == 3 +major, *_ = get_version("anyio").split(".") +_ANYIO_MAJOR = int(major) +ANYIO_V3 = _ANYIO_MAJOR == 3 if ANYIO_V3: from anyio import ExceptionGroup # type: ignore[attr-defined] -elif sys.version_info < (3, 11): +elif sys.version_info >= (3, 11): + ExceptionGroup = ExceptionGroup # noqa: PLW0127 +else: from exceptiongroup import ( ExceptionGroup, ) -else: - ExceptionGroup = ExceptionGroup # noqa: PLW0127 - try: import email_validator @@ -198,7 +196,7 @@ def with_info_plain_validator_function( # type: ignore[misc] except ImportError: # pragma: no cover # NOTE: EmailStr mock was copied from the FastAPI # https://github.com/tiangolo/fastapi/blob/master/fastapi/openapi/models.py#24 - class EmailStr(str): # type: ignore[no-redef] + class EmailStr(UserString): # type: ignore[no-redef] """EmailStr is a string that should be an email. Note: EmailStr mock was copied from the FastAPI: diff --git a/faststream/_internal/basic_types.py b/faststream/_internal/basic_types.py index f781df146e..e844171150 100644 --- a/faststream/_internal/basic_types.py +++ b/faststream/_internal/basic_types.py @@ -58,7 +58,7 @@ class StandardDataclass(Protocol): """Protocol to check type is dataclass.""" - __dataclass_fields__: ClassVar[dict[str, Any]] + __dataclass_fields__: ClassVar[AnyDict] BaseSendableMessage: TypeAlias = Union[ diff --git a/faststream/_internal/broker/abc_broker.py b/faststream/_internal/broker/abc_broker.py index f92b8c2358..34678d5f5e 100644 --- a/faststream/_internal/broker/abc_broker.py +++ b/faststream/_internal/broker/abc_broker.py @@ -76,10 +76,8 @@ def publisher( is_running: bool = False, ) -> "PublisherProto[MsgType]": publisher.add_prefix(self.prefix) - if not is_running: self._publishers.append(publisher) - return publisher def setup_publisher( diff --git a/faststream/_internal/broker/broker.py b/faststream/_internal/broker/broker.py index b3621d34ee..831295ae76 100644 --- a/faststream/_internal/broker/broker.py +++ b/faststream/_internal/broker/broker.py @@ -35,6 +35,7 @@ MsgType, ) from faststream._internal.utils.functions import to_async +from faststream.specification.proto import ServerSpecification from .abc_broker import ABCBroker from .pub_base import BrokerPublishMixin @@ -51,12 +52,13 @@ PublisherProto, ) from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class BrokerUsecase( ABCBroker[MsgType], SetupAble, + ServerSpecification, BrokerPublishMixin[MsgType], Generic[MsgType, ConnectionType], ): @@ -121,7 +123,7 @@ def __init__( Doc("AsyncAPI server description."), ], tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), ], specification_url: Annotated[ diff --git a/faststream/_internal/broker/pub_base.py b/faststream/_internal/broker/pub_base.py index 31cfb476fd..c45c36cced 100644 --- a/faststream/_internal/broker/pub_base.py +++ b/faststream/_internal/broker/pub_base.py @@ -37,7 +37,7 @@ async def _basic_publish( publish = producer.publish context = self.context # caches property - for m in self.middlewares: + for m in self.middlewares[::-1]: publish = partial(m(None, context=context).publish_scope, publish) return await publish(cmd) @@ -58,7 +58,7 @@ async def _basic_publish_batch( publish = producer.publish_batch context = self.context # caches property - for m in self.middlewares: + for m in self.middlewares[::-1]: publish = partial(m(None, context=context).publish_scope, publish) return await publish(cmd) @@ -82,7 +82,7 @@ async def _basic_request( request = producer.request context = self.context # caches property - for m in self.middlewares: + for m in self.middlewares[::-1]: request = partial(m(None, context=context).publish_scope, request) published_msg = await request(cmd) diff --git a/faststream/_internal/cli/docs/app.py b/faststream/_internal/cli/docs/app.py index d85f53de9c..d7c7b5951d 100644 --- a/faststream/_internal/cli/docs/app.py +++ b/faststream/_internal/cli/docs/app.py @@ -12,8 +12,12 @@ from faststream._internal.cli.utils.imports import import_from_string from faststream.exceptions import INSTALL_WATCHFILES, INSTALL_YAML, SCHEMA_NOT_SUPPORTED from faststream.specification.asyncapi.site import serve_app -from faststream.specification.asyncapi.v2_6_0.schema import Schema as SchemaV2_6 -from faststream.specification.asyncapi.v3_0_0.schema import Schema as SchemaV3 +from faststream.specification.asyncapi.v2_6_0.schema import ( + ApplicationSchema as SchemaV2_6, +) +from faststream.specification.asyncapi.v3_0_0.schema import ( + ApplicationSchema as SchemaV3, +) from faststream.specification.base.specification import Specification if TYPE_CHECKING: diff --git a/faststream/_internal/cli/utils/imports.py b/faststream/_internal/cli/utils/imports.py index 27be43cf05..860b69a42a 100644 --- a/faststream/_internal/cli/utils/imports.py +++ b/faststream/_internal/cli/utils/imports.py @@ -8,7 +8,9 @@ def import_from_string( - import_str: str, *, is_factory: bool = False + import_str: str, + *, + is_factory: bool = False, ) -> tuple[Path, object]: module_path, instance = _import_object_or_factory(import_str) diff --git a/faststream/_internal/fastapi/_compat.py b/faststream/_internal/fastapi/_compat.py index 2359c114e6..f4b423fb44 100644 --- a/faststream/_internal/fastapi/_compat.py +++ b/faststream/_internal/fastapi/_compat.py @@ -13,13 +13,28 @@ from fastapi.requests import Request major, minor, patch, *_ = FASTAPI_VERSION.split(".") -major = int(major) -minor = int(minor) -patch = int(patch) -FASTAPI_V2 = major > 0 or minor > 100 -FASTAPI_V106 = major > 0 or minor >= 106 -FASTAPI_v102_3 = major > 0 or minor > 112 or (minor == 112 and patch > 2) -FASTAPI_v102_4 = major > 0 or minor > 112 or (minor == 112 and patch > 3) + +_FASTAPI_MAJOR, _FASTAPI_MINOR = int(major), int(minor) + +FASTAPI_V2 = _FASTAPI_MAJOR > 0 or _FASTAPI_MINOR > 100 +FASTAPI_V106 = _FASTAPI_MAJOR > 0 or _FASTAPI_MINOR >= 106 + +try: + _FASTAPI_PATCH = int(patch) +except ValueError: + FASTAPI_v102_3 = True + FASTAPI_v102_4 = True +else: + FASTAPI_v102_3 = ( + _FASTAPI_MAJOR > 0 + or _FASTAPI_MINOR > 112 + or (_FASTAPI_MINOR == 112 and _FASTAPI_PATCH > 2) + ) + FASTAPI_v102_4 = ( + _FASTAPI_MAJOR > 0 + or _FASTAPI_MINOR > 112 + or (_FASTAPI_MINOR == 112 and _FASTAPI_PATCH > 3) + ) __all__ = ( "RequestValidationError", diff --git a/faststream/_internal/fastapi/get_dependant.py b/faststream/_internal/fastapi/get_dependant.py index 2db1b140d9..0c2d777455 100644 --- a/faststream/_internal/fastapi/get_dependant.py +++ b/faststream/_internal/fastapi/get_dependant.py @@ -89,7 +89,7 @@ def _patch_fastapi_dependent(dependant: "Dependant") -> "Dependant": lambda x: isinstance(x, FieldInfo), p.field_info.metadata or (), ), - Field(**field_data), # type: ignore[pydantic-field] + Field(**field_data), ) else: @@ -109,7 +109,7 @@ def _patch_fastapi_dependent(dependant: "Dependant") -> "Dependant": "le": info.field_info.le, }, ) - f = Field(**field_data) # type: ignore[pydantic-field] + f = Field(**field_data) params_unique[p.name] = ( info.annotation, diff --git a/faststream/_internal/fastapi/router.py b/faststream/_internal/fastapi/router.py index 4828893324..29452792b1 100644 --- a/faststream/_internal/fastapi/router.py +++ b/faststream/_internal/fastapi/router.py @@ -55,7 +55,7 @@ from faststream._internal.types import BrokerMiddleware from faststream.message import StreamMessage from faststream.specification.base.specification import Specification - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class _BackgroundMiddleware(BaseMiddleware): @@ -121,7 +121,7 @@ def __init__( generate_unique_id, ), # Specification information - specification_tags: Optional[Iterable[Union["Tag", "TagDict"]]] = None, + specification_tags: Iterable[Union["Tag", "TagDict"]] = (), schema_url: Optional[str] = "/asyncapi", **connection_kwars: Any, ) -> None: diff --git a/faststream/_internal/publisher/specified.py b/faststream/_internal/publisher/specified.py index 8ad62a1d00..a6e34a163b 100644 --- a/faststream/_internal/publisher/specified.py +++ b/faststream/_internal/publisher/specified.py @@ -1,11 +1,9 @@ from inspect import Parameter, unwrap -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from fast_depends.core import build_call_model from fast_depends.pydantic._compat import create_model, get_config_base -from faststream._internal.publisher.proto import PublisherProto -from faststream._internal.subscriber.call_wrapper.call import HandlerCallWrapper from faststream._internal.types import ( MsgType, P_HandlerParams, @@ -13,34 +11,40 @@ ) from faststream.specification.asyncapi.message import get_model_schema from faststream.specification.asyncapi.utils import to_camelcase -from faststream.specification.base.proto import SpecificationEndpoint +from faststream.specification.proto import EndpointSpecification +from faststream.specification.schema import PublisherSpec if TYPE_CHECKING: - from faststream._internal.basic_types import AnyDict + from faststream._internal.basic_types import AnyCallable, AnyDict + from faststream._internal.state import BrokerState, Pointer + from faststream._internal.subscriber.call_wrapper.call import HandlerCallWrapper -class BaseSpicificationPublisher(SpecificationEndpoint, PublisherProto[MsgType]): +class SpecificationPublisher(EndpointSpecification[PublisherSpec]): """A base class for publishers in an asynchronous API.""" + _state: "Pointer[BrokerState]" # should be set in next parent + def __init__( self, - *, + *args: Any, schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, + **kwargs: Any, ) -> None: - self.calls = [] + self.calls: list[AnyCallable] = [] - self.title_ = title_ - self.description_ = description_ - self.include_in_schema = include_in_schema self.schema_ = schema_ + super().__init__(*args, **kwargs) + def __call__( self, - func: HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn], - ) -> HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]: + func: Union[ + Callable[P_HandlerParams, T_HandlerReturn], + "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]", + ], + ) -> "HandlerCallWrapper[MsgType, P_HandlerParams, T_HandlerReturn]": + func = super().__call__(func) self.calls.append(func._original_call) return func diff --git a/faststream/_internal/publisher/usecase.py b/faststream/_internal/publisher/usecase.py index c729a5b13d..08094ca56f 100644 --- a/faststream/_internal/publisher/usecase.py +++ b/faststream/_internal/publisher/usecase.py @@ -27,8 +27,6 @@ ) from faststream.message.source_type import SourceType -from .specified import BaseSpicificationPublisher - if TYPE_CHECKING: from faststream._internal.publisher.proto import ProducerProto from faststream._internal.types import ( @@ -38,7 +36,7 @@ from faststream.response.response import PublishCommand -class PublisherUsecase(BaseSpicificationPublisher, PublisherProto[MsgType]): +class PublisherUsecase(PublisherProto[MsgType]): """A base class for publishers in an asynchronous API.""" def __init__( @@ -46,11 +44,6 @@ def __init__( *, broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: self.middlewares = middlewares self._broker_middlewares = broker_middlewares @@ -60,13 +53,6 @@ def __init__( self._fake_handler = False self.mock: Optional[MagicMock] = None - super().__init__( - title_=title_, - description_=description_, - include_in_schema=include_in_schema, - schema_=schema_, - ) - self._state: Pointer[BrokerState] = Pointer( EmptyBrokerState("You should include publisher to any broker.") ) @@ -115,7 +101,6 @@ def __call__( ensure_call_wrapper(func) ) handler._publishers.append(self) - super().__call__(handler) return handler async def _basic_publish( @@ -129,14 +114,14 @@ async def _basic_publish( context = self._state.get().di_state.context for pub_m in chain( + self.middlewares[::-1], ( _extra_middlewares or ( m(None, context=context).publish_scope - for m in self._broker_middlewares + for m in self._broker_middlewares[::-1] ) ), - self.middlewares, ): pub = partial(pub_m, pub) @@ -151,8 +136,11 @@ async def _basic_request( context = self._state.get().di_state.context for pub_m in chain( - (m(None, context=context).publish_scope for m in self._broker_middlewares), - self.middlewares, + self.middlewares[::-1], + ( + m(None, context=context).publish_scope + for m in self._broker_middlewares[::-1] + ), ): request = partial(pub_m, request) @@ -161,7 +149,8 @@ async def _basic_request( response_msg: Any = await process_msg( msg=published_msg, middlewares=( - m(published_msg, context=context) for m in self._broker_middlewares + m(published_msg, context=context) + for m in self._broker_middlewares[::-1] ), parser=self._producer._parser, decoder=self._producer._decoder, @@ -180,14 +169,14 @@ async def _basic_publish_batch( context = self._state.get().di_state.context for pub_m in chain( + self.middlewares[::-1], ( _extra_middlewares or ( m(None, context=context).publish_scope - for m in self._broker_middlewares + for m in self._broker_middlewares[::-1] ) ), - self.middlewares, ): pub = partial(pub_m, pub) diff --git a/faststream/_internal/subscriber/call_item.py b/faststream/_internal/subscriber/call_item.py index 48814e9ea0..550da4badb 100644 --- a/faststream/_internal/subscriber/call_item.py +++ b/faststream/_internal/subscriber/call_item.py @@ -153,7 +153,7 @@ async def call( """Execute wrapped handler with consume middlewares.""" call: AsyncFuncAny = self.handler.call_wrapped - for middleware in chain(self.item_middlewares, _extra_middlewares): + for middleware in chain(self.item_middlewares[::-1], _extra_middlewares): call = partial(middleware, call) try: diff --git a/faststream/_internal/subscriber/mixins.py b/faststream/_internal/subscriber/mixins.py index 412f8f2c79..c76887b757 100644 --- a/faststream/_internal/subscriber/mixins.py +++ b/faststream/_internal/subscriber/mixins.py @@ -12,8 +12,8 @@ class TasksMixin(SubscriberUsecase[Any]): - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) self.tasks: list[asyncio.Task[Any]] = [] def add_task(self, coro: Coroutine[Any, Any, Any]) -> None: @@ -36,7 +36,7 @@ class ConcurrentMixin(TasksMixin): def __init__( self, - *, + *args: Any, max_workers: int, **kwargs: Any, ) -> None: @@ -47,7 +47,7 @@ def __init__( ) self.limiter = anyio.Semaphore(max_workers) - super().__init__(**kwargs) + super().__init__(*args, **kwargs) def start_consume_task(self) -> None: self.add_task(self._serve_consume_queue()) diff --git a/faststream/_internal/subscriber/proto.py b/faststream/_internal/subscriber/proto.py index a402009407..7ddc03d7ee 100644 --- a/faststream/_internal/subscriber/proto.py +++ b/faststream/_internal/subscriber/proto.py @@ -17,16 +17,16 @@ ProducerProto, ) from faststream._internal.state import BrokerState, Pointer - from faststream._internal.subscriber.call_item import HandlerItem from faststream._internal.types import ( BrokerMiddleware, CustomCallable, - Filter, SubscriberMiddleware, ) from faststream.message import StreamMessage from faststream.response import Response + from .call_item import HandlerItem + class SubscriberProto( Endpoint, @@ -68,10 +68,6 @@ def _make_response_publisher( message: "StreamMessage[MsgType]", ) -> Iterable["BasePublisherProto"]: ... - @property - @abstractmethod - def call_name(self) -> str: ... - @abstractmethod async def start(self) -> None: ... @@ -95,7 +91,6 @@ async def get_one( def add_call( self, *, - filter_: "Filter[Any]", parser_: "CustomCallable", decoder_: "CustomCallable", middlewares_: Iterable["SubscriberMiddleware[Any]"], diff --git a/faststream/_internal/subscriber/specified.py b/faststream/_internal/subscriber/specified.py index e6dec70970..3af87b590c 100644 --- a/faststream/_internal/subscriber/specified.py +++ b/faststream/_internal/subscriber/specified.py @@ -1,30 +1,38 @@ from typing import ( TYPE_CHECKING, + Any, Optional, ) -from faststream._internal.subscriber.proto import SubscriberProto -from faststream._internal.types import MsgType from faststream.exceptions import SetupError from faststream.specification.asyncapi.message import parse_handler_params from faststream.specification.asyncapi.utils import to_camelcase -from faststream.specification.base.proto import SpecificationEndpoint +from faststream.specification.proto import EndpointSpecification +from faststream.specification.schema import SubscriberSpec if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict + from faststream._internal.types import ( + MsgType, + ) + from .call_item import HandlerItem + + +class SpecificationSubscriber( + EndpointSpecification[SubscriberSpec], +): + calls: list["HandlerItem[MsgType]"] -class BaseSpicificationSubscriber(SpecificationEndpoint, SubscriberProto[MsgType]): def __init__( self, - *, - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, + *args: Any, + **kwargs: Any, ) -> None: - self.title_ = title_ - self.description_ = description_ - self.include_in_schema = include_in_schema + self.calls = [] + + # Call next base class parent init + super().__init__(*args, **kwargs) @property def call_name(self) -> str: @@ -34,9 +42,9 @@ def call_name(self) -> str: return to_camelcase(self.calls[0].call_name) - def get_description(self) -> Optional[str]: + def get_default_description(self) -> Optional[str]: """Returns the description of the handler.""" - if not self.calls: # pragma: no cover + if not self.calls: return None return self.calls[0].description diff --git a/faststream/_internal/subscriber/usecase.py b/faststream/_internal/subscriber/usecase.py index c8ee25678a..069185c612 100644 --- a/faststream/_internal/subscriber/usecase.py +++ b/faststream/_internal/subscriber/usecase.py @@ -34,8 +34,6 @@ from faststream.middlewares.logging import CriticalLogMiddleware from faststream.response import ensure_response -from .specified import BaseSpicificationSubscriber - if TYPE_CHECKING: from fast_depends.dependencies import Dependant @@ -79,7 +77,7 @@ def __init__( self.dependencies = dependencies -class SubscriberUsecase(BaseSpicificationSubscriber, SubscriberProto[MsgType]): +class SubscriberUsecase(SubscriberProto[MsgType]): """A class representing an asynchronous handler.""" lock: "AbstractContextManager[Any]" @@ -100,18 +98,8 @@ def __init__( default_parser: "AsyncCallable", default_decoder: "AsyncCallable", ack_policy: AckPolicy, - # AsyncAPI information - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: """Initialize a new instance of the class.""" - super().__init__( - title_=title_, - description_=description_, - include_in_schema=include_in_schema, - ) - self.calls = [] self._parser = default_parser @@ -349,7 +337,9 @@ async def process_message(self, msg: MsgType) -> "Response": await h.call( message=message, # consumer middlewares - _extra_middlewares=(m.consume_scope for m in middlewares), + _extra_middlewares=( + m.consume_scope for m in middlewares[::-1] + ), ), ) @@ -362,7 +352,9 @@ async def process_message(self, msg: MsgType) -> "Response": ): await p._publish( result_msg.as_publish_command(), - _extra_middlewares=(m.publish_scope for m in middlewares), + _extra_middlewares=( + m.publish_scope for m in middlewares[::-1] + ), ) # Return data for tests diff --git a/faststream/_internal/subscriber/utils.py b/faststream/_internal/subscriber/utils.py index 213a52c414..31f2c5358d 100644 --- a/faststream/_internal/subscriber/utils.py +++ b/faststream/_internal/subscriber/utils.py @@ -13,7 +13,7 @@ ) import anyio -from typing_extensions import Literal, Self, overload +from typing_extensions import Self from faststream._internal.types import MsgType from faststream._internal.utils.functions import return_input, to_async @@ -31,26 +31,6 @@ from faststream.middlewares import BaseMiddleware -@overload -async def process_msg( - msg: Literal[None], - middlewares: Iterable["BaseMiddleware"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], - source_type: SourceType = SourceType.CONSUME, -) -> None: ... - - -@overload -async def process_msg( - msg: MsgType, - middlewares: Iterable["BaseMiddleware"], - parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]], - decoder: Callable[["StreamMessage[MsgType]"], "Any"], - source_type: SourceType = SourceType.CONSUME, -) -> "StreamMessage[MsgType]": ... - - async def process_msg( msg: Optional[MsgType], middlewares: Iterable["BaseMiddleware"], diff --git a/faststream/_internal/utils/data.py b/faststream/_internal/utils/data.py index cc12c4cec2..98e3729fac 100644 --- a/faststream/_internal/utils/data.py +++ b/faststream/_internal/utils/data.py @@ -5,8 +5,19 @@ TypedDictCls = TypeVar("TypedDictCls") -def filter_by_dict(typed_dict: type[TypedDictCls], data: AnyDict) -> TypedDictCls: +def filter_by_dict( + typed_dict: type[TypedDictCls], + data: AnyDict, +) -> tuple[TypedDictCls, AnyDict]: annotations = typed_dict.__annotations__ - return typed_dict( # type: ignore[call-arg] - {k: v for k, v in data.items() if k in annotations}, - ) + + out_data = {} + extra_data = {} + + for k, v in data.items(): + if k in annotations: + out_data[k] = v + else: + extra_data[k] = v + + return typed_dict(out_data), extra_data diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index d8a1cd6671..7a4276781b 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -53,12 +53,12 @@ from faststream.confluent.config import ConfluentConfig from faststream.confluent.message import KafkaMessage from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict Partition = TypeVar("Partition") -class KafkaBroker( +class KafkaBroker( # type: ignore[misc] KafkaRegistrator, BrokerUsecase[ Union[ @@ -304,9 +304,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -452,9 +452,10 @@ async def _connect( # type: ignore[override] self._producer.connect(native_producer) + connection_kwargs, _ = filter_by_dict(ConsumerConnectionParams, kwargs) return partial( AsyncConfluentConsumer, - **filter_by_dict(ConsumerConnectionParams, kwargs), + **connection_kwargs, logger=self._state.get().logger_state, config=self.config, ) diff --git a/faststream/confluent/broker/registrator.py b/faststream/confluent/broker/registrator.py index c9851d1db4..d372f003c9 100644 --- a/faststream/confluent/broker/registrator.py +++ b/faststream/confluent/broker/registrator.py @@ -16,7 +16,6 @@ from faststream._internal.constants import EMPTY from faststream.confluent.publisher.factory import create_publisher from faststream.confluent.subscriber.factory import create_subscriber -from faststream.exceptions import SetupError from faststream.middlewares import AckPolicy if TYPE_CHECKING: @@ -50,10 +49,10 @@ class KafkaRegistrator( ): """Includable to KafkaBroker router.""" - _subscribers: list[ + _subscribers: list[ # type: ignore[assignment] Union["SpecificationBatchSubscriber", "SpecificationDefaultSubscriber"] ] - _publishers: list[ + _publishers: list[ # type: ignore[assignment] Union["SpecificationBatchPublisher", "SpecificationDefaultPublisher"] ] @@ -167,7 +166,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -302,10 +301,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -440,7 +444,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -575,10 +579,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -713,7 +722,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -848,10 +857,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -989,7 +1003,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1124,10 +1138,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1154,52 +1173,49 @@ def subscriber( "SpecificationDefaultSubscriber", "SpecificationBatchSubscriber", ]: - - subscriber = super().subscriber( - create_subscriber( - *topics, - polling_interval=polling_interval, - partitions=partitions, - batch=batch, - max_records=max_records, - group_id=group_id, - connection_data={ - "group_instance_id": group_instance_id, - "fetch_max_wait_ms": fetch_max_wait_ms, - "fetch_max_bytes": fetch_max_bytes, - "fetch_min_bytes": fetch_min_bytes, - "max_partition_fetch_bytes": max_partition_fetch_bytes, - "auto_offset_reset": auto_offset_reset, - "enable_auto_commit": auto_commit, - "auto_commit_interval_ms": auto_commit_interval_ms, - "check_crcs": check_crcs, - "partition_assignment_strategy": partition_assignment_strategy, - "max_poll_interval_ms": max_poll_interval_ms, - "session_timeout_ms": session_timeout_ms, - "heartbeat_interval_ms": heartbeat_interval_ms, - "isolation_level": isolation_level, - }, - is_manual=not auto_commit, - # subscriber args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=self.middlewares, - broker_dependencies=self._dependencies, - # Specification - title_=title, - description_=description, - include_in_schema=self._solve_include_in_schema(include_in_schema), - ), + subscriber = create_subscriber( + *topics, + polling_interval=polling_interval, + partitions=partitions, + batch=batch, + max_records=max_records, + group_id=group_id, + connection_data={ + "group_instance_id": group_instance_id, + "fetch_max_wait_ms": fetch_max_wait_ms, + "fetch_max_bytes": fetch_max_bytes, + "fetch_min_bytes": fetch_min_bytes, + "max_partition_fetch_bytes": max_partition_fetch_bytes, + "auto_offset_reset": auto_offset_reset, + "auto_commit_interval_ms": auto_commit_interval_ms, + "check_crcs": check_crcs, + "partition_assignment_strategy": partition_assignment_strategy, + "max_poll_interval_ms": max_poll_interval_ms, + "session_timeout_ms": session_timeout_ms, + "heartbeat_interval_ms": heartbeat_interval_ms, + "isolation_level": isolation_level, + }, + auto_commit=auto_commit, + # subscriber args + ack_policy=ack_policy, + no_ack=no_ack, + no_reply=no_reply, + broker_middlewares=self.middlewares, + broker_dependencies=self._dependencies, + # Specification + title_=title, + description_=description, + include_in_schema=self._solve_include_in_schema(include_in_schema), ) if batch: - return cast("SpecificationBatchSubscriber", subscriber).add_call( - parser_=parser or self._parser, - decoder_=decoder or self._decoder, - dependencies_=dependencies, - middlewares_=middlewares, - ) - return cast("SpecificationDefaultSubscriber", subscriber).add_call( + subscriber = cast("SpecificationBatchSubscriber", subscriber) + else: + subscriber = cast("SpecificationDefaultSubscriber", subscriber) + + subscriber = super().subscriber(subscriber) # type: ignore[arg-type,assignment] + + return subscriber.add_call( parser_=parser or self._parser, decoder_=decoder or self._decoder, dependencies_=dependencies, @@ -1535,6 +1551,8 @@ def publisher( ) if batch: - return cast("SpecificationBatchPublisher", super().publisher(publisher)) + publisher = cast("SpecificationBatchPublisher", publisher) + else: + publisher = cast("SpecificationDefaultPublisher", publisher) - return cast("SpecificationDefaultPublisher", super().publisher(publisher)) + return super().publisher(publisher) diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index 7b7b0b8fe2..bd1736ddb5 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -383,7 +383,7 @@ async def getmany( ) -> tuple[Message, ...]: """Consumes a batch of messages from Kafka and groups them by topic and partition.""" raw_messages: list[Optional[Message]] = await call_or_await( - self.consumer.consume, + self.consumer.consume, # type: ignore[arg-type] num_messages=max_records or 10, timeout=timeout, ) diff --git a/faststream/confluent/fastapi/fastapi.py b/faststream/confluent/fastapi/fastapi.py index 2c7c302365..18686adc86 100644 --- a/faststream/confluent/fastapi/fastapi.py +++ b/faststream/confluent/fastapi/fastapi.py @@ -53,7 +53,7 @@ SpecificationDefaultSubscriber, ) from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict Partition = TypeVar("Partition") @@ -296,9 +296,9 @@ def __init__( Doc("Specification server description."), ] = None, specification_tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("Specification server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -561,7 +561,7 @@ def __init__( graceful_timeout=graceful_timeout, decoder=decoder, parser=parser, - middlewares=middlewares, + middlewares=middlewares, # type: ignore[arg-type] schema_url=schema_url, setup_state=setup_state, # logger options @@ -705,7 +705,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -840,10 +840,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1101,7 +1106,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1487,7 +1492,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1622,10 +1627,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1886,7 +1896,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -2021,10 +2031,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -2174,7 +2189,6 @@ def subscriber( "SpecificationBatchSubscriber", "SpecificationDefaultSubscriber", ]: - subscriber = super().subscriber( *topics, polling_interval=polling_interval, @@ -2202,6 +2216,7 @@ def subscriber( decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/confluent/message.py b/faststream/confluent/message.py index 14ea16efc2..8adb249384 100644 --- a/faststream/confluent/message.py +++ b/faststream/confluent/message.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional, Protocol, Union -from faststream.message import StreamMessage +from faststream.message import AckStatus, StreamMessage if TYPE_CHECKING: from confluent_kafka import Message @@ -59,9 +59,12 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) - self.is_manual = is_manual self.consumer = consumer + self.is_manual = is_manual + if not is_manual: + self.committed = AckStatus.ACKED + async def ack(self) -> None: """Acknowledge the Kafka message.""" if self.is_manual and not self.committed: diff --git a/faststream/confluent/publisher/factory.py b/faststream/confluent/publisher/factory.py index 284536604d..15a57de0ff 100644 --- a/faststream/confluent/publisher/factory.py +++ b/faststream/confluent/publisher/factory.py @@ -5,6 +5,7 @@ Literal, Optional, Union, + cast, overload, ) @@ -68,8 +69,9 @@ def create_publisher( headers: Optional[dict[str, str]], reply_to: str, # Publisher args - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[tuple[ConfluentMsg, ...], ConfluentMsg]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[ConfluentMsg]"], + Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], ], middlewares: Iterable["PublisherMiddleware"], # Specification args @@ -92,8 +94,9 @@ def create_publisher( headers: Optional[dict[str, str]], reply_to: str, # Publisher args - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[tuple[ConfluentMsg, ...], ConfluentMsg]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[ConfluentMsg]"], + Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], ], middlewares: Iterable["PublisherMiddleware"], # Specification args @@ -115,7 +118,10 @@ def create_publisher( partition=partition, headers=headers, reply_to=reply_to, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], + broker_middlewares, + ), middlewares=middlewares, schema_=schema_, title_=title_, @@ -130,7 +136,10 @@ def create_publisher( partition=partition, headers=headers, reply_to=reply_to, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[ConfluentMsg]"], + broker_middlewares, + ), middlewares=middlewares, schema_=schema_, title_=title_, diff --git a/faststream/confluent/publisher/specified.py b/faststream/confluent/publisher/specified.py index fec0faf183..69b4ca499b 100644 --- a/faststream/confluent/publisher/specified.py +++ b/faststream/confluent/publisher/specified.py @@ -1,58 +1,43 @@ -from typing import ( - TYPE_CHECKING, -) - -from faststream._internal.types import MsgType -from faststream.confluent.publisher.usecase import ( - BatchPublisher, - DefaultPublisher, - LogicPublisher, +from faststream._internal.publisher.specified import ( + SpecificationPublisher as SpecificationPublisherMixin, ) +from faststream.confluent.publisher.usecase import BatchPublisher, DefaultPublisher from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, PublisherSpec from faststream.specification.schema.bindings import ChannelBinding, kafka -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation -if TYPE_CHECKING: - from confluent_kafka import Message as ConfluentMsg - -class SpecificationPublisher(LogicPublisher[MsgType]): +class SpecificationPublisher(SpecificationPublisherMixin): """A class representing a publisher.""" - def get_name(self) -> str: + def get_default_name(self) -> str: return f"{self.topic}:Publisher" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, PublisherSpec]: payloads = self.get_payloads() return { - self.name: Channel( + self.name: PublisherSpec( description=self.description, - publish=Operation( + operation=Operation( message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads, "Publisher"), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, + ), + bindings=ChannelBinding( + kafka=kafka.ChannelBinding( + topic=self.topic, partitions=None, replicas=None + ) ), - bindings=ChannelBinding(kafka=kafka.ChannelBinding(topic=self.topic)), ), } -class SpecificationBatchPublisher( - BatchPublisher, - SpecificationPublisher[tuple["ConfluentMsg", ...]], -): +class SpecificationBatchPublisher(SpecificationPublisher, BatchPublisher): pass -class SpecificationDefaultPublisher( - DefaultPublisher, - SpecificationPublisher["ConfluentMsg"], -): +class SpecificationDefaultPublisher(SpecificationPublisher, DefaultPublisher): pass diff --git a/faststream/confluent/publisher/usecase.py b/faststream/confluent/publisher/usecase.py index d6b7132155..e7cd4a1fb4 100644 --- a/faststream/confluent/publisher/usecase.py +++ b/faststream/confluent/publisher/usecase.py @@ -1,7 +1,6 @@ from collections.abc import Iterable from typing import ( TYPE_CHECKING, - Any, Optional, Union, ) @@ -40,20 +39,10 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.topic = topic @@ -105,11 +94,6 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[Message]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( topic=topic, @@ -119,11 +103,6 @@ def __init__( # publisher args broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.key = key diff --git a/faststream/confluent/router.py b/faststream/confluent/router.py index 52af6078e0..94b78f849a 100644 --- a/faststream/confluent/router.py +++ b/faststream/confluent/router.py @@ -253,7 +253,7 @@ def __init__( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -388,10 +388,15 @@ def __init__( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -415,7 +420,6 @@ def __init__( Doc("Whetever to include operation in AsyncAPI schema or not."), ] = True, ) -> None: - super().__init__( call, *topics, @@ -449,8 +453,8 @@ def __init__( title=title, description=description, include_in_schema=include_in_schema, - # FastDepends args ack_policy=ack_policy, + no_ack=no_ack, ) @@ -509,7 +513,7 @@ def __init__( # basic args prefix=prefix, dependencies=dependencies, - middlewares=middlewares, + middlewares=middlewares, # type: ignore[arg-type] parser=parser, decoder=decoder, include_in_schema=include_in_schema, diff --git a/faststream/confluent/subscriber/factory.py b/faststream/confluent/subscriber/factory.py index a5e08164b2..2248d810cf 100644 --- a/faststream/confluent/subscriber/factory.py +++ b/faststream/confluent/subscriber/factory.py @@ -5,15 +5,16 @@ Literal, Optional, Union, + cast, overload, ) from faststream._internal.constants import EMPTY -from faststream.exceptions import SetupError from faststream.confluent.subscriber.specified import ( SpecificationBatchSubscriber, SpecificationDefaultSubscriber, ) +from faststream.exceptions import SetupError from faststream.middlewares import AckPolicy if TYPE_CHECKING: @@ -35,9 +36,10 @@ def create_subscriber( # Kafka information group_id: Optional[str], connection_data: "AnyDict", - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], @@ -58,9 +60,10 @@ def create_subscriber( # Kafka information group_id: Optional[str], connection_data: "AnyDict", - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[ConfluentMsg]"], @@ -81,13 +84,15 @@ def create_subscriber( # Kafka information group_id: Optional[str], connection_data: "AnyDict", - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[ConfluentMsg, tuple[ConfluentMsg, ...]]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], + Iterable["BrokerMiddleware[ConfluentMsg]"], ], # Specification args title_: Optional[str], @@ -108,13 +113,15 @@ def create_subscriber( # Kafka information group_id: Optional[str], connection_data: "AnyDict", - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable[ - "BrokerMiddleware[Union[ConfluentMsg, tuple[ConfluentMsg, ...]]]" + broker_middlewares: Union[ + Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], + Iterable["BrokerMiddleware[ConfluentMsg]"], ], # Specification args title_: Optional[str], @@ -125,14 +132,26 @@ def create_subscriber( "SpecificationBatchSubscriber", ]: _validate_input_for_misconfigure( - ack_policy=ack_policy, is_manual=is_manual, group_id=group_id, + *topics, + partitions=partitions, + ack_policy=ack_policy, + no_ack=no_ack, + auto_commit=auto_commit, + group_id=group_id, ) + if auto_commit is not EMPTY: + ack_policy = AckPolicy.ACK_FIRST if auto_commit else AckPolicy.REJECT_ON_ERROR + + if no_ack is not EMPTY: + ack_policy = AckPolicy.DO_NOTHING if no_ack else EMPTY + if ack_policy is EMPTY: - if not is_manual: - ack_policy = AckPolicy.DO_NOTHING - else: - ack_policy = AckPolicy.REJECT_ON_ERROR + ack_policy = AckPolicy.ACK_FIRST + + if ack_policy is AckPolicy.ACK_FIRST: + connection_data["enable_auto_commit"] = True + ack_policy = AckPolicy.DO_NOTHING if batch: return SpecificationBatchSubscriber( @@ -142,26 +161,31 @@ def create_subscriber( max_records=max_records, group_id=group_id, connection_data=connection_data, - is_manual=is_manual, ack_policy=ack_policy, no_reply=no_reply, broker_dependencies=broker_dependencies, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[tuple[ConfluentMsg, ...]]"], + broker_middlewares, + ), title_=title_, description_=description_, include_in_schema=include_in_schema, ) + return SpecificationDefaultSubscriber( *topics, partitions=partitions, polling_interval=polling_interval, group_id=group_id, connection_data=connection_data, - is_manual=is_manual, ack_policy=ack_policy, no_reply=no_reply, broker_dependencies=broker_dependencies, - broker_middlewares=broker_middlewares, + broker_middlewares=cast( + Iterable["BrokerMiddleware[ConfluentMsg]"], + broker_middlewares, + ), title_=title_, description_=description_, include_in_schema=include_in_schema, @@ -169,25 +193,49 @@ def create_subscriber( def _validate_input_for_misconfigure( - *, + *topics: str, + partitions: Sequence["TopicPartition"], ack_policy: "AckPolicy", - is_manual: bool, + auto_commit: bool, + no_ack: bool, group_id: Optional[str], ) -> None: - if not is_manual and ack_policy is not EMPTY and ack_policy is not AckPolicy.ACK_FIRST: + if auto_commit is not EMPTY: warnings.warn( - "You can't use ack_policy other then AckPolicy.ACK_FIRST with `auto_commit=True`", - RuntimeWarning, + "`auto_commit` option was deprecated in prior to `ack_policy=AckPolicy.ACK_FIRST`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, stacklevel=4, ) - elif is_manual and ack_policy is not EMPTY and ack_policy is AckPolicy.ACK_FIRST: + + if ack_policy is not EMPTY: + msg = "You can't use deprecated `auto_commit` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + + ack_policy = AckPolicy.ACK_FIRST if auto_commit else AckPolicy.REJECT_ON_ERROR + + if no_ack is not EMPTY: warnings.warn( - "You can't use AckPolicy.ACK_FIRST with `auto_commit=False`", - RuntimeWarning, + "`no_ack` option was deprecated in prior to `ack_policy=AckPolicy.DO_NOTHING`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, stacklevel=4, ) - if is_manual and not group_id: + if ack_policy is not EMPTY: + msg = "You can't use deprecated `no_ack` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + + ack_policy = AckPolicy.DO_NOTHING if no_ack else EMPTY + + if ack_policy is EMPTY: + ack_policy = AckPolicy.ACK_FIRST + + if not group_id and ack_policy is not AckPolicy.ACK_FIRST: msg = "You must use `group_id` with manual commit mode." raise SetupError(msg) + if not topics and not partitions: + msg = "You should provide either `topics` or `partitions`." + raise SetupError(msg) + if topics and partitions: + msg = "You can't provide both `topics` and `partitions`." + raise SetupError(msg) diff --git a/faststream/confluent/subscriber/specified.py b/faststream/confluent/subscriber/specified.py index 3d93e83e2a..6dcad3c001 100644 --- a/faststream/confluent/subscriber/specified.py +++ b/faststream/confluent/subscriber/specified.py @@ -1,64 +1,55 @@ -from typing import ( - TYPE_CHECKING, -) +from collections.abc import Iterable +from itertools import chain +from typing import TYPE_CHECKING -from faststream._internal.types import MsgType -from faststream.confluent.subscriber.usecase import ( - BatchSubscriber, - DefaultSubscriber, - LogicSubscriber, +from faststream._internal.subscriber.specified import ( + SpecificationSubscriber as SpecificationSubscriberMixin, ) +from faststream.confluent.subscriber.usecase import BatchSubscriber, DefaultSubscriber from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, SubscriberSpec from faststream.specification.schema.bindings import ChannelBinding, kafka -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation if TYPE_CHECKING: - from confluent_kafka import Message as ConfluentMsg + from faststream.confluent.schemas import TopicPartition -class SpecificationSubscriber(LogicSubscriber[MsgType]): +class SpecificationSubscriber(SpecificationSubscriberMixin): """A class to handle logic and async API operations.""" - def get_name(self) -> str: - return f'{",".join(self.topics)}:{self.call_name}' + topics: Iterable[str] + partitions: Iterable["TopicPartition"] # TODO: support partitions + + def get_default_name(self) -> str: + return f"{','.join(self.topics)}:{self.call_name}" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, SubscriberSpec]: channels = {} payloads = self.get_payloads() - for t in self.topics: + for t in chain(self.topics, {p.topic for p in self.partitions}): handler_name = self.title_ or f"{t}:{self.call_name}" - channels[handler_name] = Channel( + channels[handler_name] = SubscriberSpec( description=self.description, - subscribe=Operation( + operation=Operation( message=Message( title=f"{handler_name}:Message", payload=resolve_payloads(payloads), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, ), bindings=ChannelBinding( - kafka=kafka.ChannelBinding(topic=t), + kafka=kafka.ChannelBinding(topic=t, partitions=None, replicas=None), ), ) return channels -class SpecificationDefaultSubscriber( - DefaultSubscriber, - SpecificationSubscriber["ConfluentMsg"], -): +class SpecificationDefaultSubscriber(SpecificationSubscriber, DefaultSubscriber): pass -class SpecificationBatchSubscriber( - BatchSubscriber, - SpecificationSubscriber[tuple["ConfluentMsg", ...]], -): +class SpecificationBatchSubscriber(SpecificationSubscriber, BatchSubscriber): pass diff --git a/faststream/confluent/subscriber/usecase.py b/faststream/confluent/subscriber/usecase.py index adb321dd4a..57d39eeb95 100644 --- a/faststream/confluent/subscriber/usecase.py +++ b/faststream/confluent/subscriber/usecase.py @@ -18,6 +18,7 @@ from faststream.confluent.parser import AsyncConfluentParser from faststream.confluent.publisher.fake import KafkaFakePublisher from faststream.confluent.schemas import TopicPartition +from faststream.middlewares import AckPolicy if TYPE_CHECKING: from fast_depends.dependencies import Dependant @@ -32,7 +33,6 @@ ) from faststream.confluent.client import AsyncConfluentConsumer from faststream.message import StreamMessage - from faststream.middlewares import AckPolicy class LogicSubscriber(ABC, SubscriberUsecase[MsgType]): @@ -63,10 +63,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( default_parser=default_parser, @@ -76,10 +72,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.__connection_data = connection_data @@ -157,7 +149,7 @@ async def get_one( self, *, timeout: float = 5.0, - ) -> "Optional[StreamMessage[Message]]": + ) -> "Optional[StreamMessage[MsgType]]": assert self.consumer, "You should start subscriber at first." # nosec B101 assert ( # nosec B101 not self.calls @@ -168,7 +160,7 @@ async def get_one( context = self._state.get().di_state.context return await process_msg( - msg=raw_message, + msg=raw_message, # type: ignore[arg-type] middlewares=( m(raw_message, context=context) for m in self._broker_middlewares ), @@ -252,18 +244,15 @@ def __init__( polling_interval: float, group_id: Optional[str], connection_data: "AnyDict", - is_manual: bool, # Subscriber args ack_policy: "AckPolicy", no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Message]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: - self.parser = AsyncConfluentParser(is_manual=is_manual) + self.parser = AsyncConfluentParser( + is_manual=ack_policy is not AckPolicy.ACK_FIRST + ) super().__init__( *topics, @@ -279,10 +268,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def get_msg(self) -> Optional["Message"]: @@ -315,20 +300,17 @@ def __init__( # Kafka information group_id: Optional[str], connection_data: "AnyDict", - is_manual: bool, # Subscriber args ack_policy: "AckPolicy", no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[tuple[Message, ...]]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: self.max_records = max_records - self.parser = AsyncConfluentParser(is_manual=is_manual) + self.parser = AsyncConfluentParser( + is_manual=ack_policy is not AckPolicy.ACK_FIRST + ) super().__init__( *topics, @@ -344,10 +326,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def get_msg(self) -> Optional[tuple["Message", ...]]: diff --git a/faststream/exceptions.py b/faststream/exceptions.py index 6d51e76cb3..32557dd42c 100644 --- a/faststream/exceptions.py +++ b/faststream/exceptions.py @@ -162,4 +162,4 @@ def __str__(self) -> str: pip install watchfiles """ -SCHEMA_NOT_SUPPORTED = "{schema_filename} not supported. Make sure that your schema is valid and schema version supported by FastStream" +SCHEMA_NOT_SUPPORTED = "`{schema_filename}` not supported. Make sure that your schema is valid and schema version supported by FastStream" diff --git a/faststream/kafka/broker/broker.py b/faststream/kafka/broker/broker.py index 0f962f0d3b..bec2ed99c9 100644 --- a/faststream/kafka/broker/broker.py +++ b/faststream/kafka/broker/broker.py @@ -57,7 +57,7 @@ ) from faststream.kafka.message import KafkaMessage from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class KafkaInitKwargs(TypedDict, total=False): request_timeout_ms: Annotated[ @@ -477,9 +477,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -640,10 +640,8 @@ async def _connect( # type: ignore[override] await self._producer.connect(producer) - return partial( - aiokafka.AIOKafkaConsumer, - **filter_by_dict(ConsumerConnectionParams, kwargs), - ) + connection_kwargs, _ = filter_by_dict(ConsumerConnectionParams, kwargs) + return partial(aiokafka.AIOKafkaConsumer, **connection_kwargs) async def start(self) -> None: """Connect broker to Kafka and startup all subscribers.""" diff --git a/faststream/kafka/broker/registrator.py b/faststream/kafka/broker/registrator.py index 5d0cfcb936..1e291ff7fc 100644 --- a/faststream/kafka/broker/registrator.py +++ b/faststream/kafka/broker/registrator.py @@ -39,6 +39,7 @@ ) from faststream.kafka.subscriber.specified import ( SpecificationBatchSubscriber, + SpecificationConcurrentDefaultSubscriber, SpecificationDefaultSubscriber, ) @@ -54,7 +55,11 @@ class KafkaRegistrator( """Includable to KafkaBroker router.""" _subscribers: list[ - Union["SpecificationBatchSubscriber", "SpecificationDefaultSubscriber"], + Union[ + "SpecificationBatchSubscriber", + "SpecificationDefaultSubscriber", + "SpecificationConcurrentDefaultSubscriber", + ] ] _publishers: list[ Union["SpecificationBatchPublisher", "SpecificationDefaultPublisher"], @@ -173,7 +178,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -404,10 +409,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -545,7 +555,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -776,10 +786,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -917,7 +932,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1148,10 +1163,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1292,7 +1312,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1523,10 +1543,19 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + max_workers: Annotated[ + int, + Doc("Number of workers to process messages concurrently."), + ] = 1, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1552,61 +1581,62 @@ def subscriber( ) -> Union[ "SpecificationDefaultSubscriber", "SpecificationBatchSubscriber", + "SpecificationConcurrentDefaultSubscriber", ]: - - subscriber = super().subscriber( - create_subscriber( - *topics, - batch=batch, - batch_timeout_ms=batch_timeout_ms, - max_records=max_records, - group_id=group_id, - listener=listener, - pattern=pattern, - connection_args={ - "key_deserializer": key_deserializer, - "value_deserializer": value_deserializer, - "fetch_max_wait_ms": fetch_max_wait_ms, - "fetch_max_bytes": fetch_max_bytes, - "fetch_min_bytes": fetch_min_bytes, - "max_partition_fetch_bytes": max_partition_fetch_bytes, - "auto_offset_reset": auto_offset_reset, - "enable_auto_commit": auto_commit, - "auto_commit_interval_ms": auto_commit_interval_ms, - "check_crcs": check_crcs, - "partition_assignment_strategy": partition_assignment_strategy, - "max_poll_interval_ms": max_poll_interval_ms, - "rebalance_timeout_ms": rebalance_timeout_ms, - "session_timeout_ms": session_timeout_ms, - "heartbeat_interval_ms": heartbeat_interval_ms, - "consumer_timeout_ms": consumer_timeout_ms, - "max_poll_records": max_poll_records, - "exclude_internal_topics": exclude_internal_topics, - "isolation_level": isolation_level, - }, - partitions=partitions, - is_manual=not auto_commit, - # subscriber args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=self.middlewares, - broker_dependencies=self._dependencies, - # Specification - title_=title, - description_=description, - include_in_schema=self._solve_include_in_schema(include_in_schema), - ), + sub = create_subscriber( + *topics, + batch=batch, + max_workers=max_workers, + batch_timeout_ms=batch_timeout_ms, + max_records=max_records, + group_id=group_id, + listener=listener, + pattern=pattern, + connection_args={ + "key_deserializer": key_deserializer, + "value_deserializer": value_deserializer, + "fetch_max_wait_ms": fetch_max_wait_ms, + "fetch_max_bytes": fetch_max_bytes, + "fetch_min_bytes": fetch_min_bytes, + "max_partition_fetch_bytes": max_partition_fetch_bytes, + "auto_offset_reset": auto_offset_reset, + "auto_commit_interval_ms": auto_commit_interval_ms, + "check_crcs": check_crcs, + "partition_assignment_strategy": partition_assignment_strategy, + "max_poll_interval_ms": max_poll_interval_ms, + "rebalance_timeout_ms": rebalance_timeout_ms, + "session_timeout_ms": session_timeout_ms, + "heartbeat_interval_ms": heartbeat_interval_ms, + "consumer_timeout_ms": consumer_timeout_ms, + "max_poll_records": max_poll_records, + "exclude_internal_topics": exclude_internal_topics, + "isolation_level": isolation_level, + }, + partitions=partitions, + # acknowledgement args + ack_policy=ack_policy, + no_ack=no_ack, + auto_commit=auto_commit, + # subscriber args + no_reply=no_reply, + broker_middlewares=self.middlewares, + broker_dependencies=self._dependencies, + # Specification + title_=title, + description_=description, + include_in_schema=self._solve_include_in_schema(include_in_schema), ) + subscriber = super().subscriber(sub) + if batch: - return cast("SpecificationBatchSubscriber", subscriber).add_call( - parser_=parser or self._parser, - decoder_=decoder or self._decoder, - dependencies_=dependencies, - middlewares_=middlewares, - ) + subscriber = cast("SpecificationBatchSubscriber", subscriber) + elif max_workers > 1: + subscriber = cast("SpecificationConcurrentDefaultSubscriber", subscriber) + else: + subscriber = cast("SpecificationDefaultSubscriber", subscriber) - return cast("SpecificationDefaultSubscriber", subscriber).add_call( + return subscriber.add_call( parser_=parser or self._parser, decoder_=decoder or self._decoder, dependencies_=dependencies, diff --git a/faststream/kafka/fastapi/fastapi.py b/faststream/kafka/fastapi/fastapi.py index 747103bfea..a58d75d042 100644 --- a/faststream/kafka/fastapi/fastapi.py +++ b/faststream/kafka/fastapi/fastapi.py @@ -55,10 +55,11 @@ ) from faststream.kafka.subscriber.specified import ( SpecificationBatchSubscriber, + SpecificationConcurrentDefaultSubscriber, SpecificationDefaultSubscriber, ) from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict Partition = TypeVar("Partition") @@ -304,9 +305,9 @@ def __init__( Doc("Specification server description."), ] = None, specification_tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("Specification server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -716,7 +717,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -951,10 +952,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1208,7 +1214,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1443,10 +1449,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -1700,7 +1711,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -1935,10 +1946,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -2195,7 +2211,7 @@ def subscriber( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -2430,10 +2446,15 @@ def subscriber( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -2579,14 +2600,19 @@ def subscriber( """, ), ] = False, + max_workers: Annotated[ + int, + Doc("Number of workers to process messages concurrently."), + ] = 1, ) -> Union[ "SpecificationBatchSubscriber", "SpecificationDefaultSubscriber", + "SpecificationConcurrentDefaultSubscriber", ]: - subscriber = super().subscriber( *topics, group_id=group_id, + max_workers=max_workers, key_deserializer=key_deserializer, value_deserializer=value_deserializer, fetch_max_wait_ms=fetch_max_wait_ms, @@ -2618,6 +2644,7 @@ def subscriber( decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, @@ -2634,6 +2661,8 @@ def subscriber( if batch: return cast("SpecificationBatchSubscriber", subscriber) + if max_workers > 1: + return cast("SpecificationConcurrentDefaultSubscriber", subscriber) return cast("SpecificationDefaultSubscriber", subscriber) @overload # type: ignore[override] diff --git a/faststream/kafka/message.py b/faststream/kafka/message.py index e00c541795..20fe0d0edd 100644 --- a/faststream/kafka/message.py +++ b/faststream/kafka/message.py @@ -2,7 +2,7 @@ from aiokafka import TopicPartition as AIOKafkaTopicPartition -from faststream.message import StreamMessage +from faststream.message import AckStatus, StreamMessage if TYPE_CHECKING: from aiokafka import ConsumerRecord @@ -51,6 +51,21 @@ class KafkaMessage( This class extends `StreamMessage` and is specialized for handling Kafka ConsumerRecord objects. """ + def __init__(self, *args: Any, consumer: ConsumerProtocol, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self.consumer = consumer + self.committed = AckStatus.ACKED + + +class KafkaAckableMessage( + StreamMessage[ + Union[ + "ConsumerRecord", + tuple["ConsumerRecord", ...], + ] + ] +): def __init__( self, *args: Any, @@ -61,6 +76,12 @@ def __init__( self.consumer = consumer + async def ack(self) -> None: + """Acknowledge the Kafka message.""" + if not self.committed: + await self.consumer.commit() + await super().ack() + async def nack(self) -> None: """Reject the Kafka message.""" if not self.committed: @@ -78,11 +99,3 @@ async def nack(self) -> None: offset=raw_message.offset, ) await super().nack() - - -class KafkaAckableMessage(KafkaMessage): - async def ack(self) -> None: - """Acknowledge the Kafka message.""" - if not self.committed: - await self.consumer.commit() - await super().ack() diff --git a/faststream/kafka/publisher/specified.py b/faststream/kafka/publisher/specified.py index d765cc8f8b..b23eef8d92 100644 --- a/faststream/kafka/publisher/specified.py +++ b/faststream/kafka/publisher/specified.py @@ -1,56 +1,43 @@ -from typing import TYPE_CHECKING - -from faststream._internal.types import MsgType -from faststream.kafka.publisher.usecase import ( - BatchPublisher, - DefaultPublisher, - LogicPublisher, +from faststream._internal.publisher.specified import ( + SpecificationPublisher as SpecificationPublisherMixin, ) +from faststream.kafka.publisher.usecase import BatchPublisher, DefaultPublisher from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, PublisherSpec from faststream.specification.schema.bindings import ChannelBinding, kafka -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation - -if TYPE_CHECKING: - from aiokafka import ConsumerRecord -class SpecificationPublisher(LogicPublisher[MsgType]): +class SpecificationPublisher(SpecificationPublisherMixin): """A class representing a publisher.""" - def get_name(self) -> str: + def get_default_name(self) -> str: return f"{self.topic}:Publisher" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, PublisherSpec]: payloads = self.get_payloads() return { - self.name: Channel( + self.name: PublisherSpec( description=self.description, - publish=Operation( + operation=Operation( message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads, "Publisher"), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, + ), + bindings=ChannelBinding( + kafka=kafka.ChannelBinding( + topic=self.topic, partitions=None, replicas=None + ) ), - bindings=ChannelBinding(kafka=kafka.ChannelBinding(topic=self.topic)), ), } -class SpecificationBatchPublisher( - BatchPublisher, - SpecificationPublisher[tuple["ConsumerRecord", ...]], -): +class SpecificationBatchPublisher(SpecificationPublisher, BatchPublisher): pass -class SpecificationDefaultPublisher( - DefaultPublisher, - SpecificationPublisher["ConsumerRecord"], -): +class SpecificationDefaultPublisher(SpecificationPublisher, DefaultPublisher): pass diff --git a/faststream/kafka/publisher/usecase.py b/faststream/kafka/publisher/usecase.py index 0f005770de..895abe044c 100644 --- a/faststream/kafka/publisher/usecase.py +++ b/faststream/kafka/publisher/usecase.py @@ -42,20 +42,10 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.topic = topic @@ -154,11 +144,6 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[ConsumerRecord]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( topic=topic, @@ -168,11 +153,6 @@ def __init__( # publisher args broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.key = key diff --git a/faststream/kafka/router.py b/faststream/kafka/router.py index 10d8df303e..2038c89648 100644 --- a/faststream/kafka/router.py +++ b/faststream/kafka/router.py @@ -261,7 +261,7 @@ def __init__( Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. """, ), - ] = True, + ] = EMPTY, auto_commit_interval_ms: Annotated[ int, Doc( @@ -491,10 +491,15 @@ def __init__( Iterable["SubscriberMiddleware[KafkaMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -517,12 +522,16 @@ def __init__( bool, Doc("Whetever to include operation in AsyncAPI schema or not."), ] = True, + max_workers: Annotated[ + int, + Doc("Number of workers to process messages concurrently."), + ] = 1, ) -> None: - super().__init__( call, *topics, publishers=publishers, + max_workers=max_workers, group_id=group_id, key_deserializer=key_deserializer, value_deserializer=value_deserializer, @@ -555,12 +564,12 @@ def __init__( decoder=decoder, middlewares=middlewares, no_reply=no_reply, + ack_policy=ack_policy, + no_ack=no_ack, # AsyncAPI args title=title, description=description, include_in_schema=include_in_schema, - # FastDepends args - ack_policy=ack_policy, ) diff --git a/faststream/kafka/subscriber/factory.py b/faststream/kafka/subscriber/factory.py index da9d859e32..405e08f092 100644 --- a/faststream/kafka/subscriber/factory.py +++ b/faststream/kafka/subscriber/factory.py @@ -12,6 +12,7 @@ from faststream.exceptions import SetupError from faststream.kafka.subscriber.specified import ( SpecificationBatchSubscriber, + SpecificationConcurrentDefaultSubscriber, SpecificationDefaultSubscriber, ) from faststream.middlewares import AckPolicy @@ -37,9 +38,11 @@ def create_subscriber( pattern: Optional[str], connection_args: "AnyDict", partitions: Iterable["TopicPartition"], - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + max_workers: int, + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[tuple[ConsumerRecord, ...]]"], @@ -62,9 +65,11 @@ def create_subscriber( pattern: Optional[str], connection_args: "AnyDict", partitions: Iterable["TopicPartition"], - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + max_workers: int, + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[ConsumerRecord]"], @@ -87,9 +92,11 @@ def create_subscriber( pattern: Optional[str], connection_args: "AnyDict", partitions: Iterable["TopicPartition"], - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + max_workers: int, + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable[ @@ -102,6 +109,7 @@ def create_subscriber( ) -> Union[ "SpecificationDefaultSubscriber", "SpecificationBatchSubscriber", + "SpecificationConcurrentDefaultSubscriber", ]: ... @@ -116,9 +124,11 @@ def create_subscriber( pattern: Optional[str], connection_args: "AnyDict", partitions: Iterable["TopicPartition"], - is_manual: bool, + auto_commit: bool, # Subscriber args ack_policy: "AckPolicy", + max_workers: int, + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable[ @@ -131,21 +141,31 @@ def create_subscriber( ) -> Union[ "SpecificationDefaultSubscriber", "SpecificationBatchSubscriber", + "SpecificationConcurrentDefaultSubscriber", ]: _validate_input_for_misconfigure( *topics, pattern=pattern, partitions=partitions, ack_policy=ack_policy, - is_manual=is_manual, + no_ack=no_ack, + auto_commit=auto_commit, + max_workers=max_workers, group_id=group_id, ) + if auto_commit is not EMPTY: + ack_policy = AckPolicy.ACK_FIRST if auto_commit else AckPolicy.REJECT_ON_ERROR + + if no_ack is not EMPTY: + ack_policy = AckPolicy.DO_NOTHING if no_ack else EMPTY + if ack_policy is EMPTY: - if not is_manual: - ack_policy = AckPolicy.DO_NOTHING - else: - ack_policy = AckPolicy.REJECT_ON_ERROR + ack_policy = AckPolicy.ACK_FIRST + + if ack_policy is AckPolicy.ACK_FIRST: + connection_args["enable_auto_commit"] = True + ack_policy = AckPolicy.DO_NOTHING if batch: return SpecificationBatchSubscriber( @@ -157,7 +177,24 @@ def create_subscriber( pattern=pattern, connection_args=connection_args, partitions=partitions, - is_manual=is_manual, + ack_policy=ack_policy, + no_reply=no_reply, + broker_dependencies=broker_dependencies, + broker_middlewares=broker_middlewares, + title_=title_, + description_=description_, + include_in_schema=include_in_schema, + ) + + if max_workers > 1: + return SpecificationConcurrentDefaultSubscriber( + *topics, + max_workers=max_workers, + group_id=group_id, + listener=listener, + pattern=pattern, + connection_args=connection_args, + partitions=partitions, ack_policy=ack_policy, no_reply=no_reply, broker_dependencies=broker_dependencies, @@ -174,7 +211,6 @@ def create_subscriber( pattern=pattern, connection_args=connection_args, partitions=partitions, - is_manual=is_manual, ack_policy=ack_policy, no_reply=no_reply, broker_dependencies=broker_dependencies, @@ -190,31 +226,51 @@ def _validate_input_for_misconfigure( partitions: Iterable["TopicPartition"], pattern: Optional[str], ack_policy: "AckPolicy", - is_manual: bool, + auto_commit: bool, + no_ack: bool, group_id: Optional[str], + max_workers: int, ) -> None: - if not is_manual and ack_policy is not EMPTY and ack_policy is not AckPolicy.ACK_FIRST: + if auto_commit is not EMPTY: warnings.warn( - "You can't use ack_policy other then AckPolicy.ACK_FIRST with `auto_commit=True`", - RuntimeWarning, + "`auto_commit` option was deprecated in prior to `ack_policy=AckPolicy.ACK_FIRST`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, stacklevel=4, ) - elif is_manual and ack_policy is not EMPTY and ack_policy is AckPolicy.ACK_FIRST: + + if ack_policy is not EMPTY: + msg = "You can't use deprecated `auto_commit` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + + ack_policy = AckPolicy.ACK_FIRST if auto_commit else AckPolicy.REJECT_ON_ERROR + + if no_ack is not EMPTY: warnings.warn( - "You can't use AckPolicy.ACK_FIRST with `auto_commit=False`", - RuntimeWarning, + "`no_ack` option was deprecated in prior to `ack_policy=AckPolicy.DO_NOTHING`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, stacklevel=4, ) - if is_manual and not group_id: + if ack_policy is not EMPTY: + msg = "You can't use deprecated `no_ack` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + + ack_policy = AckPolicy.DO_NOTHING if no_ack else EMPTY + + if ack_policy is EMPTY: + ack_policy = AckPolicy.ACK_FIRST + + if max_workers > 1 and ack_policy is not AckPolicy.ACK_FIRST: + msg = "You can't use `max_workers` option with manual commit mode." + raise SetupError(msg) + + if not group_id and ack_policy is not AckPolicy.ACK_FIRST: msg = "You must use `group_id` with manual commit mode." raise SetupError(msg) if not topics and not partitions and not pattern: msg = "You should provide either `topics` or `partitions` or `pattern`." - raise SetupError( - msg, - ) + raise SetupError(msg) if topics and partitions: msg = "You can't provide both `topics` and `partitions`." raise SetupError(msg) diff --git a/faststream/kafka/subscriber/specified.py b/faststream/kafka/subscriber/specified.py index b49f10e77e..f2d5f70a3a 100644 --- a/faststream/kafka/subscriber/specified.py +++ b/faststream/kafka/subscriber/specified.py @@ -1,50 +1,52 @@ -from typing import ( - TYPE_CHECKING, -) +from collections.abc import Iterable +from itertools import chain +from typing import TYPE_CHECKING, Optional -from faststream._internal.types import MsgType +from faststream._internal.subscriber.specified import ( + SpecificationSubscriber as SpecificationSubscriberMixin, +) from faststream.kafka.subscriber.usecase import ( BatchSubscriber, + ConcurrentDefaultSubscriber, DefaultSubscriber, - LogicSubscriber, ) from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, SubscriberSpec from faststream.specification.schema.bindings import ChannelBinding, kafka -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation if TYPE_CHECKING: - from aiokafka import ConsumerRecord + from aiokafka import TopicPartition -class SpecificationSubscriber(LogicSubscriber[MsgType]): +class SpecificationSubscriber(SpecificationSubscriberMixin): """A class to handle logic and async API operations.""" - def get_name(self) -> str: - return f'{",".join(self.topics)}:{self.call_name}' + topics: Iterable[str] + partitions: Iterable["TopicPartition"] # TODO: support partitions + _pattern: Optional[str] # TODO: support pattern schema - def get_schema(self) -> dict[str, Channel]: + def get_default_name(self) -> str: + return f"{','.join(self.topics)}:{self.call_name}" + + def get_schema(self) -> dict[str, SubscriberSpec]: channels = {} payloads = self.get_payloads() - for t in self.topics: + for t in chain(self.topics, {p.topic for p in self.partitions}): handler_name = self.title_ or f"{t}:{self.call_name}" - channels[handler_name] = Channel( + channels[handler_name] = SubscriberSpec( description=self.description, - subscribe=Operation( + operation=Operation( message=Message( title=f"{handler_name}:Message", payload=resolve_payloads(payloads), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, ), bindings=ChannelBinding( - kafka=kafka.ChannelBinding(topic=t), + kafka=kafka.ChannelBinding(topic=t, partitions=None, replicas=None), ), ) @@ -52,14 +54,21 @@ def get_schema(self) -> dict[str, Channel]: class SpecificationDefaultSubscriber( + SpecificationSubscriber, DefaultSubscriber, - SpecificationSubscriber["ConsumerRecord"], ): pass class SpecificationBatchSubscriber( + SpecificationSubscriber, BatchSubscriber, - SpecificationSubscriber[tuple["ConsumerRecord", ...]], +): + pass + + +class SpecificationConcurrentDefaultSubscriber( + SpecificationSubscriber, + ConcurrentDefaultSubscriber, ): pass diff --git a/faststream/kafka/subscriber/usecase.py b/faststream/kafka/subscriber/usecase.py index fab52a66f2..37e2378a6c 100644 --- a/faststream/kafka/subscriber/usecase.py +++ b/faststream/kafka/subscriber/usecase.py @@ -1,19 +1,14 @@ -import asyncio from abc import abstractmethod from collections.abc import Iterable, Sequence from itertools import chain -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, -) +from typing import TYPE_CHECKING, Any, Callable, Optional import anyio from aiokafka import TopicPartition from aiokafka.errors import ConsumerStoppedError, KafkaError from typing_extensions import override +from faststream._internal.subscriber.mixins import ConcurrentMixin, TasksMixin from faststream._internal.subscriber.usecase import SubscriberUsecase from faststream._internal.subscriber.utils import process_msg from faststream._internal.types import ( @@ -39,7 +34,7 @@ from faststream.middlewares import AckPolicy -class LogicSubscriber(SubscriberUsecase[MsgType]): +class LogicSubscriber(TasksMixin, SubscriberUsecase[MsgType]): """A class to handle logic for consuming messages from Kafka.""" topics: Sequence[str] @@ -48,7 +43,6 @@ class LogicSubscriber(SubscriberUsecase[MsgType]): builder: Optional[Callable[..., "AIOKafkaConsumer"]] consumer: Optional["AIOKafkaConsumer"] - task: Optional["asyncio.Task[None]"] client_id: Optional[str] batch: bool parser: AioKafkaParser @@ -69,10 +63,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( default_parser=default_parser, @@ -82,10 +72,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.topics = topics @@ -101,7 +87,6 @@ def __init__( self.builder = None self.consumer = None - self.task = None @override def _setup( # type: ignore[override] @@ -153,7 +138,7 @@ async def start(self) -> None: await super().start() if self.calls: - self.task = asyncio.create_task(self._consume()) + self.add_task(self._consume()) async def close(self) -> None: await super().close() @@ -162,11 +147,6 @@ async def close(self) -> None: await self.consumer.stop() self.consumer = None - if self.task is not None and not self.task.done(): - self.task.cancel() - - self.task = None - @override async def get_one( self, @@ -190,7 +170,7 @@ async def get_one( context = self._state.get().di_state.context - msg: StreamMessage[MsgType] = await process_msg( + return await process_msg( msg=raw_message, middlewares=( m(raw_message, context=context) for m in self._broker_middlewares @@ -198,7 +178,6 @@ async def get_one( parser=self._parser, decoder=self._decoder, ) - return msg def _make_response_publisher( self, @@ -237,7 +216,10 @@ async def _consume(self) -> None: connected = True if msg: - await self.consume(msg) + await self.consume_one(msg) + + async def consume_one(self, msg: MsgType) -> None: + await self.consume(msg) @property def topic_names(self) -> list[str]: @@ -281,16 +263,11 @@ def __init__( pattern: Optional[str], connection_args: "AnyDict", partitions: Iterable["TopicPartition"], - is_manual: bool, # Subscriber args ack_policy: "AckPolicy", no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[ConsumerRecord]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: if pattern: reg, pattern = compile_path( @@ -303,7 +280,9 @@ def __init__( reg = None self.parser = AioKafkaParser( - msg_class=KafkaAckableMessage if is_manual else KafkaMessage, + msg_class=KafkaMessage + if ack_policy is ack_policy.ACK_FIRST + else KafkaAckableMessage, regex=reg, ) @@ -322,10 +301,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def get_msg(self) -> "ConsumerRecord": @@ -348,6 +323,46 @@ def get_log_context( ) +class ConcurrentDefaultSubscriber(ConcurrentMixin, DefaultSubscriber): + def __init__( + self, + *topics: str, + # Kafka information + group_id: Optional[str], + listener: Optional["ConsumerRebalanceListener"], + pattern: Optional[str], + connection_args: "AnyDict", + partitions: Iterable["TopicPartition"], + # Subscriber args + max_workers: int, + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Iterable["BrokerMiddleware[ConsumerRecord]"], + ) -> None: + super().__init__( + *topics, + group_id=group_id, + listener=listener, + pattern=pattern, + connection_args=connection_args, + partitions=partitions, + max_workers=max_workers, + # Propagated args + ack_policy=ack_policy, + no_reply=no_reply, + broker_middlewares=broker_middlewares, + broker_dependencies=broker_dependencies, + ) + + async def start(self) -> None: + await super().start() + self.start_consume_task() + + async def consume_one(self, msg: "ConsumerRecord") -> None: + await self._put_msg(msg) + + class BatchSubscriber(LogicSubscriber[tuple["ConsumerRecord", ...]]): def __init__( self, @@ -360,7 +375,6 @@ def __init__( pattern: Optional[str], connection_args: "AnyDict", partitions: Iterable["TopicPartition"], - is_manual: bool, # Subscriber args ack_policy: "AckPolicy", no_reply: bool, @@ -368,10 +382,6 @@ def __init__( broker_middlewares: Iterable[ "BrokerMiddleware[Sequence[tuple[ConsumerRecord, ...]]]" ], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: self.batch_timeout_ms = batch_timeout_ms self.max_records = max_records @@ -387,7 +397,9 @@ def __init__( reg = None self.parser = AioKafkaBatchParser( - msg_class=KafkaAckableMessage if is_manual else KafkaMessage, + msg_class=KafkaMessage + if ack_policy is ack_policy.ACK_FIRST + else KafkaAckableMessage, regex=reg, ) @@ -406,10 +418,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def get_msg(self) -> tuple["ConsumerRecord", ...]: diff --git a/faststream/middlewares/acknowledgement/conf.py b/faststream/middlewares/acknowledgement/conf.py index b8ad83f802..c5cd759e10 100644 --- a/faststream/middlewares/acknowledgement/conf.py +++ b/faststream/middlewares/acknowledgement/conf.py @@ -3,16 +3,16 @@ class AckPolicy(str, Enum): ACK_FIRST = "ack_first" - """Ack message on consume""" + """Ack message on consume.""" ACK = "ack" - """Ack message after all process""" + """Ack message after all process.""" REJECT_ON_ERROR = "reject_on_error" - """Reject message on unhandled exceptions""" + """Reject message on unhandled exceptions.""" NACK_ON_ERROR = "nack_on_error" - """Nack message on unhandled exceptions""" + """Nack message on unhandled exceptions.""" DO_NOTHING = "do_nothing" - """Not create AcknowledgementMiddleware""" + """Disable default FastStream Acknowledgement logic. User should confirm all actions manually.""" diff --git a/faststream/middlewares/acknowledgement/middleware.py b/faststream/middlewares/acknowledgement/middleware.py index e2d30f5c0c..c8419d8e10 100644 --- a/faststream/middlewares/acknowledgement/middleware.py +++ b/faststream/middlewares/acknowledgement/middleware.py @@ -110,7 +110,7 @@ async def __ack(self, **exc_extra_options: Any) -> None: await self.message.ack(**exc_extra_options, **self.extra_options) except Exception as er: if self.logger is not None: - self.logger.log(er, logging.CRITICAL, exc_info=er) + self.logger.log(repr(er), logging.CRITICAL, exc_info=er) async def __nack(self, **exc_extra_options: Any) -> None: if self.message: @@ -118,7 +118,7 @@ async def __nack(self, **exc_extra_options: Any) -> None: await self.message.nack(**exc_extra_options, **self.extra_options) except Exception as er: if self.logger is not None: - self.logger.log(er, logging.CRITICAL, exc_info=er) + self.logger.log(repr(er), logging.CRITICAL, exc_info=er) async def __reject(self, **exc_extra_options: Any) -> None: if self.message: @@ -126,4 +126,4 @@ async def __reject(self, **exc_extra_options: Any) -> None: await self.message.reject(**exc_extra_options, **self.extra_options) except Exception as er: if self.logger is not None: - self.logger.log(er, logging.CRITICAL, exc_info=er) + self.logger.log(repr(er), logging.CRITICAL, exc_info=er) diff --git a/faststream/nats/__init__.py b/faststream/nats/__init__.py index 42018a3a2a..55ed9abd15 100644 --- a/faststream/nats/__init__.py +++ b/faststream/nats/__init__.py @@ -18,7 +18,7 @@ from faststream.nats.broker.broker import NatsBroker from faststream.nats.response import NatsResponse from faststream.nats.router import NatsPublisher, NatsRoute, NatsRouter -from faststream.nats.schemas import JStream, KvWatch, ObjWatch, PullSub +from faststream.nats.schemas import JStream, KvWatch, ObjWatch, PubAck, PullSub from faststream.nats.testing import TestNatsBroker __all__ = ( @@ -38,6 +38,7 @@ "NatsRouter", "ObjWatch", "Placement", + "PubAck", "PullSub", "RePublish", "ReplayPolicy", diff --git a/faststream/nats/broker/broker.py b/faststream/nats/broker/broker.py index b5996e29e8..ee1005d7d5 100644 --- a/faststream/nats/broker/broker.py +++ b/faststream/nats/broker/broker.py @@ -36,7 +36,7 @@ from faststream.nats.publisher.producer import NatsFastProducer, NatsJSFastProducer from faststream.nats.response import NatsPublishCommand from faststream.nats.security import parse_security -from faststream.nats.subscriber.specified import SpecificationSubscriber +from faststream.nats.subscriber.usecases.basic import LogicSubscriber from faststream.response.publish_type import PublishType from .logging import make_nats_logger_state @@ -55,7 +55,7 @@ JWTCallback, SignatureCallback, ) - from nats.js.api import Placement, PubAck, RePublish, StorageType + from nats.js.api import Placement, RePublish, StorageType from nats.js.kv import KeyValue from nats.js.object_store import ObjectStore from typing_extensions import TypedDict, Unpack @@ -71,9 +71,10 @@ CustomCallable, ) from faststream.nats.message import NatsMessage - from faststream.nats.publisher.specified import SpecificationPublisher + from faststream.nats.publisher.usecase import LogicPublisher + from faststream.nats.schemas import PubAck from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class NatsInitKwargs(TypedDict, total=False): """NatsBroker.connect() method type hints.""" @@ -399,9 +400,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -560,7 +561,6 @@ async def _connect(self, **kwargs: Any) -> "Client": self._os_declarer.connect(stream) self._connection_state = ConnectedState(connection, stream) - return connection async def close( @@ -600,7 +600,7 @@ async def start(self) -> None: ) except BadRequestError as e: # noqa: PERF203 - log_context = SpecificationSubscriber.build_log_context( + log_context = LogicSubscriber.build_log_context( message=None, subject="", queue="", @@ -700,7 +700,7 @@ async def publish( Returns: `None` if you publishes a regular message. - `nats.js.api.PubAck` if you publishes a message to stream. + `faststream.nats.PubAck` if you publishes a message to stream. """ cmd = NatsPublishCommand( message=message, @@ -747,8 +747,7 @@ async def request( # type: ignore[override] Manual message **correlation_id** setter. **correlation_id** is a useful option to trace messages. stream: - This option validates that the target subject is in presented stream. - Can be omitted without any effect if you doesn't want PubAck frame. + JetStream name. This option is required if your target subscriber listens for events using JetStream. timeout: Timeout to send message to NATS. @@ -773,7 +772,7 @@ async def request( # type: ignore[override] @override def setup_subscriber( # type: ignore[override] self, - subscriber: "SpecificationSubscriber", + subscriber: "LogicSubscriber", ) -> None: return super().setup_subscriber( subscriber, @@ -785,7 +784,7 @@ def setup_subscriber( # type: ignore[override] @override def setup_publisher( # type: ignore[override] self, - publisher: "SpecificationPublisher", + publisher: "LogicPublisher", ) -> None: producer = self._js_producer if publisher.stream is not None else self._producer @@ -851,7 +850,7 @@ def _log_connection_broken( self, error_cb: Optional["ErrorCallback"] = None, ) -> "ErrorCallback": - c = SpecificationSubscriber.build_log_context(None, "") + c = LogicSubscriber.build_log_context(None, "") async def wrapper(err: Exception) -> None: if error_cb is not None: @@ -872,7 +871,7 @@ def _log_reconnected( self, cb: Optional["Callback"] = None, ) -> "Callback": - c = SpecificationSubscriber.build_log_context(None, "") + c = LogicSubscriber.build_log_context(None, "") async def wrapper() -> None: if cb is not None: diff --git a/faststream/nats/broker/registrator.py b/faststream/nats/broker/registrator.py index 365c114dcb..31cc5e0882 100644 --- a/faststream/nats/broker/registrator.py +++ b/faststream/nats/broker/registrator.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast from nats.js import api -from typing_extensions import Doc, override +from typing_extensions import Doc, deprecated, override from faststream._internal.broker.abc_broker import ABCBroker from faststream._internal.constants import EMPTY @@ -136,7 +136,13 @@ def subscriber( # type: ignore[override] ack_first: Annotated[ bool, Doc("Whether to `ack` message at start of consuming or not."), - ] = False, + deprecated( + """ + This option is deprecated and will be removed in 0.7.0 release. + Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. + """, + ), + ] = EMPTY, stream: Annotated[ Union[str, "JStream", None], Doc("Subscribe to NATS Stream with `subject` filter."), @@ -162,10 +168,15 @@ def subscriber( # type: ignore[override] int, Doc("Number of workers to process messages concurrently."), ] = 1, - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -221,6 +232,7 @@ def subscriber( # type: ignore[override] ack_first=ack_first, # subscriber args ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, broker_middlewares=self.middlewares, broker_dependencies=self._dependencies, diff --git a/faststream/nats/fastapi/fastapi.py b/faststream/nats/fastapi/fastapi.py index a62318c82b..f11475af4a 100644 --- a/faststream/nats/fastapi/fastapi.py +++ b/faststream/nats/fastapi/fastapi.py @@ -63,7 +63,7 @@ from faststream.nats.publisher.specified import SpecificationPublisher from faststream.nats.schemas import JStream, KvWatch, ObjWatch, PullSub from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class NatsRouter(StreamRouter["Msg"]): @@ -188,6 +188,10 @@ def __init__( Optional[str], Doc("Nkeys seed to be used."), ] = None, + nkeys_seed_str: Annotated[ + Optional[str], + Doc("Raw nkeys seed to be used."), + ] = None, inbox_prefix: Annotated[ Union[str, bytes], Doc( @@ -245,9 +249,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, specification_tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -515,6 +519,7 @@ def __init__( user_jwt_cb=user_jwt_cb, user_credentials=user_credentials, nkeys_seed=nkeys_seed, + nkeys_seed_str=nkeys_seed_str, inbox_prefix=inbox_prefix, pending_size=pending_size, flush_timeout=flush_timeout, @@ -651,7 +656,13 @@ def subscriber( # type: ignore[override] ack_first: Annotated[ bool, Doc("Whether to `ack` message at start of consuming or not."), - ] = False, + deprecated( + """ + This option is deprecated and will be removed in 0.7.0 release. + Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. + """, + ), + ] = EMPTY, stream: Annotated[ Union[str, "JStream", None], Doc("Subscribe to NATS Stream with `subject` filter."), @@ -677,10 +688,15 @@ def subscriber( # type: ignore[override] int, Doc("Number of workers to process messages concurrently."), ] = 1, - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -853,6 +869,7 @@ def subscriber( # type: ignore[override] middlewares=middlewares, max_workers=max_workers, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/nats/message.py b/faststream/nats/message.py index cbefcce62d..ce7486aceb 100644 --- a/faststream/nats/message.py +++ b/faststream/nats/message.py @@ -17,6 +17,11 @@ async def ack(self) -> None: await self.raw_message.ack() await super().ack() + async def ack_sync(self) -> None: + if not self.raw_message._ackd: + await self.raw_message.ack_sync() + await super().ack() + async def nack( self, delay: Optional[float] = None, diff --git a/faststream/nats/parser.py b/faststream/nats/parser.py index d5ddcfe316..7553ef7882 100644 --- a/faststream/nats/parser.py +++ b/faststream/nats/parser.py @@ -54,9 +54,11 @@ async def decode_message( class NatsParser(NatsBaseParser): """A class to parse NATS core messages.""" - def __init__(self, *, pattern: str) -> None: + def __init__(self, *, pattern: str, is_ack_disabled: bool) -> None: super().__init__(pattern=pattern) + self.is_ack_disabled = is_ack_disabled + async def parse_message( self, message: "Msg", @@ -68,7 +70,8 @@ async def parse_message( headers = message.header or {} - message._ackd = True # prevent Core message from acknowledgement + if self.is_ack_disabled: + message._ackd = True return NatsMessage( raw_message=message, diff --git a/faststream/nats/publisher/producer.py b/faststream/nats/publisher/producer.py index aba0f78349..7feefe3a59 100644 --- a/faststream/nats/publisher/producer.py +++ b/faststream/nats/publisher/producer.py @@ -19,13 +19,14 @@ if TYPE_CHECKING: from nats.aio.client import Client from nats.aio.msg import Msg - from nats.js import JetStreamContext, api + from nats.js import JetStreamContext from faststream._internal.types import ( AsyncCallable, CustomCallable, ) from faststream.nats.response import NatsPublishCommand + from faststream.nats.schemas import PubAck class NatsFastProducer(ProducerProto): @@ -39,7 +40,7 @@ def __init__( parser: Optional["CustomCallable"], decoder: Optional["CustomCallable"], ) -> None: - default = NatsParser(pattern="") + default = NatsParser(pattern="", is_ack_disabled=True) self._parser = resolve_custom_func(parser, default.parse_message) self._decoder = resolve_custom_func(decoder, default.decode_message) @@ -110,7 +111,9 @@ def __init__( parser: Optional["CustomCallable"], decoder: Optional["CustomCallable"], ) -> None: - default = NatsParser(pattern="") # core parser to serializer responses + default = NatsParser( + pattern="", is_ack_disabled=True + ) # core parser to serializer responses self._parser = resolve_custom_func(parser, default.parse_message) self._decoder = resolve_custom_func(decoder, default.decode_message) @@ -126,7 +129,7 @@ def disconnect(self) -> None: async def publish( # type: ignore[override] self, cmd: "NatsPublishCommand", - ) -> "api.PubAck": + ) -> "PubAck": payload, content_type = encode_message(cmd.body) headers_to_send = { diff --git a/faststream/nats/publisher/specified.py b/faststream/nats/publisher/specified.py index 41cfdc27b9..029c62b344 100644 --- a/faststream/nats/publisher/specified.py +++ b/faststream/nats/publisher/specified.py @@ -1,35 +1,38 @@ +from faststream._internal.publisher.specified import ( + SpecificationPublisher as SpecificationPublisherMixin, +) from faststream.nats.publisher.usecase import LogicPublisher from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, PublisherSpec from faststream.specification.schema.bindings import ChannelBinding, nats -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation -class SpecificationPublisher(LogicPublisher): +class SpecificationPublisher( + SpecificationPublisherMixin, + LogicPublisher, +): """A class to represent a NATS publisher.""" - def get_name(self) -> str: + def get_default_name(self) -> str: return f"{self.subject}:Publisher" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, PublisherSpec]: payloads = self.get_payloads() return { - self.name: Channel( + self.name: PublisherSpec( description=self.description, - publish=Operation( + operation=Operation( message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads, "Publisher"), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, ), bindings=ChannelBinding( nats=nats.ChannelBinding( subject=self.subject, + queue=None, ), ), ), diff --git a/faststream/nats/publisher/usecase.py b/faststream/nats/publisher/usecase.py index 9d3ccd92dc..05b98bb20a 100644 --- a/faststream/nats/publisher/usecase.py +++ b/faststream/nats/publisher/usecase.py @@ -1,7 +1,6 @@ from collections.abc import Iterable from typing import ( TYPE_CHECKING, - Any, Optional, Union, ) @@ -15,13 +14,11 @@ from faststream.response.publish_type import PublishType if TYPE_CHECKING: - from nats.js import api - from faststream._internal.basic_types import SendableMessage from faststream._internal.types import BrokerMiddleware, PublisherMiddleware from faststream.nats.message import NatsMessage from faststream.nats.publisher.producer import NatsFastProducer, NatsJSFastProducer - from faststream.nats.schemas import JStream + from faststream.nats.schemas import JStream, PubAck from faststream.response.response import PublishCommand @@ -41,21 +38,11 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[Msg]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: """Initialize NATS publisher object.""" super().__init__( broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.subject = subject @@ -86,7 +73,7 @@ async def publish( correlation_id: Optional[str] = None, stream: Optional[str] = None, timeout: Optional[float] = None, - ) -> "api.PubAck": ... + ) -> "PubAck": ... @override async def publish( @@ -98,7 +85,7 @@ async def publish( correlation_id: Optional[str] = None, stream: Optional[str] = None, timeout: Optional[float] = None, - ) -> Optional["api.PubAck"]: + ) -> Optional["PubAck"]: """Publish message directly. Args: @@ -123,7 +110,7 @@ async def publish( Returns: `None` if you publishes a regular message. - `nats.js.api.PubAck` if you publishes a message to stream. + `faststream.nats.PubAck` if you publishes a message to stream. """ cmd = NatsPublishCommand( message, diff --git a/faststream/nats/router.py b/faststream/nats/router.py index be895eb8af..5b9defc5d3 100644 --- a/faststream/nats/router.py +++ b/faststream/nats/router.py @@ -9,7 +9,7 @@ ) from nats.js import api -from typing_extensions import Doc +from typing_extensions import Doc, deprecated from faststream._internal.broker.router import ( ArgsContainer, @@ -226,7 +226,13 @@ def __init__( ack_first: Annotated[ bool, Doc("Whether to `ack` message at start of consuming or not."), - ] = False, + deprecated( + """ + This option is deprecated and will be removed in 0.7.0 release. + Please, use `ack_policy=AckPolicy.ACK_FIRST` instead. + """, + ), + ] = EMPTY, stream: Annotated[ Union[str, "JStream", None], Doc("Subscribe to NATS Stream with `subject` filter."), @@ -252,10 +258,15 @@ def __init__( int, Doc("Number of workers to process messages concurrently."), ] = 1, - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -306,6 +317,7 @@ def __init__( decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/nats/schemas/__init__.py b/faststream/nats/schemas/__init__.py index 1edd51bcbe..accadfc731 100644 --- a/faststream/nats/schemas/__init__.py +++ b/faststream/nats/schemas/__init__.py @@ -1,3 +1,5 @@ +from nats.js.api import PubAck + from faststream.nats.schemas.js_stream import JStream from faststream.nats.schemas.kv_watch import KvWatch from faststream.nats.schemas.obj_watch import ObjWatch @@ -7,5 +9,6 @@ "JStream", "KvWatch", "ObjWatch", + "PubAck", "PullSub", ) diff --git a/faststream/nats/schemas/js_stream.py b/faststream/nats/schemas/js_stream.py index 3ad4fc2e4f..62e97124f6 100644 --- a/faststream/nats/schemas/js_stream.py +++ b/faststream/nats/schemas/js_stream.py @@ -6,7 +6,6 @@ from faststream._internal.proto import NameRequired from faststream._internal.utils.path import compile_path -from faststream.middlewares import AckPolicy if TYPE_CHECKING: from re import Pattern @@ -121,10 +120,13 @@ def __init__( "cluster may be available but for reads only.", ), ] = None, - ack_policy: Annotated[ - AckPolicy, - Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), - ] = AckPolicy.REJECT_ON_ERROR, + no_ack: Annotated[ + bool, + Doc( + "Should stream acknowledge writes or not. Without acks publisher can't determine, does message " + "received by stream or not." + ), + ] = False, template_owner: Optional[str] = None, duplicate_window: Annotated[ float, @@ -189,7 +191,6 @@ def __init__( super().__init__(name) subjects = subjects or [] - no_ack = ack_policy is AckPolicy.DO_NOTHING self.subjects = subjects self.declare = declare diff --git a/faststream/nats/subscriber/factory.py b/faststream/nats/subscriber/factory.py index ac149a1f68..534f4868f4 100644 --- a/faststream/nats/subscriber/factory.py +++ b/faststream/nats/subscriber/factory.py @@ -63,6 +63,7 @@ def create_subscriber( stream: Optional["JStream"], # Subscriber args ack_policy: "AckPolicy", + no_ack: bool, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Any]"], @@ -96,6 +97,7 @@ def create_subscriber( headers_only=headers_only, pull_sub=pull_sub, ack_policy=ack_policy, + no_ack=no_ack, kv_watch=kv_watch, obj_watch=obj_watch, ack_first=ack_first, @@ -103,6 +105,12 @@ def create_subscriber( stream=stream, ) + if ack_first is not EMPTY: + ack_policy = AckPolicy.ACK_FIRST if ack_first else AckPolicy.REJECT_ON_ERROR + + if no_ack is not EMPTY: + no_ack = AckPolicy.DO_NOTHING if no_ack else EMPTY + if ack_policy is EMPTY: ack_policy = AckPolicy.REJECT_ON_ERROR @@ -134,8 +142,10 @@ def create_subscriber( else: # JS Push Subscriber if ack_policy is AckPolicy.ACK_FIRST: - ack_first = True + manual_ack = False ack_policy = AckPolicy.DO_NOTHING + else: + manual_ack = True extra_options.update( { @@ -144,7 +154,7 @@ def create_subscriber( "flow_control": flow_control, "deliver_policy": deliver_policy, "headers_only": headers_only, - "manual_ack": not ack_first, + "manual_ack": manual_ack, }, ) @@ -192,6 +202,7 @@ def create_subscriber( extra_options=extra_options, # Subscriber args no_reply=no_reply, + ack_policy=ack_policy, broker_dependencies=broker_dependencies, broker_middlewares=broker_middlewares, # Specification @@ -208,6 +219,7 @@ def create_subscriber( extra_options=extra_options, # Subscriber args no_reply=no_reply, + ack_policy=ack_policy, broker_dependencies=broker_dependencies, broker_middlewares=broker_middlewares, # Specification @@ -330,32 +342,11 @@ def _validate_input_for_misconfigure( # noqa: PLR0915 kv_watch: Optional["KvWatch"], obj_watch: Optional["ObjWatch"], ack_policy: "AckPolicy", # default EMPTY - ack_first: bool, # default False + no_ack: bool, # default EMPTY + ack_first: bool, # default EMPTY max_workers: int, # default 1 stream: Optional["JStream"], ) -> None: - if not subject and not config: - msg = "You must provide either the `subject` or `config` option." - raise SetupError(msg) - - if stream and kv_watch: - msg = "You can't use both the `stream` and `kv_watch` options simultaneously." - raise SetupError(msg) - - if stream and obj_watch: - msg = "You can't use both the `stream` and `obj_watch` options simultaneously." - raise SetupError(msg) - - if kv_watch and obj_watch: - msg = ( - "You can't use both the `kv_watch` and `obj_watch` options simultaneously." - ) - raise SetupError(msg) - - if pull_sub and not stream: - msg = "JetStream Pull Subscriber can only be used with the `stream` option." - raise SetupError(msg) - if ack_policy is not EMPTY: if obj_watch is not None: warnings.warn( @@ -371,20 +362,74 @@ def _validate_input_for_misconfigure( # noqa: PLR0915 stacklevel=4, ) - elif stream is None: + elif stream is None and ack_policy is not AckPolicy.DO_NOTHING: + warnings.warn( + ( + "Core subscriber supports only `ack_policy=AckPolicy.DO_NOTHING` option for very specific cases. " + "If you are using different option, probably, you should use JetStream Subscriber instead." + ), + RuntimeWarning, + stacklevel=4, + ) + + if max_msgs > 0 and any((stream, kv_watch, obj_watch)): warnings.warn( - "You can't use acknowledgement policy with core subscriber. Use JetStream instead.", + "The `max_msgs` option can be used only with a NATS Core Subscriber.", RuntimeWarning, stacklevel=4, ) - if max_msgs > 0 and any((stream, kv_watch, obj_watch)): + if ack_first is not EMPTY: + warnings.warn( + "`ack_first` option was deprecated in prior to `ack_policy=AckPolicy.ACK_FIRST`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, + stacklevel=4, + ) + + if ack_policy is not EMPTY: + msg = "You can't use deprecated `ack_first` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + + ack_policy = AckPolicy.ACK_FIRST if ack_first else AckPolicy.REJECT_ON_ERROR + + if no_ack is not EMPTY: warnings.warn( - "The `max_msgs` option can be used only with a NATS Core Subscriber.", - RuntimeWarning, + "`no_ack` option was deprecated in prior to `ack_policy=AckPolicy.DO_NOTHING`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, stacklevel=4, ) + if ack_policy is not EMPTY: + msg = "You can't use deprecated `no_ack` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + + no_ack = AckPolicy.DO_NOTHING if no_ack else EMPTY + + if ack_policy is EMPTY: + ack_policy = AckPolicy.REJECT_ON_ERROR + + if not subject and not config: + msg = "You must provide either the `subject` or `config` option." + raise SetupError(msg) + + if stream and kv_watch: + msg = "You can't use both the `stream` and `kv_watch` options simultaneously." + raise SetupError(msg) + + if stream and obj_watch: + msg = "You can't use both the `stream` and `obj_watch` options simultaneously." + raise SetupError(msg) + + if kv_watch and obj_watch: + msg = ( + "You can't use both the `kv_watch` and `obj_watch` options simultaneously." + ) + raise SetupError(msg) + + if pull_sub and not stream: + msg = "JetStream Pull Subscriber can only be used with the `stream` option." + raise SetupError(msg) + if not stream: if obj_watch or kv_watch: # Obj/Kv Subscriber @@ -466,9 +511,9 @@ def _validate_input_for_misconfigure( # noqa: PLR0915 stacklevel=4, ) - if ack_first: + if ack_policy is AckPolicy.ACK_FIRST: warnings.warn( - message="The `ack_first` option can be used only with JetStream Push Subscription.", + message="The `ack_policy=AckPolicy.ACK_FIRST:` option can be used only with JetStream Push Subscription.", category=RuntimeWarning, stacklevel=4, ) @@ -489,9 +534,9 @@ def _validate_input_for_misconfigure( # noqa: PLR0915 stacklevel=4, ) - if ack_first: + if ack_policy is AckPolicy.ACK_FIRST: warnings.warn( - message="The `ack_first` option has no effect with JetStream Pull Subscription. It can only be used with JetStream Push Subscription.", + message="The `ack_policy=AckPolicy.ACK_FIRST` option has no effect with JetStream Pull Subscription. It can only be used with JetStream Push Subscription.", category=RuntimeWarning, stacklevel=4, ) diff --git a/faststream/nats/subscriber/specified.py b/faststream/nats/subscriber/specified.py index 2c3387ded3..dc8c6720ef 100644 --- a/faststream/nats/subscriber/specified.py +++ b/faststream/nats/subscriber/specified.py @@ -1,7 +1,8 @@ -from typing import Any - from typing_extensions import override +from faststream._internal.subscriber.specified import ( + SpecificationSubscriber as SpecificationSubscriberMixin, +) from faststream.nats.subscriber.usecases import ( BatchPullStreamSubscriber, ConcurrentCoreSubscriber, @@ -9,38 +10,35 @@ ConcurrentPushStreamSubscriber, CoreSubscriber, KeyValueWatchSubscriber, - LogicSubscriber, ObjStoreWatchSubscriber, PullStreamSubscriber, PushStreamSubscription, ) from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, SubscriberSpec from faststream.specification.schema.bindings import ChannelBinding, nats -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation -class SpecificationSubscriber(LogicSubscriber[Any]): +class SpecificationSubscriber(SpecificationSubscriberMixin): """A class to represent a NATS handler.""" - def get_name(self) -> str: + subject: str + + def get_default_name(self) -> str: return f"{self.subject}:{self.call_name}" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() return { - self.name: Channel( + self.name: SubscriberSpec( description=self.description, - subscribe=Operation( + operation=Operation( message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, ), bindings=ChannelBinding( nats=nats.ChannelBinding( @@ -108,11 +106,11 @@ class SpecificationKeyValueWatchSubscriber( """KeyValueWatch consumer with Specification methods.""" @override - def get_name(self) -> str: + def get_default_name(self) -> str: return "" @override - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, SubscriberSpec]: return {} @@ -123,9 +121,9 @@ class SpecificationObjStoreWatchSubscriber( """ObjStoreWatch consumer with Specification methods.""" @override - def get_name(self) -> str: + def get_default_name(self) -> str: return "" @override - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, SubscriberSpec]: return {} diff --git a/faststream/nats/subscriber/usecase.py b/faststream/nats/subscriber/usecase.py deleted file mode 100644 index e2c6b207e1..0000000000 --- a/faststream/nats/subscriber/usecase.py +++ /dev/null @@ -1,1198 +0,0 @@ -from abc import abstractmethod -from collections.abc import Awaitable, Iterable -from contextlib import suppress -from typing import ( - TYPE_CHECKING, - Annotated, - Any, - Callable, - Generic, - Optional, - cast, -) - -import anyio -from nats.errors import ConnectionClosedError, TimeoutError -from nats.js.api import ConsumerConfig, ObjectInfo -from typing_extensions import Doc, override - -from faststream._internal.subscriber.mixins import ConcurrentMixin, TasksMixin -from faststream._internal.subscriber.usecase import SubscriberUsecase -from faststream._internal.subscriber.utils import process_msg -from faststream._internal.types import MsgType -from faststream.middlewares import AckPolicy -from faststream.nats.helpers import KVBucketDeclarer, OSBucketDeclarer -from faststream.nats.message import NatsMessage -from faststream.nats.parser import ( - BatchParser, - JsParser, - KvParser, - NatsParser, - ObjParser, -) -from faststream.nats.publisher.fake import NatsFakePublisher -from faststream.nats.schemas.js_stream import compile_nats_wildcard -from faststream.nats.subscriber.adapters import ( - UnsubscribeAdapter, - Unsubscriptable, -) - -from .state import ConnectedSubscriberState, EmptySubscriberState, SubscriberState - -if TYPE_CHECKING: - from fast_depends.dependencies import Dependant - from nats.aio.msg import Msg - from nats.aio.subscription import Subscription - from nats.js import JetStreamContext - from nats.js.kv import KeyValue - from nats.js.object_store import ObjectStore - - from faststream._internal.basic_types import ( - AnyDict, - SendableMessage, - ) - from faststream._internal.publisher.proto import BasePublisherProto, ProducerProto - from faststream._internal.state import ( - BrokerState as BasicState, - Pointer, - ) - from faststream._internal.types import ( - AsyncCallable, - BrokerMiddleware, - CustomCallable, - ) - from faststream.message import StreamMessage - from faststream.nats.broker.state import BrokerState - from faststream.nats.helpers import KVBucketDeclarer, OSBucketDeclarer - from faststream.nats.message import NatsKvMessage, NatsObjMessage - from faststream.nats.schemas import JStream, KvWatch, ObjWatch, PullSub - - -class LogicSubscriber(SubscriberUsecase[MsgType], Generic[MsgType]): - """A class to represent a NATS handler.""" - - subscription: Optional[Unsubscriptable] - _fetch_sub: Optional[Unsubscriptable] - producer: Optional["ProducerProto"] - - def __init__( - self, - *, - subject: str, - config: "ConsumerConfig", - extra_options: Optional["AnyDict"], - # Subscriber args - default_parser: "AsyncCallable", - default_decoder: "AsyncCallable", - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - self.subject = subject - self.config = config - - self.extra_options = extra_options or {} - - super().__init__( - default_parser=default_parser, - default_decoder=default_decoder, - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, - ) - - self._fetch_sub = None - self.subscription = None - self.producer = None - - self._connection_state: SubscriberState = EmptySubscriberState() - - @override - def _setup( # type: ignore[override] - self, - *, - connection_state: "BrokerState", - os_declarer: "OSBucketDeclarer", - kv_declarer: "KVBucketDeclarer", - # basic args - extra_context: "AnyDict", - # broker options - broker_parser: Optional["CustomCallable"], - broker_decoder: Optional["CustomCallable"], - # dependant args - state: "Pointer[BasicState]", - ) -> None: - self._connection_state = ConnectedSubscriberState( - parent_state=connection_state, - os_declarer=os_declarer, - kv_declarer=kv_declarer, - ) - - super()._setup( - extra_context=extra_context, - broker_parser=broker_parser, - broker_decoder=broker_decoder, - state=state, - ) - - @property - def clear_subject(self) -> str: - """Compile `test.{name}` to `test.*` subject.""" - _, path = compile_nats_wildcard(self.subject) - return path - - async def start(self) -> None: - """Create NATS subscription and start consume tasks.""" - await super().start() - - if self.calls: - await self._create_subscription() - - async def close(self) -> None: - """Clean up handler subscription, cancel consume task in graceful mode.""" - await super().close() - - if self.subscription is not None: - await self.subscription.unsubscribe() - self.subscription = None - - if self._fetch_sub is not None: - await self._fetch_sub.unsubscribe() - self.subscription = None - - @abstractmethod - async def _create_subscription(self) -> None: - """Create NATS subscription object to consume messages.""" - raise NotImplementedError - - @staticmethod - def build_log_context( - message: Annotated[ - Optional["StreamMessage[MsgType]"], - Doc("Message which we are building context for"), - ], - subject: Annotated[ - str, - Doc("NATS subject we are listening"), - ], - *, - queue: Annotated[ - str, - Doc("Using queue group name"), - ] = "", - stream: Annotated[ - str, - Doc("Stream object we are listening"), - ] = "", - ) -> dict[str, str]: - """Static method to build log context out of `self.consume` scope.""" - return { - "subject": subject, - "queue": queue, - "stream": stream, - "message_id": getattr(message, "message_id", ""), - } - - def add_prefix(self, prefix: str) -> None: - """Include Subscriber in router.""" - if self.subject: - self.subject = f"{prefix}{self.subject}" - else: - self.config.filter_subjects = [ - f"{prefix}{subject}" for subject in (self.config.filter_subjects or ()) - ] - - @property - def _resolved_subject_string(self) -> str: - return self.subject or ", ".join(self.config.filter_subjects or ()) - - -class _DefaultSubscriber(LogicSubscriber[MsgType]): - def __init__( - self, - *, - subject: str, - config: "ConsumerConfig", - # default args - extra_options: Optional["AnyDict"], - # Subscriber args - default_parser: "AsyncCallable", - default_decoder: "AsyncCallable", - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - super().__init__( - subject=subject, - config=config, - extra_options=extra_options, - # subscriber args - default_parser=default_parser, - default_decoder=default_decoder, - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - def _make_response_publisher( - self, - message: "StreamMessage[Any]", - ) -> Iterable["BasePublisherProto"]: - """Create Publisher objects to use it as one of `publishers` in `self.consume` scope.""" - return ( - NatsFakePublisher( - producer=self._state.get().producer, - subject=message.reply_to, - ), - ) - - def get_log_context( - self, - message: Annotated[ - Optional["StreamMessage[MsgType]"], - Doc("Message which we are building context for"), - ], - ) -> dict[str, str]: - """Log context factory using in `self.consume` scope.""" - return self.build_log_context( - message=message, - subject=self.subject, - ) - - -class CoreSubscriber(_DefaultSubscriber["Msg"]): - subscription: Optional["Subscription"] - _fetch_sub: Optional["Subscription"] - - def __init__( - self, - *, - # default args - subject: str, - config: "ConsumerConfig", - queue: str, - extra_options: Optional["AnyDict"], - # Subscriber args - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - parser_ = NatsParser(pattern=subject) - - self.queue = queue - - super().__init__( - subject=subject, - config=config, - extra_options=extra_options, - # subscriber args - default_parser=parser_.parse_message, - default_decoder=parser_.decode_message, - # Propagated args - ack_policy=AckPolicy.DO_NOTHING, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def get_one( - self, - *, - timeout: float = 5.0, - ) -> "Optional[NatsMessage]": - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - if self._fetch_sub is None: - fetch_sub = self._fetch_sub = await self._connection_state.client.subscribe( - subject=self.clear_subject, - queue=self.queue, - **self.extra_options, - ) - else: - fetch_sub = self._fetch_sub - - try: - raw_message = await fetch_sub.next_msg(timeout=timeout) - except TimeoutError: - return None - - context = self._state.get().di_state.context - - msg: NatsMessage = await process_msg( # type: ignore[assignment] - msg=raw_message, - middlewares=( - m(raw_message, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.subscription = await self._connection_state.client.subscribe( - subject=self.clear_subject, - queue=self.queue, - cb=self.consume, - **self.extra_options, - ) - - def get_log_context( - self, - message: Annotated[ - Optional["StreamMessage[Msg]"], - Doc("Message which we are building context for"), - ], - ) -> dict[str, str]: - """Log context factory using in `self.consume` scope.""" - return self.build_log_context( - message=message, - subject=self.subject, - queue=self.queue, - ) - - -class ConcurrentCoreSubscriber( - ConcurrentMixin, - CoreSubscriber, -): - def __init__( - self, - *, - max_workers: int, - # default args - subject: str, - config: "ConsumerConfig", - queue: str, - extra_options: Optional["AnyDict"], - # Subscriber args - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - super().__init__( - max_workers=max_workers, - # basic args - subject=subject, - config=config, - queue=queue, - extra_options=extra_options, - # Propagated args - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.start_consume_task() - - self.subscription = await self._connection_state.client.subscribe( - subject=self.clear_subject, - queue=self.queue, - cb=self._put_msg, - **self.extra_options, - ) - - -class _StreamSubscriber(_DefaultSubscriber["Msg"]): - _fetch_sub: Optional["JetStreamContext.PullSubscription"] - - def __init__( - self, - *, - stream: "JStream", - # default args - subject: str, - config: "ConsumerConfig", - queue: str, - extra_options: Optional["AnyDict"], - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - parser_ = JsParser(pattern=subject) - - self.queue = queue - self.stream = stream - - super().__init__( - subject=subject, - config=config, - extra_options=extra_options, - # subscriber args - default_parser=parser_.parse_message, - default_decoder=parser_.decode_message, - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - def get_log_context( - self, - message: Annotated[ - Optional["StreamMessage[Msg]"], - Doc("Message which we are building context for"), - ], - ) -> dict[str, str]: - """Log context factory using in `self.consume` scope.""" - return self.build_log_context( - message=message, - subject=self._resolved_subject_string, - queue=self.queue, - stream=self.stream.name, - ) - - @override - async def get_one( - self, - *, - timeout: float = 5, - ) -> Optional["NatsMessage"]: - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - if not self._fetch_sub: - extra_options = { - "pending_bytes_limit": self.extra_options["pending_bytes_limit"], - "pending_msgs_limit": self.extra_options["pending_msgs_limit"], - "durable": self.extra_options["durable"], - "stream": self.extra_options["stream"], - } - if inbox_prefix := self.extra_options.get("inbox_prefix"): - extra_options["inbox_prefix"] = inbox_prefix - - self._fetch_sub = await self._connection_state.js.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **extra_options, - ) - - try: - raw_message = ( - await self._fetch_sub.fetch( - batch=1, - timeout=timeout, - ) - )[0] - except (TimeoutError, ConnectionClosedError): - return None - - context = self._state.get().di_state.context - - msg: NatsMessage = await process_msg( # type: ignore[assignment] - msg=raw_message, - middlewares=( - m(raw_message, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - -class PushStreamSubscription(_StreamSubscriber): - subscription: Optional["JetStreamContext.PushSubscription"] - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.subscription = await self._connection_state.js.subscribe( - subject=self.clear_subject, - queue=self.queue, - cb=self.consume, - config=self.config, - **self.extra_options, - ) - - -class ConcurrentPushStreamSubscriber( - ConcurrentMixin, - _StreamSubscriber, -): - subscription: Optional["JetStreamContext.PushSubscription"] - - def __init__( - self, - *, - max_workers: int, - stream: "JStream", - # default args - subject: str, - config: "ConsumerConfig", - queue: str, - extra_options: Optional["AnyDict"], - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - super().__init__( - max_workers=max_workers, - # basic args - stream=stream, - subject=subject, - config=config, - queue=queue, - extra_options=extra_options, - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.start_consume_task() - - self.subscription = await self._connection_state.js.subscribe( - subject=self.clear_subject, - queue=self.queue, - cb=self._put_msg, - config=self.config, - **self.extra_options, - ) - - -class PullStreamSubscriber( - TasksMixin, - _StreamSubscriber, -): - subscription: Optional["JetStreamContext.PullSubscription"] - - def __init__( - self, - *, - pull_sub: "PullSub", - stream: "JStream", - # default args - subject: str, - config: "ConsumerConfig", - extra_options: Optional["AnyDict"], - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - self.pull_sub = pull_sub - - super().__init__( - # basic args - stream=stream, - subject=subject, - config=config, - extra_options=extra_options, - queue="", - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.subscription = await self._connection_state.js.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - self.add_task(self._consume_pull(cb=self.consume)) - - async def _consume_pull( - self, - cb: Callable[["Msg"], Awaitable["SendableMessage"]], - ) -> None: - """Endless task consuming messages using NATS Pull subscriber.""" - assert self.subscription # nosec B101 - - while self.running: # pragma: no branch - messages = [] - with suppress(TimeoutError, ConnectionClosedError): - messages = await self.subscription.fetch( - batch=self.pull_sub.batch_size, - timeout=self.pull_sub.timeout, - ) - - if messages: - async with anyio.create_task_group() as tg: - for msg in messages: - tg.start_soon(cb, msg) - - -class ConcurrentPullStreamSubscriber( - ConcurrentMixin, - PullStreamSubscriber, -): - def __init__( - self, - *, - max_workers: int, - # default args - pull_sub: "PullSub", - stream: "JStream", - subject: str, - config: "ConsumerConfig", - extra_options: Optional["AnyDict"], - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - super().__init__( - max_workers=max_workers, - # basic args - pull_sub=pull_sub, - stream=stream, - subject=subject, - config=config, - extra_options=extra_options, - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.start_consume_task() - - self.subscription = await self._connection_state.js.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - self.add_task(self._consume_pull(cb=self._put_msg)) - - -class BatchPullStreamSubscriber( - TasksMixin, - _DefaultSubscriber[list["Msg"]], -): - """Batch-message consumer class.""" - - subscription: Optional["JetStreamContext.PullSubscription"] - _fetch_sub: Optional["JetStreamContext.PullSubscription"] - - def __init__( - self, - *, - # default args - subject: str, - config: "ConsumerConfig", - stream: "JStream", - pull_sub: "PullSub", - extra_options: Optional["AnyDict"], - # Subscriber args - ack_policy: "AckPolicy", - no_reply: bool, - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[list[Msg]]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - parser = BatchParser(pattern=subject) - - self.stream = stream - self.pull_sub = pull_sub - - super().__init__( - subject=subject, - config=config, - extra_options=extra_options, - # subscriber args - default_parser=parser.parse_batch, - default_decoder=parser.decode_batch, - # Propagated args - ack_policy=ack_policy, - no_reply=no_reply, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def get_one( - self, - *, - timeout: float = 5, - ) -> Optional["NatsMessage"]: - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - if not self._fetch_sub: - fetch_sub = ( - self._fetch_sub - ) = await self._connection_state.js.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - else: - fetch_sub = self._fetch_sub - - try: - raw_message = await fetch_sub.fetch( - batch=1, - timeout=timeout, - ) - except TimeoutError: - return None - - context = self._state.get().di_state.context - - return cast( - NatsMessage, - await process_msg( - msg=raw_message, - middlewares=( - m(raw_message, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ), - ) - - @override - async def _create_subscription(self) -> None: - """Create NATS subscription and start consume task.""" - if self.subscription: - return - - self.subscription = await self._connection_state.js.pull_subscribe( - subject=self.clear_subject, - config=self.config, - **self.extra_options, - ) - self.add_task(self._consume_pull()) - - async def _consume_pull(self) -> None: - """Endless task consuming messages using NATS Pull subscriber.""" - assert self.subscription, "You should call `create_subscription` at first." # nosec B101 - - while self.running: # pragma: no branch - with suppress(TimeoutError, ConnectionClosedError): - messages = await self.subscription.fetch( - batch=self.pull_sub.batch_size, - timeout=self.pull_sub.timeout, - ) - - if messages: - await self.consume(messages) - - -class KeyValueWatchSubscriber( - TasksMixin, - LogicSubscriber["KeyValue.Entry"], -): - subscription: Optional["UnsubscribeAdapter[KeyValue.KeyWatcher]"] - _fetch_sub: Optional[UnsubscribeAdapter["KeyValue.KeyWatcher"]] - - def __init__( - self, - *, - subject: str, - config: "ConsumerConfig", - kv_watch: "KvWatch", - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[KeyValue.Entry]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - parser = KvParser(pattern=subject) - self.kv_watch = kv_watch - - super().__init__( - subject=subject, - config=config, - extra_options=None, - ack_policy=AckPolicy.DO_NOTHING, - no_reply=True, - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def get_one( - self, - *, - timeout: float = 5, - ) -> Optional["NatsKvMessage"]: - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - if not self._fetch_sub: - bucket = await self._connection_state.kv_declarer.create_key_value( - bucket=self.kv_watch.name, - declare=self.kv_watch.declare, - ) - - fetch_sub = self._fetch_sub = UnsubscribeAdapter["KeyValue.KeyWatcher"]( - await bucket.watch( - keys=self.clear_subject, - headers_only=self.kv_watch.headers_only, - include_history=self.kv_watch.include_history, - ignore_deletes=self.kv_watch.ignore_deletes, - meta_only=self.kv_watch.meta_only, - ), - ) - else: - fetch_sub = self._fetch_sub - - raw_message = None - sleep_interval = timeout / 10 - with anyio.move_on_after(timeout): - while ( # noqa: ASYNC110 - # type: ignore[no-untyped-call] - raw_message := await fetch_sub.obj.updates(timeout) - ) is None: - await anyio.sleep(sleep_interval) - - context = self._state.get().di_state.context - - msg: NatsKvMessage = await process_msg( - msg=raw_message, - middlewares=( - m(raw_message, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - @override - async def _create_subscription(self) -> None: - if self.subscription: - return - - bucket = await self._connection_state.kv_declarer.create_key_value( - bucket=self.kv_watch.name, - declare=self.kv_watch.declare, - ) - - self.subscription = UnsubscribeAdapter["KeyValue.KeyWatcher"]( - await bucket.watch( - keys=self.clear_subject, - headers_only=self.kv_watch.headers_only, - include_history=self.kv_watch.include_history, - ignore_deletes=self.kv_watch.ignore_deletes, - meta_only=self.kv_watch.meta_only, - ), - ) - - self.add_task(self.__consume_watch()) - - async def __consume_watch(self) -> None: - assert self.subscription, "You should call `create_subscription` at first." # nosec B101 - - key_watcher = self.subscription.obj - - while self.running: - with suppress(ConnectionClosedError, TimeoutError): - message = cast( - Optional["KeyValue.Entry"], - # type: ignore[no-untyped-call] - await key_watcher.updates(self.kv_watch.timeout), - ) - - if message: - await self.consume(message) - - def _make_response_publisher( - self, - message: Annotated[ - "StreamMessage[KeyValue.Entry]", - Doc("Message requiring reply"), - ], - ) -> Iterable["BasePublisherProto"]: - """Create Publisher objects to use it as one of `publishers` in `self.consume` scope.""" - return () - - def get_log_context( - self, - message: Annotated[ - Optional["StreamMessage[KeyValue.Entry]"], - Doc("Message which we are building context for"), - ], - ) -> dict[str, str]: - """Log context factory using in `self.consume` scope.""" - return self.build_log_context( - message=message, - subject=self.subject, - stream=self.kv_watch.name, - ) - - -OBJECT_STORAGE_CONTEXT_KEY = "__object_storage" - - -class ObjStoreWatchSubscriber( - TasksMixin, - LogicSubscriber[ObjectInfo], -): - subscription: Optional["UnsubscribeAdapter[ObjectStore.ObjectWatcher]"] - _fetch_sub: Optional[UnsubscribeAdapter["ObjectStore.ObjectWatcher"]] - - def __init__( - self, - *, - subject: str, - config: "ConsumerConfig", - obj_watch: "ObjWatch", - broker_dependencies: Iterable["Dependant"], - broker_middlewares: Iterable["BrokerMiddleware[list[Msg]]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, - ) -> None: - parser = ObjParser(pattern="") - - self.obj_watch = obj_watch - self.obj_watch_conn = None - - super().__init__( - subject=subject, - config=config, - extra_options=None, - ack_policy=AckPolicy.DO_NOTHING, - no_reply=True, - default_parser=parser.parse_message, - default_decoder=parser.decode_message, - broker_middlewares=broker_middlewares, - broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, - ) - - @override - async def get_one( - self, - *, - timeout: float = 5, - ) -> Optional["NatsObjMessage"]: - assert ( # nosec B101 - not self.calls - ), "You can't use `get_one` method if subscriber has registered handlers." - - if not self._fetch_sub: - self.bucket = await self._connection_state.os_declarer.create_object_store( - bucket=self.subject, - declare=self.obj_watch.declare, - ) - - obj_watch = await self.bucket.watch( - ignore_deletes=self.obj_watch.ignore_deletes, - include_history=self.obj_watch.include_history, - meta_only=self.obj_watch.meta_only, - ) - fetch_sub = self._fetch_sub = UnsubscribeAdapter[ - "ObjectStore.ObjectWatcher" - ](obj_watch) - else: - fetch_sub = self._fetch_sub - - raw_message = None - sleep_interval = timeout / 10 - with anyio.move_on_after(timeout): - while ( # noqa: ASYNC110 - # type: ignore[no-untyped-call] - raw_message := await fetch_sub.obj.updates(timeout) - ) is None: - await anyio.sleep(sleep_interval) - - context = self._state.get().di_state.context - - msg: NatsObjMessage = await process_msg( - msg=raw_message, - middlewares=( - m(raw_message, context=context) for m in self._broker_middlewares - ), - parser=self._parser, - decoder=self._decoder, - ) - return msg - - @override - async def _create_subscription(self) -> None: - if self.subscription: - return - - self.bucket = await self._connection_state.os_declarer.create_object_store( - bucket=self.subject, - declare=self.obj_watch.declare, - ) - - self.add_task(self.__consume_watch()) - - async def __consume_watch(self) -> None: - assert self.bucket, "You should call `create_subscription` at first." # nosec B101 - - # Should be created inside task to avoid nats-py lock - obj_watch = await self.bucket.watch( - ignore_deletes=self.obj_watch.ignore_deletes, - include_history=self.obj_watch.include_history, - meta_only=self.obj_watch.meta_only, - ) - - self.subscription = UnsubscribeAdapter["ObjectStore.ObjectWatcher"](obj_watch) - - context = self._state.get().di_state.context - - while self.running: - with suppress(TimeoutError): - message = cast( - Optional["ObjectInfo"], - # type: ignore[no-untyped-call] - await obj_watch.updates(self.obj_watch.timeout), - ) - - if message: - with context.scope(OBJECT_STORAGE_CONTEXT_KEY, self.bucket): - await self.consume(message) - - def _make_response_publisher( - self, - message: Annotated[ - "StreamMessage[ObjectInfo]", - Doc("Message requiring reply"), - ], - ) -> Iterable["BasePublisherProto"]: - """Create Publisher objects to use it as one of `publishers` in `self.consume` scope.""" - return () - - def get_log_context( - self, - message: Annotated[ - Optional["StreamMessage[ObjectInfo]"], - Doc("Message which we are building context for"), - ], - ) -> dict[str, str]: - """Log context factory using in `self.consume` scope.""" - return self.build_log_context( - message=message, - subject=self.subject, - ) diff --git a/faststream/nats/subscriber/usecases/basic.py b/faststream/nats/subscriber/usecases/basic.py index 3b5f30e1fc..bee03746b3 100644 --- a/faststream/nats/subscriber/usecases/basic.py +++ b/faststream/nats/subscriber/usecases/basic.py @@ -3,7 +3,6 @@ from typing import ( TYPE_CHECKING, Any, - Generic, Optional, ) @@ -46,7 +45,7 @@ from faststream.nats.helpers import KVBucketDeclarer, OSBucketDeclarer -class LogicSubscriber(SubscriberUsecase[MsgType], Generic[MsgType]): +class LogicSubscriber(SubscriberUsecase[MsgType]): """Basic class for all NATS Subscriber types (KeyValue, ObjectStorage, Core & JetStream).""" subscription: Optional[Unsubscriptable] @@ -66,10 +65,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: self.subject = subject self.config = config @@ -84,10 +79,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self._fetch_sub = None @@ -201,10 +192,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[MsgType]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( subject=subject, @@ -218,10 +205,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) def _make_response_publisher( diff --git a/faststream/nats/subscriber/usecases/core_subscriber.py b/faststream/nats/subscriber/usecases/core_subscriber.py index 3cff6547d2..329eb66a50 100644 --- a/faststream/nats/subscriber/usecases/core_subscriber.py +++ b/faststream/nats/subscriber/usecases/core_subscriber.py @@ -39,16 +39,16 @@ def __init__( config: "ConsumerConfig", queue: str, extra_options: Optional["AnyDict"], + ack_policy: AckPolicy, # Subscriber args no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: - parser_ = NatsParser(pattern=subject) + parser_ = NatsParser( + pattern=subject, + is_ack_disabled=ack_policy is not AckPolicy.DO_NOTHING, + ) self.queue = queue @@ -64,10 +64,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override @@ -145,13 +141,10 @@ def __init__( queue: str, extra_options: Optional["AnyDict"], # Subscriber args + ack_policy: AckPolicy, no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( max_workers=max_workers, @@ -161,13 +154,10 @@ def __init__( queue=queue, extra_options=extra_options, # Propagated args + ack_policy=ack_policy, no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override diff --git a/faststream/nats/subscriber/usecases/key_value_subscriber.py b/faststream/nats/subscriber/usecases/key_value_subscriber.py index cf4a2a3f4e..9b3f27c494 100644 --- a/faststream/nats/subscriber/usecases/key_value_subscriber.py +++ b/faststream/nats/subscriber/usecases/key_value_subscriber.py @@ -52,10 +52,6 @@ def __init__( kv_watch: "KvWatch", broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[KeyValue.Entry]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = KvParser(pattern=subject) self.kv_watch = kv_watch @@ -70,10 +66,6 @@ def __init__( default_decoder=parser.decode_message, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override diff --git a/faststream/nats/subscriber/usecases/object_storage_subscriber.py b/faststream/nats/subscriber/usecases/object_storage_subscriber.py index a1d5bace48..0e6332ce3e 100644 --- a/faststream/nats/subscriber/usecases/object_storage_subscriber.py +++ b/faststream/nats/subscriber/usecases/object_storage_subscriber.py @@ -56,10 +56,6 @@ def __init__( obj_watch: "ObjWatch", broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[list[Msg]]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = ObjParser(pattern="") @@ -76,10 +72,6 @@ def __init__( default_decoder=parser.decode_message, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override diff --git a/faststream/nats/subscriber/usecases/stream_basic.py b/faststream/nats/subscriber/usecases/stream_basic.py index c053f2ce5e..80de14d278 100644 --- a/faststream/nats/subscriber/usecases/stream_basic.py +++ b/faststream/nats/subscriber/usecases/stream_basic.py @@ -50,10 +50,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser_ = JsParser(pattern=subject) @@ -72,10 +68,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) def get_log_context( diff --git a/faststream/nats/subscriber/usecases/stream_pull_subscriber.py b/faststream/nats/subscriber/usecases/stream_pull_subscriber.py index 44d82e89dd..7fa638eb11 100644 --- a/faststream/nats/subscriber/usecases/stream_pull_subscriber.py +++ b/faststream/nats/subscriber/usecases/stream_pull_subscriber.py @@ -58,10 +58,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: self.pull_sub = pull_sub @@ -77,10 +73,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override @@ -136,10 +128,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( max_workers=max_workers, @@ -154,10 +142,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override @@ -199,10 +183,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[list[Msg]]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = BatchParser(pattern=subject) @@ -221,10 +201,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override diff --git a/faststream/nats/subscriber/usecases/stream_push_subscriber.py b/faststream/nats/subscriber/usecases/stream_push_subscriber.py index ac14ae3509..66ea31c68d 100644 --- a/faststream/nats/subscriber/usecases/stream_push_subscriber.py +++ b/faststream/nats/subscriber/usecases/stream_push_subscriber.py @@ -65,10 +65,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[Msg]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( max_workers=max_workers, @@ -83,10 +79,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI args - description_=description_, - title_=title_, - include_in_schema=include_in_schema, ) @override diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 63c9fb9a22..0b0ce88302 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -86,7 +86,7 @@ class FakeProducer(NatsFastProducer): def __init__(self, broker: NatsBroker) -> None: self.broker = broker - default = NatsParser(pattern="") + default = NatsParser(pattern="", is_ack_disabled=True) self._parser = resolve_custom_func(broker._parser, default.parse_message) self._decoder = resolve_custom_func(broker._decoder, default.decode_message) diff --git a/faststream/prometheus/container.py b/faststream/prometheus/container.py index dd93701b05..ed7ee8bc5f 100644 --- a/faststream/prometheus/container.py +++ b/faststream/prometheus/container.py @@ -1,7 +1,11 @@ from collections.abc import Sequence -from typing import Optional +from typing import TYPE_CHECKING, Optional, cast -from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram +from prometheus_client import Counter, Gauge, Histogram + +if TYPE_CHECKING: + from prometheus_client import CollectorRegistry + from prometheus_client.registry import Collector class MetricsContainer: @@ -44,58 +48,118 @@ def __init__( self._registry = registry self._metrics_prefix = metrics_prefix - self.received_messages_total = Counter( - name=f"{metrics_prefix}_received_messages_total", + received_messages_total_name = f"{metrics_prefix}_received_messages_total" + self.received_messages_total = cast( + Counter, self._get_registered_metric(received_messages_total_name) + ) or Counter( + name=received_messages_total_name, documentation="Count of received messages by broker and handler", labelnames=["app_name", "broker", "handler"], registry=registry, ) - self.received_messages_size_bytes = Histogram( - name=f"{metrics_prefix}_received_messages_size_bytes", + + received_messages_size_bytes_name = ( + f"{metrics_prefix}_received_messages_size_bytes" + ) + self.received_messages_size_bytes = cast( + Histogram, self._get_registered_metric(received_messages_size_bytes_name) + ) or Histogram( + name=received_messages_size_bytes_name, documentation="Histogram of received messages size in bytes by broker and handler", labelnames=["app_name", "broker", "handler"], registry=registry, buckets=received_messages_size_buckets or self.DEFAULT_SIZE_BUCKETS, ) - self.received_messages_in_process = Gauge( - name=f"{metrics_prefix}_received_messages_in_process", + + received_messages_in_process_name = ( + f"{metrics_prefix}_received_messages_in_process" + ) + self.received_messages_in_process = cast( + Gauge, self._get_registered_metric(received_messages_in_process_name) + ) or Gauge( + name=received_messages_in_process_name, documentation="Gauge of received messages in process by broker and handler", labelnames=["app_name", "broker", "handler"], registry=registry, ) - self.received_processed_messages_total = Counter( - name=f"{metrics_prefix}_received_processed_messages_total", + + received_processed_messages_total_name = ( + f"{metrics_prefix}_received_processed_messages_total" + ) + self.received_processed_messages_total = cast( + Counter, self._get_registered_metric(received_processed_messages_total_name) + ) or Counter( + name=received_processed_messages_total_name, documentation="Count of received processed messages by broker, handler and status", labelnames=["app_name", "broker", "handler", "status"], registry=registry, ) - self.received_processed_messages_duration_seconds = Histogram( - name=f"{metrics_prefix}_received_processed_messages_duration_seconds", + + received_processed_messages_duration_seconds_name = ( + f"{metrics_prefix}_received_processed_messages_duration_seconds" + ) + self.received_processed_messages_duration_seconds = cast( + Histogram, + self._get_registered_metric( + received_processed_messages_duration_seconds_name + ), + ) or Histogram( + name=received_processed_messages_duration_seconds_name, documentation="Histogram of received processed messages duration in seconds by broker and handler", labelnames=["app_name", "broker", "handler"], registry=registry, ) - self.received_processed_messages_exceptions_total = Counter( - name=f"{metrics_prefix}_received_processed_messages_exceptions_total", + + received_processed_messages_exceptions_total_name = ( + f"{metrics_prefix}_received_processed_messages_exceptions_total" + ) + self.received_processed_messages_exceptions_total = cast( + Counter, + self._get_registered_metric( + received_processed_messages_exceptions_total_name + ), + ) or Counter( + name=received_processed_messages_exceptions_total_name, documentation="Count of received processed messages exceptions by broker, handler and exception_type", labelnames=["app_name", "broker", "handler", "exception_type"], registry=registry, ) - self.published_messages_total = Counter( - name=f"{metrics_prefix}_published_messages_total", + + published_messages_total_name = f"{metrics_prefix}_published_messages_total" + self.published_messages_total = cast( + Counter, self._get_registered_metric(published_messages_total_name) + ) or Counter( + name=published_messages_total_name, documentation="Count of published messages by destination and status", labelnames=["app_name", "broker", "destination", "status"], registry=registry, ) - self.published_messages_duration_seconds = Histogram( - name=f"{metrics_prefix}_published_messages_duration_seconds", + + published_messages_duration_seconds_name = ( + f"{metrics_prefix}_published_messages_duration_seconds" + ) + self.published_messages_duration_seconds = cast( + Histogram, + self._get_registered_metric(published_messages_duration_seconds_name), + ) or Histogram( + name=published_messages_duration_seconds_name, documentation="Histogram of published messages duration in seconds by broker and destination", labelnames=["app_name", "broker", "destination"], registry=registry, ) - self.published_messages_exceptions_total = Counter( - name=f"{metrics_prefix}_published_messages_exceptions_total", + + published_messages_exceptions_total_name = ( + f"{metrics_prefix}_published_messages_exceptions_total" + ) + self.published_messages_exceptions_total = cast( + Counter, + self._get_registered_metric(published_messages_exceptions_total_name), + ) or Counter( + name=published_messages_exceptions_total_name, documentation="Count of published messages exceptions by broker, destination and exception_type", labelnames=["app_name", "broker", "destination", "exception_type"], registry=registry, ) + + def _get_registered_metric(self, metric_name: str) -> Optional["Collector"]: + return self._registry._names_to_collectors.get(metric_name) diff --git a/faststream/prometheus/middleware.py b/faststream/prometheus/middleware.py index d61dc42b0c..18462f2de8 100644 --- a/faststream/prometheus/middleware.py +++ b/faststream/prometheus/middleware.py @@ -4,6 +4,7 @@ from faststream import BaseMiddleware from faststream._internal.constants import EMPTY +from faststream.exceptions import IgnoredException from faststream.message import SourceType from faststream.prometheus.consts import ( PROCESSING_STATUS_BY_ACK_STATUS, @@ -58,8 +59,8 @@ def __call__( /, *, context: "ContextRepo", - ) -> "_PrometheusMiddleware": - return _PrometheusMiddleware( + ) -> "BasePrometheusMiddleware": + return BasePrometheusMiddleware( msg, metrics_manager=self._metrics_manager, settings_provider_factory=self._settings_provider_factory, @@ -67,7 +68,7 @@ def __call__( ) -class _PrometheusMiddleware(BaseMiddleware): +class BasePrometheusMiddleware(BaseMiddleware): def __init__( self, msg: Optional[Any], @@ -121,11 +122,13 @@ async def consume_scope( except Exception as e: err = e - self._metrics_manager.add_received_processed_message_exception( - exception_type=type(err).__name__, - broker=messaging_system, - handler=destination_name, - ) + + if not isinstance(err, IgnoredException): + self._metrics_manager.add_received_processed_message_exception( + exception_type=type(err).__name__, + broker=messaging_system, + handler=destination_name, + ) raise finally: diff --git a/faststream/rabbit/broker/broker.py b/faststream/rabbit/broker/broker.py index b0e60e62cf..722b12ff1a 100644 --- a/faststream/rabbit/broker/broker.py +++ b/faststream/rabbit/broker/broker.py @@ -59,7 +59,7 @@ from faststream.rabbit.message import RabbitMessage from faststream.rabbit.types import AioPikaSendableMessage from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class RabbitBroker( @@ -196,9 +196,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -807,7 +807,7 @@ async def ping(self, timeout: Optional[float]) -> bool: if cancel_scope.cancel_called: return False - if not self._connection.is_closed: + if self._connection.connected.is_set(): return True await anyio.sleep(sleep_time) diff --git a/faststream/rabbit/broker/registrator.py b/faststream/rabbit/broker/registrator.py index 37ca0066f7..0ce98da45e 100644 --- a/faststream/rabbit/broker/registrator.py +++ b/faststream/rabbit/broker/registrator.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast -from typing_extensions import Doc, override +from typing_extensions import Doc, deprecated, override from faststream._internal.broker.abc_broker import ABCBroker from faststream._internal.constants import EMPTY @@ -59,10 +59,15 @@ def subscriber( # type: ignore[override] Optional["AnyDict"], Doc("Extra consumer arguments to use in `queue.consume(...)` method."), ] = None, - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, # broker arguments dependencies: Annotated[ Iterable["Dependant"], @@ -112,6 +117,7 @@ def subscriber( # type: ignore[override] consume_args=consume_args, # subscriber args ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, broker_middlewares=self.middlewares, broker_dependencies=self._dependencies, diff --git a/faststream/rabbit/fastapi/fastapi.py b/faststream/rabbit/fastapi/fastapi.py index 6e32718447..a53a93e5fa 100644 --- a/faststream/rabbit/fastapi/fastapi.py +++ b/faststream/rabbit/fastapi/fastapi.py @@ -50,7 +50,7 @@ from faststream.rabbit.message import RabbitMessage from faststream.rabbit.publisher.specified import SpecificationPublisher from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class RabbitRouter(StreamRouter["IncomingMessage"]): @@ -176,9 +176,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, specification_tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], @@ -512,10 +512,15 @@ def subscriber( # type: ignore[override] Iterable["SubscriberMiddleware[RabbitMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -673,6 +678,7 @@ def subscriber( # type: ignore[override] decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/rabbit/publisher/specified.py b/faststream/rabbit/publisher/specified.py index 6d16769926..e8da19b3fd 100644 --- a/faststream/rabbit/publisher/specified.py +++ b/faststream/rabbit/publisher/specified.py @@ -1,34 +1,77 @@ +from collections.abc import Iterable +from typing import ( + TYPE_CHECKING, + Any, + Optional, +) + +from faststream._internal.publisher.specified import ( + SpecificationPublisher as SpecificationPublisherMixin, +) +from faststream.rabbit.schemas.proto import BaseRMQInformation as RMQSpecificationMixin from faststream.rabbit.utils import is_routing_exchange from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, PublisherSpec from faststream.specification.schema.bindings import ( ChannelBinding, OperationBinding, amqp, ) -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation -from .usecase import LogicPublisher +from .usecase import LogicPublisher, PublishKwargs +if TYPE_CHECKING: + from aio_pika import IncomingMessage -class SpecificationPublisher(LogicPublisher): - """AsyncAPI-compatible Rabbit Publisher class. + from faststream._internal.types import BrokerMiddleware, PublisherMiddleware + from faststream.rabbit.schemas import RabbitExchange, RabbitQueue - Creating by - ```python - publisher: SpecificationPublisher = ( - broker.publisher(...) - ) - # or - publisher: SpecificationPublisher = ( - router.publisher(...) - ) - ``` - """ +class SpecificationPublisher( + SpecificationPublisherMixin, + RMQSpecificationMixin, + LogicPublisher, +): + """AsyncAPI-compatible Rabbit Publisher class.""" - def get_name(self) -> str: + def __init__( + self, + *, + routing_key: str, + queue: "RabbitQueue", + exchange: "RabbitExchange", + # PublishCommand options + message_kwargs: "PublishKwargs", + # Publisher args + broker_middlewares: Iterable["BrokerMiddleware[IncomingMessage]"], + middlewares: Iterable["PublisherMiddleware"], + # AsyncAPI args + schema_: Optional[Any], + title_: Optional[str], + description_: Optional[str], + include_in_schema: bool, + ) -> None: + super().__init__( + title_=title_, + description_=description_, + include_in_schema=include_in_schema, + schema_=schema_, + # propagate to RMQSpecificationMixin + queue=queue, + exchange=exchange, + ) + + LogicPublisher.__init__( + self, + queue=queue, + exchange=exchange, + routing_key=routing_key, + message_kwargs=message_kwargs, + middlewares=middlewares, + broker_middlewares=broker_middlewares, + ) + + def get_default_name(self) -> str: routing = ( self.routing_key or (self.queue.routing if is_routing_exchange(self.exchange) else None) @@ -37,26 +80,28 @@ def get_name(self) -> str: return f"{routing}:{getattr(self.exchange, 'name', None) or '_'}:Publisher" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, PublisherSpec]: payloads = self.get_payloads() + exchange_binding = amqp.Exchange.from_exchange(self.exchange) + queue_binding = amqp.Queue.from_queue(self.queue) + return { - self.name: Channel( + self.name: PublisherSpec( description=self.description, - publish=Operation( + operation=Operation( bindings=OperationBinding( amqp=amqp.OperationBinding( - cc=self.routing or None, - deliveryMode=2 - if self.message_options.get("persist") - else 1, - replyTo=self.message_options.get("reply_to"), # type: ignore[arg-type] - mandatory=self.publish_options.get("mandatory"), # type: ignore[arg-type] - priority=self.message_options.get("priority"), # type: ignore[arg-type] + routing_key=self.routing or None, + queue=queue_binding, + exchange=exchange_binding, + ack=True, + persist=self.message_options.get("persist"), + priority=self.message_options.get("priority"), + reply_to=self.message_options.get("reply_to"), + mandatory=self.publish_options.get("mandatory"), ), - ) - if is_routing_exchange(self.exchange) - else None, + ), message=Message( title=f"{self.name}:Message", payload=resolve_payloads( @@ -64,34 +109,13 @@ def get_schema(self) -> dict[str, Channel]: "Publisher", served_words=2 if self.title_ is None else 1, ), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), ), bindings=ChannelBinding( amqp=amqp.ChannelBinding( - is_="routingKey", - queue=amqp.Queue( - name=self.queue.name, - durable=self.queue.durable, - exclusive=self.queue.exclusive, - autoDelete=self.queue.auto_delete, - vhost=self.virtual_host, - ) - if is_routing_exchange(self.exchange) and self.queue.name - else None, - exchange=( - amqp.Exchange(type="default", vhost=self.virtual_host) - if not self.exchange.name - else amqp.Exchange( - type=self.exchange.type.value, - name=self.exchange.name, - durable=self.exchange.durable, - autoDelete=self.exchange.auto_delete, - vhost=self.virtual_host, - ) - ), + virtual_host=self.virtual_host, + queue=queue_binding, + exchange=exchange_binding, ), ), ), diff --git a/faststream/rabbit/publisher/usecase.py b/faststream/rabbit/publisher/usecase.py index f34cf06b3c..0ae13cf319 100644 --- a/faststream/rabbit/publisher/usecase.py +++ b/faststream/rabbit/publisher/usecase.py @@ -3,7 +3,6 @@ from typing import ( TYPE_CHECKING, Annotated, - Any, Optional, Union, ) @@ -15,7 +14,7 @@ from faststream._internal.utils.data import filter_by_dict from faststream.message import gen_cor_id from faststream.rabbit.response import RabbitPublishCommand -from faststream.rabbit.schemas import BaseRMQInformation, RabbitExchange, RabbitQueue +from faststream.rabbit.schemas import RabbitExchange, RabbitQueue from faststream.response.publish_type import PublishType from .options import MessageOptions, PublishOptions @@ -47,10 +46,7 @@ class PublishKwargs(MessageOptions, PublishOptions, total=False): ] -class LogicPublisher( - PublisherUsecase[IncomingMessage], - BaseRMQInformation, -): +class LogicPublisher(PublisherUsecase[IncomingMessage]): """A class to represent a RabbitMQ publisher.""" app_id: Optional[str] @@ -68,53 +64,38 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[IncomingMessage]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: + self.queue = queue + self.routing_key = routing_key + + self.exchange = exchange + super().__init__( broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) - self.routing_key = routing_key - request_options = dict(message_kwargs) self.headers = request_options.pop("headers") or {} self.reply_to = request_options.pop("reply_to", "") self.timeout = request_options.pop("timeout", None) - self.message_options = filter_by_dict(MessageOptions, request_options) - self.publish_options = filter_by_dict(PublishOptions, request_options) - # BaseRMQInformation - self.queue = queue - self.exchange = exchange + message_options, _ = filter_by_dict(MessageOptions, request_options) + self.message_options = message_options + + publish_options, _ = filter_by_dict(PublishOptions, request_options) + self.publish_options = publish_options - # Setup it later self.app_id = None - self.virtual_host = "" @override def _setup( # type: ignore[override] self, *, - app_id: Optional[str], - virtual_host: str, state: "BrokerState", ) -> None: - if app_id: - self.message_options["app_id"] = app_id - self.app_id = app_id - - self.virtual_host = virtual_host - + # AppId was set in `faststream.rabbit.schemas.proto.BaseRMQInformation` + self.message_options["app_id"] = self.app_id super()._setup(state=state) @property diff --git a/faststream/rabbit/router.py b/faststream/rabbit/router.py index 8eaec725ef..862ea5bcb0 100644 --- a/faststream/rabbit/router.py +++ b/faststream/rabbit/router.py @@ -1,7 +1,7 @@ from collections.abc import Awaitable, Iterable from typing import TYPE_CHECKING, Annotated, Any, Callable, Optional, Union -from typing_extensions import Doc +from typing_extensions import Doc, deprecated from faststream._internal.broker.router import ( ArgsContainer, @@ -231,10 +231,15 @@ def __init__( Iterable["SubscriberMiddleware[RabbitMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -269,6 +274,7 @@ def __init__( decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/rabbit/schemas/proto.py b/faststream/rabbit/schemas/proto.py index 41045b94fa..2109772124 100644 --- a/faststream/rabbit/schemas/proto.py +++ b/faststream/rabbit/schemas/proto.py @@ -1,13 +1,40 @@ -from typing import Optional, Protocol +from typing import TYPE_CHECKING, Any, Optional -from faststream.rabbit.schemas.exchange import RabbitExchange -from faststream.rabbit.schemas.queue import RabbitQueue +if TYPE_CHECKING: + from faststream.rabbit.schemas.exchange import RabbitExchange + from faststream.rabbit.schemas.queue import RabbitQueue -class BaseRMQInformation(Protocol): +class BaseRMQInformation: """Base class to store Specification RMQ bindings.""" virtual_host: str - queue: RabbitQueue - exchange: RabbitExchange + queue: "RabbitQueue" + exchange: "RabbitExchange" app_id: Optional[str] + + def __init__( + self, + *, + queue: "RabbitQueue", + exchange: "RabbitExchange", + ) -> None: + self.queue = queue + self.exchange = exchange + + # Setup it later + self.app_id = None + self.virtual_host = "" + + def _setup( + self, + *, + app_id: Optional[str], + virtual_host: str, + **kwargs: Any, + ) -> None: + self.app_id = app_id + self.virtual_host = virtual_host + + # Setup next parent class + super()._setup(**kwargs) diff --git a/faststream/rabbit/subscriber/factory.py b/faststream/rabbit/subscriber/factory.py index 8a4475ec58..df49821c6c 100644 --- a/faststream/rabbit/subscriber/factory.py +++ b/faststream/rabbit/subscriber/factory.py @@ -1,7 +1,9 @@ +import warnings from collections.abc import Iterable from typing import TYPE_CHECKING, Optional from faststream._internal.constants import EMPTY +from faststream.exceptions import SetupError from faststream.middlewares import AckPolicy from faststream.rabbit.subscriber.specified import SpecificationSubscriber @@ -24,15 +26,16 @@ def create_subscriber( broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[IncomingMessage]"], ack_policy: "AckPolicy", + no_ack: bool, # AsyncAPI args title_: Optional[str], description_: Optional[str], include_in_schema: bool, ) -> SpecificationSubscriber: - _validate_input_for_misconfigure() + _validate_input_for_misconfigure(ack_policy=ack_policy, no_ack=no_ack) if ack_policy is EMPTY: - ack_policy = AckPolicy.REJECT_ON_ERROR + ack_policy = AckPolicy.DO_NOTHING if no_ack else AckPolicy.REJECT_ON_ERROR return SpecificationSubscriber( queue=queue, @@ -48,5 +51,18 @@ def create_subscriber( ) -def _validate_input_for_misconfigure() -> None: - """Nothing to check yet.""" +def _validate_input_for_misconfigure( + *, + ack_policy: "AckPolicy", + no_ack: bool, +) -> None: + if no_ack is not EMPTY: + warnings.warn( + "`no_ack` option was deprecated in prior to `ack_policy=AckPolicy.DO_NOTHING`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, + stacklevel=4, + ) + + if ack_policy is not EMPTY: + msg = "You can't use deprecated `no_ack` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) diff --git a/faststream/rabbit/subscriber/specified.py b/faststream/rabbit/subscriber/specified.py index 275f509d2b..b071ea2828 100644 --- a/faststream/rabbit/subscriber/specified.py +++ b/faststream/rabbit/subscriber/specified.py @@ -1,67 +1,107 @@ +from collections.abc import Iterable +from typing import TYPE_CHECKING, Optional + +from faststream._internal.subscriber.specified import ( + SpecificationSubscriber as SpecificationSubscriberMixin, +) +from faststream.rabbit.schemas.proto import BaseRMQInformation as RMQSpecificationMixin from faststream.rabbit.subscriber.usecase import LogicSubscriber -from faststream.rabbit.utils import is_routing_exchange from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, SubscriberSpec from faststream.specification.schema.bindings import ( ChannelBinding, OperationBinding, amqp, ) -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation +if TYPE_CHECKING: + from aio_pika import IncomingMessage + from fast_depends.dependencies import Dependant + + from faststream._internal.basic_types import AnyDict + from faststream._internal.types import BrokerMiddleware + from faststream.middlewares import AckPolicy + from faststream.rabbit.schemas.exchange import RabbitExchange + from faststream.rabbit.schemas.queue import RabbitQueue -class SpecificationSubscriber(LogicSubscriber): + +class SpecificationSubscriber( + SpecificationSubscriberMixin, + RMQSpecificationMixin, + LogicSubscriber, +): """AsyncAPI-compatible Rabbit Subscriber class.""" - def get_name(self) -> str: + def __init__( + self, + *, + queue: "RabbitQueue", + exchange: "RabbitExchange", + consume_args: Optional["AnyDict"], + # Subscriber args + ack_policy: "AckPolicy", + no_reply: bool, + broker_dependencies: Iterable["Dependant"], + broker_middlewares: Iterable["BrokerMiddleware[IncomingMessage]"], + # AsyncAPI args + title_: Optional[str], + description_: Optional[str], + include_in_schema: bool, + ) -> None: + super().__init__( + title_=title_, + description_=description_, + include_in_schema=include_in_schema, + # propagate to RMQSpecificationMixin + queue=queue, + exchange=exchange, + ) + + LogicSubscriber.__init__( + self, + queue=queue, + consume_args=consume_args, + ack_policy=ack_policy, + no_reply=no_reply, + broker_dependencies=broker_dependencies, + broker_middlewares=broker_middlewares, + ) + + def get_default_name(self) -> str: return f"{self.queue.name}:{getattr(self.exchange, 'name', None) or '_'}:{self.call_name}" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() + exchange_binding = amqp.Exchange.from_exchange(self.exchange) + queue_binding = amqp.Queue.from_queue(self.queue) + return { - self.name: Channel( + self.name: SubscriberSpec( description=self.description, - subscribe=Operation( + operation=Operation( bindings=OperationBinding( amqp=amqp.OperationBinding( - cc=self.queue.routing, + routing_key=self.queue.routing, + queue=queue_binding, + exchange=exchange_binding, + ack=True, + reply_to=None, + persist=None, + mandatory=None, + priority=None, ), - ) - if is_routing_exchange(self.exchange) - else None, + ), message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), ), bindings=ChannelBinding( amqp=amqp.ChannelBinding( - is_="routingKey", - queue=amqp.Queue( - name=self.queue.name, - durable=self.queue.durable, - exclusive=self.queue.exclusive, - autoDelete=self.queue.auto_delete, - vhost=self.virtual_host, - ) - if is_routing_exchange(self.exchange) and self.queue.name - else None, - exchange=( - amqp.Exchange(type="default", vhost=self.virtual_host) - if not self.exchange.name - else amqp.Exchange( - type=self.exchange.type.value, - name=self.exchange.name, - durable=self.exchange.durable, - autoDelete=self.exchange.auto_delete, - vhost=self.virtual_host, - ) - ), + virtual_host=self.virtual_host, + queue=queue_binding, + exchange=exchange_binding, ), ), ), diff --git a/faststream/rabbit/subscriber/usecase.py b/faststream/rabbit/subscriber/usecase.py index e91333ed25..df229a5cc4 100644 --- a/faststream/rabbit/subscriber/usecase.py +++ b/faststream/rabbit/subscriber/usecase.py @@ -15,7 +15,6 @@ from faststream.exceptions import SetupError from faststream.rabbit.parser import AioPikaParser from faststream.rabbit.publisher.fake import RabbitFakePublisher -from faststream.rabbit.schemas import BaseRMQInformation if TYPE_CHECKING: from aio_pika import IncomingMessage, RobustQueue @@ -36,10 +35,7 @@ ) -class LogicSubscriber( - SubscriberUsecase["IncomingMessage"], - BaseRMQInformation, -): +class LogicSubscriber(SubscriberUsecase["IncomingMessage"]): """A class to handle logic for RabbitMQ message consumption.""" app_id: Optional[str] @@ -53,18 +49,15 @@ def __init__( self, *, queue: "RabbitQueue", - exchange: "RabbitExchange", consume_args: Optional["AnyDict"], # Subscriber args ack_policy: "AckPolicy", no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[IncomingMessage]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: + self.queue = queue + parser = AioPikaParser(pattern=queue.path_regex) super().__init__( @@ -75,10 +68,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.consume_args = consume_args or {} @@ -86,20 +75,13 @@ def __init__( self._consumer_tag = None self._queue_obj = None - # BaseRMQInformation - self.queue = queue - self.exchange = exchange # Setup it later - self.app_id = None - self.virtual_host = "" self.declarer = None @override def _setup( # type: ignore[override] self, *, - app_id: Optional[str], - virtual_host: str, declarer: "RabbitDeclarer", # basic args extra_context: "AnyDict", @@ -109,8 +91,6 @@ def _setup( # type: ignore[override] # dependant args state: "BrokerState", ) -> None: - self.app_id = app_id - self.virtual_host = virtual_host self.declarer = declarer super()._setup( diff --git a/faststream/redis/broker/broker.py b/faststream/redis/broker/broker.py index 8b5a0a0017..bae4bd9de3 100644 --- a/faststream/redis/broker/broker.py +++ b/faststream/redis/broker/broker.py @@ -56,7 +56,7 @@ ) from faststream.redis.message import BaseMessage, RedisMessage from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class RedisInitKwargs(TypedDict, total=False): host: Optional[str] @@ -162,9 +162,9 @@ def __init__( Doc("AsyncAPI server description."), ] = None, tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ Optional["LoggerProto"], diff --git a/faststream/redis/broker/registrator.py b/faststream/redis/broker/registrator.py index 10cf4afe98..50885af7fa 100644 --- a/faststream/redis/broker/registrator.py +++ b/faststream/redis/broker/registrator.py @@ -1,7 +1,7 @@ from collections.abc import Iterable from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast -from typing_extensions import Doc, override +from typing_extensions import Doc, deprecated, override from faststream._internal.broker.abc_broker import ABCBroker from faststream._internal.constants import EMPTY @@ -67,10 +67,15 @@ def subscriber( # type: ignore[override] Iterable["SubscriberMiddleware[UnifyRedisMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -103,6 +108,7 @@ def subscriber( # type: ignore[override] stream=stream, # subscriber args ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, broker_middlewares=self.middlewares, broker_dependencies=self._dependencies, diff --git a/faststream/redis/fastapi/fastapi.py b/faststream/redis/fastapi/fastapi.py index c64968aac8..f71f329276 100644 --- a/faststream/redis/fastapi/fastapi.py +++ b/faststream/redis/fastapi/fastapi.py @@ -50,7 +50,7 @@ from faststream.redis.message import UnifyRedisMessage from faststream.redis.publisher.specified import SpecificationPublisher from faststream.security import BaseSecurity - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import Tag, TagDict class RedisRouter(StreamRouter[UnifyRedisDict]): @@ -125,12 +125,12 @@ def __init__( Doc("AsyncAPI server description."), ] = None, specification_tags: Annotated[ - Optional[Iterable[Union["Tag", "TagDict"]]], + Iterable[Union["Tag", "TagDict"]], Doc("AsyncAPI server tags."), - ] = None, + ] = (), # logging args logger: Annotated[ - Union["LoggerProto", None, object], + Optional["LoggerProto"], Doc("User specified logger to pass into Context and log service messages."), ] = EMPTY, log_level: Annotated[ @@ -462,10 +462,15 @@ def subscriber( # type: ignore[override] Iterable["SubscriberMiddleware[UnifyRedisMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -623,6 +628,7 @@ def subscriber( # type: ignore[override] decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/redis/publisher/specified.py b/faststream/redis/publisher/specified.py index f0598834c6..3ccd57c931 100644 --- a/faststream/redis/publisher/specified.py +++ b/faststream/redis/publisher/specified.py @@ -1,40 +1,41 @@ from typing import TYPE_CHECKING +from faststream._internal.publisher.specified import ( + SpecificationPublisher as SpecificationPublisherMixin, +) from faststream.redis.publisher.usecase import ( ChannelPublisher, ListBatchPublisher, ListPublisher, - LogicPublisher, StreamPublisher, ) from faststream.redis.schemas.proto import RedisSpecificationProtocol from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, PublisherSpec from faststream.specification.schema.bindings import ChannelBinding, redis -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation if TYPE_CHECKING: from faststream.redis.schemas import ListSub -class SpecificationPublisher(LogicPublisher, RedisSpecificationProtocol): +class SpecificationPublisher( + SpecificationPublisherMixin, + RedisSpecificationProtocol[PublisherSpec], +): """A class to represent a Redis publisher.""" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, PublisherSpec]: payloads = self.get_payloads() return { - self.name: Channel( + self.name: PublisherSpec( description=self.description, - publish=Operation( + operation=Operation( message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads, "Publisher"), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, ), bindings=ChannelBinding( redis=self.channel_binding, @@ -43,8 +44,8 @@ def get_schema(self) -> dict[str, Channel]: } -class SpecificationChannelPublisher(ChannelPublisher, SpecificationPublisher): - def get_name(self) -> str: +class SpecificationChannelPublisher(SpecificationPublisher, ChannelPublisher): + def get_default_name(self) -> str: return f"{self.channel.name}:Publisher" @property @@ -58,7 +59,7 @@ def channel_binding(self) -> "redis.ChannelBinding": class _ListPublisherMixin(SpecificationPublisher): list: "ListSub" - def get_name(self) -> str: + def get_default_name(self) -> str: return f"{self.list.name}:Publisher" @property @@ -69,16 +70,16 @@ def channel_binding(self) -> "redis.ChannelBinding": ) -class SpecificationListPublisher(ListPublisher, _ListPublisherMixin): +class SpecificationListPublisher(_ListPublisherMixin, ListPublisher): pass -class SpecificationListBatchPublisher(ListBatchPublisher, _ListPublisherMixin): +class SpecificationListBatchPublisher(_ListPublisherMixin, ListBatchPublisher): pass -class SpecificationStreamPublisher(StreamPublisher, SpecificationPublisher): - def get_name(self) -> str: +class SpecificationStreamPublisher(SpecificationPublisher, StreamPublisher): + def get_default_name(self) -> str: return f"{self.stream.name}:Publisher" @property diff --git a/faststream/redis/publisher/usecase.py b/faststream/redis/publisher/usecase.py index 479b9ccf66..601d18130c 100644 --- a/faststream/redis/publisher/usecase.py +++ b/faststream/redis/publisher/usecase.py @@ -33,20 +33,10 @@ def __init__( # Publisher args broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI args - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( broker_middlewares=broker_middlewares, middlewares=middlewares, - # AsyncAPI args - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.reply_to = reply_to @@ -67,21 +57,12 @@ def __init__( # Regular publisher options broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI options - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( reply_to=reply_to, headers=headers, broker_middlewares=broker_middlewares, middlewares=middlewares, - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.channel = channel @@ -204,21 +185,12 @@ def __init__( # Regular publisher options broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI options - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( reply_to=reply_to, headers=headers, broker_middlewares=broker_middlewares, middlewares=middlewares, - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.list = list @@ -399,21 +371,12 @@ def __init__( # Regular publisher options broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], middlewares: Iterable["PublisherMiddleware"], - # AsyncAPI options - schema_: Optional[Any], - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( reply_to=reply_to, headers=headers, broker_middlewares=broker_middlewares, middlewares=middlewares, - schema_=schema_, - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.stream = stream diff --git a/faststream/redis/router.py b/faststream/redis/router.py index fb155829a0..a2a3b8ec2a 100644 --- a/faststream/redis/router.py +++ b/faststream/redis/router.py @@ -1,7 +1,7 @@ from collections.abc import Awaitable, Iterable from typing import TYPE_CHECKING, Annotated, Any, Callable, Optional, Union -from typing_extensions import Doc +from typing_extensions import Doc, deprecated from faststream._internal.broker.router import ( ArgsContainer, @@ -149,10 +149,15 @@ def __init__( Iterable["SubscriberMiddleware[UnifyRedisMessage]"], Doc("Subscriber middlewares to wrap incoming message processing."), ] = (), - ack_policy: Annotated[ - AckPolicy, + no_ack: Annotated[ + bool, Doc("Whether to disable **FastStream** auto acknowledgement logic or not."), + deprecated( + "This option was deprecated in 0.6.0 to prior to **ack_policy=AckPolicy.DO_NOTHING**. " + "Scheduled to remove in 0.7.0" + ), ] = EMPTY, + ack_policy: AckPolicy = EMPTY, no_reply: Annotated[ bool, Doc( @@ -187,6 +192,7 @@ def __init__( decoder=decoder, middlewares=middlewares, ack_policy=ack_policy, + no_ack=no_ack, no_reply=no_reply, title=title, description=description, diff --git a/faststream/redis/schemas/proto.py b/faststream/redis/schemas/proto.py index 685d4aa679..1f4191df73 100644 --- a/faststream/redis/schemas/proto.py +++ b/faststream/redis/schemas/proto.py @@ -2,14 +2,14 @@ from typing import TYPE_CHECKING, Any, Union from faststream.exceptions import SetupError -from faststream.specification.base.proto import SpecificationEndpoint +from faststream.specification.proto.endpoint import EndpointSpecification, T if TYPE_CHECKING: from faststream.redis.schemas import ListSub, PubSub, StreamSub from faststream.specification.schema.bindings import redis -class RedisSpecificationProtocol(SpecificationEndpoint): +class RedisSpecificationProtocol(EndpointSpecification[T]): @property @abstractmethod def channel_binding(self) -> "redis.ChannelBinding": ... diff --git a/faststream/redis/schemas/stream_sub.py b/faststream/redis/schemas/stream_sub.py index 07488d5f86..50a0b6d606 100644 --- a/faststream/redis/schemas/stream_sub.py +++ b/faststream/redis/schemas/stream_sub.py @@ -3,7 +3,6 @@ from faststream._internal.proto import NameRequired from faststream.exceptions import SetupError -from faststream.middlewares import AckPolicy class StreamSub(NameRequired): @@ -28,13 +27,11 @@ def __init__( group: Optional[str] = None, consumer: Optional[str] = None, batch: bool = False, - ack_policy: AckPolicy = AckPolicy.REJECT_ON_ERROR, + no_ack: bool = False, last_id: Optional[str] = None, maxlen: Optional[int] = None, max_records: Optional[int] = None, ) -> None: - no_ack = ack_policy is AckPolicy.DO_NOTHING - if (group and not consumer) or (not group and consumer): msg = "You should specify `group` and `consumer` both" raise SetupError(msg) diff --git a/faststream/redis/subscriber/factory.py b/faststream/redis/subscriber/factory.py index 378b561348..db42423078 100644 --- a/faststream/redis/subscriber/factory.py +++ b/faststream/redis/subscriber/factory.py @@ -39,6 +39,7 @@ def create_subscriber( stream: Union["StreamSub", str, None], # Subscriber args ack_policy: "AckPolicy", + no_ack: bool, no_reply: bool = False, broker_dependencies: Iterable["Dependant"] = (), broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"] = (), @@ -52,10 +53,11 @@ def create_subscriber( list=list, stream=stream, ack_policy=ack_policy, + no_ack=no_ack, ) if ack_policy is EMPTY: - ack_policy = AckPolicy.REJECT_ON_ERROR + ack_policy = AckPolicy.DO_NOTHING if no_ack else AckPolicy.REJECT_ON_ERROR if (channel_sub := PubSub.validate(channel)) is not None: return SpecificationChannelSubscriber( @@ -133,9 +135,21 @@ def _validate_input_for_misconfigure( list: Union["ListSub", str, None], stream: Union["StreamSub", str, None], ack_policy: AckPolicy, + no_ack: bool, ) -> None: validate_options(channel=channel, list=list, stream=stream) + if no_ack is not EMPTY: + warnings.warn( + "`no_ack` option was deprecated in prior to `ack_policy=AckPolicy.DO_NOTHING`. Scheduled to remove in 0.7.0", + category=DeprecationWarning, + stacklevel=4, + ) + + if ack_policy is not EMPTY: + msg = "You can't use deprecated `no_ack` and `ack_policy` simultaneously. Please, use `ack_policy` only." + raise SetupError(msg) + if ack_policy is not EMPTY: if channel: warnings.warn( diff --git a/faststream/redis/subscriber/specified.py b/faststream/redis/subscriber/specified.py index 800e5b1f02..e943a80aeb 100644 --- a/faststream/redis/subscriber/specified.py +++ b/faststream/redis/subscriber/specified.py @@ -1,37 +1,37 @@ +from faststream._internal.subscriber.specified import ( + SpecificationSubscriber as SpecificationSubscriberMixin, +) from faststream.redis.schemas import ListSub, StreamSub from faststream.redis.schemas.proto import RedisSpecificationProtocol from faststream.redis.subscriber.usecase import ( BatchListSubscriber, ChannelSubscriber, ListSubscriber, - LogicSubscriber, StreamBatchSubscriber, StreamSubscriber, ) from faststream.specification.asyncapi.utils import resolve_payloads +from faststream.specification.schema import Message, Operation, SubscriberSpec from faststream.specification.schema.bindings import ChannelBinding, redis -from faststream.specification.schema.channel import Channel -from faststream.specification.schema.message import CorrelationId, Message -from faststream.specification.schema.operation import Operation -class SpecificationSubscriber(LogicSubscriber, RedisSpecificationProtocol): +class SpecificationSubscriber( + SpecificationSubscriberMixin, RedisSpecificationProtocol[SubscriberSpec] +): """A class to represent a Redis handler.""" - def get_schema(self) -> dict[str, Channel]: + def get_schema(self) -> dict[str, SubscriberSpec]: payloads = self.get_payloads() return { - self.name: Channel( + self.name: SubscriberSpec( description=self.description, - subscribe=Operation( + operation=Operation( message=Message( title=f"{self.name}:Message", payload=resolve_payloads(payloads), - correlationId=CorrelationId( - location="$message.header#/correlation_id", - ), ), + bindings=None, ), bindings=ChannelBinding( redis=self.channel_binding, @@ -40,8 +40,8 @@ def get_schema(self) -> dict[str, Channel]: } -class SpecificationChannelSubscriber(ChannelSubscriber, SpecificationSubscriber): - def get_name(self) -> str: +class SpecificationChannelSubscriber(SpecificationSubscriber, ChannelSubscriber): + def get_default_name(self) -> str: return f"{self.channel.name}:{self.call_name}" @property @@ -55,7 +55,7 @@ def channel_binding(self) -> "redis.ChannelBinding": class _StreamSubscriberMixin(SpecificationSubscriber): stream_sub: StreamSub - def get_name(self) -> str: + def get_default_name(self) -> str: return f"{self.stream_sub.name}:{self.call_name}" @property @@ -68,18 +68,18 @@ def channel_binding(self) -> "redis.ChannelBinding": ) -class SpecificationStreamSubscriber(StreamSubscriber, _StreamSubscriberMixin): +class SpecificationStreamSubscriber(_StreamSubscriberMixin, StreamSubscriber): pass -class SpecificationStreamBatchSubscriber(StreamBatchSubscriber, _StreamSubscriberMixin): +class SpecificationStreamBatchSubscriber(_StreamSubscriberMixin, StreamBatchSubscriber): pass class _ListSubscriberMixin(SpecificationSubscriber): list_sub: ListSub - def get_name(self) -> str: + def get_default_name(self) -> str: return f"{self.list_sub.name}:{self.call_name}" @property @@ -90,9 +90,9 @@ def channel_binding(self) -> "redis.ChannelBinding": ) -class SpecificationListSubscriber(ListSubscriber, _ListSubscriberMixin): +class SpecificationListSubscriber(_ListSubscriberMixin, ListSubscriber): pass -class SpecificationListBatchSubscriber(BatchListSubscriber, _ListSubscriberMixin): +class SpecificationListBatchSubscriber(_ListSubscriberMixin, BatchListSubscriber): pass diff --git a/faststream/redis/subscriber/usecase.py b/faststream/redis/subscriber/usecase.py index 1f89ca54b4..4d689a9193 100644 --- a/faststream/redis/subscriber/usecase.py +++ b/faststream/redis/subscriber/usecase.py @@ -76,10 +76,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( default_parser=default_parser, @@ -89,10 +85,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self._client = None @@ -208,10 +200,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = RedisPubSubParser(pattern=channel.path_regex) super().__init__( @@ -222,10 +210,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.channel = channel @@ -333,10 +317,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( default_parser=default_parser, @@ -346,10 +326,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.list_sub = list @@ -438,10 +414,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = RedisListParser() super().__init__( @@ -453,10 +425,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def _get_msgs(self, client: "Redis[bytes]") -> None: @@ -486,10 +454,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = RedisBatchListParser() super().__init__( @@ -501,10 +465,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def _get_msgs(self, client: "Redis[bytes]") -> None: @@ -538,10 +498,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: super().__init__( default_parser=default_parser, @@ -551,10 +507,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) self.stream_sub = stream @@ -728,10 +680,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = RedisStreamParser() super().__init__( @@ -743,10 +691,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def _get_msgs( @@ -795,10 +739,6 @@ def __init__( no_reply: bool, broker_dependencies: Iterable["Dependant"], broker_middlewares: Iterable["BrokerMiddleware[UnifyRedisDict]"], - # AsyncAPI args - title_: Optional[str], - description_: Optional[str], - include_in_schema: bool, ) -> None: parser = RedisBatchStreamParser() super().__init__( @@ -810,10 +750,6 @@ def __init__( no_reply=no_reply, broker_middlewares=broker_middlewares, broker_dependencies=broker_dependencies, - # AsyncAPI - title_=title_, - description_=description_, - include_in_schema=include_in_schema, ) async def _get_msgs( diff --git a/faststream/security.py b/faststream/security.py index 583f7fd30d..e208d3023e 100644 --- a/faststream/security.py +++ b/faststream/security.py @@ -167,7 +167,7 @@ def get_requirement(self) -> list["AnyDict"]: def get_schema(self) -> dict[str, dict[str, str]]: """Get the security schema for SASL/OAUTHBEARER authentication.""" - return {"oauthbearer": {"type": "oauthBearer"}} + return {"oauthbearer": {"type": "oauth2", "$ref": ""}} class SASLGSSAPI(BaseSecurity): diff --git a/faststream/specification/__init__.py b/faststream/specification/__init__.py index 502bf43de6..7738408d36 100644 --- a/faststream/specification/__init__.py +++ b/faststream/specification/__init__.py @@ -1,5 +1,5 @@ from .asyncapi.factory import AsyncAPI -from .schema import Contact, ExternalDocs, License, Tag +from .schema.extra import Contact, ExternalDocs, License, Tag __all__ = ( "AsyncAPI", diff --git a/faststream/specification/asyncapi/factory.py b/faststream/specification/asyncapi/factory.py index 7b537be3c8..dc263a6b52 100644 --- a/faststream/specification/asyncapi/factory.py +++ b/faststream/specification/asyncapi/factory.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Literal, Optional, Union from faststream.specification.base.specification import Specification @@ -6,70 +6,67 @@ if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict, AnyHttpUrl from faststream._internal.broker.broker import BrokerUsecase - from faststream.specification.schema.contact import Contact, ContactDict - from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict - from faststream.specification.schema.license import License, LicenseDict - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, + ) -class AsyncAPI(Specification): - def __new__( # type: ignore[misc] - cls, - broker: "BrokerUsecase[Any, Any]", - /, - title: str = "FastStream", - app_version: str = "0.1.0", - schema_version: Union[Literal["3.0.0", "2.6.0"], str] = "3.0.0", - description: str = "", - terms_of_service: Optional["AnyHttpUrl"] = None, - license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, - contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, - tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, - external_docs: Optional[ - Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] - ] = None, - identifier: Optional[str] = None, - ) -> Specification: - if schema_version.startswith("3.0."): - from .v3_0_0.facade import AsyncAPI3 +def AsyncAPI( # noqa: N802 + broker: "BrokerUsecase[Any, Any]", + /, + title: str = "FastStream", + app_version: str = "0.1.0", + schema_version: Union[Literal["3.0.0", "2.6.0"], str] = "3.0.0", + description: str = "", + terms_of_service: Optional["AnyHttpUrl"] = None, + license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, + contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, + tags: Iterable[Union["Tag", "TagDict", "AnyDict"]] = (), + external_docs: Optional[ + Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] + ] = None, + identifier: Optional[str] = None, +) -> Specification: + if schema_version.startswith("3.0."): + from .v3_0_0.facade import AsyncAPI3 - return AsyncAPI3( - broker, - title=title, - app_version=app_version, - schema_version=schema_version, - description=description, - terms_of_service=terms_of_service, - contact=contact, - license=license, - identifier=identifier, - tags=tags, - external_docs=external_docs, - ) - if schema_version.startswith("2.6."): - from .v2_6_0.facade import AsyncAPI2 + return AsyncAPI3( + broker, + title=title, + app_version=app_version, + schema_version=schema_version, + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + identifier=identifier, + tags=tags, + external_docs=external_docs, + ) - return AsyncAPI2( - broker, - title=title, - app_version=app_version, - schema_version=schema_version, - description=description, - terms_of_service=terms_of_service, - contact=contact, - license=license, - identifier=identifier, - tags=tags, - external_docs=external_docs, - ) - msg = f"Unsupported schema version: {schema_version}" - raise NotImplementedError(msg) + if schema_version.startswith("2.6."): + from .v2_6_0.facade import AsyncAPI2 - def to_json(self) -> str: - raise NotImplementedError + return AsyncAPI2( + broker, + title=title, + app_version=app_version, + schema_version=schema_version, + description=description, + terms_of_service=terms_of_service, + contact=contact, + license=license, + identifier=identifier, + tags=tags, + external_docs=external_docs, + ) - def to_jsonable(self) -> Any: - raise NotImplementedError - - def to_yaml(self) -> str: - raise NotImplementedError + msg = f"Unsupported schema version: {schema_version}" + raise NotImplementedError(msg) diff --git a/faststream/specification/asyncapi/utils.py b/faststream/specification/asyncapi/utils.py index 2e6ffadfe2..7f16a215dc 100644 --- a/faststream/specification/asyncapi/utils.py +++ b/faststream/specification/asyncapi/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict @@ -49,3 +49,36 @@ def resolve_payloads( def clear_key(key: str) -> str: return key.replace("/", ".") + + +def move_pydantic_refs( + original: Any, + key: str, +) -> Any: + """Remove pydantic references and replacem them by real schemas.""" + if not isinstance(original, dict): + return original + + data = original.copy() + + for k in data: + item = data[k] + + if isinstance(item, str): + if key in item: + data[k] = data[k].replace(key, "components/schemas") + + elif isinstance(item, dict): + data[k] = move_pydantic_refs(data[k], key) + + elif isinstance(item, list): + for i in range(len(data[k])): + data[k][i] = move_pydantic_refs(item[i], key) + + if ( + isinstance(desciminator := data.get("discriminator"), dict) + and "propertyName" in desciminator + ): + data["discriminator"] = desciminator["propertyName"] + + return data diff --git a/faststream/specification/asyncapi/v2_6_0/facade.py b/faststream/specification/asyncapi/v2_6_0/facade.py index 80c7cedd38..d8c4b5618b 100644 --- a/faststream/specification/asyncapi/v2_6_0/facade.py +++ b/faststream/specification/asyncapi/v2_6_0/facade.py @@ -1,18 +1,24 @@ -from collections.abc import Sequence +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, Union from faststream.specification.base.specification import Specification from .generate import get_app_schema -from .schema import Schema +from .schema import ApplicationSchema if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict, AnyHttpUrl from faststream._internal.broker.broker import BrokerUsecase - from faststream.specification.schema.contact import Contact, ContactDict - from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict - from faststream.specification.schema.license import License, LicenseDict - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, + ) class AsyncAPI2(Specification): @@ -28,7 +34,7 @@ def __init__( contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, identifier: Optional[str] = None, - tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, + tags: Iterable[Union["Tag", "TagDict", "AnyDict"]] = (), external_docs: Optional[ Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] ] = None, @@ -55,7 +61,7 @@ def to_yaml(self) -> str: return self.schema.to_yaml() @property - def schema(self) -> Schema: # type: ignore[override] + def schema(self) -> ApplicationSchema: # type: ignore[override] return get_app_schema( self.broker, title=self.title, diff --git a/faststream/specification/asyncapi/v2_6_0/generate.py b/faststream/specification/asyncapi/v2_6_0/generate.py index 2c6a3b3900..4c81514da7 100644 --- a/faststream/specification/asyncapi/v2_6_0/generate.py +++ b/faststream/specification/asyncapi/v2_6_0/generate.py @@ -4,32 +4,33 @@ from faststream._internal._compat import DEF_KEY from faststream._internal.basic_types import AnyDict, AnyHttpUrl from faststream._internal.constants import ContentTypes -from faststream.specification.asyncapi.utils import clear_key +from faststream.specification.asyncapi.utils import clear_key, move_pydantic_refs from faststream.specification.asyncapi.v2_6_0.schema import ( + ApplicationInfo, + ApplicationSchema, Channel, Components, - Info, + Contact, + ExternalDocs, + License, + Message, Reference, - Schema, Server, Tag, - channel_from_spec, - contact_from_spec, - docs_from_spec, - license_from_spec, - tag_from_spec, ) -from faststream.specification.asyncapi.v2_6_0.schema.message import Message if TYPE_CHECKING: from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.types import ConnectionType, MsgType - from faststream.specification.schema.contact import Contact, ContactDict - from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict - from faststream.specification.schema.license import License, LicenseDict - from faststream.specification.schema.tag import ( - Tag as SpecsTag, - TagDict as SpecsTagDict, + from faststream.specification.schema.extra import ( + Contact as SpecContact, + ContactDict, + ExternalDocs as SpecDocs, + ExternalDocsDict, + License as SpecLicense, + LicenseDict, + Tag as SpecTag, + TagDict, ) @@ -41,12 +42,12 @@ def get_app_schema( schema_version: str, description: str, terms_of_service: Optional["AnyHttpUrl"], - contact: Optional[Union["Contact", "ContactDict", "AnyDict"]], - license: Optional[Union["License", "LicenseDict", "AnyDict"]], + contact: Optional[Union["SpecContact", "ContactDict", "AnyDict"]], + license: Optional[Union["SpecLicense", "LicenseDict", "AnyDict"]], identifier: Optional[str], - tags: Optional[Sequence[Union["SpecsTag", "SpecsTagDict", "AnyDict"]]], - external_docs: Optional[Union["ExternalDocs", "ExternalDocsDict", "AnyDict"]], -) -> Schema: + tags: Sequence[Union["SpecTag", "TagDict", "AnyDict"]], + external_docs: Optional[Union["SpecDocs", "ExternalDocsDict", "AnyDict"]], +) -> ApplicationSchema: """Get the application schema.""" broker._setup() @@ -62,20 +63,20 @@ def get_app_schema( for channel_name, ch in channels.items(): resolve_channel_messages(ch, channel_name, payloads, messages) - return Schema( - info=Info( + return ApplicationSchema( + info=ApplicationInfo( title=title, version=app_version, description=description, termsOfService=terms_of_service, - contact=contact_from_spec(contact) if contact else None, - license=license_from_spec(license) if license else None, + contact=Contact.from_spec(contact), + license=License.from_spec(license), ), + tags=[Tag.from_spec(tag) for tag in tags] or None, + externalDocs=ExternalDocs.from_spec(external_docs), asyncapi=schema_version, defaultContentType=ContentTypes.JSON.value, id=identifier, - tags=[tag_from_spec(tag) for tag in tags] if tags else None, - externalDocs=docs_from_spec(external_docs) if external_docs else None, servers=servers, channels=channels, components=Components( @@ -121,31 +122,22 @@ def get_broker_server( """Get the broker server for an application.""" servers = {} - tags: Optional[list[Union[Tag, AnyDict]]] = None - if broker.tags: - tags = [tag_from_spec(tag) for tag in broker.tags] - broker_meta: AnyDict = { "protocol": broker.protocol, "protocolVersion": broker.protocol_version, "description": broker.description, - "tags": tags or None, + "tags": [Tag.from_spec(tag) for tag in broker.tags] or None, + "security": broker.security.get_requirement() if broker.security else None, # TODO # "variables": "", # "bindings": "", } - if broker.security is not None: - broker_meta["security"] = broker.security.get_requirement() - urls = broker.url if isinstance(broker.url, list) else [broker.url] for i, url in enumerate(urls, 1): server_name = "development" if len(urls) == 1 else f"Server{i}" - servers[server_name] = Server( - url=url, - **broker_meta, - ) + servers[server_name] = Server(url=url, **broker_meta) return servers @@ -157,16 +149,16 @@ def get_broker_channels( channels = {} for h in broker._subscribers: - schema = h.schema() - channels.update( - {key: channel_from_spec(channel) for key, channel in schema.items()}, - ) + # TODO: add duplication key warning + channels.update({ + key: Channel.from_sub(channel) for key, channel in h.schema().items() + }) for p in broker._publishers: - schema = p.schema() - channels.update( - {key: channel_from_spec(channel) for key, channel in schema.items()}, - ) + # TODO: add duplication key warning + channels.update({ + key: Channel.from_pub(channel) for key, channel in p.schema().items() + }) return channels @@ -223,36 +215,3 @@ def _resolve_msg_payloads( message_title = clear_key(m.title) messages[message_title] = m return Reference(**{"$ref": f"#/components/messages/{message_title}"}) - - -def move_pydantic_refs( - original: Any, - key: str, -) -> Any: - """Remove pydantic references and replacem them by real schemas.""" - if not isinstance(original, dict): - return original - - data = original.copy() - - for k in data: - item = data[k] - - if isinstance(item, str): - if key in item: - data[k] = data[k].replace(key, "components/schemas") - - elif isinstance(item, dict): - data[k] = move_pydantic_refs(data[k], key) - - elif isinstance(item, list): - for i in range(len(data[k])): - data[k][i] = move_pydantic_refs(item[i], key) - - if ( - isinstance(desciminator := data.get("discriminator"), dict) - and "propertyName" in desciminator - ): - data["discriminator"] = desciminator["propertyName"] - - return data diff --git a/faststream/specification/asyncapi/v2_6_0/schema/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/__init__.py index 1e29c5b8cd..e0cbcbd7b2 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/__init__.py @@ -1,61 +1,31 @@ -from .channels import ( - Channel, - from_spec as channel_from_spec, -) +from .channels import Channel from .components import Components -from .contact import ( - Contact, - from_spec as contact_from_spec, -) -from .docs import ( - ExternalDocs, - from_spec as docs_from_spec, -) -from .info import Info -from .license import ( - License, - from_spec as license_from_spec, -) -from .message import ( - CorrelationId, - Message, - from_spec as message_from_spec, -) -from .operations import ( - Operation, - from_spec as operation_from_spec, -) -from .schema import Schema +from .contact import Contact +from .docs import ExternalDocs +from .info import ApplicationInfo +from .license import License +from .message import CorrelationId, Message +from .operations import Operation +from .schema import ApplicationSchema from .servers import Server, ServerVariable -from .tag import ( - Tag, - from_spec as tag_from_spec, -) +from .tag import Tag from .utils import Parameter, Reference __all__ = ( + "ApplicationInfo", + "ApplicationSchema", "Channel", "Channel", "Components", "Contact", "CorrelationId", "ExternalDocs", - "Info", "License", "Message", "Operation", "Parameter", "Reference", - "Schema", "Server", "ServerVariable", "Tag", - "channel_from_spec", - "channel_from_spec", - "contact_from_spec", - "docs_from_spec", - "license_from_spec", - "message_from_spec", - "operation_from_spec", - "tag_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/__init__.py index a0e9cb8389..84b0fa22e8 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/__init__.py @@ -1,13 +1,6 @@ -from .main import ( - ChannelBinding, - OperationBinding, - channel_binding_from_spec, - operation_binding_from_spec, -) +from .main import ChannelBinding, OperationBinding __all__ = ( "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/__init__.py index 7ead3ce532..8555fd981a 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/__init__.py @@ -1,15 +1,7 @@ -from .channel import ( - ChannelBinding, - from_spec as channel_binding_from_spec, -) -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/channel.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/channel.py index 1317967a1d..aa729dce29 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/channel.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/channel.py @@ -3,12 +3,12 @@ References: https://github.com/asyncapi/bindings/tree/master/amqp """ -from typing import Literal, Optional +from typing import Literal, Optional, overload from pydantic import BaseModel, Field from typing_extensions import Self -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import amqp class Queue(BaseModel): @@ -28,14 +28,25 @@ class Queue(BaseModel): autoDelete: bool vhost: str = "/" + @overload @classmethod - def from_spec(cls, binding: spec.bindings.amqp.Queue) -> Self: + def from_spec(cls, binding: None, vhost: str) -> None: ... + + @overload + @classmethod + def from_spec(cls, binding: amqp.Queue, vhost: str) -> Self: ... + + @classmethod + def from_spec(cls, binding: Optional[amqp.Queue], vhost: str) -> Optional[Self]: + if binding is None: + return None + return cls( name=binding.name, durable=binding.durable, exclusive=binding.exclusive, - autoDelete=binding.autoDelete, - vhost=binding.vhost, + autoDelete=binding.auto_delete, + vhost=vhost, ) @@ -65,14 +76,25 @@ class Exchange(BaseModel): autoDelete: Optional[bool] = None vhost: str = "/" + @overload + @classmethod + def from_spec(cls, binding: None, vhost: str) -> None: ... + + @overload + @classmethod + def from_spec(cls, binding: amqp.Exchange, vhost: str) -> Self: ... + @classmethod - def from_spec(cls, binding: spec.bindings.amqp.Exchange) -> Self: + def from_spec(cls, binding: Optional[amqp.Exchange], vhost: str) -> Optional[Self]: + if binding is None: + return None + return cls( name=binding.name, type=binding.type, durable=binding.durable, - autoDelete=binding.autoDelete, - vhost=binding.vhost, + autoDelete=binding.auto_delete, + vhost=vhost, ) @@ -92,19 +114,31 @@ class ChannelBinding(BaseModel): exchange: Optional[Exchange] = None @classmethod - def from_spec(cls, binding: spec.bindings.amqp.ChannelBinding) -> Self: + def from_sub(cls, binding: Optional[amqp.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None + return cls( **{ - "is": binding.is_, - "queue": Queue.from_spec(binding.queue) - if binding.queue is not None - else None, - "exchange": Exchange.from_spec(binding.exchange) - if binding.exchange is not None + "is": "routingKey", + "queue": Queue.from_spec(binding.queue, binding.virtual_host) + if binding.exchange.is_respect_routing_key else None, + "exchange": Exchange.from_spec(binding.exchange, binding.virtual_host), }, ) + @classmethod + def from_pub(cls, binding: Optional[amqp.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None -def from_spec(binding: spec.bindings.amqp.ChannelBinding) -> ChannelBinding: - return ChannelBinding.from_spec(binding) + return cls( + **{ + "is": "routingKey", + "queue": Queue.from_spec(binding.queue, binding.virtual_host) + if binding.exchange.is_respect_routing_key and binding.queue.name + else None, + "exchange": Exchange.from_spec(binding.exchange, binding.virtual_host), + }, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/operation.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/operation.py index cd90dde96d..47ed19af93 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/operation.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/amqp/operation.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, PositiveInt from typing_extensions import Self -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import amqp class OperationBinding(BaseModel): @@ -22,24 +22,38 @@ class OperationBinding(BaseModel): """ cc: Optional[str] = None - ack: bool = True + ack: bool replyTo: Optional[str] = None deliveryMode: Optional[int] = None mandatory: Optional[bool] = None priority: Optional[PositiveInt] = None + bindingVersion: str = "0.2.0" @classmethod - def from_spec(cls, binding: spec.bindings.amqp.OperationBinding) -> Self: + def from_sub(cls, binding: Optional[amqp.OperationBinding]) -> Optional[Self]: + if not binding: + return None + return cls( - cc=binding.cc, + cc=binding.routing_key if binding.exchange.is_respect_routing_key else None, ack=binding.ack, - replyTo=binding.replyTo, - deliveryMode=binding.deliveryMode, + replyTo=binding.reply_to, + deliveryMode=None if binding.persist is None else int(binding.persist) + 1, mandatory=binding.mandatory, priority=binding.priority, ) + @classmethod + def from_pub(cls, binding: Optional[amqp.OperationBinding]) -> Optional[Self]: + if not binding: + return None -def from_spec(binding: spec.bindings.amqp.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + return cls( + cc=binding.routing_key if binding.exchange.is_respect_routing_key else None, + ack=binding.ack, + replyTo=binding.reply_to, + deliveryMode=None if binding.persist is None else int(binding.persist) + 1, + mandatory=binding.mandatory, + priority=binding.priority, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/__init__.py index 7ead3ce532..8555fd981a 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/__init__.py @@ -1,15 +1,7 @@ -from .channel import ( - ChannelBinding, - from_spec as channel_binding_from_spec, -) -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/channel.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/channel.py index d6eda36274..1f304410ba 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/channel.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/channel.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, PositiveInt from typing_extensions import Self -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import kafka class ChannelBinding(BaseModel): @@ -30,14 +30,23 @@ class ChannelBinding(BaseModel): # topicConfiguration @classmethod - def from_spec(cls, binding: spec.bindings.kafka.ChannelBinding) -> Self: + def from_sub(cls, binding: Optional[kafka.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None + return cls( topic=binding.topic, partitions=binding.partitions, replicas=binding.replicas, - bindingVersion=binding.bindingVersion, ) + @classmethod + def from_pub(cls, binding: Optional[kafka.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None -def from_spec(binding: spec.bindings.kafka.ChannelBinding) -> ChannelBinding: - return ChannelBinding.from_spec(binding) + return cls( + topic=binding.topic, + partitions=binding.partitions, + replicas=binding.replicas, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/operation.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/operation.py index 005cc92cf7..4155ce220e 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/operation.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/kafka/operation.py @@ -9,7 +9,7 @@ from typing_extensions import Self from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import kafka class OperationBinding(BaseModel): @@ -28,14 +28,23 @@ class OperationBinding(BaseModel): bindingVersion: str = "0.4.0" @classmethod - def from_spec(cls, binding: spec.bindings.kafka.OperationBinding) -> Self: + def from_sub(cls, binding: Optional[kafka.OperationBinding]) -> Optional[Self]: + if not binding: + return None + return cls( - groupId=binding.groupId, - clientId=binding.clientId, - replyTo=binding.replyTo, - bindingVersion=binding.bindingVersion, + groupId=binding.group_id, + clientId=binding.client_id, + replyTo=binding.reply_to, ) + @classmethod + def from_pub(cls, binding: Optional[kafka.OperationBinding]) -> Optional[Self]: + if not binding: + return None -def from_spec(binding: spec.bindings.kafka.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + return cls( + groupId=binding.group_id, + clientId=binding.client_id, + replyTo=binding.reply_to, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/__init__.py index 7ead3ce532..8555fd981a 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/__init__.py @@ -1,15 +1,7 @@ -from .channel import ( - ChannelBinding, - from_spec as channel_binding_from_spec, -) -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/channel.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/channel.py index 258e08ea3a..bf4b7dbd98 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/channel.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/channel.py @@ -1,10 +1,9 @@ -from typing import Optional +from typing import Optional, overload from pydantic import BaseModel from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 -from faststream.specification import schema as spec from faststream.specification.asyncapi.v2_6_0.schema.bindings import ( amqp as amqp_bindings, kafka as kafka_bindings, @@ -12,6 +11,7 @@ redis as redis_bindings, sqs as sqs_bindings, ) +from faststream.specification.schema.bindings import ChannelBinding as SpecBinding class ChannelBinding(BaseModel): @@ -23,7 +23,6 @@ class ChannelBinding(BaseModel): sqs : SQS channel binding (optional) nats : NATS channel binding (optional) redis : Redis channel binding (optional) - """ amqp: Optional[amqp_bindings.ChannelBinding] = None @@ -40,26 +39,78 @@ class ChannelBinding(BaseModel): class Config: extra = "allow" + @overload + @classmethod + def from_sub(cls, binding: None) -> None: ... + + @overload @classmethod - def from_spec(cls, binding: spec.bindings.ChannelBinding) -> Self: - return cls( - amqp=amqp_bindings.channel_binding_from_spec(binding.amqp) - if binding.amqp is not None - else None, - kafka=kafka_bindings.channel_binding_from_spec(binding.kafka) - if binding.kafka is not None - else None, - sqs=sqs_bindings.channel_binding_from_spec(binding.sqs) - if binding.sqs is not None - else None, - nats=nats_bindings.channel_binding_from_spec(binding.nats) - if binding.nats is not None - else None, - redis=redis_bindings.channel_binding_from_spec(binding.redis) - if binding.redis is not None - else None, - ) - - -def from_spec(binding: spec.bindings.ChannelBinding) -> ChannelBinding: - return ChannelBinding.from_spec(binding) + def from_sub(cls, binding: SpecBinding) -> Self: ... + + @classmethod + def from_sub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.ChannelBinding.from_sub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.ChannelBinding.from_sub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.ChannelBinding.from_sub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.ChannelBinding.from_sub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.ChannelBinding.from_sub(binding.sqs)): + return cls(sqs=sqs) + + return None + + @overload + @classmethod + def from_pub(cls, binding: None) -> None: ... + + @overload + @classmethod + def from_pub(cls, binding: SpecBinding) -> Self: ... + + @classmethod + def from_pub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.ChannelBinding.from_pub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.ChannelBinding.from_pub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.ChannelBinding.from_pub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.ChannelBinding.from_pub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.ChannelBinding.from_pub(binding.sqs)): + return cls(sqs=sqs) + + return None diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/operation.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/operation.py index 61d614dd4d..7367b7921f 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/operation.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/main/operation.py @@ -1,10 +1,9 @@ -from typing import Optional +from typing import Optional, overload from pydantic import BaseModel from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 -from faststream.specification import schema as spec from faststream.specification.asyncapi.v2_6_0.schema.bindings import ( amqp as amqp_bindings, kafka as kafka_bindings, @@ -12,6 +11,7 @@ redis as redis_bindings, sqs as sqs_bindings, ) +from faststream.specification.schema.bindings import OperationBinding as SpecBinding class OperationBinding(BaseModel): @@ -23,7 +23,6 @@ class OperationBinding(BaseModel): sqs : SQS operation binding (optional) nats : NATS operation binding (optional) redis : Redis operation binding (optional) - """ amqp: Optional[amqp_bindings.OperationBinding] = None @@ -40,26 +39,78 @@ class OperationBinding(BaseModel): class Config: extra = "allow" + @overload + @classmethod + def from_sub(cls, binding: None) -> None: ... + + @overload @classmethod - def from_spec(cls, binding: spec.bindings.OperationBinding) -> Self: - return cls( - amqp=amqp_bindings.operation_binding_from_spec(binding.amqp) - if binding.amqp is not None - else None, - kafka=kafka_bindings.operation_binding_from_spec(binding.kafka) - if binding.kafka is not None - else None, - sqs=sqs_bindings.operation_binding_from_spec(binding.sqs) - if binding.sqs is not None - else None, - nats=nats_bindings.operation_binding_from_spec(binding.nats) - if binding.nats is not None - else None, - redis=redis_bindings.operation_binding_from_spec(binding.redis) - if binding.redis is not None - else None, - ) - - -def from_spec(binding: spec.bindings.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + def from_sub(cls, binding: SpecBinding) -> Self: ... + + @classmethod + def from_sub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.OperationBinding.from_sub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.OperationBinding.from_sub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.OperationBinding.from_sub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.OperationBinding.from_sub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.OperationBinding.from_sub(binding.sqs)): + return cls(sqs=sqs) + + return None + + @overload + @classmethod + def from_pub(cls, binding: None) -> None: ... + + @overload + @classmethod + def from_pub(cls, binding: SpecBinding) -> Self: ... + + @classmethod + def from_pub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.OperationBinding.from_pub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.OperationBinding.from_pub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.OperationBinding.from_pub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.OperationBinding.from_pub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.OperationBinding.from_pub(binding.sqs)): + return cls(sqs=sqs) + + return None diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/__init__.py index 7ead3ce532..8555fd981a 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/__init__.py @@ -1,15 +1,7 @@ -from .channel import ( - ChannelBinding, - from_spec as channel_binding_from_spec, -) -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/channel.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/channel.py index ba39c7569b..4cc83faddb 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/channel.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/channel.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from typing_extensions import Self -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import nats class ChannelBinding(BaseModel): @@ -25,13 +25,23 @@ class ChannelBinding(BaseModel): bindingVersion: str = "custom" @classmethod - def from_spec(cls, binding: spec.bindings.nats.ChannelBinding) -> Self: + def from_sub(cls, binding: Optional[nats.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None + return cls( subject=binding.subject, queue=binding.queue, - bindingVersion=binding.bindingVersion, + bindingVersion="custom", ) + @classmethod + def from_pub(cls, binding: Optional[nats.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None -def from_spec(binding: spec.bindings.nats.ChannelBinding) -> ChannelBinding: - return ChannelBinding.from_spec(binding) + return cls( + subject=binding.subject, + queue=binding.queue, + bindingVersion="custom", + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/operation.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/operation.py index b38a2f89dd..5e1514fcba 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/operation.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/nats/operation.py @@ -9,7 +9,7 @@ from typing_extensions import Self from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import nats class OperationBinding(BaseModel): @@ -24,12 +24,19 @@ class OperationBinding(BaseModel): bindingVersion: str = "custom" @classmethod - def from_spec(cls, binding: spec.bindings.nats.OperationBinding) -> Self: + def from_sub(cls, binding: Optional[nats.OperationBinding]) -> Optional[Self]: + if not binding: + return None + return cls( - replyTo=binding.replyTo, - bindingVersion=binding.bindingVersion, + replyTo=binding.reply_to, ) + @classmethod + def from_pub(cls, binding: Optional[nats.OperationBinding]) -> Optional[Self]: + if not binding: + return None -def from_spec(binding: spec.bindings.nats.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + return cls( + replyTo=binding.reply_to, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/__init__.py index 7ead3ce532..8555fd981a 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/__init__.py @@ -1,15 +1,7 @@ -from .channel import ( - ChannelBinding, - from_spec as channel_binding_from_spec, -) -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/channel.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/channel.py index 579f9170ea..abc5bf96d6 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/channel.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/channel.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from typing_extensions import Self -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import redis class ChannelBinding(BaseModel): @@ -22,20 +22,30 @@ class ChannelBinding(BaseModel): channel: str method: Optional[str] = None - group_name: Optional[str] = None - consumer_name: Optional[str] = None + groupName: Optional[str] = None + consumerName: Optional[str] = None bindingVersion: str = "custom" @classmethod - def from_spec(cls, binding: spec.bindings.redis.ChannelBinding) -> Self: + def from_sub(cls, binding: Optional[redis.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None + return cls( channel=binding.channel, method=binding.method, - group_name=binding.group_name, - consumer_name=binding.consumer_name, - bindingVersion=binding.bindingVersion, + groupName=binding.group_name, + consumerName=binding.consumer_name, ) + @classmethod + def from_pub(cls, binding: Optional[redis.ChannelBinding]) -> Optional[Self]: + if binding is None: + return None -def from_spec(binding: spec.bindings.redis.ChannelBinding) -> ChannelBinding: - return ChannelBinding.from_spec(binding) + return cls( + channel=binding.channel, + method=binding.method, + groupName=binding.group_name, + consumerName=binding.consumer_name, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/operation.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/operation.py index 39a4c94b99..cce0316160 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/operation.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/redis/operation.py @@ -9,7 +9,7 @@ from typing_extensions import Self from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import redis class OperationBinding(BaseModel): @@ -24,12 +24,19 @@ class OperationBinding(BaseModel): bindingVersion: str = "custom" @classmethod - def from_spec(cls, binding: spec.bindings.redis.OperationBinding) -> Self: + def from_sub(cls, binding: Optional[redis.OperationBinding]) -> Optional[Self]: + if not binding: + return None + return cls( - replyTo=binding.replyTo, - bindingVersion=binding.bindingVersion, + replyTo=binding.reply_to, ) + @classmethod + def from_pub(cls, binding: Optional[redis.OperationBinding]) -> Optional[Self]: + if not binding: + return None -def from_spec(binding: spec.bindings.redis.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + return cls( + replyTo=binding.reply_to, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/__init__.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/__init__.py index 7ead3ce532..33cdca3a8b 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/__init__.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/__init__.py @@ -1,15 +1,4 @@ -from .channel import ( - ChannelBinding, - from_spec as channel_binding_from_spec, -) -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding -__all__ = ( - "ChannelBinding", - "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", -) +__all__ = ("ChannelBinding", "OperationBinding") diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/channel.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/channel.py index 631e9c7bb4..93a1c5ac80 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/channel.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/channel.py @@ -7,7 +7,7 @@ from typing_extensions import Self from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import sqs class ChannelBinding(BaseModel): @@ -22,12 +22,12 @@ class ChannelBinding(BaseModel): bindingVersion: str = "custom" @classmethod - def from_spec(cls, binding: spec.bindings.sqs.ChannelBinding) -> Self: + def from_spec(cls, binding: sqs.ChannelBinding) -> Self: return cls( queue=binding.queue, bindingVersion=binding.bindingVersion, ) -def from_spec(binding: spec.bindings.sqs.ChannelBinding) -> ChannelBinding: +def from_spec(binding: sqs.ChannelBinding) -> ChannelBinding: return ChannelBinding.from_spec(binding) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/operation.py b/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/operation.py index 4ea0ece20a..35aa598d20 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/operation.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/bindings/sqs/operation.py @@ -9,7 +9,7 @@ from typing_extensions import Self from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream.specification.schema.bindings import sqs class OperationBinding(BaseModel): @@ -24,12 +24,12 @@ class OperationBinding(BaseModel): bindingVersion: str = "custom" @classmethod - def from_spec(cls, binding: spec.bindings.sqs.OperationBinding) -> Self: + def from_spec(cls, binding: sqs.OperationBinding) -> Self: return cls( replyTo=binding.replyTo, bindingVersion=binding.bindingVersion, ) -def from_spec(binding: spec.bindings.sqs.OperationBinding) -> OperationBinding: +def from_spec(binding: sqs.OperationBinding) -> OperationBinding: return OperationBinding.from_spec(binding) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/channels.py b/faststream/specification/asyncapi/v2_6_0/schema/channels.py index 99ac5585bc..5310578554 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/channels.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/channels.py @@ -4,15 +4,10 @@ from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.bindings import ( - ChannelBinding, - channel_binding_from_spec, -) -from faststream.specification.asyncapi.v2_6_0.schema.operations import ( - Operation, - from_spec as operation_from_spec, -) +from faststream.specification.schema import PublisherSpec, SubscriberSpec + +from .bindings import ChannelBinding +from .operations import Operation class Channel(BaseModel): @@ -28,7 +23,6 @@ class Channel(BaseModel): Configurations: model_config : configuration for the model (only applicable for Pydantic version 2) Config : configuration for the class (only applicable for Pydantic version 1) - """ description: Optional[str] = None @@ -49,21 +43,21 @@ class Config: extra = "allow" @classmethod - def from_spec(cls, channel: spec.channel.Channel) -> Self: + def from_sub(cls, subscriber: SubscriberSpec) -> Self: return cls( - description=channel.description, - servers=channel.servers, - bindings=channel_binding_from_spec(channel.bindings) - if channel.bindings is not None - else None, - subscribe=operation_from_spec(channel.subscribe) - if channel.subscribe is not None - else None, - publish=operation_from_spec(channel.publish) - if channel.publish is not None - else None, + description=subscriber.description, + servers=None, + bindings=ChannelBinding.from_sub(subscriber.bindings), + subscribe=None, + publish=Operation.from_sub(subscriber.operation), ) - -def from_spec(channel: spec.channel.Channel) -> Channel: - return Channel.from_spec(channel) + @classmethod + def from_pub(cls, publisher: PublisherSpec) -> Self: + return cls( + description=publisher.description, + servers=None, + bindings=ChannelBinding.from_pub(publisher.bindings), + subscribe=Operation.from_pub(publisher.operation), + publish=None, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/components.py b/faststream/specification/asyncapi/v2_6_0/schema/components.py index 764d639a24..a80c3420d0 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/components.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/components.py @@ -36,7 +36,6 @@ class Components(BaseModel): - channelBindings - operationBindings - messageBindings - """ messages: Optional[dict[str, Message]] = None diff --git a/faststream/specification/asyncapi/v2_6_0/schema/contact.py b/faststream/specification/asyncapi/v2_6_0/schema/contact.py index 809b5190c4..9456d7dd61 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/contact.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/contact.py @@ -1,15 +1,15 @@ -from typing import ( - Optional, - Union, - overload, -) +from typing import Optional, Union, overload from pydantic import AnyHttpUrl, BaseModel from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2, EmailStr from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream._internal.utils.data import filter_by_dict +from faststream.specification.schema.extra import ( + Contact as SpecContact, + ContactDict, +) class Contact(BaseModel): @@ -19,10 +19,10 @@ class Contact(BaseModel): name : name of the contact (str) url : URL of the contact (Optional[AnyHttpUrl]) email : email of the contact (Optional[EmailStr]) - """ name: str + # Use default values to be able build from dict url: Optional[AnyHttpUrl] = None email: Optional[EmailStr] = None @@ -34,31 +34,39 @@ class Contact(BaseModel): class Config: extra = "allow" + @overload @classmethod - def from_spec(cls, contact: spec.contact.Contact) -> Self: - return cls( - name=contact.name, - url=contact.url, - email=contact.email, - ) - + def from_spec(cls, contact: None) -> None: ... -@overload -def from_spec(contact: spec.contact.Contact) -> Contact: ... + @overload + @classmethod + def from_spec(cls, contact: SpecContact) -> Self: ... + @overload + @classmethod + def from_spec(cls, contact: ContactDict) -> Self: ... -@overload -def from_spec(contact: spec.contact.ContactDict) -> AnyDict: ... + @overload + @classmethod + def from_spec(cls, contact: AnyDict) -> AnyDict: ... + @classmethod + def from_spec( + cls, contact: Union[SpecContact, ContactDict, AnyDict, None] + ) -> Union[Self, AnyDict, None]: + if contact is None: + return None -@overload -def from_spec(contact: AnyDict) -> AnyDict: ... + if isinstance(contact, SpecContact): + return cls( + name=contact.name, + url=contact.url, + email=contact.email, + ) + contact_data, custom_data = filter_by_dict(ContactDict, contact) -def from_spec( - contact: Union[spec.contact.Contact, spec.contact.ContactDict, AnyDict], -) -> Union[Contact, AnyDict]: - if isinstance(contact, spec.contact.Contact): - return Contact.from_spec(contact) + if custom_data: + return contact - return dict(contact) + return cls(**contact_data) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/docs.py b/faststream/specification/asyncapi/v2_6_0/schema/docs.py index 34b2e3ed8d..6ad8d6a252 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/docs.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/docs.py @@ -5,7 +5,11 @@ from faststream._internal._compat import PYDANTIC_V2 from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream._internal.utils.data import filter_by_dict +from faststream.specification.schema.extra import ( + ExternalDocs as SpecDocs, + ExternalDocsDict, +) class ExternalDocs(BaseModel): @@ -14,10 +18,10 @@ class ExternalDocs(BaseModel): Attributes: url : URL of the external documentation description : optional description of the external documentation - """ url: AnyHttpUrl + # Use default values to be able build from dict description: Optional[str] = None if PYDANTIC_V2: @@ -28,27 +32,35 @@ class ExternalDocs(BaseModel): class Config: extra = "allow" + @overload @classmethod - def from_spec(cls, docs: spec.docs.ExternalDocs) -> Self: - return cls(url=docs.url, description=docs.description) - + def from_spec(cls, docs: None) -> None: ... -@overload -def from_spec(docs: spec.docs.ExternalDocs) -> ExternalDocs: ... + @overload + @classmethod + def from_spec(cls, docs: SpecDocs) -> Self: ... + @overload + @classmethod + def from_spec(cls, docs: ExternalDocsDict) -> Self: ... -@overload -def from_spec(docs: spec.docs.ExternalDocsDict) -> AnyDict: ... + @overload + @classmethod + def from_spec(cls, docs: AnyDict) -> AnyDict: ... + @classmethod + def from_spec( + cls, docs: Union[SpecDocs, ExternalDocsDict, AnyDict, None] + ) -> Union[Self, AnyDict, None]: + if docs is None: + return None -@overload -def from_spec(docs: AnyDict) -> AnyDict: ... + if isinstance(docs, SpecDocs): + return cls(url=docs.url, description=docs.description) + docs_data, custom_data = filter_by_dict(ExternalDocsDict, docs) -def from_spec( - docs: Union[spec.docs.ExternalDocs, spec.docs.ExternalDocsDict, AnyDict], -) -> Union[ExternalDocs, AnyDict]: - if isinstance(docs, spec.docs.ExternalDocs): - return ExternalDocs.from_spec(docs) + if custom_data: + return docs - return dict(docs) + return cls(**docs_data) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/info.py b/faststream/specification/asyncapi/v2_6_0/schema/info.py index b4cf3bceec..50f79fa026 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/info.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/info.py @@ -8,19 +8,19 @@ from faststream._internal.basic_types import AnyDict from faststream.specification.asyncapi.v2_6_0.schema.contact import Contact from faststream.specification.asyncapi.v2_6_0.schema.license import License -from faststream.specification.base.info import BaseInfo +from faststream.specification.base.info import BaseApplicationInfo -class Info(BaseInfo): - """A class to represent information. +class ApplicationInfo(BaseApplicationInfo): + """A class to represent application information. Attributes: title : title of the information - version : version of the information (default: "1.0.0") - description : description of the information (default: "") - termsOfService : terms of service for the information (default: None) - contact : contact information for the information (default: None) - license : license information for the information (default: None) + version : version of the information + description : description of the information + termsOfService : terms of service for the information + contact : contact information for the information + license : license information for the information """ termsOfService: Optional[AnyHttpUrl] = None diff --git a/faststream/specification/asyncapi/v2_6_0/schema/license.py b/faststream/specification/asyncapi/v2_6_0/schema/license.py index 761511789a..1d3b778d8e 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/license.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/license.py @@ -1,17 +1,15 @@ -from typing import ( - Optional, - Union, - overload, -) +from typing import Optional, Union, overload from pydantic import AnyHttpUrl, BaseModel from typing_extensions import Self -from faststream._internal._compat import ( - PYDANTIC_V2, -) +from faststream._internal._compat import PYDANTIC_V2 from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec +from faststream._internal.utils.data import filter_by_dict +from faststream.specification.schema.extra import ( + License as SpecLicense, + LicenseDict, +) class License(BaseModel): @@ -23,10 +21,10 @@ class License(BaseModel): Config: extra : allow additional attributes in the model (PYDANTIC_V2) - """ name: str + # Use default values to be able build from dict url: Optional[AnyHttpUrl] = None if PYDANTIC_V2: @@ -37,30 +35,38 @@ class License(BaseModel): class Config: extra = "allow" + @overload @classmethod - def from_spec(cls, license: spec.license.License) -> Self: - return cls( - name=license.name, - url=license.url, - ) - + def from_spec(cls, license: None) -> None: ... -@overload -def from_spec(license: spec.license.License) -> License: ... + @overload + @classmethod + def from_spec(cls, license: SpecLicense) -> Self: ... + @overload + @classmethod + def from_spec(cls, license: LicenseDict) -> Self: ... -@overload -def from_spec(license: spec.license.LicenseDict) -> AnyDict: ... + @overload + @classmethod + def from_spec(cls, license: AnyDict) -> AnyDict: ... + @classmethod + def from_spec( + cls, license: Union[SpecLicense, LicenseDict, AnyDict, None] + ) -> Union[Self, AnyDict, None]: + if license is None: + return None -@overload -def from_spec(license: AnyDict) -> AnyDict: ... + if isinstance(license, SpecLicense): + return cls( + name=license.name, + url=license.url, + ) + license_data, custom_data = filter_by_dict(LicenseDict, license) -def from_spec( - license: Union[spec.license.License, spec.license.LicenseDict, AnyDict], -) -> Union[License, AnyDict]: - if isinstance(license, spec.license.License): - return License.from_spec(license) + if custom_data: + return license - return dict(license) + return cls(**license_data) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/message.py b/faststream/specification/asyncapi/v2_6_0/schema/message.py index d9cde0e403..5f56df156c 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/message.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/message.py @@ -1,15 +1,12 @@ from typing import Optional, Union -import typing_extensions from pydantic import BaseModel +from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.tag import ( - Tag, - from_spec as tag_from_spec, -) +from faststream.specification.asyncapi.v2_6_0.schema.tag import Tag +from faststream.specification.schema.message import Message as SpecMessage class CorrelationId(BaseModel): @@ -21,11 +18,10 @@ class CorrelationId(BaseModel): Configurations: extra : allows extra fields in the correlation ID model - """ - description: Optional[str] = None location: str + description: Optional[str] = None if PYDANTIC_V2: model_config = {"extra": "allow"} @@ -35,13 +31,6 @@ class CorrelationId(BaseModel): class Config: extra = "allow" - @classmethod - def from_spec(cls, cor_id: spec.message.CorrelationId) -> typing_extensions.Self: - return cls( - description=cor_id.description, - location=cor_id.location, - ) - class Message(BaseModel): """A class to represent a message. @@ -56,7 +45,6 @@ class Message(BaseModel): contentType : content type of the message payload : dictionary representing the payload of the message tags : list of tags associated with the message - """ title: Optional[str] = None @@ -86,23 +74,18 @@ class Config: extra = "allow" @classmethod - def from_spec(cls, message: spec.message.Message) -> typing_extensions.Self: + def from_spec(cls, message: SpecMessage) -> Self: return cls( title=message.title, - name=message.name, - summary=message.summary, - description=message.description, - messageId=message.messageId, - correlationId=CorrelationId.from_spec(message.correlationId) - if message.correlationId is not None - else None, - contentType=message.contentType, payload=message.payload, - tags=[tag_from_spec(tag) for tag in message.tags] - if message.tags is not None - else None, + correlationId=CorrelationId( + description=None, + location="$message.header#/correlation_id", + ), + name=None, + summary=None, + description=None, + messageId=None, + contentType=None, + tags=None, ) - - -def from_spec(message: spec.message.Message) -> Message: - return Message.from_spec(message) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/operations.py b/faststream/specification/asyncapi/v2_6_0/schema/operations.py index 9d6c6d61fa..c837c844d7 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/operations.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/operations.py @@ -5,22 +5,12 @@ from faststream._internal._compat import PYDANTIC_V2 from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.bindings import OperationBinding -from faststream.specification.asyncapi.v2_6_0.schema.bindings.main import ( - operation_binding_from_spec, -) -from faststream.specification.asyncapi.v2_6_0.schema.message import ( - Message, - from_spec as message_from_spec, -) -from faststream.specification.asyncapi.v2_6_0.schema.tag import ( - Tag, - from_spec as tag_from_spec, -) -from faststream.specification.asyncapi.v2_6_0.schema.utils import ( - Reference, -) +from faststream.specification.schema.operation import Operation as OperationSpec + +from .bindings import OperationBinding +from .message import Message +from .tag import Tag +from .utils import Reference class Operation(BaseModel): @@ -34,7 +24,6 @@ class Operation(BaseModel): message : message of the operation security : security details of the operation tags : tags associated with the operation - """ operationId: Optional[str] = None @@ -61,22 +50,25 @@ class Config: extra = "allow" @classmethod - def from_spec(cls, operation: spec.operation.Operation) -> Self: + def from_sub(cls, operation: OperationSpec) -> Self: return cls( - operationId=operation.operationId, - summary=operation.summary, - description=operation.description, - bindings=operation_binding_from_spec(operation.bindings) - if operation.bindings is not None - else None, - message=message_from_spec(operation.message) - if operation.message is not None - else None, - tags=[tag_from_spec(tag) for tag in operation.tags] - if operation.tags is not None - else None, + message=Message.from_spec(operation.message), + bindings=OperationBinding.from_sub(operation.bindings), + operationId=None, + summary=None, + description=None, + tags=None, + security=None, ) - -def from_spec(operation: spec.operation.Operation) -> Operation: - return Operation.from_spec(operation) + @classmethod + def from_pub(cls, operation: OperationSpec) -> Self: + return cls( + message=Message.from_spec(operation.message), + bindings=OperationBinding.from_pub(operation.bindings), + operationId=None, + summary=None, + description=None, + tags=None, + security=None, + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/schema.py b/faststream/specification/asyncapi/v2_6_0/schema/schema.py index 9c19130cc2..8f4a70a701 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/schema.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/schema.py @@ -4,14 +4,14 @@ from faststream.specification.asyncapi.v2_6_0.schema.channels import Channel from faststream.specification.asyncapi.v2_6_0.schema.components import Components from faststream.specification.asyncapi.v2_6_0.schema.docs import ExternalDocs -from faststream.specification.asyncapi.v2_6_0.schema.info import Info +from faststream.specification.asyncapi.v2_6_0.schema.info import ApplicationInfo from faststream.specification.asyncapi.v2_6_0.schema.servers import Server from faststream.specification.asyncapi.v2_6_0.schema.tag import Tag -from faststream.specification.base.schema import BaseSchema +from faststream.specification.base.schema import BaseApplicationSchema -class Schema(BaseSchema): - """A class to represent a schema. +class ApplicationSchema(BaseApplicationSchema): + """A class to represent an application schema. Attributes: asyncapi : version of the async API @@ -25,9 +25,9 @@ class Schema(BaseSchema): externalDocs : optional external documentation """ - info: Info + info: ApplicationInfo - asyncapi: Union[Literal["2.6.0"], str] = "2.6.0" + asyncapi: Union[Literal["2.6.0"], str] id: Optional[str] = None defaultContentType: Optional[str] = None servers: Optional[dict[str, Server]] = None diff --git a/faststream/specification/asyncapi/v2_6_0/schema/servers.py b/faststream/specification/asyncapi/v2_6_0/schema/servers.py index a50be2669e..cae721cfd1 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/servers.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/servers.py @@ -18,7 +18,6 @@ class ServerVariable(BaseModel): default : default value for the server variable (optional) description : description of the server variable (optional) examples : list of example values for the server variable (optional) - """ enum: Optional[list[str]] = None @@ -50,19 +49,15 @@ class Server(BaseModel): Note: The attributes `description`, `protocolVersion`, `tags`, `security`, `variables`, and `bindings` are all optional. - - Configurations: - If `PYDANTIC_V2` is True, the model configuration is set to allow extra attributes. - Otherwise, the `Config` class is defined with the `extra` attribute set to "allow". - """ url: str protocol: str - description: Optional[str] = None - protocolVersion: Optional[str] = None - tags: Optional[list[Union[Tag, AnyDict]]] = None - security: Optional[SecurityRequirement] = None + protocolVersion: Optional[str] + description: Optional[str] + tags: Optional[list[Union[Tag, AnyDict]]] + security: Optional[SecurityRequirement] + variables: Optional[dict[str, Union[ServerVariable, Reference]]] = None if PYDANTIC_V2: diff --git a/faststream/specification/asyncapi/v2_6_0/schema/tag.py b/faststream/specification/asyncapi/v2_6_0/schema/tag.py index 182c4effd9..a4fdae8d8d 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/tag.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/tag.py @@ -1,14 +1,15 @@ from typing import Optional, Union, overload -import typing_extensions from pydantic import BaseModel +from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.docs import ( - ExternalDocs, - from_spec as docs_from_spec, +from faststream._internal.utils.data import filter_by_dict +from faststream.specification.asyncapi.v2_6_0.schema.docs import ExternalDocs +from faststream.specification.schema.extra import ( + Tag as SpecTag, + TagDict, ) @@ -19,10 +20,10 @@ class Tag(BaseModel): name : name of the tag description : description of the tag (optional) externalDocs : external documentation for the tag (optional) - """ name: str + # Use default values to be able build from dict description: Optional[str] = None externalDocs: Optional[ExternalDocs] = None @@ -34,31 +35,34 @@ class Tag(BaseModel): class Config: extra = "allow" + @overload @classmethod - def from_spec(cls, tag: spec.tag.Tag) -> typing_extensions.Self: - return cls( - name=tag.name, - description=tag.description, - externalDocs=docs_from_spec(tag.externalDocs) if tag.externalDocs else None, - ) - - -@overload -def from_spec(tag: spec.tag.Tag) -> Tag: ... - + def from_spec(cls, tag: SpecTag) -> Self: ... -@overload -def from_spec(tag: spec.tag.TagDict) -> AnyDict: ... + @overload + @classmethod + def from_spec(cls, tag: TagDict) -> Self: ... + @overload + @classmethod + def from_spec(cls, tag: AnyDict) -> AnyDict: ... -@overload -def from_spec(tag: AnyDict) -> AnyDict: ... + @classmethod + def from_spec(cls, tag: Union[SpecTag, TagDict, AnyDict]) -> Union[Self, AnyDict]: + if isinstance(tag, SpecTag): + return cls( + name=tag.name, + description=tag.description, + externalDocs=ExternalDocs.from_spec(tag.external_docs), + ) + tag_data, custom_data = filter_by_dict(TagDict, tag) -def from_spec( - tag: Union[spec.tag.Tag, spec.tag.TagDict, AnyDict], -) -> Union[Tag, AnyDict]: - if isinstance(tag, spec.tag.Tag): - return Tag.from_spec(tag) + if custom_data: + return tag - return dict(tag) + return cls( + name=tag_data.get("name"), + description=tag_data.get("description"), + externalDocs=tag_data.get("external_docs"), + ) diff --git a/faststream/specification/asyncapi/v2_6_0/schema/utils.py b/faststream/specification/asyncapi/v2_6_0/schema/utils.py index d145abe37d..6d492ffeb5 100644 --- a/faststream/specification/asyncapi/v2_6_0/schema/utils.py +++ b/faststream/specification/asyncapi/v2_6_0/schema/utils.py @@ -6,7 +6,6 @@ class Reference(BaseModel): Attributes: ref : the reference string - """ ref: str = Field(..., alias="$ref") diff --git a/faststream/specification/asyncapi/v3_0_0/facade.py b/faststream/specification/asyncapi/v3_0_0/facade.py index 26e668bd8f..4ce47b6f90 100644 --- a/faststream/specification/asyncapi/v3_0_0/facade.py +++ b/faststream/specification/asyncapi/v3_0_0/facade.py @@ -1,18 +1,24 @@ -from collections.abc import Sequence +from collections.abc import Iterable from typing import TYPE_CHECKING, Any, Optional, Union from faststream.specification.base.specification import Specification from .generate import get_app_schema -from .schema import Schema +from .schema import ApplicationSchema if TYPE_CHECKING: from faststream._internal.basic_types import AnyDict, AnyHttpUrl from faststream._internal.broker.broker import BrokerUsecase - from faststream.specification.schema.contact import Contact, ContactDict - from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict - from faststream.specification.schema.license import License, LicenseDict - from faststream.specification.schema.tag import Tag, TagDict + from faststream.specification.schema.extra import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, + ) class AsyncAPI3(Specification): @@ -28,7 +34,7 @@ def __init__( contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, identifier: Optional[str] = None, - tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, + tags: Iterable[Union["Tag", "TagDict", "AnyDict"]] = (), external_docs: Optional[ Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] ] = None, @@ -55,7 +61,7 @@ def to_yaml(self) -> str: return self.schema.to_yaml() @property - def schema(self) -> Schema: # type: ignore[override] + def schema(self) -> ApplicationSchema: # type: ignore[override] return get_app_schema( self.broker, title=self.title, diff --git a/faststream/specification/asyncapi/v3_0_0/generate.py b/faststream/specification/asyncapi/v3_0_0/generate.py index b14dca2772..1efc8c4fdc 100644 --- a/faststream/specification/asyncapi/v3_0_0/generate.py +++ b/faststream/specification/asyncapi/v3_0_0/generate.py @@ -5,40 +5,34 @@ from faststream._internal._compat import DEF_KEY from faststream._internal.basic_types import AnyDict, AnyHttpUrl from faststream._internal.constants import ContentTypes -from faststream.specification.asyncapi.utils import clear_key -from faststream.specification.asyncapi.v2_6_0.generate import move_pydantic_refs -from faststream.specification.asyncapi.v2_6_0.schema import ( - Reference, - Tag, - contact_from_spec, - docs_from_spec, - license_from_spec, - tag_from_spec, -) -from faststream.specification.asyncapi.v2_6_0.schema.message import Message +from faststream.specification.asyncapi.utils import clear_key, move_pydantic_refs from faststream.specification.asyncapi.v3_0_0.schema import ( + ApplicationInfo, + ApplicationSchema, Channel, Components, - Info, + Contact, + ExternalDocs, + License, + Message, Operation, - Schema, + Reference, Server, - channel_from_spec, - operation_from_spec, -) -from faststream.specification.asyncapi.v3_0_0.schema.operations import ( - Action, + Tag, ) if TYPE_CHECKING: from faststream._internal.broker.broker import BrokerUsecase from faststream._internal.types import ConnectionType, MsgType - from faststream.specification.schema.contact import Contact, ContactDict - from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict - from faststream.specification.schema.license import License, LicenseDict - from faststream.specification.schema.tag import ( - Tag as SpecsTag, - TagDict as SpecsTagDict, + from faststream.specification.schema.extra import ( + Contact as SpecContact, + ContactDict, + ExternalDocs as SpecDocs, + ExternalDocsDict, + License as SpecLicense, + LicenseDict, + Tag as SpecTag, + TagDict, ) @@ -50,18 +44,17 @@ def get_app_schema( schema_version: str, description: str, terms_of_service: Optional["AnyHttpUrl"], - contact: Optional[Union["Contact", "ContactDict", "AnyDict"]], - license: Optional[Union["License", "LicenseDict", "AnyDict"]], + contact: Optional[Union["SpecContact", "ContactDict", "AnyDict"]], + license: Optional[Union["SpecLicense", "LicenseDict", "AnyDict"]], identifier: Optional[str], - tags: Optional[Sequence[Union["SpecsTag", "SpecsTagDict", "AnyDict"]]], - external_docs: Optional[Union["ExternalDocs", "ExternalDocsDict", "AnyDict"]], -) -> Schema: + tags: Optional[Sequence[Union["SpecTag", "TagDict", "AnyDict"]]], + external_docs: Optional[Union["SpecDocs", "ExternalDocsDict", "AnyDict"]], +) -> ApplicationSchema: """Get the application schema.""" broker._setup() servers = get_broker_server(broker) - channels = get_broker_channels(broker) - operations = get_broker_operations(broker) + channels, operations = get_broker_channels(broker) messages: dict[str, Message] = {} payloads: dict[str, AnyDict] = {} @@ -86,16 +79,16 @@ def get_app_schema( channel.messages = msgs - return Schema( - info=Info( + return ApplicationSchema( + info=ApplicationInfo( title=title, version=app_version, description=description, termsOfService=terms_of_service, - contact=contact_from_spec(contact) if contact else None, - license=license_from_spec(license) if license else None, - tags=[tag_from_spec(tag) for tag in tags] if tags else None, - externalDocs=docs_from_spec(external_docs) if external_docs else None, + contact=Contact.from_spec(contact), + license=License.from_spec(license), + tags=[Tag.from_spec(tag) for tag in tags] or None, + externalDocs=ExternalDocs.from_spec(external_docs), ), asyncapi=schema_version, defaultContentType=ContentTypes.JSON.value, @@ -121,7 +114,7 @@ def get_broker_server( tags: Optional[list[Union[Tag, AnyDict]]] = None if broker.tags: - tags = [tag_from_spec(tag) for tag in broker.tags] + tags = [Tag.from_spec(tag) for tag in broker.tags] broker_meta: AnyDict = { "protocol": broker.protocol, @@ -152,77 +145,52 @@ def get_broker_server( return servers -def get_broker_operations( - broker: "BrokerUsecase[MsgType, ConnectionType]", -) -> dict[str, Operation]: - """Get the broker operations for an application.""" - operations = {} - - for h in broker._subscribers: - for channel, specs_channel in h.schema().items(): - channel_name = clear_key(channel) - - if specs_channel.subscribe is not None: - operations[f"{channel_name}Subscribe"] = operation_from_spec( - specs_channel.subscribe, - Action.RECEIVE, - channel_name, - ) - - for p in broker._publishers: - for channel, specs_channel in p.schema().items(): - channel_name = clear_key(channel) - - if specs_channel.publish is not None: - operations[f"{channel_name}"] = operation_from_spec( - specs_channel.publish, - Action.SEND, - channel_name, - ) - - return operations - - def get_broker_channels( broker: "BrokerUsecase[MsgType, ConnectionType]", -) -> dict[str, Channel]: +) -> tuple[dict[str, Channel], dict[str, Operation]]: """Get the broker channels for an application.""" channels = {} + operations = {} for sub in broker._subscribers: - channels_schema_v3_0 = {} - for channel_name, specs_channel in sub.schema().items(): - if specs_channel.subscribe: - message = specs_channel.subscribe.message - assert message.title - - *left, right = message.title.split(":") - message.title = ":".join(left) + f":Subscribe{right}" - - # TODO: why we are format just a key? - channels_schema_v3_0[clear_key(channel_name)] = channel_from_spec( - specs_channel, - message, - channel_name, - "SubscribeMessage", - ) - - channels.update(channels_schema_v3_0) + for key, channel in sub.schema().items(): + channel_obj = Channel.from_sub(key, channel) + + channel_key = clear_key(key) + # TODO: add duplication key warning + channels[channel_key] = channel_obj + + operations[f"{channel_key}Subscribe"] = Operation.from_sub( + messages=[ + Reference(**{ + "$ref": f"#/channels/{channel_key}/messages/{msg_name}" + }) + for msg_name in channel_obj.messages + ], + channel=Reference(**{"$ref": f"#/channels/{channel_key}"}), + operation=channel.operation, + ) for pub in broker._publishers: - channels_schema_v3_0 = {} - for channel_name, specs_channel in pub.schema().items(): - if specs_channel.publish: - channels_schema_v3_0[clear_key(channel_name)] = channel_from_spec( - specs_channel, - specs_channel.publish.message, - channel_name, - "Message", - ) - - channels.update(channels_schema_v3_0) - - return channels + for key, channel in pub.schema().items(): + channel_obj = Channel.from_pub(key, channel) + + channel_key = clear_key(key) + # TODO: add duplication key warning + channels[channel_key] = channel_obj + + operations[channel_key] = Operation.from_pub( + messages=[ + Reference(**{ + "$ref": f"#/channels/{channel_key}/messages/{msg_name}" + }) + for msg_name in channel_obj.messages + ], + channel=Reference(**{"$ref": f"#/channels/{channel_key}"}), + operation=channel.operation, + ) + + return channels, operations def _resolve_msg_payloads( diff --git a/faststream/specification/asyncapi/v3_0_0/schema/__init__.py b/faststream/specification/asyncapi/v3_0_0/schema/__init__.py index ef2cfe21b7..e0cbcbd7b2 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/__init__.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/__init__.py @@ -1,23 +1,31 @@ -from .channels import ( - Channel, - from_spec as channel_from_spec, -) +from .channels import Channel from .components import Components -from .info import Info -from .operations import ( - Operation, - from_spec as operation_from_spec, -) -from .schema import Schema -from .servers import Server +from .contact import Contact +from .docs import ExternalDocs +from .info import ApplicationInfo +from .license import License +from .message import CorrelationId, Message +from .operations import Operation +from .schema import ApplicationSchema +from .servers import Server, ServerVariable +from .tag import Tag +from .utils import Parameter, Reference __all__ = ( + "ApplicationInfo", + "ApplicationSchema", + "Channel", "Channel", "Components", - "Info", + "Contact", + "CorrelationId", + "ExternalDocs", + "License", + "Message", "Operation", - "Schema", + "Parameter", + "Reference", "Server", - "channel_from_spec", - "operation_from_spec", + "ServerVariable", + "Tag", ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/__init__.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/__init__.py index d477f704cd..c304608c5b 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/__init__.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/__init__.py @@ -1,11 +1,9 @@ from .main import ( + ChannelBinding, OperationBinding, - channel_binding_from_spec, - operation_binding_from_spec, ) __all__ = ( + "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/__init__.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/__init__.py index 96c7406698..8555fd981a 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/__init__.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/__init__.py @@ -1,11 +1,7 @@ -from .channel import from_spec as channel_binding_from_spec -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( + "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/channel.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/channel.py index a31498ee5f..ecab8e4a17 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/channel.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/channel.py @@ -1,21 +1,7 @@ -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.bindings.amqp import ChannelBinding -from faststream.specification.asyncapi.v2_6_0.schema.bindings.amqp.channel import ( - Exchange, - Queue, +from faststream.specification.asyncapi.v2_6_0.schema.bindings.amqp import ( + ChannelBinding as V2Binding, ) -def from_spec(binding: spec.bindings.amqp.ChannelBinding) -> ChannelBinding: - return ChannelBinding( - **{ - "is": binding.is_, - "bindingVersion": "0.3.0", - "queue": Queue.from_spec(binding.queue) - if binding.queue is not None - else None, - "exchange": Exchange.from_spec(binding.exchange) - if binding.exchange is not None - else None, - }, - ) +class ChannelBinding(V2Binding): + bindingVersion: str = "0.3.0" diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/operation.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/operation.py index 1357dd325f..77ba8356a0 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/operation.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/amqp/operation.py @@ -5,41 +5,46 @@ from typing import Optional -from pydantic import BaseModel, PositiveInt from typing_extensions import Self -from faststream.specification import schema as spec +from faststream.specification.asyncapi.v2_6_0.schema.bindings.amqp import ( + OperationBinding as V2Binding, +) +from faststream.specification.schema.bindings import amqp -class OperationBinding(BaseModel): - """A class to represent an operation binding. - - Attributes: - cc : optional string representing the cc - ack : boolean indicating if the operation is acknowledged - replyTo : optional dictionary representing the replyTo - bindingVersion : string representing the binding version - """ - +class OperationBinding(V2Binding): cc: Optional[list[str]] = None - ack: bool = True - replyTo: Optional[str] = None - deliveryMode: Optional[int] = None - mandatory: Optional[bool] = None - priority: Optional[PositiveInt] = None bindingVersion: str = "0.3.0" @classmethod - def from_spec(cls, binding: spec.bindings.amqp.OperationBinding) -> Self: + def from_sub(cls, binding: Optional[amqp.OperationBinding]) -> Optional[Self]: + if not binding: + return None + return cls( - cc=[binding.cc] if binding.cc is not None else None, + cc=[binding.routing_key] + if (binding.routing_key and binding.exchange.is_respect_routing_key) + else None, ack=binding.ack, - replyTo=binding.replyTo, - deliveryMode=binding.deliveryMode, + replyTo=binding.reply_to, + deliveryMode=None if binding.persist is None else int(binding.persist) + 1, mandatory=binding.mandatory, priority=binding.priority, ) + @classmethod + def from_pub(cls, binding: Optional[amqp.OperationBinding]) -> Optional[Self]: + if not binding: + return None -def from_spec(binding: spec.bindings.amqp.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + return cls( + cc=None + if (not binding.routing_key or not binding.exchange.is_respect_routing_key) + else [binding.routing_key], + ack=binding.ack, + replyTo=binding.reply_to, + deliveryMode=None if binding.persist is None else int(binding.persist) + 1, + mandatory=binding.mandatory, + priority=binding.priority, + ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/kafka.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/kafka.py new file mode 100644 index 0000000000..5605abeefa --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/kafka.py @@ -0,0 +1,9 @@ +from faststream.specification.asyncapi.v2_6_0.schema.bindings.kafka import ( + ChannelBinding, + OperationBinding, +) + +__all__ = ( + "ChannelBinding", + "OperationBinding", +) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/__init__.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/__init__.py index 96c7406698..8555fd981a 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/__init__.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/__init__.py @@ -1,11 +1,7 @@ -from .channel import from_spec as channel_binding_from_spec -from .operation import ( - OperationBinding, - from_spec as operation_binding_from_spec, -) +from .channel import ChannelBinding +from .operation import OperationBinding __all__ = ( + "ChannelBinding", "OperationBinding", - "channel_binding_from_spec", - "operation_binding_from_spec", ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/channel.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/channel.py index 41aef76aaa..c7552a11d1 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/channel.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/channel.py @@ -1,17 +1,100 @@ -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.bindings import ChannelBinding -from faststream.specification.asyncapi.v2_6_0.schema.bindings.main import ( - channel_binding_from_spec, -) -from faststream.specification.asyncapi.v3_0_0.schema.bindings.amqp import ( - channel_binding_from_spec as amqp_channel_binding_from_spec, +from typing import Optional + +from pydantic import BaseModel +from typing_extensions import Self + +from faststream._internal._compat import PYDANTIC_V2 +from faststream.specification.asyncapi.v3_0_0.schema.bindings import ( + amqp as amqp_bindings, + kafka as kafka_bindings, + nats as nats_bindings, + redis as redis_bindings, + sqs as sqs_bindings, ) +from faststream.specification.schema.bindings import ChannelBinding as SpecBinding + + +class ChannelBinding(BaseModel): + """A class to represent channel bindings. + + Attributes: + amqp : AMQP channel binding (optional) + kafka : Kafka channel binding (optional) + sqs : SQS channel binding (optional) + nats : NATS channel binding (optional) + redis : Redis channel binding (optional) + """ + + amqp: Optional[amqp_bindings.ChannelBinding] = None + kafka: Optional[kafka_bindings.ChannelBinding] = None + sqs: Optional[sqs_bindings.ChannelBinding] = None + nats: Optional[nats_bindings.ChannelBinding] = None + redis: Optional[redis_bindings.ChannelBinding] = None + + if PYDANTIC_V2: + model_config = {"extra": "allow"} + + else: + + class Config: + extra = "allow" + + @classmethod + def from_sub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.ChannelBinding.from_sub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.ChannelBinding.from_sub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.ChannelBinding.from_sub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.ChannelBinding.from_sub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.ChannelBinding.from_sub(binding.sqs)): + return cls(sqs=sqs) + + return None + + @classmethod + def from_pub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.ChannelBinding.from_pub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.ChannelBinding.from_pub(binding.kafka) + ): + return cls(kafka=kafka) + if binding.nats and ( + nats := nats_bindings.ChannelBinding.from_pub(binding.nats) + ): + return cls(nats=nats) -def from_spec(binding: spec.bindings.ChannelBinding) -> ChannelBinding: - channel_binding = channel_binding_from_spec(binding) + if binding.redis and ( + redis := redis_bindings.ChannelBinding.from_pub(binding.redis) + ): + return cls(redis=redis) - if binding.amqp: - channel_binding.amqp = amqp_channel_binding_from_spec(binding.amqp) + if binding.sqs and (sqs := sqs_bindings.ChannelBinding.from_pub(binding.sqs)): + return cls(sqs=sqs) - return channel_binding + return None diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/operation.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/operation.py index 6d0a70069e..fc37c3dc75 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/operation.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/main/operation.py @@ -4,16 +4,14 @@ from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.bindings import ( +from faststream.specification.asyncapi.v3_0_0.schema.bindings import ( + amqp as amqp_bindings, kafka as kafka_bindings, nats as nats_bindings, redis as redis_bindings, sqs as sqs_bindings, ) -from faststream.specification.asyncapi.v3_0_0.schema.bindings import ( - amqp as amqp_bindings, -) +from faststream.specification.schema.bindings import OperationBinding as SpecBinding class OperationBinding(BaseModel): @@ -25,7 +23,6 @@ class OperationBinding(BaseModel): sqs : SQS operation binding (optional) nats : NATS operation binding (optional) redis : Redis operation binding (optional) - """ amqp: Optional[amqp_bindings.OperationBinding] = None @@ -43,25 +40,61 @@ class Config: extra = "allow" @classmethod - def from_spec(cls, binding: spec.bindings.OperationBinding) -> Self: - return cls( - amqp=amqp_bindings.operation_binding_from_spec(binding.amqp) - if binding.amqp is not None - else None, - kafka=kafka_bindings.operation_binding_from_spec(binding.kafka) - if binding.kafka is not None - else None, - sqs=sqs_bindings.operation_binding_from_spec(binding.sqs) - if binding.sqs is not None - else None, - nats=nats_bindings.operation_binding_from_spec(binding.nats) - if binding.nats is not None - else None, - redis=redis_bindings.operation_binding_from_spec(binding.redis) - if binding.redis is not None - else None, - ) - - -def from_spec(binding: spec.bindings.OperationBinding) -> OperationBinding: - return OperationBinding.from_spec(binding) + def from_sub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.OperationBinding.from_sub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.OperationBinding.from_sub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.OperationBinding.from_sub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.OperationBinding.from_sub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.OperationBinding.from_sub(binding.sqs)): + return cls(sqs=sqs) + + return None + + @classmethod + def from_pub(cls, binding: Optional[SpecBinding]) -> Optional[Self]: + if binding is None: + return None + + if binding.amqp and ( + amqp := amqp_bindings.OperationBinding.from_pub(binding.amqp) + ): + return cls(amqp=amqp) + + if binding.kafka and ( + kafka := kafka_bindings.OperationBinding.from_pub(binding.kafka) + ): + return cls(kafka=kafka) + + if binding.nats and ( + nats := nats_bindings.OperationBinding.from_pub(binding.nats) + ): + return cls(nats=nats) + + if binding.redis and ( + redis := redis_bindings.OperationBinding.from_pub(binding.redis) + ): + return cls(redis=redis) + + if binding.sqs and (sqs := sqs_bindings.OperationBinding.from_pub(binding.sqs)): + return cls(sqs=sqs) + + return None diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/nats.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/nats.py new file mode 100644 index 0000000000..21d5c46926 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/nats.py @@ -0,0 +1,9 @@ +from faststream.specification.asyncapi.v2_6_0.schema.bindings.nats import ( + ChannelBinding, + OperationBinding, +) + +__all__ = ( + "ChannelBinding", + "OperationBinding", +) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/redis.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/redis.py new file mode 100644 index 0000000000..26d44644f7 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/redis.py @@ -0,0 +1,9 @@ +from faststream.specification.asyncapi.v2_6_0.schema.bindings.redis import ( + ChannelBinding, + OperationBinding, +) + +__all__ = ( + "ChannelBinding", + "OperationBinding", +) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/bindings/sqs.py b/faststream/specification/asyncapi/v3_0_0/schema/bindings/sqs.py new file mode 100644 index 0000000000..e437a1cc58 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/bindings/sqs.py @@ -0,0 +1,9 @@ +from faststream.specification.asyncapi.v2_6_0.schema.bindings.sqs import ( + ChannelBinding, + OperationBinding, +) + +__all__ = ( + "ChannelBinding", + "OperationBinding", +) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/channels.py b/faststream/specification/asyncapi/v3_0_0/schema/channels.py index 3a5ccd40bb..c0a2dbe553 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/channels.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/channels.py @@ -4,16 +4,11 @@ from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.bindings import ChannelBinding -from faststream.specification.asyncapi.v2_6_0.schema.message import ( - Message, - from_spec as message_from_spec, -) -from faststream.specification.asyncapi.v2_6_0.schema.utils import Reference -from faststream.specification.asyncapi.v3_0_0.schema.bindings.main import ( - channel_binding_from_spec, -) +from faststream.specification.asyncapi.v3_0_0.schema.bindings import ChannelBinding +from faststream.specification.asyncapi.v3_0_0.schema.message import Message +from faststream.specification.schema import PublisherSpec, SubscriberSpec + +from .utils import Reference class Channel(BaseModel): @@ -29,7 +24,6 @@ class Channel(BaseModel): Configurations: model_config : configuration for the model (only applicable for Pydantic version 2) Config : configuration for the class (only applicable for Pydantic version 1) - """ address: str @@ -50,30 +44,31 @@ class Config: extra = "allow" @classmethod - def from_spec( - cls, - channel: spec.channel.Channel, - message: spec.message.Message, - channel_name: str, - message_name: str, - ) -> Self: + def from_sub(cls, address: str, subscriber: SubscriberSpec) -> Self: + message = subscriber.operation.message + assert message.title + + *left, right = message.title.split(":") + message.title = ":".join((*left, f"Subscribe{right}")) + return cls( - address=channel_name, + description=subscriber.description, + address=address, messages={ - message_name: message_from_spec(message), + "SubscribeMessage": Message.from_spec(message), }, - description=channel.description, - servers=channel.servers, - bindings=channel_binding_from_spec(channel.bindings) - if channel.bindings - else None, + bindings=ChannelBinding.from_sub(subscriber.bindings), + servers=None, ) - -def from_spec( - channel: spec.channel.Channel, - message: spec.message.Message, - channel_name: str, - message_name: str, -) -> Channel: - return Channel.from_spec(channel, message, channel_name, message_name) + @classmethod + def from_pub(cls, address: str, publisher: PublisherSpec) -> Self: + return cls( + description=publisher.description, + address=address, + messages={ + "Message": Message.from_spec(publisher.operation.message), + }, + bindings=ChannelBinding.from_pub(publisher.bindings), + servers=None, + ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/contact.py b/faststream/specification/asyncapi/v3_0_0/schema/contact.py new file mode 100644 index 0000000000..c42e750b28 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/contact.py @@ -0,0 +1,3 @@ +from faststream.specification.asyncapi.v2_6_0.schema import Contact + +__all__ = ("Contact",) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/docs.py b/faststream/specification/asyncapi/v3_0_0/schema/docs.py new file mode 100644 index 0000000000..0a71688697 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/docs.py @@ -0,0 +1,3 @@ +from faststream.specification.asyncapi.v2_6_0.schema import ExternalDocs + +__all__ = ("ExternalDocs",) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/info.py b/faststream/specification/asyncapi/v3_0_0/schema/info.py index 6d15c9e4dc..c9303e690c 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/info.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/info.py @@ -14,16 +14,16 @@ License, Tag, ) -from faststream.specification.base.info import BaseInfo +from faststream.specification.base.info import BaseApplicationInfo -class Info(BaseInfo): - """A class to represent information. +class ApplicationInfo(BaseApplicationInfo): + """A class to represent application information. Attributes: - termsOfService : terms of service for the information (default: None) - contact : contact information for the information (default: None) - license : license information for the information (default: None) + termsOfService : terms of service for the information + contact : contact information for the information + license : license information for the information tags : optional list of tags externalDocs : optional external documentation """ diff --git a/faststream/specification/asyncapi/v3_0_0/schema/license.py b/faststream/specification/asyncapi/v3_0_0/schema/license.py new file mode 100644 index 0000000000..44ee4b2813 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/license.py @@ -0,0 +1,3 @@ +from faststream.specification.asyncapi.v2_6_0.schema import License + +__all__ = ("License",) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/message.py b/faststream/specification/asyncapi/v3_0_0/schema/message.py new file mode 100644 index 0000000000..fa665082e9 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/message.py @@ -0,0 +1,6 @@ +from faststream.specification.asyncapi.v2_6_0.schema.message import ( + CorrelationId, + Message, +) + +__all__ = ("CorrelationId", "Message") diff --git a/faststream/specification/asyncapi/v3_0_0/schema/operations.py b/faststream/specification/asyncapi/v3_0_0/schema/operations.py index ffc647674a..8afff3c5c6 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/operations.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/operations.py @@ -1,21 +1,17 @@ from enum import Enum from typing import Optional, Union -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Self from faststream._internal._compat import PYDANTIC_V2 from faststream._internal.basic_types import AnyDict -from faststream.specification import schema as spec -from faststream.specification.asyncapi.v2_6_0.schema.tag import Tag -from faststream.specification.asyncapi.v2_6_0.schema.utils import ( - Reference, -) -from faststream.specification.asyncapi.v3_0_0.schema.bindings import OperationBinding -from faststream.specification.asyncapi.v3_0_0.schema.bindings.main import ( - operation_binding_from_spec, -) -from faststream.specification.asyncapi.v3_0_0.schema.channels import Channel +from faststream.specification.schema.operation import Operation as OperationSpec + +from .bindings import OperationBinding +from .channels import Channel +from .tag import Tag +from .utils import Reference class Action(str, Enum): @@ -27,24 +23,24 @@ class Operation(BaseModel): """A class to represent an operation. Attributes: - operationId : ID of the operation + operation_id : ID of the operation summary : summary of the operation description : description of the operation bindings : bindings of the operation message : message of the operation security : security details of the operation tags : tags associated with the operation - """ action: Action + channel: Union[Channel, Reference] + summary: Optional[str] = None description: Optional[str] = None bindings: Optional[OperationBinding] = None - messages: list[Reference] - channel: Union[Channel, Reference] + messages: list[Reference] = Field(default_factory=list) security: Optional[dict[str, list[str]]] = None @@ -62,38 +58,37 @@ class Config: extra = "allow" @classmethod - def from_spec( + def from_sub( cls, - operation: spec.operation.Operation, - action: Action, - channel_name: str, + messages: list[Reference], + channel: Reference, + operation: OperationSpec, ) -> Self: return cls( - action=action, - summary=operation.summary, - description=operation.description, - bindings=operation_binding_from_spec(operation.bindings) - if operation.bindings - else None, - messages=[ - Reference( - **{ - "$ref": f"#/channels/{channel_name}/messages/SubscribeMessage" - if action is Action.RECEIVE - else f"#/channels/{channel_name}/messages/Message", - }, - ), - ], - channel=Reference( - **{"$ref": f"#/channels/{channel_name}"}, - ), - security=operation.security, + action=Action.RECEIVE, + messages=messages, + channel=channel, + bindings=OperationBinding.from_sub(operation.bindings), + summary=None, + description=None, + security=None, + tags=None, ) - -def from_spec( - operation: spec.operation.Operation, - action: Action, - channel_name: str, -) -> Operation: - return Operation.from_spec(operation, action, channel_name) + @classmethod + def from_pub( + cls, + messages: list[Reference], + channel: Reference, + operation: OperationSpec, + ) -> Self: + return cls( + action=Action.SEND, + messages=messages, + channel=channel, + bindings=OperationBinding.from_pub(operation.bindings), + summary=None, + description=None, + security=None, + tags=None, + ) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/schema.py b/faststream/specification/asyncapi/v3_0_0/schema/schema.py index ad60b8bfae..dc894ecb4e 100644 --- a/faststream/specification/asyncapi/v3_0_0/schema/schema.py +++ b/faststream/specification/asyncapi/v3_0_0/schema/schema.py @@ -1,15 +1,17 @@ from typing import Literal, Optional, Union +from pydantic import Field + from faststream.specification.asyncapi.v3_0_0.schema.channels import Channel from faststream.specification.asyncapi.v3_0_0.schema.components import Components -from faststream.specification.asyncapi.v3_0_0.schema.info import Info +from faststream.specification.asyncapi.v3_0_0.schema.info import ApplicationInfo from faststream.specification.asyncapi.v3_0_0.schema.operations import Operation from faststream.specification.asyncapi.v3_0_0.schema.servers import Server -from faststream.specification.base.schema import BaseSchema +from faststream.specification.base.schema import BaseApplicationSchema -class Schema(BaseSchema): - """A class to represent a schema. +class ApplicationSchema(BaseApplicationSchema): + """A class to represent an application schema. Attributes: asyncapi : version of the async API @@ -21,12 +23,12 @@ class Schema(BaseSchema): components : optional components of the schema """ - info: Info + info: ApplicationInfo asyncapi: Union[Literal["3.0.0"], str] = "3.0.0" id: Optional[str] = None defaultContentType: Optional[str] = None servers: Optional[dict[str, Server]] = None - channels: dict[str, Channel] - operations: dict[str, Operation] + channels: dict[str, Channel] = Field(default_factory=dict) + operations: dict[str, Operation] = Field(default_factory=dict) components: Optional[Components] = None diff --git a/faststream/specification/asyncapi/v3_0_0/schema/tag.py b/faststream/specification/asyncapi/v3_0_0/schema/tag.py new file mode 100644 index 0000000000..e16c4f61cd --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/tag.py @@ -0,0 +1,3 @@ +from faststream.specification.asyncapi.v2_6_0.schema import Tag + +__all__ = ("Tag",) diff --git a/faststream/specification/asyncapi/v3_0_0/schema/utils.py b/faststream/specification/asyncapi/v3_0_0/schema/utils.py new file mode 100644 index 0000000000..c53f3ce1a0 --- /dev/null +++ b/faststream/specification/asyncapi/v3_0_0/schema/utils.py @@ -0,0 +1,6 @@ +from faststream.specification.asyncapi.v2_6_0.schema import Parameter, Reference + +__all__ = ( + "Parameter", + "Reference", +) diff --git a/faststream/specification/base/info.py b/faststream/specification/base/info.py index 79e5164de7..6e282dc19e 100644 --- a/faststream/specification/base/info.py +++ b/faststream/specification/base/info.py @@ -3,18 +3,18 @@ from faststream._internal._compat import PYDANTIC_V2 -class BaseInfo(BaseModel): +class BaseApplicationInfo(BaseModel): """A class to represent basic application information. Attributes: title : application title - version : application version (default: "1.0.0") - description : application description (default: "") + version : application version + description : application description """ title: str - version: str = "1.0.0" - description: str = "" + version: str + description: str if PYDANTIC_V2: model_config = {"extra": "allow"} diff --git a/faststream/specification/base/proto.py b/faststream/specification/base/proto.py deleted file mode 100644 index 42d118c46c..0000000000 --- a/faststream/specification/base/proto.py +++ /dev/null @@ -1,47 +0,0 @@ -from abc import abstractmethod -from typing import Any, Optional, Protocol - -from faststream.specification.schema.channel import Channel - - -class SpecificationEndpoint(Protocol): - """A class representing an asynchronous API operation.""" - - title_: Optional[str] - description_: Optional[str] - include_in_schema: bool - - @property - def name(self) -> str: - """Returns the name of the API operation.""" - return self.title_ or self.get_name() - - @abstractmethod - def get_name(self) -> str: - """Name property fallback.""" - raise NotImplementedError - - @property - def description(self) -> Optional[str]: - """Returns the description of the API operation.""" - return self.description_ or self.get_description() - - def get_description(self) -> Optional[str]: - """Description property fallback.""" - return None - - def schema(self) -> dict[str, Channel]: - """Returns the schema of the API operation as a dictionary of channel names and channel objects.""" - if self.include_in_schema: - return self.get_schema() - return {} - - @abstractmethod - def get_schema(self) -> dict[str, Channel]: - """Generate AsyncAPI schema.""" - raise NotImplementedError - - @abstractmethod - def get_payloads(self) -> Any: - """Generate AsyncAPI payloads.""" - raise NotImplementedError diff --git a/faststream/specification/base/schema.py b/faststream/specification/base/schema.py index 914b389bc2..828e1699b7 100644 --- a/faststream/specification/base/schema.py +++ b/faststream/specification/base/schema.py @@ -4,11 +4,11 @@ from faststream._internal._compat import model_to_json, model_to_jsonable -from .info import BaseInfo +from .info import BaseApplicationInfo -class BaseSchema(BaseModel): - """A class to represent a Pydantic-serializable schema. +class BaseApplicationSchema(BaseModel): + """A class to represent a Pydantic-serializable application schema. Attributes: info : information about the schema @@ -19,7 +19,7 @@ class BaseSchema(BaseModel): to_yaml() -> str: Convert the schema to a YAML string. """ - info: BaseInfo + info: BaseApplicationInfo def to_jsonable(self) -> Any: """Convert the schema to a JSON-serializable object.""" diff --git a/faststream/specification/base/specification.py b/faststream/specification/base/specification.py index d9dc0fcf14..0c3946e76f 100644 --- a/faststream/specification/base/specification.py +++ b/faststream/specification/base/specification.py @@ -1,11 +1,11 @@ from typing import Any, Protocol, runtime_checkable -from .schema import BaseSchema +from .schema import BaseApplicationSchema @runtime_checkable class Specification(Protocol): - schema: BaseSchema + schema: BaseApplicationSchema def to_json(self) -> str: return self.schema.to_json() diff --git a/faststream/specification/proto/__init__.py b/faststream/specification/proto/__init__.py new file mode 100644 index 0000000000..3189e7cc8f --- /dev/null +++ b/faststream/specification/proto/__init__.py @@ -0,0 +1,4 @@ +from .broker import ServerSpecification +from .endpoint import EndpointSpecification + +__all__ = ("EndpointSpecification", "ServerSpecification") diff --git a/faststream/specification/proto/broker.py b/faststream/specification/proto/broker.py new file mode 100644 index 0000000000..225393b24e --- /dev/null +++ b/faststream/specification/proto/broker.py @@ -0,0 +1,14 @@ +from collections.abc import Iterable +from typing import Optional, Protocol, Union + +from faststream.security import BaseSecurity +from faststream.specification.schema.extra import Tag, TagDict + + +class ServerSpecification(Protocol): + url: Union[str, list[str]] + protocol: Optional[str] + protocol_version: Optional[str] + description: Optional[str] + tags: Iterable[Union[Tag, TagDict]] + security: Optional[BaseSecurity] diff --git a/faststream/specification/proto/endpoint.py b/faststream/specification/proto/endpoint.py new file mode 100644 index 0000000000..380acb1071 --- /dev/null +++ b/faststream/specification/proto/endpoint.py @@ -0,0 +1,62 @@ +from abc import ABC, abstractmethod +from typing import Any, Generic, Optional, TypeVar + +T = TypeVar("T") + + +class EndpointSpecification(ABC, Generic[T]): + """A class representing an asynchronous API operation: Pub or Sub.""" + + title_: Optional[str] + description_: Optional[str] + include_in_schema: bool + + def __init__( + self, + *args: Any, + title_: Optional[str], + description_: Optional[str], + include_in_schema: bool, + **kwargs: Any, + ) -> None: + self.title_ = title_ + self.description_ = description_ + self.include_in_schema = include_in_schema + + # Call next base class parent init + super().__init__(*args, **kwargs) + + @property + def name(self) -> str: + """Returns the name of the API operation.""" + return self.title_ or self.get_default_name() + + @abstractmethod + def get_default_name(self) -> str: + """Name property fallback.""" + raise NotImplementedError + + @property + def description(self) -> Optional[str]: + """Returns the description of the API operation.""" + return self.description_ or self.get_default_description() + + def get_default_description(self) -> Optional[str]: + """Description property fallback.""" + return None + + def schema(self) -> dict[str, T]: + """Returns the schema of the API operation as a dictionary of channel names and channel objects.""" + if self.include_in_schema: + return self.get_schema() + return {} + + @abstractmethod + def get_schema(self) -> dict[str, T]: + """Generate AsyncAPI schema.""" + raise NotImplementedError + + @abstractmethod + def get_payloads(self) -> Any: + """Generate AsyncAPI payloads.""" + raise NotImplementedError diff --git a/faststream/specification/schema/__init__.py b/faststream/specification/schema/__init__.py index a2ec26fa7a..009a6a63d7 100644 --- a/faststream/specification/schema/__init__.py +++ b/faststream/specification/schema/__init__.py @@ -1,34 +1,29 @@ -from . import ( - bindings, - channel, - contact, - docs, - info, - license, - message, - operation, - security, - tag, +from .extra import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, ) -from .contact import Contact -from .docs import ExternalDocs -from .license import License -from .tag import Tag +from .message import Message +from .operation import Operation +from .publisher import PublisherSpec +from .subscriber import SubscriberSpec __all__ = ( "Contact", + "ContactDict", "ExternalDocs", + "ExternalDocsDict", "License", + "LicenseDict", + "Message", + "Operation", + "PublisherSpec", + "SubscriberSpec", "Tag", - # module aliases - "bindings", - "channel", - "contact", - "docs", - "info", - "license", - "message", - "operation", - "security", - "tag", + "TagDict", ) diff --git a/faststream/specification/schema/bindings/amqp.py b/faststream/specification/schema/bindings/amqp.py index 42f29dd1c8..f15201bb8e 100644 --- a/faststream/specification/schema/bindings/amqp.py +++ b/faststream/specification/schema/bindings/amqp.py @@ -1,43 +1,29 @@ -"""AsyncAPI AMQP bindings. - -References: https://github.com/asyncapi/bindings/tree/master/amqp -""" - from dataclasses import dataclass -from typing import Literal, Optional +from typing import TYPE_CHECKING, Literal, Optional + +if TYPE_CHECKING: + from faststream.rabbit.schemas import RabbitExchange, RabbitQueue @dataclass class Queue: - """A class to represent a queue. - - Attributes: - name : name of the queue - durable : indicates if the queue is durable - exclusive : indicates if the queue is exclusive - autoDelete : indicates if the queue should be automatically deleted - vhost : virtual host of the queue (default is "/") - """ - name: str durable: bool exclusive: bool - autoDelete: bool - vhost: str = "/" + auto_delete: bool + + @classmethod + def from_queue(cls, queue: "RabbitQueue") -> "Queue": + return cls( + name=queue.name, + durable=queue.durable, + exclusive=queue.exclusive, + auto_delete=queue.auto_delete, + ) @dataclass class Exchange: - """A class to represent an exchange. - - Attributes: - name : name of the exchange (optional) - type : type of the exchange, can be one of "default", "direct", "topic", "fanout", "headers" - durable : whether the exchange is durable (optional) - autoDelete : whether the exchange is automatically deleted (optional) - vhost : virtual host of the exchange, default is "/" - """ - type: Literal[ "default", "direct", @@ -51,40 +37,43 @@ class Exchange: name: Optional[str] = None durable: Optional[bool] = None - autoDelete: Optional[bool] = None - vhost: str = "/" + auto_delete: Optional[bool] = None + + @classmethod + def from_exchange(cls, exchange: "RabbitExchange") -> "Exchange": + if not exchange.name: + return cls(type="default") + return cls( + type=exchange.type.value, + name=exchange.name, + durable=exchange.durable, + auto_delete=exchange.auto_delete, + ) + + @property + def is_respect_routing_key(self) -> bool: + """Is exchange respects routing key or not.""" + return self.type in { + "default", + "direct", + "topic", + } @dataclass class ChannelBinding: - """A class to represent channel binding. - - Attributes: - is_ : Type of binding, can be "queue" or "routingKey" - bindingVersion : Version of the binding - queue : Optional queue object - exchange : Optional exchange object - """ - - is_: Literal["queue", "routingKey"] - queue: Optional[Queue] = None - exchange: Optional[Exchange] = None + queue: Queue + exchange: Exchange + virtual_host: str @dataclass class OperationBinding: - """A class to represent an operation binding. - - Attributes: - cc : optional string representing the cc - ack : boolean indicating if the operation is acknowledged - replyTo : optional dictionary representing the replyTo - bindingVersion : string representing the binding version - """ - - cc: Optional[str] = None - ack: bool = True - replyTo: Optional[str] = None - deliveryMode: Optional[int] = None - mandatory: Optional[bool] = None - priority: Optional[int] = None + routing_key: Optional[str] + queue: Queue + exchange: Exchange + ack: bool + reply_to: Optional[str] + persist: Optional[bool] + mandatory: Optional[bool] + priority: Optional[int] diff --git a/faststream/specification/schema/bindings/kafka.py b/faststream/specification/schema/bindings/kafka.py index 142a2a5285..fc9d0867c8 100644 --- a/faststream/specification/schema/bindings/kafka.py +++ b/faststream/specification/schema/bindings/kafka.py @@ -15,13 +15,11 @@ class ChannelBinding: topic : optional string representing the topic partitions : optional positive integer representing the number of partitions replicas : optional positive integer representing the number of replicas - bindingVersion : string representing the binding version """ - topic: Optional[str] = None - partitions: Optional[int] = None - replicas: Optional[int] = None - bindingVersion: str = "0.4.0" + topic: Optional[str] + partitions: Optional[int] + replicas: Optional[int] # TODO: # topicConfiguration @@ -32,13 +30,11 @@ class OperationBinding: """A class to represent an operation binding. Attributes: - groupId : optional dictionary representing the group ID - clientId : optional dictionary representing the client ID - replyTo : optional dictionary representing the reply-to - bindingVersion : version of the binding (default: "0.4.0") + group_id : optional dictionary representing the group ID + client_id : optional dictionary representing the client ID + reply_to : optional dictionary representing the reply-to """ - groupId: Optional[dict[str, Any]] = None - clientId: Optional[dict[str, Any]] = None - replyTo: Optional[dict[str, Any]] = None - bindingVersion: str = "0.4.0" + group_id: Optional[dict[str, Any]] + client_id: Optional[dict[str, Any]] + reply_to: Optional[dict[str, Any]] diff --git a/faststream/specification/schema/bindings/nats.py b/faststream/specification/schema/bindings/nats.py index 034efada4e..412f29d557 100644 --- a/faststream/specification/schema/bindings/nats.py +++ b/faststream/specification/schema/bindings/nats.py @@ -14,12 +14,10 @@ class ChannelBinding: Attributes: subject : subject of the channel binding queue : optional queue for the channel binding - bindingVersion : version of the channel binding, default is "custom" """ subject: str - queue: Optional[str] = None - bindingVersion: str = "custom" + queue: Optional[str] @dataclass @@ -27,9 +25,7 @@ class OperationBinding: """A class to represent an operation binding. Attributes: - replyTo : optional dictionary containing reply information - bindingVersion : version of the binding (default is "custom") + reply_to : optional dictionary containing reply information """ - replyTo: Optional[dict[str, Any]] = None - bindingVersion: str = "custom" + reply_to: Optional[dict[str, Any]] diff --git a/faststream/specification/schema/bindings/redis.py b/faststream/specification/schema/bindings/redis.py index c1b3e138a0..17287aa5e4 100644 --- a/faststream/specification/schema/bindings/redis.py +++ b/faststream/specification/schema/bindings/redis.py @@ -14,14 +14,12 @@ class ChannelBinding: Attributes: channel : the channel name method : the method used for binding (ssubscribe, psubscribe, subscribe) - bindingVersion : the version of the binding """ channel: str method: Optional[str] = None group_name: Optional[str] = None consumer_name: Optional[str] = None - bindingVersion: str = "custom" @dataclass @@ -29,9 +27,7 @@ class OperationBinding: """A class to represent an operation binding. Attributes: - replyTo : optional dictionary containing reply information - bindingVersion : version of the binding (default is "custom") + reply_to : optional dictionary containing reply information """ - replyTo: Optional[dict[str, Any]] = None - bindingVersion: str = "custom" + reply_to: Optional[dict[str, Any]] = None diff --git a/faststream/specification/schema/channel.py b/faststream/specification/schema/channel.py deleted file mode 100644 index 89db7d7a6f..0000000000 --- a/faststream/specification/schema/channel.py +++ /dev/null @@ -1,24 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -from faststream.specification.schema.bindings import ChannelBinding -from faststream.specification.schema.operation import Operation - - -@dataclass -class Channel: - """Channel specification. - - Attributes: - description : optional description of the channel - servers : optional list of servers associated with the channel - bindings : optional channel binding - subscribe : optional operation for subscribing to the channel - publish : optional operation for publishing to the channel - """ - - description: Optional[str] = None - servers: Optional[list[str]] = None - bindings: Optional[ChannelBinding] = None - subscribe: Optional[Operation] = None - publish: Optional[Operation] = None diff --git a/faststream/specification/schema/components.py b/faststream/specification/schema/components.py deleted file mode 100644 index 39e6011591..0000000000 --- a/faststream/specification/schema/components.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import ( - Any, - Optional, -) - -from pydantic import BaseModel - -from faststream._internal._compat import ( - PYDANTIC_V2, -) -from faststream.specification.schema.message import Message - - -class Components(BaseModel): - """A class to represent components in a system. - - Attributes: - messages : Optional dictionary of messages - schemas : Optional dictionary of schemas - - Note: - The following attributes are not implemented yet: - - servers - - serverVariables - - channels - - securitySchemes - """ - - messages: Optional[dict[str, Message]] = None - schemas: Optional[dict[str, dict[str, Any]]] = None - securitySchemes: Optional[dict[str, dict[str, Any]]] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" diff --git a/faststream/specification/schema/contact.py b/faststream/specification/schema/contact.py deleted file mode 100644 index 2de5d06292..0000000000 --- a/faststream/specification/schema/contact.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import ( - Optional, -) - -from pydantic import AnyHttpUrl, BaseModel -from typing_extensions import Required, TypedDict - -from faststream._internal._compat import PYDANTIC_V2, EmailStr - - -class ContactDict(TypedDict, total=False): - """A class to represent a dictionary of contact information. - - Attributes: - name : required name of the contact (type: str) - url : URL of the contact (type: AnyHttpUrl) - email : email address of the contact (type: EmailStr) - """ - - name: Required[str] - url: AnyHttpUrl - email: EmailStr - - -class Contact(BaseModel): - """A class to represent a contact. - - Attributes: - name : name of the contact (str) - url : URL of the contact (Optional[AnyHttpUrl]) - email : email of the contact (Optional[EmailStr]) - """ - - name: str - url: Optional[AnyHttpUrl] = None - email: Optional[EmailStr] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" diff --git a/faststream/specification/schema/docs.py b/faststream/specification/schema/docs.py deleted file mode 100644 index d5b69fe7b4..0000000000 --- a/faststream/specification/schema/docs.py +++ /dev/null @@ -1,29 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -from typing_extensions import Required, TypedDict - - -class ExternalDocsDict(TypedDict, total=False): - """A dictionary type for representing external documentation. - - Attributes: - url : Required URL for the external documentation - description : Description of the external documentation - """ - - url: Required[str] - description: str - - -@dataclass -class ExternalDocs: - """A class to represent external documentation. - - Attributes: - url : URL of the external documentation - description : optional description of the external documentation - """ - - url: str - description: Optional[str] = None diff --git a/faststream/specification/schema/extra/__init__.py b/faststream/specification/schema/extra/__init__.py new file mode 100644 index 0000000000..f2417a905f --- /dev/null +++ b/faststream/specification/schema/extra/__init__.py @@ -0,0 +1,15 @@ +from .contact import Contact, ContactDict +from .external_docs import ExternalDocs, ExternalDocsDict +from .license import License, LicenseDict +from .tag import Tag, TagDict + +__all__ = ( + "Contact", + "ContactDict", + "ExternalDocs", + "ExternalDocsDict", + "License", + "LicenseDict", + "Tag", + "TagDict", +) diff --git a/faststream/specification/schema/extra/contact.py b/faststream/specification/schema/extra/contact.py new file mode 100644 index 0000000000..dfabbbacb3 --- /dev/null +++ b/faststream/specification/schema/extra/contact.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from typing import Optional + +from pydantic import AnyHttpUrl +from typing_extensions import Required, TypedDict + +from faststream._internal._compat import EmailStr + + +class ContactDict(TypedDict, total=False): + name: Required[str] + url: AnyHttpUrl + email: EmailStr + + +@dataclass +class Contact: + name: str + url: Optional[AnyHttpUrl] = None + email: Optional[EmailStr] = None diff --git a/faststream/specification/schema/extra/external_docs.py b/faststream/specification/schema/extra/external_docs.py new file mode 100644 index 0000000000..600a6d3a95 --- /dev/null +++ b/faststream/specification/schema/extra/external_docs.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass +from typing import Optional + +from typing_extensions import Required, TypedDict + + +class ExternalDocsDict(TypedDict, total=False): + url: Required[str] + description: str + + +@dataclass +class ExternalDocs: + url: str + description: Optional[str] = None diff --git a/faststream/specification/schema/extra/license.py b/faststream/specification/schema/extra/license.py new file mode 100644 index 0000000000..7bd4039621 --- /dev/null +++ b/faststream/specification/schema/extra/license.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass +from typing import Optional + +from pydantic import AnyHttpUrl +from typing_extensions import Required, TypedDict + + +class LicenseDict(TypedDict, total=False): + name: Required[str] + url: AnyHttpUrl + + +@dataclass +class License: + name: str + url: Optional[AnyHttpUrl] = None diff --git a/faststream/specification/schema/extra/tag.py b/faststream/specification/schema/extra/tag.py new file mode 100644 index 0000000000..1d62ed7491 --- /dev/null +++ b/faststream/specification/schema/extra/tag.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass +from typing import Optional, Union + +from typing_extensions import Required, TypedDict + +from .external_docs import ExternalDocs, ExternalDocsDict + + +class TagDict(TypedDict, total=False): + name: Required[str] + description: str + external_docs: Union[ExternalDocs, ExternalDocsDict] + + +@dataclass +class Tag: + name: str + description: Optional[str] = None + external_docs: Optional[Union[ExternalDocs, ExternalDocsDict]] = None diff --git a/faststream/specification/schema/info.py b/faststream/specification/schema/info.py deleted file mode 100644 index 67f2341e4f..0000000000 --- a/faststream/specification/schema/info.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import ( - Any, - Optional, - Union, -) - -from pydantic import AnyHttpUrl, BaseModel - -from faststream.specification.schema.contact import Contact, ContactDict -from faststream.specification.schema.license import License, LicenseDict - - -class Info(BaseModel): - """A class to represent information. - - Attributes: - title : title of the information - version : version of the information (default: "1.0.0") - description : description of the information (default: "") - termsOfService : terms of service for the information (default: None) - contact : contact information for the information (default: None) - license : license information for the information (default: None) - """ - - termsOfService: Optional[AnyHttpUrl] = None - contact: Optional[Union[Contact, ContactDict, dict[str, Any]]] = None - license: Optional[Union[License, LicenseDict, dict[str, Any]]] = None diff --git a/faststream/specification/schema/license.py b/faststream/specification/schema/license.py deleted file mode 100644 index f95faf3e10..0000000000 --- a/faststream/specification/schema/license.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import ( - Optional, -) - -from pydantic import AnyHttpUrl, BaseModel -from typing_extensions import Required, TypedDict - -from faststream._internal._compat import ( - PYDANTIC_V2, -) - - -class LicenseDict(TypedDict, total=False): - """A dictionary-like class to represent a license. - - Attributes: - name : required name of the license (type: str) - url : URL of the license (type: AnyHttpUrl) - """ - - name: Required[str] - url: AnyHttpUrl - - -class License(BaseModel): - """A class to represent a license. - - Attributes: - name : name of the license - url : URL of the license (optional) - - Config: - extra : allow additional attributes in the model (PYDANTIC_V2) - """ - - name: str - url: Optional[AnyHttpUrl] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" diff --git a/faststream/specification/schema/message.py b/faststream/specification/schema/message.py deleted file mode 100644 index 865ec95553..0000000000 --- a/faststream/specification/schema/message.py +++ /dev/null @@ -1,51 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Optional, Union - -from faststream.specification.schema.docs import ExternalDocs -from faststream.specification.schema.tag import Tag - - -@dataclass -class CorrelationId: - """Correlation ID specification. - - Attributes: - description : optional description of the correlation ID - location : location of the correlation ID - - Configurations: - extra : allows extra fields in the correlation ID model - """ - - location: str - description: Optional[str] = None - - -@dataclass -class Message: - """Message specification. - - Attributes: - title : title of the message - name : name of the message - summary : summary of the message - description : description of the message - messageId : ID of the message - correlationId : correlation ID of the message - contentType : content type of the message - payload : dictionary representing the payload of the message - tags : list of tags associated with the message - externalDocs : external documentation associated with the message - """ - - payload: dict[str, Any] - title: Optional[str] = None - name: Optional[str] = None - summary: Optional[str] = None - description: Optional[str] = None - messageId: Optional[str] = None - correlationId: Optional[CorrelationId] = None - contentType: Optional[str] = None - - tags: Optional[list[Union[Tag, dict[str, Any]]]] = None - externalDocs: Optional[Union[ExternalDocs, dict[str, Any]]] = None diff --git a/faststream/specification/schema/message/__init__.py b/faststream/specification/schema/message/__init__.py new file mode 100644 index 0000000000..6221895ab5 --- /dev/null +++ b/faststream/specification/schema/message/__init__.py @@ -0,0 +1,3 @@ +from .model import Message + +__all__ = ("Message",) diff --git a/faststream/specification/schema/message/model.py b/faststream/specification/schema/message/model.py new file mode 100644 index 0000000000..8b8c37f24a --- /dev/null +++ b/faststream/specification/schema/message/model.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Optional + +from faststream._internal.basic_types import AnyDict + + +@dataclass +class Message: + payload: AnyDict # JSON Schema + + title: Optional[str] diff --git a/faststream/specification/schema/operation.py b/faststream/specification/schema/operation.py deleted file mode 100644 index e88d3c39e4..0000000000 --- a/faststream/specification/schema/operation.py +++ /dev/null @@ -1,33 +0,0 @@ -from dataclasses import dataclass -from typing import Any, Optional, Union - -from faststream.specification.schema.bindings import OperationBinding -from faststream.specification.schema.message import Message -from faststream.specification.schema.tag import Tag - - -@dataclass -class Operation: - """A class to represent an operation. - - Attributes: - operationId : ID of the operation - summary : summary of the operation - description : description of the operation - bindings : bindings of the operation - message : message of the operation - security : security details of the operation - tags : tags associated with the operation - """ - - message: Message - - operationId: Optional[str] = None - summary: Optional[str] = None - description: Optional[str] = None - - bindings: Optional[OperationBinding] = None - - security: Optional[dict[str, list[str]]] = None - - tags: Optional[list[Union[Tag, dict[str, Any]]]] = None diff --git a/faststream/specification/schema/operation/__init__.py b/faststream/specification/schema/operation/__init__.py new file mode 100644 index 0000000000..85cbafe10a --- /dev/null +++ b/faststream/specification/schema/operation/__init__.py @@ -0,0 +1,3 @@ +from .model import Operation + +__all__ = ("Operation",) diff --git a/faststream/specification/schema/operation/model.py b/faststream/specification/schema/operation/model.py new file mode 100644 index 0000000000..2e72e523e9 --- /dev/null +++ b/faststream/specification/schema/operation/model.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass +from typing import Optional + +from faststream.specification.schema.bindings import OperationBinding +from faststream.specification.schema.message import Message + + +@dataclass +class Operation: + message: Message + bindings: Optional[OperationBinding] diff --git a/faststream/specification/schema/publisher.py b/faststream/specification/schema/publisher.py new file mode 100644 index 0000000000..f534619d0f --- /dev/null +++ b/faststream/specification/schema/publisher.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Optional + +from .bindings import ChannelBinding +from .operation import Operation + + +@dataclass +class PublisherSpec: + description: str + operation: Operation + bindings: Optional[ChannelBinding] diff --git a/faststream/specification/schema/security.py b/faststream/specification/schema/security.py deleted file mode 100644 index d940cbdc4f..0000000000 --- a/faststream/specification/schema/security.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Literal, Optional - -from pydantic import AnyHttpUrl, BaseModel, Field - -from faststream._internal._compat import PYDANTIC_V2 - - -class OauthFlowObj(BaseModel): - """A class to represent an OAuth flow object. - - Attributes: - authorizationUrl : Optional[AnyHttpUrl] : The URL for authorization - tokenUrl : Optional[AnyHttpUrl] : The URL for token - refreshUrl : Optional[AnyHttpUrl] : The URL for refresh - scopes : dict[str, str] : The scopes for the OAuth flow - """ - - authorizationUrl: Optional[AnyHttpUrl] = None - tokenUrl: Optional[AnyHttpUrl] = None - refreshUrl: Optional[AnyHttpUrl] = None - scopes: dict[str, str] - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" - - -class OauthFlows(BaseModel): - """A class to represent OAuth flows. - - Attributes: - implicit : Optional[OauthFlowObj] : Implicit OAuth flow object - password : Optional[OauthFlowObj] : Password OAuth flow object - clientCredentials : Optional[OauthFlowObj] : Client credentials OAuth flow object - authorizationCode : Optional[OauthFlowObj] : Authorization code OAuth flow object - """ - - implicit: Optional[OauthFlowObj] = None - password: Optional[OauthFlowObj] = None - clientCredentials: Optional[OauthFlowObj] = None - authorizationCode: Optional[OauthFlowObj] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" - - -class SecuritySchemaComponent(BaseModel): - """A class to represent a security schema component. - - Attributes: - type : Literal, the type of the security schema component - name : optional name of the security schema component - description : optional description of the security schema component - in_ : optional location of the security schema component - schema_ : optional schema of the security schema component - bearerFormat : optional bearer format of the security schema component - openIdConnectUrl : optional OpenID Connect URL of the security schema component - flows : optional OAuth flows of the security schema component - """ - - type: Literal[ - "userPassword", - "apikey", - "X509", - "symmetricEncryption", - "asymmetricEncryption", - "httpApiKey", - "http", - "oauth2", - "openIdConnect", - "plain", - "scramSha256", - "scramSha512", - "gssapi", - ] - name: Optional[str] = None - description: Optional[str] = None - in_: Optional[str] = Field( - default=None, - alias="in", - ) - schema_: Optional[str] = Field( - default=None, - alias="schema", - ) - bearerFormat: Optional[str] = None - openIdConnectUrl: Optional[str] = None - flows: Optional[OauthFlows] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" diff --git a/faststream/specification/schema/servers.py b/faststream/specification/schema/servers.py deleted file mode 100644 index d296c359a2..0000000000 --- a/faststream/specification/schema/servers.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Any, Optional, Union - -from pydantic import BaseModel - -from faststream._internal._compat import PYDANTIC_V2 -from faststream.specification.schema.tag import Tag - -SecurityRequirement = list[dict[str, list[str]]] - - -class ServerVariable(BaseModel): - """A class to represent a server variable. - - Attributes: - enum : list of possible values for the server variable (optional) - default : default value for the server variable (optional) - description : description of the server variable (optional) - examples : list of example values for the server variable (optional) - """ - - enum: Optional[list[str]] = None - default: Optional[str] = None - description: Optional[str] = None - examples: Optional[list[str]] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" - - -class Server(BaseModel): - """A class to represent a server. - - Attributes: - url : URL of the server - protocol : protocol used by the server - description : optional description of the server - protocolVersion : optional version of the protocol used by the server - tags : optional list of tags associated with the server - security : optional security requirement for the server - variables : optional dictionary of server variables - - Note: - The attributes `description`, `protocolVersion`, `tags`, `security`, `variables`, and `bindings` are all optional. - """ - - url: str - protocol: str - description: Optional[str] = None - protocolVersion: Optional[str] = None - tags: Optional[list[Union[Tag, dict[str, Any]]]] = None - security: Optional[SecurityRequirement] = None - variables: Optional[dict[str, ServerVariable]] = None - - if PYDANTIC_V2: - model_config = {"extra": "allow"} - - else: - - class Config: - extra = "allow" diff --git a/faststream/specification/schema/subscriber.py b/faststream/specification/schema/subscriber.py new file mode 100644 index 0000000000..9d41177b4f --- /dev/null +++ b/faststream/specification/schema/subscriber.py @@ -0,0 +1,12 @@ +from dataclasses import dataclass +from typing import Optional + +from .bindings import ChannelBinding +from .operation import Operation + + +@dataclass +class SubscriberSpec: + description: Optional[str] + operation: Operation + bindings: Optional[ChannelBinding] diff --git a/faststream/specification/schema/tag.py b/faststream/specification/schema/tag.py deleted file mode 100644 index ff9509d2c8..0000000000 --- a/faststream/specification/schema/tag.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Union - -from typing_extensions import Required, TypedDict - -from faststream.specification.schema.docs import ExternalDocs, ExternalDocsDict - - -class TagDict(TypedDict, total=False): - """A dictionary-like class for storing tags. - - Attributes: - name : required name of the tag - description : description of the tag - externalDocs : external documentation for the tag - - """ - - name: Required[str] - description: str - externalDocs: Union[ExternalDocs, ExternalDocsDict] - - -@dataclass -class Tag: - """A class to represent a tag. - - Attributes: - name : name of the tag - description : description of the tag (optional) - externalDocs : external documentation for the tag (optional) - - """ - - name: str - description: Optional[str] = None - externalDocs: Optional[Union[ExternalDocs, ExternalDocsDict]] = None diff --git a/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_basic.py b/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_basic.py index 463d541d3d..05577fbbb2 100644 --- a/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_basic.py +++ b/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_basic.py @@ -13,7 +13,7 @@ def test_basic_customization() -> None: "kafka": {"bindingVersion": "0.4.0", "topic": "input_data"}, }, "servers": ["development"], - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/input_data:OnInputData:Message", }, @@ -23,7 +23,7 @@ def test_basic_customization() -> None: "bindings": { "kafka": {"bindingVersion": "0.4.0", "topic": "output_data"}, }, - "publish": { + "subscribe": { "message": { "$ref": "#/components/messages/output_data:Publisher:Message", }, diff --git a/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_handler.py b/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_handler.py index 3b852203be..c7499bb15e 100644 --- a/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_handler.py +++ b/tests/a_docs/getting_started/asyncapi/asyncapi_customization/test_handler.py @@ -16,7 +16,7 @@ def test_handler_customization() -> None: assert subscriber_value == IsPartialDict({ "servers": ["development"], "bindings": {"kafka": {"topic": "input_data", "bindingVersion": "0.4.0"}}, - "subscribe": { + "publish": { "message": {"$ref": "#/components/messages/input_data:Consume:Message"}, }, }), subscriber_value @@ -32,7 +32,7 @@ def test_handler_customization() -> None: "description": "My publisher description", "servers": ["development"], "bindings": {"kafka": {"topic": "output_data", "bindingVersion": "0.4.0"}}, - "publish": { + "subscribe": { "message": {"$ref": "#/components/messages/output_data:Produce:Message"} }, } diff --git a/tests/asyncapi/base/v2_6_0/arguments.py b/tests/asyncapi/base/v2_6_0/arguments.py index b2479fc7ab..aaac817e85 100644 --- a/tests/asyncapi/base/v2_6_0/arguments.py +++ b/tests/asyncapi/base/v2_6_0/arguments.py @@ -5,7 +5,6 @@ import pydantic from dirty_equals import IsDict, IsPartialDict, IsStr from fast_depends import Depends -from fastapi import Depends as APIDepends from typing_extensions import Literal from faststream import Context @@ -17,7 +16,7 @@ class FastAPICompatible: broker_class: type[BrokerUsecase] - dependency_builder = staticmethod(APIDepends) + dependency_builder = staticmethod(Depends) def build_app(self, broker: BrokerUsecase[Any, Any]) -> BrokerUsecase[Any, Any]: """Patch it to test FastAPI scheme generation too.""" @@ -67,7 +66,7 @@ async def handle(msg) -> None: assert key == "custom_name" assert schema["channels"][key]["description"] == "Test description.", schema[ "channels" - ][key]["description"] + ][key] def test_empty(self) -> None: broker = self.broker_class() @@ -424,7 +423,7 @@ class TestModel(pydantic.BaseModel): @broker.subscriber("test") async def handle(model: TestModel) -> None: ... - schema = AsyncAPI(self.build_app(broker)).to_jsonable() + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0").to_jsonable() payload = schema["components"]["schemas"] diff --git a/tests/asyncapi/base/v2_6_0/fastapi.py b/tests/asyncapi/base/v2_6_0/fastapi.py index d814b9df4e..c6b1bd6a2d 100644 --- a/tests/asyncapi/base/v2_6_0/fastapi.py +++ b/tests/asyncapi/base/v2_6_0/fastapi.py @@ -2,7 +2,7 @@ import pytest from dirty_equals import IsStr -from fastapi import FastAPI +from fastapi import Depends, FastAPI from fastapi.testclient import TestClient from faststream._internal.broker.broker import BrokerUsecase @@ -15,6 +15,8 @@ class FastAPITestCase: router_class: type[StreamRouter[MsgType]] broker_wrapper: Callable[[BrokerUsecase[MsgType, Any]], BrokerUsecase[MsgType, Any]] + dependency_builder = staticmethod(Depends) + @pytest.mark.skip() @pytest.mark.asyncio() async def test_fastapi_full_information(self) -> None: diff --git a/tests/asyncapi/base/v2_6_0/from_spec/__init__.py b/tests/asyncapi/base/v2_6_0/from_spec/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/asyncapi/base/v2_6_0/from_spec/test_contact.py b/tests/asyncapi/base/v2_6_0/from_spec/test_contact.py new file mode 100644 index 0000000000..ad97d872ce --- /dev/null +++ b/tests/asyncapi/base/v2_6_0/from_spec/test_contact.py @@ -0,0 +1,59 @@ +from typing import Any + +import pytest + +from faststream.specification import Contact +from faststream.specification.asyncapi.v2_6_0.schema import Contact as AsyncAPIContact + + +@pytest.mark.parametrize( + ("arg", "result"), + ( + pytest.param( + None, + None, + id="None", + ), + pytest.param( + Contact( + name="test", + url="http://contact.com", + email="support@gmail.com", + ), + AsyncAPIContact( + name="test", + url="http://contact.com", + email="support@gmail.com", + ), + id="Contact object", + ), + pytest.param( + { + "name": "test", + "url": "http://contact.com", + }, + AsyncAPIContact( + name="test", + url="http://contact.com", + ), + id="Contact dict", + ), + pytest.param( + { + "name": "test", + "url": "http://contact.com", + "email": "support@gmail.com", + "extra": "test", + }, + { + "name": "test", + "url": "http://contact.com", + "email": "support@gmail.com", + "extra": "test", + }, + id="Unknown dict", + ), + ), +) +def test_contact_factory_method(arg: Any, result: Any) -> None: + assert AsyncAPIContact.from_spec(arg) == result diff --git a/tests/asyncapi/base/v2_6_0/from_spec/test_external_docs.py b/tests/asyncapi/base/v2_6_0/from_spec/test_external_docs.py new file mode 100644 index 0000000000..7b2ede38c8 --- /dev/null +++ b/tests/asyncapi/base/v2_6_0/from_spec/test_external_docs.py @@ -0,0 +1,35 @@ +from typing import Any + +import pytest + +from faststream.specification import ExternalDocs +from faststream.specification.asyncapi.v2_6_0.schema import ExternalDocs as AsyncAPIDocs + + +@pytest.mark.parametrize( + ("arg", "result"), + ( + pytest.param( + None, + None, + id="None", + ), + pytest.param( + ExternalDocs(description="test", url="http://docs.com"), + AsyncAPIDocs(description="test", url="http://docs.com"), + id="ExternalDocs object", + ), + pytest.param( + {"description": "test", "url": "http://docs.com"}, + AsyncAPIDocs(description="test", url="http://docs.com"), + id="ExternalDocs dict", + ), + pytest.param( + {"description": "test", "url": "http://docs.com", "extra": "test"}, + {"description": "test", "url": "http://docs.com", "extra": "test"}, + id="Unknown dict", + ), + ), +) +def test_external_docs_factory_method(arg: Any, result: Any) -> None: + assert AsyncAPIDocs.from_spec(arg) == result diff --git a/tests/asyncapi/base/v2_6_0/from_spec/test_license.py b/tests/asyncapi/base/v2_6_0/from_spec/test_license.py new file mode 100644 index 0000000000..c6e2e9421b --- /dev/null +++ b/tests/asyncapi/base/v2_6_0/from_spec/test_license.py @@ -0,0 +1,35 @@ +from typing import Any + +import pytest + +from faststream.specification import License +from faststream.specification.asyncapi.v2_6_0.schema import License as AsyncAPICLicense + + +@pytest.mark.parametrize( + ("arg", "result"), + ( + pytest.param( + None, + None, + id="None", + ), + pytest.param( + License(name="test", url="http://license.com"), + AsyncAPICLicense(name="test", url="http://license.com"), + id="License object", + ), + pytest.param( + {"name": "test", "url": "http://license.com"}, + AsyncAPICLicense(name="test", url="http://license.com"), + id="License dict", + ), + pytest.param( + {"name": "test", "url": "http://license.com", "extra": "test"}, + {"name": "test", "url": "http://license.com", "extra": "test"}, + id="Unknown dict", + ), + ), +) +def test_license_factory_method(arg: Any, result: Any) -> None: + assert AsyncAPICLicense.from_spec(arg) == result diff --git a/tests/asyncapi/base/v2_6_0/from_spec/test_tag.py b/tests/asyncapi/base/v2_6_0/from_spec/test_tag.py new file mode 100644 index 0000000000..66eedcd811 --- /dev/null +++ b/tests/asyncapi/base/v2_6_0/from_spec/test_tag.py @@ -0,0 +1,49 @@ +from typing import Any + +import pytest + +from faststream.specification import ExternalDocs, Tag +from faststream.specification.asyncapi.v2_6_0.schema import ( + ExternalDocs as AsyncAPIDocs, + Tag as AsyncAPITag, +) + + +@pytest.mark.parametrize( + ("arg", "result"), + ( + pytest.param( + Tag( + name="test", + description="test", + external_docs=ExternalDocs(url="http://docs.com"), + ), + AsyncAPITag( + name="test", + description="test", + externalDocs=AsyncAPIDocs(url="http://docs.com"), + ), + id="Tag object", + ), + pytest.param( + { + "name": "test", + "description": "test", + "external_docs": {"url": "http://docs.com"}, + }, + AsyncAPITag( + name="test", + description="test", + externalDocs=AsyncAPIDocs(url="http://docs.com"), + ), + id="Tag dict", + ), + pytest.param( + {"name": "test", "description": "test", "extra": "test"}, + {"name": "test", "description": "test", "extra": "test"}, + id="Unknown dict", + ), + ), +) +def test_tag_factory_method(arg: Any, result: Any) -> None: + assert AsyncAPITag.from_spec(arg) == result diff --git a/tests/asyncapi/base/v2_6_0/publisher.py b/tests/asyncapi/base/v2_6_0/publisher.py index 6705975a72..d61baa2d19 100644 --- a/tests/asyncapi/base/v2_6_0/publisher.py +++ b/tests/asyncapi/base/v2_6_0/publisher.py @@ -32,7 +32,7 @@ async def handle(msg) -> None: ... key = tuple(schema["channels"].keys())[0] # noqa: RUF015 assert schema["channels"][key].get("description") is None - assert schema["channels"][key].get("publish") is not None + assert schema["channels"][key].get("subscribe") is not None payload = schema["components"]["schemas"] for v in payload.values(): @@ -120,7 +120,7 @@ def test_not_include(self) -> None: async def handler(msg: str) -> None: pass - schema = AsyncAPI(self.build_app(broker)) + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0") assert schema.to_jsonable()["channels"] == {}, schema.to_jsonable()["channels"] @@ -133,7 +133,7 @@ class TestModel(pydantic.BaseModel): @broker.publisher("test") async def handle(msg) -> TestModel: ... - schema = AsyncAPI(self.build_app(broker)).to_jsonable() + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0").to_jsonable() payload = schema["components"]["schemas"] diff --git a/tests/asyncapi/base/v3_0_0/fastapi.py b/tests/asyncapi/base/v3_0_0/fastapi.py index cdc986b568..edf8cfe993 100644 --- a/tests/asyncapi/base/v3_0_0/fastapi.py +++ b/tests/asyncapi/base/v3_0_0/fastapi.py @@ -26,7 +26,6 @@ async def test_fastapi_full_information(self) -> None: ) app = FastAPI( - lifespan=broker.lifespan_context, title="CustomApp", version="1.1.1", description="Test description", @@ -77,7 +76,7 @@ async def test_fastapi_asyncapi_routes(self) -> None: @router.subscriber("test") async def handler() -> None: ... - app = FastAPI(lifespan=router.lifespan_context) + app = FastAPI() app.include_router(router) async with self.broker_wrapper(router.broker): @@ -107,7 +106,7 @@ async def handler() -> None: ... async def test_fastapi_asyncapi_not_fount(self) -> None: broker = self.router_factory(include_in_schema=False) - app = FastAPI(lifespan=broker.lifespan_context) + app = FastAPI() app.include_router(broker) async with self.broker_wrapper(broker.broker): @@ -125,7 +124,7 @@ async def test_fastapi_asyncapi_not_fount(self) -> None: async def test_fastapi_asyncapi_not_fount_by_url(self) -> None: broker = self.router_factory(schema_url=None) - app = FastAPI(lifespan=broker.lifespan_context) + app = FastAPI() app.include_router(broker) async with self.broker_wrapper(broker.broker): diff --git a/tests/asyncapi/confluent/v2_6_0/test_arguments.py b/tests/asyncapi/confluent/v2_6_0/test_arguments.py index 7f40bcfb3e..4bc16826e8 100644 --- a/tests/asyncapi/confluent/v2_6_0/test_arguments.py +++ b/tests/asyncapi/confluent/v2_6_0/test_arguments.py @@ -1,4 +1,4 @@ -from faststream.confluent import KafkaBroker +from faststream.confluent import KafkaBroker, TopicPartition from faststream.specification.asyncapi import AsyncAPI from tests.asyncapi.base.v2_6_0.arguments import ArgumentsTestcase @@ -18,3 +18,37 @@ async def handle(msg) -> None: ... assert schema["channels"][key]["bindings"] == { "kafka": {"bindingVersion": "0.4.0", "topic": "test"}, } + + def test_subscriber_with_one_topic_partitions(self) -> None: + broker = self.broker_class() + + part1 = TopicPartition("topic_name", 1) + part2 = TopicPartition("topic_name", 2) + + @broker.subscriber(partitions=[part1, part2]) + async def handle(msg): ... + + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0").to_jsonable() + key = tuple(schema["channels"].keys())[0] # noqa: RUF015 + + assert schema["channels"][key]["bindings"] == { + "kafka": {"bindingVersion": "0.4.0", "topic": "topic_name"} + } + + def test_subscriber_with_multi_topics_partitions(self) -> None: + broker = self.broker_class() + + part1 = TopicPartition("topic_name1", 1) + part2 = TopicPartition("topic_name2", 2) + + @broker.subscriber(partitions=[part1, part2]) + async def handle(msg): ... + + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0").to_jsonable() + key1 = tuple(schema["channels"].keys())[0] # noqa: RUF015 + key2 = tuple(schema["channels"].keys())[1] + + assert sorted(( + schema["channels"][key1]["bindings"]["kafka"]["topic"], + schema["channels"][key2]["bindings"]["kafka"]["topic"], + )) == sorted(("topic_name1", "topic_name2")) diff --git a/tests/asyncapi/confluent/v2_6_0/test_connection.py b/tests/asyncapi/confluent/v2_6_0/test_connection.py index 56ad2af682..368bbc00dd 100644 --- a/tests/asyncapi/confluent/v2_6_0/test_connection.py +++ b/tests/asyncapi/confluent/v2_6_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.confluent import KafkaBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/confluent/v2_6_0/test_naming.py b/tests/asyncapi/confluent/v2_6_0/test_naming.py index 85bd4cdc06..2fdbd64687 100644 --- a/tests/asyncapi/confluent/v2_6_0/test_naming.py +++ b/tests/asyncapi/confluent/v2_6_0/test_naming.py @@ -29,7 +29,7 @@ async def handle() -> None: ... "test:Handle": { "servers": ["development"], "bindings": {"kafka": {"topic": "test", "bindingVersion": "0.4.0"}}, - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test:Handle:Message" }, diff --git a/tests/asyncapi/confluent/v2_6_0/test_router.py b/tests/asyncapi/confluent/v2_6_0/test_router.py index b3cd359b4c..c73885cddb 100644 --- a/tests/asyncapi/confluent/v2_6_0/test_router.py +++ b/tests/asyncapi/confluent/v2_6_0/test_router.py @@ -40,7 +40,7 @@ async def handle(msg) -> None: ... "bindings": { "kafka": {"topic": "test_test", "bindingVersion": "0.4.0"}, }, - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test_test:Handle:Message", }, diff --git a/tests/asyncapi/confluent/v2_6_0/test_security.py b/tests/asyncapi/confluent/v2_6_0/test_security.py index f8b90168b7..f3a1bcf913 100644 --- a/tests/asyncapi/confluent/v2_6_0/test_security.py +++ b/tests/asyncapi/confluent/v2_6_0/test_security.py @@ -18,13 +18,13 @@ "test_1:TestTopic": { "bindings": {"kafka": {"bindingVersion": "0.4.0", "topic": "test_1"}}, "servers": ["development"], - "subscribe": { + "publish": { "message": {"$ref": "#/components/messages/test_1:TestTopic:Message"}, }, }, "test_2:Publisher": { "bindings": {"kafka": {"bindingVersion": "0.4.0", "topic": "test_2"}}, - "publish": { + "subscribe": { "message": {"$ref": "#/components/messages/test_2:Publisher:Message"}, }, "servers": ["development"], @@ -186,7 +186,7 @@ async def test_topic(msg: str) -> str: {"oauthbearer": []}, ] sasl_oauthbearer_security_schema["components"]["securitySchemes"] = { - "oauthbearer": {"type": "oauthBearer"}, + "oauthbearer": {"type": "oauth2", "$ref": ""} } assert schema == sasl_oauthbearer_security_schema diff --git a/tests/asyncapi/confluent/v3_0_0/test_connection.py b/tests/asyncapi/confluent/v3_0_0/test_connection.py index d49503ef9a..63b9c51da3 100644 --- a/tests/asyncapi/confluent/v3_0_0/test_connection.py +++ b/tests/asyncapi/confluent/v3_0_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.confluent import KafkaBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/confluent/v3_0_0/test_security.py b/tests/asyncapi/confluent/v3_0_0/test_security.py index 6eb6a2d423..2aa802e304 100644 --- a/tests/asyncapi/confluent/v3_0_0/test_security.py +++ b/tests/asyncapi/confluent/v3_0_0/test_security.py @@ -205,7 +205,7 @@ async def test_topic(msg: str) -> str: {"oauthbearer": []}, ] sasl_oauthbearer_security_schema["components"]["securitySchemes"] = { - "oauthbearer": {"type": "oauthBearer"}, + "oauthbearer": {"type": "oauth2", "$ref": ""} } assert schema == sasl_oauthbearer_security_schema diff --git a/tests/asyncapi/kafka/v2_6_0/test_app.py b/tests/asyncapi/kafka/v2_6_0/test_app.py index 77470ef1cd..2bd9b5a916 100644 --- a/tests/asyncapi/kafka/v2_6_0/test_app.py +++ b/tests/asyncapi/kafka/v2_6_0/test_app.py @@ -1,9 +1,6 @@ from faststream.kafka import KafkaBroker +from faststream.specification import Contact, ExternalDocs, License, Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.contact import Contact -from faststream.specification.schema.docs import ExternalDocs -from faststream.specification.schema.license import License -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/kafka/v2_6_0/test_arguments.py b/tests/asyncapi/kafka/v2_6_0/test_arguments.py index ded081b739..c6af2e02ac 100644 --- a/tests/asyncapi/kafka/v2_6_0/test_arguments.py +++ b/tests/asyncapi/kafka/v2_6_0/test_arguments.py @@ -1,4 +1,4 @@ -from faststream.kafka import KafkaBroker +from faststream.kafka import KafkaBroker, TopicPartition from faststream.specification.asyncapi import AsyncAPI from tests.asyncapi.base.v2_6_0.arguments import ArgumentsTestcase @@ -18,3 +18,37 @@ async def handle(msg) -> None: ... assert schema["channels"][key]["bindings"] == { "kafka": {"bindingVersion": "0.4.0", "topic": "test"}, } + + def test_subscriber_with_one_topic_partitions(self) -> None: + broker = self.broker_class() + + part1 = TopicPartition("topic_name", 1) + part2 = TopicPartition("topic_name", 2) + + @broker.subscriber(partitions=[part1, part2]) + async def handle(msg): ... + + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0").to_jsonable() + key = tuple(schema["channels"].keys())[0] # noqa: RUF015 + + assert schema["channels"][key]["bindings"] == { + "kafka": {"bindingVersion": "0.4.0", "topic": "topic_name"} + } + + def test_subscriber_with_multi_topics_partitions(self) -> None: + broker = self.broker_class() + + part1 = TopicPartition("topic_name1", 1) + part2 = TopicPartition("topic_name2", 2) + + @broker.subscriber(partitions=[part1, part2]) + async def handle(msg): ... + + schema = AsyncAPI(self.build_app(broker), schema_version="2.6.0").to_jsonable() + key1 = tuple(schema["channels"].keys())[0] # noqa: RUF015 + key2 = tuple(schema["channels"].keys())[1] + + assert sorted(( + schema["channels"][key1]["bindings"]["kafka"]["topic"], + schema["channels"][key2]["bindings"]["kafka"]["topic"], + )) == sorted(("topic_name1", "topic_name2")) diff --git a/tests/asyncapi/kafka/v2_6_0/test_connection.py b/tests/asyncapi/kafka/v2_6_0/test_connection.py index cc7b61114b..2107e3882b 100644 --- a/tests/asyncapi/kafka/v2_6_0/test_connection.py +++ b/tests/asyncapi/kafka/v2_6_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.kafka import KafkaBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/kafka/v2_6_0/test_naming.py b/tests/asyncapi/kafka/v2_6_0/test_naming.py index 44297b5ef2..bba38e11b7 100644 --- a/tests/asyncapi/kafka/v2_6_0/test_naming.py +++ b/tests/asyncapi/kafka/v2_6_0/test_naming.py @@ -29,7 +29,7 @@ async def handle() -> None: ... "test:Handle": { "servers": ["development"], "bindings": {"kafka": {"topic": "test", "bindingVersion": "0.4.0"}}, - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test:Handle:Message" }, diff --git a/tests/asyncapi/kafka/v2_6_0/test_router.py b/tests/asyncapi/kafka/v2_6_0/test_router.py index 043ebfa453..2fd0342eb0 100644 --- a/tests/asyncapi/kafka/v2_6_0/test_router.py +++ b/tests/asyncapi/kafka/v2_6_0/test_router.py @@ -40,7 +40,7 @@ async def handle(msg) -> None: ... "bindings": { "kafka": {"topic": "test_test", "bindingVersion": "0.4.0"}, }, - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test_test:Handle:Message", }, diff --git a/tests/asyncapi/kafka/v2_6_0/test_security.py b/tests/asyncapi/kafka/v2_6_0/test_security.py index 88661995bf..b1275e242d 100644 --- a/tests/asyncapi/kafka/v2_6_0/test_security.py +++ b/tests/asyncapi/kafka/v2_6_0/test_security.py @@ -18,13 +18,13 @@ "test_1:TestTopic": { "bindings": {"kafka": {"bindingVersion": "0.4.0", "topic": "test_1"}}, "servers": ["development"], - "subscribe": { + "publish": { "message": {"$ref": "#/components/messages/test_1:TestTopic:Message"}, }, }, "test_2:Publisher": { "bindings": {"kafka": {"bindingVersion": "0.4.0", "topic": "test_2"}}, - "publish": { + "subscribe": { "message": {"$ref": "#/components/messages/test_2:Publisher:Message"}, }, "servers": ["development"], @@ -186,7 +186,7 @@ async def test_topic(msg: str) -> str: {"oauthbearer": []}, ] sasl_oauthbearer_security_schema["components"]["securitySchemes"] = { - "oauthbearer": {"type": "oauthBearer"}, + "oauthbearer": {"type": "oauth2", "$ref": ""} } assert schema == sasl_oauthbearer_security_schema diff --git a/tests/asyncapi/kafka/v3_0_0/test_connection.py b/tests/asyncapi/kafka/v3_0_0/test_connection.py index 280cb798d1..e1fb6cfaab 100644 --- a/tests/asyncapi/kafka/v3_0_0/test_connection.py +++ b/tests/asyncapi/kafka/v3_0_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.kafka import KafkaBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/kafka/v3_0_0/test_security.py b/tests/asyncapi/kafka/v3_0_0/test_security.py index 95ce792aee..ddb06cce77 100644 --- a/tests/asyncapi/kafka/v3_0_0/test_security.py +++ b/tests/asyncapi/kafka/v3_0_0/test_security.py @@ -205,7 +205,7 @@ async def test_topic(msg: str) -> str: {"oauthbearer": []}, ] sasl_oauthbearer_security_schema["components"]["securitySchemes"] = { - "oauthbearer": {"type": "oauthBearer"}, + "oauthbearer": {"type": "oauth2", "$ref": ""} } assert schema == sasl_oauthbearer_security_schema diff --git a/tests/asyncapi/nats/v2_6_0/test_arguments.py b/tests/asyncapi/nats/v2_6_0/test_arguments.py index d3b9b53a34..5ad34a0001 100644 --- a/tests/asyncapi/nats/v2_6_0/test_arguments.py +++ b/tests/asyncapi/nats/v2_6_0/test_arguments.py @@ -17,4 +17,4 @@ async def handle(msg) -> None: ... assert schema["channels"][key]["bindings"] == { "nats": {"bindingVersion": "custom", "subject": "test"}, - } + }, schema["channels"][key]["bindings"] diff --git a/tests/asyncapi/nats/v2_6_0/test_connection.py b/tests/asyncapi/nats/v2_6_0/test_connection.py index 8cb4110d78..486bbb8033 100644 --- a/tests/asyncapi/nats/v2_6_0/test_connection.py +++ b/tests/asyncapi/nats/v2_6_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.nats import NatsBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/nats/v2_6_0/test_naming.py b/tests/asyncapi/nats/v2_6_0/test_naming.py index 676be562f9..9c0738f9de 100644 --- a/tests/asyncapi/nats/v2_6_0/test_naming.py +++ b/tests/asyncapi/nats/v2_6_0/test_naming.py @@ -31,7 +31,7 @@ async def handle() -> None: ... "bindings": { "nats": {"subject": "test", "bindingVersion": "custom"}, }, - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test:Handle:Message" }, diff --git a/tests/asyncapi/nats/v2_6_0/test_router.py b/tests/asyncapi/nats/v2_6_0/test_router.py index b830dc69fe..7986cba82e 100644 --- a/tests/asyncapi/nats/v2_6_0/test_router.py +++ b/tests/asyncapi/nats/v2_6_0/test_router.py @@ -40,7 +40,7 @@ async def handle(msg) -> None: ... "bindings": { "nats": {"subject": "test_test", "bindingVersion": "custom"}, }, - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test_test:Handle:Message", }, diff --git a/tests/asyncapi/nats/v3_0_0/test_connection.py b/tests/asyncapi/nats/v3_0_0/test_connection.py index f4913252ef..f88fc0fb83 100644 --- a/tests/asyncapi/nats/v3_0_0/test_connection.py +++ b/tests/asyncapi/nats/v3_0_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.nats import NatsBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/rabbit/v2_6_0/test_connection.py b/tests/asyncapi/rabbit/v2_6_0/test_connection.py index 6fae2ac389..15781dcf0e 100644 --- a/tests/asyncapi/rabbit/v2_6_0/test_connection.py +++ b/tests/asyncapi/rabbit/v2_6_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.rabbit import RabbitBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: @@ -74,7 +74,7 @@ def test_custom() -> None: }, }, }, - "publish": { + "subscribe": { "bindings": { "amqp": { "ack": True, @@ -115,4 +115,4 @@ def test_custom() -> None: }, }, } - ) + ), schema diff --git a/tests/asyncapi/rabbit/v2_6_0/test_naming.py b/tests/asyncapi/rabbit/v2_6_0/test_naming.py index b63ab03fcb..2ee937f21e 100644 --- a/tests/asyncapi/rabbit/v2_6_0/test_naming.py +++ b/tests/asyncapi/rabbit/v2_6_0/test_naming.py @@ -72,7 +72,7 @@ async def handle() -> None: ... "exchange": {"type": "default", "vhost": "/"}, }, }, - "subscribe": { + "publish": { "bindings": { "amqp": { "cc": "test", diff --git a/tests/asyncapi/rabbit/v2_6_0/test_publisher.py b/tests/asyncapi/rabbit/v2_6_0/test_publisher.py index b9c17a9d00..e24edd3fed 100644 --- a/tests/asyncapi/rabbit/v2_6_0/test_publisher.py +++ b/tests/asyncapi/rabbit/v2_6_0/test_publisher.py @@ -29,7 +29,7 @@ async def handle(msg) -> None: ... "is": "routingKey", }, }, - "publish": { + "subscribe": { "bindings": { "amqp": { "ack": True, @@ -105,7 +105,15 @@ async def handle(msg) -> None: ... "is": "routingKey", }, }, - "publish": { + "subscribe": { + "bindings": { + "amqp": { + "ack": True, + "bindingVersion": "0.2.0", + "deliveryMode": 1, + "mandatory": True, + }, + }, "message": { "$ref": "#/components/messages/_:test-ex:Publisher:Message", }, @@ -138,7 +146,7 @@ async def handle(msg) -> None: ... "is": "routingKey", }, }, - "publish": { + "subscribe": { "bindings": { "amqp": { "ack": True, @@ -168,7 +176,7 @@ async def handle(msg) -> None: ... "is": "routingKey", }, }, - "publish": { + "subscribe": { "bindings": { "amqp": { "ack": True, diff --git a/tests/asyncapi/rabbit/v2_6_0/test_router.py b/tests/asyncapi/rabbit/v2_6_0/test_router.py index 8dca0c9b1b..8e042da398 100644 --- a/tests/asyncapi/rabbit/v2_6_0/test_router.py +++ b/tests/asyncapi/rabbit/v2_6_0/test_router.py @@ -59,7 +59,7 @@ async def handle(msg) -> None: ... "exchange": {"type": "default", "vhost": "/"}, }, }, - "subscribe": { + "publish": { "bindings": { "amqp": { "cc": "test_key", diff --git a/tests/asyncapi/rabbit/v3_0_0/test_connection.py b/tests/asyncapi/rabbit/v3_0_0/test_connection.py index 0403cef9c5..971a89afec 100644 --- a/tests/asyncapi/rabbit/v3_0_0/test_connection.py +++ b/tests/asyncapi/rabbit/v3_0_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.rabbit import RabbitBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/rabbit/v3_0_0/test_publisher.py b/tests/asyncapi/rabbit/v3_0_0/test_publisher.py index 1456b5b86d..a108270da5 100644 --- a/tests/asyncapi/rabbit/v3_0_0/test_publisher.py +++ b/tests/asyncapi/rabbit/v3_0_0/test_publisher.py @@ -141,13 +141,19 @@ async def handle(msg) -> None: ... assert schema["operations"] == { "_:test-ex:Publisher": { "action": "send", - "channel": { - "$ref": "#/channels/_:test-ex:Publisher", + "bindings": { + "amqp": { + "ack": True, + "bindingVersion": "0.3.0", + "deliveryMode": 1, + "mandatory": True, + } }, + "channel": {"$ref": "#/channels/_:test-ex:Publisher"}, "messages": [ - {"$ref": "#/channels/_:test-ex:Publisher/messages/Message"}, + {"$ref": "#/channels/_:test-ex:Publisher/messages/Message"} ], - }, + } } def test_reusable_exchange(self) -> None: diff --git a/tests/asyncapi/redis/v2_6_0/test_arguments.py b/tests/asyncapi/redis/v2_6_0/test_arguments.py index b65598f6f5..403cccad84 100644 --- a/tests/asyncapi/redis/v2_6_0/test_arguments.py +++ b/tests/asyncapi/redis/v2_6_0/test_arguments.py @@ -79,8 +79,8 @@ async def handle(msg) -> None: ... "redis": { "bindingVersion": "custom", "channel": "test", - "consumer_name": "consumer", - "group_name": "group", + "consumerName": "consumer", + "groupName": "group", "method": "xreadgroup", }, } diff --git a/tests/asyncapi/redis/v2_6_0/test_connection.py b/tests/asyncapi/redis/v2_6_0/test_connection.py index 221e4cd430..194371e767 100644 --- a/tests/asyncapi/redis/v2_6_0/test_connection.py +++ b/tests/asyncapi/redis/v2_6_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.redis import RedisBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/asyncapi/redis/v2_6_0/test_naming.py b/tests/asyncapi/redis/v2_6_0/test_naming.py index 5afbf50267..e2558bb9a6 100644 --- a/tests/asyncapi/redis/v2_6_0/test_naming.py +++ b/tests/asyncapi/redis/v2_6_0/test_naming.py @@ -28,7 +28,7 @@ async def handle() -> None: ... }, }, "servers": ["development"], - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test:Handle:Message" }, diff --git a/tests/asyncapi/redis/v2_6_0/test_router.py b/tests/asyncapi/redis/v2_6_0/test_router.py index 6b785d61bf..7d37538dbc 100644 --- a/tests/asyncapi/redis/v2_6_0/test_router.py +++ b/tests/asyncapi/redis/v2_6_0/test_router.py @@ -35,7 +35,7 @@ async def handle(msg) -> None: ... }, }, "servers": ["development"], - "subscribe": { + "publish": { "message": { "$ref": "#/components/messages/test_test:Handle:Message", }, diff --git a/tests/asyncapi/redis/v3_0_0/test_arguments.py b/tests/asyncapi/redis/v3_0_0/test_arguments.py index b9a3136274..0def5e4f41 100644 --- a/tests/asyncapi/redis/v3_0_0/test_arguments.py +++ b/tests/asyncapi/redis/v3_0_0/test_arguments.py @@ -79,8 +79,8 @@ async def handle(msg) -> None: ... "redis": { "bindingVersion": "custom", "channel": "test", - "consumer_name": "consumer", - "group_name": "group", + "consumerName": "consumer", + "groupName": "group", "method": "xreadgroup", }, } diff --git a/tests/asyncapi/redis/v3_0_0/test_connection.py b/tests/asyncapi/redis/v3_0_0/test_connection.py index 51d7224c50..968e67b464 100644 --- a/tests/asyncapi/redis/v3_0_0/test_connection.py +++ b/tests/asyncapi/redis/v3_0_0/test_connection.py @@ -1,6 +1,6 @@ from faststream.redis import RedisBroker +from faststream.specification import Tag from faststream.specification.asyncapi import AsyncAPI -from faststream.specification.schema.tag import Tag def test_base() -> None: diff --git a/tests/brokers/base/basic.py b/tests/brokers/base/basic.py index 23d6e3a1ec..28f9dbfa78 100644 --- a/tests/brokers/base/basic.py +++ b/tests/brokers/base/basic.py @@ -2,6 +2,7 @@ from typing import Any from faststream._internal.broker.broker import BrokerUsecase +from faststream._internal.broker.router import BrokerRouter class BaseTestcaseConfig: @@ -31,3 +32,7 @@ def get_subscriber_params( dict[str, Any], ]: return args, kwargs + + @abstractmethod + def get_router(self, **kwargs: Any) -> BrokerRouter: + raise NotImplementedError diff --git a/tests/brokers/base/middlewares.py b/tests/brokers/base/middlewares.py index 47e07c6398..9df2f84d19 100644 --- a/tests/brokers/base/middlewares.py +++ b/tests/brokers/base/middlewares.py @@ -1,5 +1,5 @@ import asyncio -from unittest.mock import Mock, call +from unittest.mock import MagicMock, call import pytest @@ -12,12 +12,249 @@ from .basic import BaseTestcaseConfig +@pytest.mark.asyncio() +class MiddlewaresOrderTestcase(BaseTestcaseConfig): + async def test_broker_middleware_order(self, queue: str, mock: MagicMock): + class InnerMiddleware(BaseMiddleware): + async def __aenter__(self): + mock.enter_inner() + mock.enter("inner") + + async def __aexit__(self, *args): + mock.exit_inner() + mock.exit("inner") + + async def consume_scope(self, call_next, msg): + mock.consume_inner() + mock.sub("inner") + return await call_next(msg) + + async def publish_scope(self, call_next, cmd): + mock.publish_inner() + mock.pub("inner") + return await call_next(cmd) + + class OuterMiddleware(BaseMiddleware): + async def __aenter__(self): + mock.enter_outer() + mock.enter("outer") + + async def __aexit__(self, *args): + mock.exit_outer() + mock.exit("outer") + + async def consume_scope(self, call_next, msg): + mock.consume_outer() + mock.sub("outer") + return await call_next(msg) + + async def publish_scope(self, call_next, cmd): + mock.publish_outer() + mock.pub("outer") + return await call_next(cmd) + + broker = self.get_broker(middlewares=[OuterMiddleware, InnerMiddleware]) + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + pass + + async with self.patch_broker(broker) as br: + await br.publish(None, queue) + + mock.consume_inner.assert_called_once() + mock.consume_outer.assert_called_once() + mock.publish_inner.assert_called_once() + mock.publish_outer.assert_called_once() + mock.enter_inner.assert_called_once() + mock.enter_outer.assert_called_once() + mock.exit_inner.assert_called_once() + mock.exit_outer.assert_called_once() + + assert [c.args[0] for c in mock.sub.call_args_list] == ["outer", "inner"] + assert [c.args[0] for c in mock.pub.call_args_list] == ["outer", "inner"] + assert [c.args[0] for c in mock.enter.call_args_list] == ["outer", "inner"] + assert [c.args[0] for c in mock.exit.call_args_list] == ["inner", "outer"] + + async def test_publisher_middleware_order(self, queue: str, mock: MagicMock): + class InnerMiddleware(BaseMiddleware): + async def publish_scope(self, call_next, cmd): + mock.publish_inner() + mock("inner") + return await call_next(cmd) + + class MiddleMiddleware(BaseMiddleware): + async def publish_scope(self, call_next, cmd): + mock.publish_middle() + mock("middle") + return await call_next(cmd) + + class OuterMiddleware(BaseMiddleware): + async def publish_scope(self, call_next, cmd): + mock.publish_outer() + mock("outer") + return await call_next(cmd) + + broker = self.get_broker(middlewares=[OuterMiddleware]) + publisher = broker.publisher( + queue, + middlewares=[ + MiddleMiddleware(None, context=None).publish_scope, + InnerMiddleware(None, context=None).publish_scope, + ], + ) + + args, kwargs = self.get_subscriber_params(queue) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + pass + + async with self.patch_broker(broker): + await publisher.publish(None, queue) + + mock.publish_inner.assert_called_once() + mock.publish_middle.assert_called_once() + mock.publish_outer.assert_called_once() + + assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"] + + async def test_publisher_with_router_middleware_order( + self, queue: str, mock: MagicMock + ): + class InnerMiddleware(BaseMiddleware): + async def publish_scope(self, call_next, cmd): + mock.publish_inner() + mock("inner") + return await call_next(cmd) + + class MiddleMiddleware(BaseMiddleware): + async def publish_scope(self, call_next, cmd): + mock.publish_middle() + mock("middle") + return await call_next(cmd) + + class OuterMiddleware(BaseMiddleware): + async def publish_scope(self, call_next, cmd): + mock.publish_outer() + mock("outer") + return await call_next(cmd) + + broker = self.get_broker(middlewares=[OuterMiddleware]) + router = self.get_router(middlewares=[MiddleMiddleware]) + router2 = self.get_router(middlewares=[InnerMiddleware]) + + publisher = router2.publisher(queue) + + args, kwargs = self.get_subscriber_params(queue) + + @router2.subscriber(*args, **kwargs) + async def handler(msg): + pass + + router.include_router(router2) + broker.include_router(router) + + async with self.patch_broker(broker): + await publisher.publish(None, queue) + + mock.publish_inner.assert_called_once() + mock.publish_middle.assert_called_once() + mock.publish_outer.assert_called_once() + + assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"] + + async def test_consume_middleware_order(self, queue: str, mock: MagicMock): + class InnerMiddleware(BaseMiddleware): + async def consume_scope(self, call_next, cmd): + mock.consume_inner() + mock("inner") + return await call_next(cmd) + + class MiddleMiddleware(BaseMiddleware): + async def consume_scope(self, call_next, cmd): + mock.consume_middle() + mock("middle") + return await call_next(cmd) + + class OuterMiddleware(BaseMiddleware): + async def consume_scope(self, call_next, cmd): + mock.consume_outer() + mock("outer") + return await call_next(cmd) + + broker = self.get_broker(middlewares=[OuterMiddleware]) + + args, kwargs = self.get_subscriber_params( + queue, + middlewares=[ + MiddleMiddleware(None, context=None).consume_scope, + InnerMiddleware(None, context=None).consume_scope, + ], + ) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + pass + + async with self.patch_broker(broker) as br: + await br.publish(None, queue) + + mock.consume_inner.assert_called_once() + mock.consume_middle.assert_called_once() + mock.consume_outer.assert_called_once() + + assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"] + + async def test_consume_with_middleware_order(self, queue: str, mock: MagicMock): + class InnerMiddleware(BaseMiddleware): + async def consume_scope(self, call_next, cmd): + mock.consume_inner() + mock("inner") + return await call_next(cmd) + + class MiddleMiddleware(BaseMiddleware): + async def consume_scope(self, call_next, cmd): + mock.consume_middle() + mock("middle") + return await call_next(cmd) + + class OuterMiddleware(BaseMiddleware): + async def consume_scope(self, call_next, cmd): + mock.consume_outer() + mock("outer") + return await call_next(cmd) + + broker = self.get_broker(middlewares=[OuterMiddleware]) + router = self.get_router(middlewares=[MiddleMiddleware]) + router2 = self.get_router(middlewares=[InnerMiddleware]) + + args, kwargs = self.get_subscriber_params(queue) + + @router2.subscriber(*args, **kwargs) + async def handler(msg): + pass + + router.include_router(router2) + broker.include_router(router) + async with self.patch_broker(broker) as br: + await br.publish(None, queue) + + mock.consume_inner.assert_called_once() + mock.consume_middle.assert_called_once() + mock.consume_outer.assert_called_once() + + assert [c.args[0] for c in mock.call_args_list] == ["outer", "middle", "inner"] + + @pytest.mark.asyncio() class LocalMiddlewareTestcase(BaseTestcaseConfig): async def test_subscriber_middleware( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -56,7 +293,7 @@ async def handler(m) -> str: async def test_publisher_middleware( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -97,7 +334,7 @@ async def handler(m) -> str: async def test_local_middleware_not_shared_between_subscribers( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event1 = asyncio.Event() event2 = asyncio.Event() @@ -147,7 +384,7 @@ async def handler(m) -> str: async def test_local_middleware_consume_not_shared_between_filters( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event1 = asyncio.Event() event2 = asyncio.Event() @@ -196,7 +433,7 @@ async def handler2(m) -> str: mock.end.assert_called_once() assert mock.call_count == 2 - async def test_error_traceback(self, queue: str, mock: Mock) -> None: + async def test_error_traceback(self, queue: str, mock: MagicMock) -> None: event = asyncio.Event() async def mid(call_next, msg): @@ -237,7 +474,7 @@ class MiddlewareTestcase(LocalMiddlewareTestcase): async def test_global_middleware( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -278,7 +515,7 @@ async def handler(m) -> str: async def test_add_global_middleware( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -333,7 +570,7 @@ async def handler2(m) -> str: async def test_patch_publish( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -374,7 +611,7 @@ async def handler_resp(m) -> None: async def test_global_publisher_middleware( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -422,7 +659,7 @@ class ExceptionMiddlewareTestcase(BaseTestcaseConfig): async def test_exception_middleware_default_msg( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -465,7 +702,7 @@ async def subscriber2(msg=Context("message")) -> None: async def test_exception_middleware_skip_msg( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -506,7 +743,7 @@ async def subscriber2(msg=Context("message")) -> None: async def test_exception_middleware_do_not_catch_skip_msg( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -541,7 +778,7 @@ async def subscriber(m): async def test_exception_middleware_reraise( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -582,7 +819,7 @@ async def subscriber2(msg=Context("message")) -> None: async def test_exception_middleware_different_handler( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -663,7 +900,7 @@ async def value_error_handler(exc) -> str: async def test_exception_middleware_decoder_error( self, queue: str, - mock: Mock, + mock: MagicMock, ) -> None: event = asyncio.Event() diff --git a/tests/brokers/confluent/basic.py b/tests/brokers/confluent/basic.py index 2ea7a94c91..4b9e626695 100644 --- a/tests/brokers/confluent/basic.py +++ b/tests/brokers/confluent/basic.py @@ -1,10 +1,15 @@ from typing import Any -from faststream.confluent import TopicPartition -from tests.brokers.base.basic import BaseTestcaseConfig as _Base +from faststream.confluent import ( + KafkaBroker, + KafkaRouter, + TestKafkaBroker, + TopicPartition, +) +from tests.brokers.base.basic import BaseTestcaseConfig -class ConfluentTestcaseConfig(_Base): +class ConfluentTestcaseConfig(BaseTestcaseConfig): timeout: float = 10.0 def get_subscriber_params( @@ -27,3 +32,21 @@ def get_subscriber_params( "partitions": partitions, **kwargs, } + + def get_broker( + self, + apply_types: bool = False, + **kwargs: Any, + ) -> KafkaBroker: + return KafkaBroker(apply_types=apply_types, **kwargs) + + def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> KafkaBroker: + return broker + + def get_router(self, **kwargs: Any) -> KafkaRouter: + return KafkaRouter(**kwargs) + + +class ConfluentMemoryTestcaseConfig(ConfluentTestcaseConfig): + def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> KafkaBroker: + return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/confluent/test_consume.py b/tests/brokers/confluent/test_consume.py index 2e886bc1f6..4312be906f 100644 --- a/tests/brokers/confluent/test_consume.py +++ b/tests/brokers/confluent/test_consume.py @@ -1,11 +1,9 @@ import asyncio -from typing import Any from unittest.mock import patch import pytest from faststream import AckPolicy -from faststream.confluent import KafkaBroker from faststream.confluent.annotations import KafkaMessage from faststream.confluent.client import AsyncConfluentConsumer from faststream.exceptions import AckMessage @@ -19,9 +17,6 @@ class TestConsume(ConfluentTestcaseConfig, BrokerRealConsumeTestcase): """A class to represent a test Kafka broker.""" - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - @pytest.mark.asyncio() async def test_consume_batch(self, queue: str) -> None: consume_broker = self.get_broker() @@ -86,7 +81,7 @@ def subscriber(m, msg: KafkaMessage) -> None: @pytest.mark.asyncio() @pytest.mark.slow() - async def test_consume_ack( + async def test_consume_auto_ack( self, queue: str, ) -> None: @@ -97,7 +92,7 @@ async def test_consume_ack( args, kwargs = self.get_subscriber_params( queue, group_id="test", - auto_commit=False, + ack_policy=AckPolicy.REJECT_ON_ERROR, ) @consume_broker.subscriber(*args, **kwargs) @@ -141,7 +136,7 @@ async def test_consume_ack_manual( args, kwargs = self.get_subscriber_params( queue, group_id="test", - auto_commit=False, + ack_policy=AckPolicy.REJECT_ON_ERROR, ) @consume_broker.subscriber(*args, **kwargs) @@ -181,7 +176,7 @@ async def test_consume_ack_raise( args, kwargs = self.get_subscriber_params( queue, group_id="test", - auto_commit=False, + ack_policy=AckPolicy.REJECT_ON_ERROR, ) @consume_broker.subscriber(*args, **kwargs) @@ -221,7 +216,7 @@ async def test_nack( args, kwargs = self.get_subscriber_params( queue, group_id="test", - auto_commit=False, + ack_policy=AckPolicy.REJECT_ON_ERROR, ) @consume_broker.subscriber(*args, **kwargs) @@ -302,8 +297,8 @@ async def test_consume_with_no_auto_commit( args, kwargs = self.get_subscriber_params( queue, - auto_commit=False, group_id="test", + ack_policy=AckPolicy.REJECT_ON_ERROR, ) @consume_broker.subscriber(*args, **kwargs) @@ -316,8 +311,8 @@ async def subscriber_no_auto_commit(msg: KafkaMessage) -> None: args, kwargs = self.get_subscriber_params( queue, - auto_commit=True, group_id="test", + ack_policy=AckPolicy.REJECT_ON_ERROR, ) @broker2.subscriber(*args, **kwargs) diff --git a/tests/brokers/confluent/test_fastapi.py b/tests/brokers/confluent/test_fastapi.py index d47612c91e..1059e4d433 100644 --- a/tests/brokers/confluent/test_fastapi.py +++ b/tests/brokers/confluent/test_fastapi.py @@ -1,15 +1,13 @@ import asyncio -from typing import Any from unittest.mock import Mock import pytest -from faststream.confluent import KafkaBroker, KafkaRouter +from faststream.confluent import KafkaRouter from faststream.confluent.fastapi import KafkaRouter as StreamRouter -from faststream.confluent.testing import TestKafkaBroker from tests.brokers.base.fastapi import FastAPILocalTestcase, FastAPITestcase -from .basic import ConfluentTestcaseConfig +from .basic import ConfluentMemoryTestcaseConfig, ConfluentTestcaseConfig @pytest.mark.confluent() @@ -47,13 +45,10 @@ async def hello(msg: list[str]): mock.assert_called_with(["hi"]) -class TestRouterLocal(ConfluentTestcaseConfig, FastAPILocalTestcase): +class TestRouterLocal(ConfluentMemoryTestcaseConfig, FastAPILocalTestcase): router_class = StreamRouter broker_router_class = KafkaRouter - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) - async def test_batch_testclient( self, mock: Mock, diff --git a/tests/brokers/confluent/test_logger.py b/tests/brokers/confluent/test_logger.py index 01b74ef5f9..ce1c8f145f 100644 --- a/tests/brokers/confluent/test_logger.py +++ b/tests/brokers/confluent/test_logger.py @@ -2,8 +2,6 @@ import pytest -from faststream.confluent import KafkaBroker - from .basic import ConfluentTestcaseConfig @@ -14,7 +12,7 @@ class TestLogger(ConfluentTestcaseConfig): @pytest.mark.asyncio() async def test_custom_logger(self, queue: str) -> None: test_logger = logging.getLogger("test_logger") - broker = KafkaBroker(logger=test_logger) + broker = self.get_broker(logger=test_logger) args, kwargs = self.get_subscriber_params(queue) diff --git a/tests/brokers/confluent/test_middlewares.py b/tests/brokers/confluent/test_middlewares.py index 5309d8008e..b81c9a4325 100644 --- a/tests/brokers/confluent/test_middlewares.py +++ b/tests/brokers/confluent/test_middlewares.py @@ -1,23 +1,23 @@ -from typing import Any - import pytest -from faststream.confluent import KafkaBroker from tests.brokers.base.middlewares import ( ExceptionMiddlewareTestcase, MiddlewareTestcase, + MiddlewaresOrderTestcase, ) -from .basic import ConfluentTestcaseConfig +from .basic import ConfluentMemoryTestcaseConfig, ConfluentTestcaseConfig + + +class TestMiddlewaresOrder(ConfluentMemoryTestcaseConfig, MiddlewaresOrderTestcase): + pass @pytest.mark.confluent() class TestMiddlewares(ConfluentTestcaseConfig, MiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) + pass @pytest.mark.confluent() class TestExceptionMiddlewares(ConfluentTestcaseConfig, ExceptionMiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) + pass diff --git a/tests/brokers/confluent/test_misconfigure.py b/tests/brokers/confluent/test_misconfigure.py new file mode 100644 index 0000000000..9c533c22e5 --- /dev/null +++ b/tests/brokers/confluent/test_misconfigure.py @@ -0,0 +1,60 @@ +import pytest + +from faststream import AckPolicy +from faststream.confluent import KafkaBroker, TopicPartition +from faststream.exceptions import SetupError + + +def test_deprecated_options(queue: str) -> None: + broker = KafkaBroker() + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", auto_commit=False) + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, auto_commit=True) + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", no_ack=False) + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", no_ack=True) + + +def test_deprecated_conflicts_actual(queue: str) -> None: + broker = KafkaBroker() + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, auto_commit=False, ack_policy=AckPolicy.ACK) + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, no_ack=False, ack_policy=AckPolicy.ACK) + + +def test_manual_ack_policy_without_group(queue: str) -> None: + broker = KafkaBroker() + + broker.subscriber(queue, group_id="test", ack_policy=AckPolicy.DO_NOTHING) + + with pytest.raises(SetupError): + broker.subscriber(queue, ack_policy=AckPolicy.DO_NOTHING) + + +def test_manual_commit_without_group(queue: str) -> None: + broker = KafkaBroker() + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", auto_commit=False) + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, auto_commit=False) + + +def test_wrong_destination(queue: str) -> None: + broker = KafkaBroker() + + with pytest.raises(SetupError): + broker.subscriber() + + with pytest.raises(SetupError): + broker.subscriber(queue, partitions=[TopicPartition(queue, 1)]) diff --git a/tests/brokers/confluent/test_parser.py b/tests/brokers/confluent/test_parser.py index deb150cc9a..65aa2bff15 100644 --- a/tests/brokers/confluent/test_parser.py +++ b/tests/brokers/confluent/test_parser.py @@ -1,8 +1,5 @@ -from typing import Any - import pytest -from faststream.confluent import KafkaBroker from tests.brokers.base.parser import CustomParserTestcase from .basic import ConfluentTestcaseConfig @@ -10,5 +7,4 @@ @pytest.mark.confluent() class TestCustomParser(ConfluentTestcaseConfig, CustomParserTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) + pass diff --git a/tests/brokers/confluent/test_publish.py b/tests/brokers/confluent/test_publish.py index 0d6fc9f4e1..7ed7a76ac8 100644 --- a/tests/brokers/confluent/test_publish.py +++ b/tests/brokers/confluent/test_publish.py @@ -1,11 +1,10 @@ import asyncio -from typing import Any from unittest.mock import Mock import pytest from faststream import Context -from faststream.confluent import KafkaBroker, KafkaResponse +from faststream.confluent import KafkaResponse from tests.brokers.base.publish import BrokerPublishTestcase from .basic import ConfluentTestcaseConfig @@ -13,9 +12,6 @@ @pytest.mark.confluent() class TestPublish(ConfluentTestcaseConfig, BrokerPublishTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - @pytest.mark.asyncio() async def test_publish_batch(self, queue: str) -> None: pub_broker = self.get_broker() diff --git a/tests/brokers/confluent/test_requests.py b/tests/brokers/confluent/test_requests.py index a0343ced57..fac1b09331 100644 --- a/tests/brokers/confluent/test_requests.py +++ b/tests/brokers/confluent/test_requests.py @@ -3,10 +3,9 @@ import pytest from faststream import BaseMiddleware -from faststream.confluent import KafkaBroker, KafkaRouter, TestKafkaBroker from tests.brokers.base.requests import RequestsTestcase -from .basic import ConfluentTestcaseConfig +from .basic import ConfluentMemoryTestcaseConfig class Mid(BaseMiddleware): @@ -19,15 +18,6 @@ async def consume_scope(self, call_next, msg): @pytest.mark.asyncio() -class TestRequestTestClient(ConfluentTestcaseConfig, RequestsTestcase): +class TestRequestTestClient(ConfluentMemoryTestcaseConfig, RequestsTestcase): def get_middleware(self, **kwargs: Any): return Mid - - def get_router(self, **kwargs: Any): - return KafkaRouter(**kwargs) - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/confluent/test_router.py b/tests/brokers/confluent/test_router.py index 8d134e62c3..c26198d8d4 100644 --- a/tests/brokers/confluent/test_router.py +++ b/tests/brokers/confluent/test_router.py @@ -1,36 +1,20 @@ -from typing import Any - import pytest from faststream.confluent import ( - KafkaBroker, KafkaPublisher, KafkaRoute, - KafkaRouter, - TestKafkaBroker, ) from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase -from .basic import ConfluentTestcaseConfig +from .basic import ConfluentMemoryTestcaseConfig, ConfluentTestcaseConfig @pytest.mark.confluent() class TestRouter(ConfluentTestcaseConfig, RouterTestcase): - broker_class = KafkaRouter route_class = KafkaRoute publisher_class = KafkaPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - -class TestRouterLocal(ConfluentTestcaseConfig, RouterLocalTestcase): - broker_class = KafkaRouter +class TestRouterLocal(ConfluentMemoryTestcaseConfig, RouterLocalTestcase): route_class = KafkaRoute publisher_class = KafkaPublisher - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/confluent/test_test_client.py b/tests/brokers/confluent/test_test_client.py index 73077ec147..01fc161746 100644 --- a/tests/brokers/confluent/test_test_client.py +++ b/tests/brokers/confluent/test_test_client.py @@ -1,28 +1,20 @@ import asyncio -from typing import Any from unittest.mock import patch import pytest -from faststream import BaseMiddleware -from faststream.confluent import KafkaBroker, TestKafkaBroker +from faststream import AckPolicy, BaseMiddleware from faststream.confluent.annotations import KafkaMessage from faststream.confluent.message import FAKE_CONSUMER from faststream.confluent.testing import FakeProducer from tests.brokers.base.testclient import BrokerTestclientTestcase from tests.tools import spy_decorator -from .basic import ConfluentTestcaseConfig +from .basic import ConfluentMemoryTestcaseConfig @pytest.mark.asyncio() -class TestTestclient(ConfluentTestcaseConfig, BrokerTestclientTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) - +class TestTestclient(ConfluentMemoryTestcaseConfig, BrokerTestclientTestcase): async def test_message_nack_seek( self, queue: str, @@ -32,8 +24,8 @@ async def test_message_nack_seek( @broker.subscriber( queue, group_id=f"{queue}-consume", - auto_commit=False, auto_offset_reset="earliest", + ack_policy=AckPolicy.REJECT_ON_ERROR, ) async def m(msg: KafkaMessage) -> None: await msg.nack() diff --git a/tests/brokers/kafka/basic.py b/tests/brokers/kafka/basic.py new file mode 100644 index 0000000000..39c095a637 --- /dev/null +++ b/tests/brokers/kafka/basic.py @@ -0,0 +1,24 @@ +from typing import Any + +from faststream.kafka import KafkaBroker, KafkaRouter, TestKafkaBroker +from tests.brokers.base.basic import BaseTestcaseConfig + + +class KafkaTestcaseConfig(BaseTestcaseConfig): + def get_broker( + self, + apply_types: bool = False, + **kwargs: Any, + ) -> KafkaBroker: + return KafkaBroker(apply_types=apply_types, **kwargs) + + def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> KafkaBroker: + return broker + + def get_router(self, **kwargs: Any) -> KafkaRouter: + return KafkaRouter(**kwargs) + + +class KafkaMemoryTestcaseConfig(KafkaTestcaseConfig): + def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> KafkaBroker: + return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/kafka/test_consume.py b/tests/brokers/kafka/test_consume.py index f725f90371..9a9172fbd4 100644 --- a/tests/brokers/kafka/test_consume.py +++ b/tests/brokers/kafka/test_consume.py @@ -1,23 +1,21 @@ import asyncio -from typing import Any -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from aiokafka import AIOKafkaConsumer from faststream import AckPolicy from faststream.exceptions import AckMessage -from faststream.kafka import KafkaBroker, TopicPartition +from faststream.kafka import TopicPartition from faststream.kafka.annotations import KafkaMessage from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator +from .basic import KafkaTestcaseConfig -@pytest.mark.kafka() -class TestConsume(BrokerRealConsumeTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) +@pytest.mark.kafka() +class TestConsume(KafkaTestcaseConfig, BrokerRealConsumeTestcase): @pytest.mark.asyncio() async def test_consume_by_pattern( self, @@ -114,7 +112,7 @@ def subscriber(m, msg: KafkaMessage) -> None: @pytest.mark.asyncio() @pytest.mark.slow() - async def test_consume_ack( + async def test_consume_auto_ack( self, queue: str, ) -> None: @@ -122,7 +120,9 @@ async def test_consume_ack( consume_broker = self.get_broker(apply_types=True) - @consume_broker.subscriber(queue, group_id="test", auto_commit=False) + @consume_broker.subscriber( + queue, group_id="test", ack_policy=AckPolicy.REJECT_ON_ERROR + ) async def handler(msg: KafkaMessage) -> None: event.set() @@ -188,7 +188,9 @@ async def test_consume_ack_manual( consume_broker = self.get_broker(apply_types=True) - @consume_broker.subscriber(queue, group_id="test", auto_commit=False) + @consume_broker.subscriber( + queue, group_id="test", ack_policy=AckPolicy.REJECT_ON_ERROR + ) async def handler(msg: KafkaMessage) -> None: await msg.ack() event.set() @@ -219,7 +221,7 @@ async def handler(msg: KafkaMessage) -> None: @pytest.mark.asyncio() @pytest.mark.slow() - async def test_consume_ack_raise( + async def test_consume_ack_by_raise( self, queue: str, ) -> None: @@ -227,7 +229,9 @@ async def test_consume_ack_raise( consume_broker = self.get_broker(apply_types=True) - @consume_broker.subscriber(queue, group_id="test", auto_commit=False) + @consume_broker.subscriber( + queue, group_id="test", ack_policy=AckPolicy.REJECT_ON_ERROR + ) async def handler(msg: KafkaMessage): event.set() raise AckMessage @@ -258,7 +262,7 @@ async def handler(msg: KafkaMessage): @pytest.mark.asyncio() @pytest.mark.slow() - async def test_nack( + async def test_manual_nack( self, queue: str, ) -> None: @@ -266,7 +270,9 @@ async def test_nack( consume_broker = self.get_broker(apply_types=True) - @consume_broker.subscriber(queue, group_id="test", auto_commit=False) + @consume_broker.subscriber( + queue, group_id="test", ack_policy=AckPolicy.REJECT_ON_ERROR + ) async def handler(msg: KafkaMessage) -> None: await msg.nack() event.set() @@ -334,3 +340,42 @@ async def handler(msg: KafkaMessage) -> None: m.mock.assert_not_called() assert event.is_set() + + @pytest.mark.asyncio() + @pytest.mark.slow() + async def test_concurrent_consume(self, queue: str, mock: MagicMock) -> None: + event = asyncio.Event() + event2 = asyncio.Event() + + consume_broker = self.get_broker() + + args, kwargs = self.get_subscriber_params(queue, max_workers=2) + + @consume_broker.subscriber(*args, **kwargs) + async def handler(msg) -> None: + mock() + if event.is_set(): + event2.set() + else: + event.set() + + # probably, we should increase it + await asyncio.sleep(0.1) + + async with self.patch_broker(consume_broker) as br: + await br.start() + + for i in range(5): + await br.publish(i, queue) + + await asyncio.wait( + ( + asyncio.create_task(event.wait()), + asyncio.create_task(event2.wait()), + ), + timeout=3, + ) + + assert event.is_set() + assert event2.is_set() + assert mock.call_count == 2, mock.call_count diff --git a/tests/brokers/kafka/test_fastapi.py b/tests/brokers/kafka/test_fastapi.py index 899deaffce..3da7a8bf51 100644 --- a/tests/brokers/kafka/test_fastapi.py +++ b/tests/brokers/kafka/test_fastapi.py @@ -1,13 +1,14 @@ import asyncio -from typing import Any from unittest.mock import Mock import pytest -from faststream.kafka import KafkaBroker, KafkaRouter, TestKafkaBroker +from faststream.kafka import KafkaRouter from faststream.kafka.fastapi import KafkaRouter as StreamRouter from tests.brokers.base.fastapi import FastAPILocalTestcase, FastAPITestcase +from .basic import KafkaMemoryTestcaseConfig + @pytest.mark.kafka() class TestKafkaRouter(FastAPITestcase): @@ -42,13 +43,10 @@ async def hello(msg: list[str]): mock.assert_called_with(["hi"]) -class TestRouterLocal(FastAPILocalTestcase): +class TestRouterLocal(KafkaMemoryTestcaseConfig, FastAPILocalTestcase): router_class = StreamRouter broker_router_class = KafkaRouter - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) - async def test_batch_testclient( self, mock: Mock, diff --git a/tests/brokers/kafka/test_middlewares.py b/tests/brokers/kafka/test_middlewares.py index 3bde2314f0..58d5bf7d31 100644 --- a/tests/brokers/kafka/test_middlewares.py +++ b/tests/brokers/kafka/test_middlewares.py @@ -1,21 +1,23 @@ -from typing import Any - import pytest -from faststream.kafka import KafkaBroker from tests.brokers.base.middlewares import ( ExceptionMiddlewareTestcase, MiddlewareTestcase, + MiddlewaresOrderTestcase, ) +from .basic import KafkaMemoryTestcaseConfig, KafkaTestcaseConfig + + +class TestMiddlewaresOrder(KafkaMemoryTestcaseConfig, MiddlewaresOrderTestcase): + pass + @pytest.mark.kafka() -class TestMiddlewares(MiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) +class TestMiddlewares(KafkaTestcaseConfig, MiddlewareTestcase): + pass @pytest.mark.kafka() -class TestExceptionMiddlewares(ExceptionMiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) +class TestExceptionMiddlewares(KafkaTestcaseConfig, ExceptionMiddlewareTestcase): + pass diff --git a/tests/brokers/kafka/test_misconfigure.py b/tests/brokers/kafka/test_misconfigure.py new file mode 100644 index 0000000000..79bd8bdef8 --- /dev/null +++ b/tests/brokers/kafka/test_misconfigure.py @@ -0,0 +1,90 @@ +import pytest + +from faststream import AckPolicy +from faststream.exceptions import SetupError +from faststream.kafka import KafkaBroker, TopicPartition +from faststream.kafka.subscriber.specified import ( + SpecificationConcurrentDefaultSubscriber, +) + + +def test_deprecated_options(queue: str) -> None: + broker = KafkaBroker() + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", auto_commit=False) + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, auto_commit=True) + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, no_ack=False) + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", no_ack=True) + + +def test_deprecated_conflicts_actual(queue: str) -> None: + broker = KafkaBroker() + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, auto_commit=False, ack_policy=AckPolicy.ACK) + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, no_ack=False, ack_policy=AckPolicy.ACK) + + +def test_manual_ack_policy_without_group(queue: str) -> None: + broker = KafkaBroker() + + broker.subscriber(queue, group_id="test", ack_policy=AckPolicy.DO_NOTHING) + + with pytest.raises(SetupError): + broker.subscriber(queue, ack_policy=AckPolicy.DO_NOTHING) + + +def test_manual_commit_without_group(queue: str) -> None: + broker = KafkaBroker() + + with pytest.warns(DeprecationWarning): + broker.subscriber(queue, group_id="test", auto_commit=False) + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, auto_commit=False) + + +def test_max_workers_with_manual(queue: str) -> None: + broker = KafkaBroker() + + with pytest.warns(DeprecationWarning): + sub = broker.subscriber(queue, max_workers=3, auto_commit=True) + assert isinstance(sub, SpecificationConcurrentDefaultSubscriber) + + with pytest.raises(SetupError), pytest.warns(DeprecationWarning): + broker.subscriber(queue, max_workers=3, auto_commit=False) + + +def test_max_workers_with_ack_policy(queue: str) -> None: + broker = KafkaBroker() + + sub = broker.subscriber(queue, max_workers=3, ack_policy=AckPolicy.ACK_FIRST) + assert isinstance(sub, SpecificationConcurrentDefaultSubscriber) + + with pytest.raises(SetupError): + broker.subscriber(queue, max_workers=3, ack_policy=AckPolicy.REJECT_ON_ERROR) + + +def test_wrong_destination(queue: str) -> None: + broker = KafkaBroker() + + with pytest.raises(SetupError): + broker.subscriber() + + with pytest.raises(SetupError): + broker.subscriber(queue, partitions=[TopicPartition(queue, 1)]) + + with pytest.raises(SetupError): + broker.subscriber(partitions=[TopicPartition(queue, 1)], pattern=".*") + + with pytest.raises(SetupError): + broker.subscriber(queue, pattern=".*") diff --git a/tests/brokers/kafka/test_parser.py b/tests/brokers/kafka/test_parser.py index 0b91d82ff5..0e229bbd37 100644 --- a/tests/brokers/kafka/test_parser.py +++ b/tests/brokers/kafka/test_parser.py @@ -1,12 +1,10 @@ -from typing import Any - import pytest -from faststream.kafka import KafkaBroker from tests.brokers.base.parser import CustomParserTestcase +from .basic import KafkaTestcaseConfig + @pytest.mark.kafka() -class TestCustomParser(CustomParserTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) +class TestCustomParser(KafkaTestcaseConfig, CustomParserTestcase): + pass diff --git a/tests/brokers/kafka/test_publish.py b/tests/brokers/kafka/test_publish.py index 80cb7b017b..bb884c450a 100644 --- a/tests/brokers/kafka/test_publish.py +++ b/tests/brokers/kafka/test_publish.py @@ -1,19 +1,17 @@ import asyncio -from typing import Any from unittest.mock import Mock import pytest from faststream import Context -from faststream.kafka import KafkaBroker, KafkaResponse +from faststream.kafka import KafkaResponse from tests.brokers.base.publish import BrokerPublishTestcase +from .basic import KafkaTestcaseConfig -@pytest.mark.kafka() -class TestPublish(BrokerPublishTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) +@pytest.mark.kafka() +class TestPublish(KafkaTestcaseConfig, BrokerPublishTestcase): @pytest.mark.asyncio() async def test_publish_batch(self, queue: str) -> None: pub_broker = self.get_broker() diff --git a/tests/brokers/kafka/test_requests.py b/tests/brokers/kafka/test_requests.py index f84ba2a5db..c3d7a4cb4d 100644 --- a/tests/brokers/kafka/test_requests.py +++ b/tests/brokers/kafka/test_requests.py @@ -3,9 +3,10 @@ import pytest from faststream import BaseMiddleware -from faststream.kafka import KafkaBroker, KafkaRouter, TestKafkaBroker from tests.brokers.base.requests import RequestsTestcase +from .basic import KafkaMemoryTestcaseConfig + class Mid(BaseMiddleware): async def on_receive(self) -> None: @@ -17,15 +18,6 @@ async def consume_scope(self, call_next, msg): @pytest.mark.asyncio() -class TestRequestTestClient(RequestsTestcase): +class TestRequestTestClient(KafkaMemoryTestcaseConfig, RequestsTestcase): def get_middleware(self, **kwargs: Any): return Mid - - def get_router(self, **kwargs: Any): - return KafkaRouter(**kwargs) - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/kafka/test_router.py b/tests/brokers/kafka/test_router.py index 8460b6cc56..e9b27f5a01 100644 --- a/tests/brokers/kafka/test_router.py +++ b/tests/brokers/kafka/test_router.py @@ -1,34 +1,20 @@ -from typing import Any - import pytest from faststream.kafka import ( - KafkaBroker, KafkaPublisher, KafkaRoute, - KafkaRouter, - TestKafkaBroker, ) from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase +from .basic import KafkaMemoryTestcaseConfig, KafkaTestcaseConfig + @pytest.mark.kafka() -class TestRouter(RouterTestcase): - broker_class = KafkaRouter +class TestRouter(KafkaTestcaseConfig, RouterTestcase): route_class = KafkaRoute publisher_class = KafkaPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - -class TestRouterLocal(RouterLocalTestcase): - broker_class = KafkaRouter +class TestRouterLocal(KafkaMemoryTestcaseConfig, RouterLocalTestcase): route_class = KafkaRoute publisher_class = KafkaPublisher - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) diff --git a/tests/brokers/kafka/test_stuff.py b/tests/brokers/kafka/test_stuff.py deleted file mode 100644 index 6d066d8ebe..0000000000 --- a/tests/brokers/kafka/test_stuff.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest - -from faststream.kafka import KafkaBroker - - -def test_wrong_subscriber() -> None: - broker = KafkaBroker() - - with pytest.raises(ValueError): # noqa: PT011 - broker.subscriber("test", auto_commit=False)(lambda: None) diff --git a/tests/brokers/kafka/test_test_client.py b/tests/brokers/kafka/test_test_client.py index b490604453..6f28a38353 100644 --- a/tests/brokers/kafka/test_test_client.py +++ b/tests/brokers/kafka/test_test_client.py @@ -1,26 +1,21 @@ import asyncio -from typing import Any from unittest.mock import patch import pytest -from faststream import BaseMiddleware -from faststream.kafka import KafkaBroker, TestKafkaBroker, TopicPartition +from faststream import AckPolicy, BaseMiddleware +from faststream.kafka import TopicPartition from faststream.kafka.annotations import KafkaMessage from faststream.kafka.message import FAKE_CONSUMER from faststream.kafka.testing import FakeProducer from tests.brokers.base.testclient import BrokerTestclientTestcase from tests.tools import spy_decorator +from .basic import KafkaMemoryTestcaseConfig -@pytest.mark.asyncio() -class TestTestclient(BrokerTestclientTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: - return KafkaBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: KafkaBroker, **kwargs: Any) -> TestKafkaBroker: - return TestKafkaBroker(broker, **kwargs) +@pytest.mark.asyncio() +class TestTestclient(KafkaMemoryTestcaseConfig, BrokerTestclientTestcase): async def test_partition_match( self, queue: str, @@ -77,7 +72,7 @@ async def test_message_nack_seek( ) -> None: broker = self.get_broker(apply_types=True) - @broker.subscriber(queue) + @broker.subscriber(queue, group_id=f"{queue}1", ack_policy=AckPolicy.DO_NOTHING) async def m(msg: KafkaMessage) -> None: await msg.nack() diff --git a/tests/brokers/nats/basic.py b/tests/brokers/nats/basic.py new file mode 100644 index 0000000000..bc73b67da6 --- /dev/null +++ b/tests/brokers/nats/basic.py @@ -0,0 +1,24 @@ +from typing import Any + +from faststream.nats import NatsBroker, NatsRouter, TestNatsBroker +from tests.brokers.base.basic import BaseTestcaseConfig + + +class NatsTestcaseConfig(BaseTestcaseConfig): + def get_broker( + self, + apply_types: bool = False, + **kwargs: Any, + ) -> NatsBroker: + return NatsBroker(apply_types=apply_types, **kwargs) + + def patch_broker(self, broker: NatsBroker, **kwargs: Any) -> NatsBroker: + return broker + + def get_router(self, **kwargs: Any) -> NatsRouter: + return NatsRouter(**kwargs) + + +class NatsMemoryTestcaseConfig(NatsTestcaseConfig): + def patch_broker(self, broker: NatsBroker, **kwargs: Any) -> NatsBroker: + return TestNatsBroker(broker, **kwargs) diff --git a/tests/brokers/nats/test_consume.py b/tests/brokers/nats/test_consume.py index 82ff607e4a..18fb87c818 100644 --- a/tests/brokers/nats/test_consume.py +++ b/tests/brokers/nats/test_consume.py @@ -1,23 +1,62 @@ import asyncio -from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, patch import pytest from nats.aio.msg import Msg -from nats.js.api import PubAck from faststream import AckPolicy from faststream.exceptions import AckMessage -from faststream.nats import ConsumerConfig, JStream, NatsBroker, PullSub +from faststream.nats import ConsumerConfig, JStream, PubAck, PullSub from faststream.nats.annotations import NatsMessage +from faststream.nats.message import NatsMessage as StreamMessage from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator +from .basic import NatsTestcaseConfig + @pytest.mark.nats() -class TestConsume(BrokerRealConsumeTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) +class TestConsume(NatsTestcaseConfig, BrokerRealConsumeTestcase): + async def test_concurrent_subscriber( + self, + queue: str, + mock: MagicMock, + ) -> None: + event = asyncio.Event() + event2 = asyncio.Event() + + broker = self.get_broker() + + args, kwargs = self.get_subscriber_params(queue, max_workers=2) + + @broker.subscriber(*args, **kwargs) + async def handler(msg): + mock() + + if event.is_set(): + event2.set() + else: + event.set() + + await asyncio.sleep(1.0) + + async with self.patch_broker(broker) as br: + await br.start() + + for i in range(5): + await br.publish(i, queue) + + await asyncio.wait( + ( + asyncio.create_task(event.wait()), + asyncio.create_task(event2.wait()), + ), + timeout=3, + ) + + assert event.is_set() + assert event2.is_set() + assert mock.call_count == 2, mock.call_count async def test_consume_js( self, @@ -28,7 +67,9 @@ async def test_consume_js( consume_broker = self.get_broker() - @consume_broker.subscriber(queue, stream=stream) + args, kwargs = self.get_subscriber_params(queue, stream=stream) + + @consume_broker.subscriber(*args, **kwargs) def subscriber(m) -> None: event.set() @@ -53,8 +94,8 @@ def subscriber(m) -> None: async def test_consume_with_filter( self, - queue, - mock: Mock, + queue: str, + mock: MagicMock, ) -> None: event = asyncio.Event() @@ -147,23 +188,31 @@ def subscriber(m) -> None: assert event.is_set() mock.assert_called_once_with([b"hello"]) - async def test_consume_ack( + async def test_core_consume_no_ack( self, queue: str, - stream: JStream, + mock: MagicMock, ) -> None: event = asyncio.Event() consume_broker = self.get_broker(apply_types=True) - @consume_broker.subscriber(queue, stream=stream) + args, kwargs = self.get_subscriber_params( + queue, ack_policy=AckPolicy.DO_NOTHING + ) + + @consume_broker.subscriber(*args, **kwargs) async def handler(msg: NatsMessage) -> None: + mock(msg.raw_message._ackd) event.set() async with self.patch_broker(consume_broker) as br: await br.start() - with patch.object(Msg, "ack", spy_decorator(Msg.ack)) as m: + # Check, that Core Subscriber doesn't call Acknowledgement automatically + with patch.object( + StreamMessage, "ack", spy_decorator(StreamMessage.ack) + ) as m: await asyncio.wait( ( asyncio.create_task(br.publish("hello", queue)), @@ -171,19 +220,21 @@ async def handler(msg: NatsMessage) -> None: ), timeout=3, ) - m.mock.assert_called_once() + assert not m.mock.called assert event.is_set() + mock.assert_called_once_with(False) - async def test_core_consume_no_ack( + async def test_consume_ack( self, queue: str, + stream: JStream, ) -> None: event = asyncio.Event() consume_broker = self.get_broker(apply_types=True) - @consume_broker.subscriber(queue) + @consume_broker.subscriber(queue, stream=stream) async def handler(msg: NatsMessage) -> None: event.set() @@ -198,7 +249,7 @@ async def handler(msg: NatsMessage) -> None: ), timeout=3, ) - assert not m.mock.called + m.mock.assert_called_once() assert event.is_set() @@ -231,6 +282,34 @@ async def handler(msg: NatsMessage) -> None: assert event.is_set() + async def test_consume_ack_sync_manual( + self, + queue: str, + event: asyncio.Event, + stream: JStream, + ): + consume_broker = self.get_broker(apply_types=True) + + @consume_broker.subscriber(queue, stream=stream) + async def handler(msg: NatsMessage): + await msg.ack_sync() + event.set() + + async with self.patch_broker(consume_broker) as br: + await br.start() + + with patch.object(Msg, "ack_sync", spy_decorator(Msg.ack_sync)) as m: + await asyncio.wait( + ( + asyncio.create_task(br.publish("hello", queue)), + asyncio.create_task(event.wait()), + ), + timeout=3, + ) + m.mock.assert_called_once() + + assert event.is_set() + async def test_consume_ack_raise( self, queue: str, @@ -513,7 +592,7 @@ async def test_get_one_pull_timeout( self, queue: str, stream: JStream, - mock: Mock, + mock: MagicMock, ) -> None: broker = self.get_broker(apply_types=True) subscriber = broker.subscriber( @@ -567,7 +646,7 @@ async def test_get_one_batch_timeout( self, queue: str, stream: JStream, - mock: Mock, + mock: MagicMock, ) -> None: broker = self.get_broker(apply_types=True) subscriber = broker.subscriber( @@ -652,7 +731,7 @@ async def test_get_one_kv_timeout( self, queue: str, stream: JStream, - mock: Mock, + mock: MagicMock, ) -> None: broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue, kv_watch=queue + "1") @@ -700,7 +779,7 @@ async def test_get_one_os_timeout( self, queue: str, stream: JStream, - mock: Mock, + mock: MagicMock, ) -> None: broker = self.get_broker(apply_types=True) subscriber = broker.subscriber(queue, obj_watch=True) diff --git a/tests/brokers/nats/test_fastapi.py b/tests/brokers/nats/test_fastapi.py index bbcdac2113..15fbe319f7 100644 --- a/tests/brokers/nats/test_fastapi.py +++ b/tests/brokers/nats/test_fastapi.py @@ -1,16 +1,17 @@ import asyncio -from typing import Any from unittest.mock import MagicMock import pytest -from faststream.nats import JStream, NatsBroker, NatsRouter, PullSub, TestNatsBroker +from faststream.nats import JStream, NatsRouter, PullSub from faststream.nats.fastapi import NatsRouter as StreamRouter from tests.brokers.base.fastapi import FastAPILocalTestcase, FastAPITestcase +from .basic import NatsMemoryTestcaseConfig, NatsTestcaseConfig + @pytest.mark.nats() -class TestRouter(FastAPITestcase): +class TestRouter(NatsTestcaseConfig, FastAPITestcase): router_class = StreamRouter broker_router_class = NatsRouter @@ -76,13 +77,10 @@ def subscriber(m: list[str]) -> None: mock.assert_called_once_with(["hello"]) -class TestRouterLocal(FastAPILocalTestcase): +class TestRouterLocal(NatsMemoryTestcaseConfig, FastAPILocalTestcase): router_class = StreamRouter broker_router_class = NatsRouter - def patch_broker(self, broker: NatsBroker, **kwargs: Any) -> NatsBroker: - return TestNatsBroker(broker, **kwargs) - async def test_consume_batch( self, queue: str, diff --git a/tests/brokers/nats/test_middlewares.py b/tests/brokers/nats/test_middlewares.py index d646418eba..c726d7e231 100644 --- a/tests/brokers/nats/test_middlewares.py +++ b/tests/brokers/nats/test_middlewares.py @@ -1,21 +1,23 @@ -from typing import Any - import pytest -from faststream.nats import NatsBroker from tests.brokers.base.middlewares import ( ExceptionMiddlewareTestcase, MiddlewareTestcase, + MiddlewaresOrderTestcase, ) +from .basic import NatsMemoryTestcaseConfig, NatsTestcaseConfig + + +class TestMiddlewaresOrder(NatsMemoryTestcaseConfig, MiddlewaresOrderTestcase): + pass + @pytest.mark.nats() -class TestMiddlewares(MiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) +class TestMiddlewares(NatsTestcaseConfig, MiddlewareTestcase): + pass @pytest.mark.nats() -class TestExceptionMiddlewares(ExceptionMiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) +class TestExceptionMiddlewares(NatsTestcaseConfig, ExceptionMiddlewareTestcase): + pass diff --git a/tests/brokers/nats/test_parser.py b/tests/brokers/nats/test_parser.py index 82caa39d3f..635cbccb65 100644 --- a/tests/brokers/nats/test_parser.py +++ b/tests/brokers/nats/test_parser.py @@ -1,12 +1,10 @@ -from typing import Any - import pytest -from faststream.nats import NatsBroker from tests.brokers.base.parser import CustomParserTestcase +from .basic import NatsTestcaseConfig + @pytest.mark.nats() -class TestCustomParser(CustomParserTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) +class TestCustomParser(NatsTestcaseConfig, CustomParserTestcase): + pass diff --git a/tests/brokers/nats/test_publish.py b/tests/brokers/nats/test_publish.py index c7367ac389..cee41ef6a4 100644 --- a/tests/brokers/nats/test_publish.py +++ b/tests/brokers/nats/test_publish.py @@ -4,17 +4,16 @@ import pytest from faststream import Context -from faststream.nats import NatsBroker, NatsResponse +from faststream.nats import NatsResponse from tests.brokers.base.publish import BrokerPublishTestcase +from .basic import NatsTestcaseConfig + @pytest.mark.nats() -class TestPublish(BrokerPublishTestcase): +class TestPublish(NatsTestcaseConfig, BrokerPublishTestcase): """Test publish method of NATS broker.""" - def get_broker(self, apply_types: bool = False, **kwargs) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) - @pytest.mark.asyncio() async def test_response( self, diff --git a/tests/brokers/nats/test_requests.py b/tests/brokers/nats/test_requests.py index 579f13113d..f440a83c6d 100644 --- a/tests/brokers/nats/test_requests.py +++ b/tests/brokers/nats/test_requests.py @@ -1,11 +1,10 @@ -from typing import Any - import pytest from faststream import BaseMiddleware -from faststream.nats import NatsBroker, NatsRouter, TestNatsBroker from tests.brokers.base.requests import RequestsTestcase +from .basic import NatsMemoryTestcaseConfig, NatsTestcaseConfig + class Mid(BaseMiddleware): async def on_receive(self) -> None: @@ -21,12 +20,6 @@ class NatsRequestsTestcase(RequestsTestcase): def get_middleware(self, **kwargs): return Mid - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) - - def get_router(self, **kwargs): - return NatsRouter(**kwargs) - async def test_broker_stream_request(self, queue: str) -> None: broker = self.get_broker() @@ -78,10 +71,9 @@ async def handler(msg) -> str: @pytest.mark.nats() -class TestRealRequests(NatsRequestsTestcase): +class TestRealRequests(NatsTestcaseConfig, NatsRequestsTestcase): pass -class TestRequestTestClient(NatsRequestsTestcase): - def patch_broker(self, broker, **kwargs): - return TestNatsBroker(broker, **kwargs) +class TestRequestTestClient(NatsMemoryTestcaseConfig, NatsRequestsTestcase): + pass diff --git a/tests/brokers/nats/test_router.py b/tests/brokers/nats/test_router.py index 8af70bcfa0..4bd846abd0 100644 --- a/tests/brokers/nats/test_router.py +++ b/tests/brokers/nats/test_router.py @@ -1,29 +1,24 @@ import asyncio -from typing import Any import pytest from faststream import Path from faststream.nats import ( JStream, - NatsBroker, NatsPublisher, NatsRoute, NatsRouter, - TestNatsBroker, ) from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase +from .basic import NatsMemoryTestcaseConfig, NatsTestcaseConfig + @pytest.mark.nats() -class TestRouter(RouterTestcase): - broker_class = NatsRouter +class TestRouter(NatsTestcaseConfig, RouterTestcase): route_class = NatsRoute publisher_class = NatsPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) - async def test_router_path( self, event, @@ -158,16 +153,10 @@ def response(m) -> None: assert event.is_set() -class TestRouterLocal(RouterLocalTestcase): +class TestRouterLocal(NatsMemoryTestcaseConfig, RouterLocalTestcase): route_class = NatsRoute publisher_class = NatsPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: NatsBroker, **kwargs: Any) -> NatsBroker: - return TestNatsBroker(broker, **kwargs) - async def test_include_stream( self, router: NatsRouter, diff --git a/tests/brokers/nats/test_test_client.py b/tests/brokers/nats/test_test_client.py index 7ca172e6ce..1c008ed9b2 100644 --- a/tests/brokers/nats/test_test_client.py +++ b/tests/brokers/nats/test_test_client.py @@ -1,5 +1,4 @@ import asyncio -from typing import Any import pytest @@ -7,22 +6,16 @@ from faststream.nats import ( ConsumerConfig, JStream, - NatsBroker, PullSub, - TestNatsBroker, ) from faststream.nats.testing import FakeProducer from tests.brokers.base.testclient import BrokerTestclientTestcase +from .basic import NatsMemoryTestcaseConfig -@pytest.mark.asyncio() -class TestTestclient(BrokerTestclientTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: - return NatsBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: NatsBroker, **kwargs: Any) -> NatsBroker: - return TestNatsBroker(broker, **kwargs) +@pytest.mark.asyncio() +class TestTestclient(NatsMemoryTestcaseConfig, BrokerTestclientTestcase): @pytest.mark.asyncio() async def test_stream_publish( self, diff --git a/tests/brokers/rabbit/basic.py b/tests/brokers/rabbit/basic.py new file mode 100644 index 0000000000..6a451530c5 --- /dev/null +++ b/tests/brokers/rabbit/basic.py @@ -0,0 +1,24 @@ +from typing import Any + +from faststream.rabbit import RabbitBroker, RabbitRouter, TestRabbitBroker +from tests.brokers.base.basic import BaseTestcaseConfig + + +class RabbitTestcaseConfig(BaseTestcaseConfig): + def get_broker( + self, + apply_types: bool = False, + **kwargs: Any, + ) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types, **kwargs) + + def patch_broker(self, broker: RabbitBroker, **kwargs: Any) -> RabbitBroker: + return broker + + def get_router(self, **kwargs: Any) -> RabbitRouter: + return RabbitRouter(**kwargs) + + +class RabbitMemoryTestcaseConfig(RabbitTestcaseConfig): + def patch_broker(self, broker: RabbitBroker, **kwargs: Any) -> RabbitBroker: + return TestRabbitBroker(broker, **kwargs) diff --git a/tests/brokers/rabbit/test_consume.py b/tests/brokers/rabbit/test_consume.py index 19317948fa..45b0ad0026 100644 --- a/tests/brokers/rabbit/test_consume.py +++ b/tests/brokers/rabbit/test_consume.py @@ -1,5 +1,4 @@ import asyncio -from typing import Any from unittest.mock import patch import pytest @@ -7,17 +6,16 @@ from faststream import AckPolicy from faststream.exceptions import AckMessage, NackMessage, RejectMessage, SkipMessage -from faststream.rabbit import RabbitBroker, RabbitExchange, RabbitQueue +from faststream.rabbit import RabbitExchange, RabbitQueue from faststream.rabbit.annotations import RabbitMessage from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator +from .basic import RabbitTestcaseConfig -@pytest.mark.rabbit() -class TestConsume(BrokerRealConsumeTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) +@pytest.mark.rabbit() +class TestConsume(RabbitTestcaseConfig, BrokerRealConsumeTestcase): @pytest.mark.asyncio() async def test_consume_from_exchange( self, @@ -398,7 +396,9 @@ async def test_consume_no_ack( consume_broker = self.get_broker(apply_types=True) @consume_broker.subscriber( - queue, exchange=exchange, ack_policy=AckPolicy.DO_NOTHING + queue, + exchange=exchange, + ack_policy=AckPolicy.DO_NOTHING, ) async def handler(msg: RabbitMessage) -> None: event.set() diff --git a/tests/brokers/rabbit/test_fastapi.py b/tests/brokers/rabbit/test_fastapi.py index 3bd99eae62..6ffa26392a 100644 --- a/tests/brokers/rabbit/test_fastapi.py +++ b/tests/brokers/rabbit/test_fastapi.py @@ -5,9 +5,10 @@ from faststream.rabbit import ExchangeType, RabbitExchange, RabbitQueue, RabbitRouter from faststream.rabbit.fastapi import RabbitRouter as StreamRouter -from faststream.rabbit.testing import TestRabbitBroker from tests.brokers.base.fastapi import FastAPILocalTestcase, FastAPITestcase +from .basic import RabbitMemoryTestcaseConfig + @pytest.mark.rabbit() class TestRouter(FastAPITestcase): @@ -55,13 +56,10 @@ def subscriber(msg: str, name: str) -> None: @pytest.mark.asyncio() -class TestRouterLocal(FastAPILocalTestcase): +class TestRouterLocal(RabbitMemoryTestcaseConfig, FastAPILocalTestcase): router_class = StreamRouter broker_router_class = RabbitRouter - def patch_broker(self, broker, **kwargs): - return TestRabbitBroker(broker, **kwargs) - async def test_path(self) -> None: router = self.router_class() diff --git a/tests/brokers/rabbit/test_middlewares.py b/tests/brokers/rabbit/test_middlewares.py index 9f21be5f8b..f56c836e8d 100644 --- a/tests/brokers/rabbit/test_middlewares.py +++ b/tests/brokers/rabbit/test_middlewares.py @@ -1,21 +1,23 @@ -from typing import Any - import pytest -from faststream.rabbit import RabbitBroker from tests.brokers.base.middlewares import ( ExceptionMiddlewareTestcase, MiddlewareTestcase, + MiddlewaresOrderTestcase, ) +from .basic import RabbitMemoryTestcaseConfig, RabbitTestcaseConfig + + +class TestMiddlewaresOrder(RabbitMemoryTestcaseConfig, MiddlewaresOrderTestcase): + pass + @pytest.mark.rabbit() -class TestMiddlewares(MiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) +class TestMiddlewares(RabbitTestcaseConfig, MiddlewareTestcase): + pass @pytest.mark.rabbit() -class TestExceptionMiddlewares(ExceptionMiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) +class TestExceptionMiddlewares(RabbitTestcaseConfig, ExceptionMiddlewareTestcase): + pass diff --git a/tests/brokers/rabbit/test_parser.py b/tests/brokers/rabbit/test_parser.py index db5b52aef1..2ecaf2967c 100644 --- a/tests/brokers/rabbit/test_parser.py +++ b/tests/brokers/rabbit/test_parser.py @@ -1,12 +1,10 @@ -from typing import Any - import pytest -from faststream.rabbit import RabbitBroker from tests.brokers.base.parser import CustomParserTestcase +from .basic import RabbitTestcaseConfig + @pytest.mark.rabbit() -class TestCustomParser(CustomParserTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) +class TestCustomParser(RabbitTestcaseConfig, CustomParserTestcase): + pass diff --git a/tests/brokers/rabbit/test_publish.py b/tests/brokers/rabbit/test_publish.py index f5000d9104..a8aa2508f1 100644 --- a/tests/brokers/rabbit/test_publish.py +++ b/tests/brokers/rabbit/test_publish.py @@ -1,24 +1,23 @@ import asyncio -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from unittest.mock import Mock, patch import pytest from faststream import Context -from faststream.rabbit import RabbitBroker, RabbitResponse +from faststream.rabbit import RabbitResponse from faststream.rabbit.publisher.producer import AioPikaFastProducer from tests.brokers.base.publish import BrokerPublishTestcase from tests.tools import spy_decorator +from .basic import RabbitTestcaseConfig + if TYPE_CHECKING: from faststream.rabbit.response import RabbitPublishCommand @pytest.mark.rabbit() -class TestPublish(BrokerPublishTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) - +class TestPublish(RabbitTestcaseConfig, BrokerPublishTestcase): @pytest.mark.asyncio() async def test_reply_config( self, diff --git a/tests/brokers/rabbit/test_requests.py b/tests/brokers/rabbit/test_requests.py index fd5fb93ebf..8eb64a075a 100644 --- a/tests/brokers/rabbit/test_requests.py +++ b/tests/brokers/rabbit/test_requests.py @@ -1,11 +1,10 @@ -from typing import Any - import pytest from faststream import BaseMiddleware -from faststream.rabbit import RabbitBroker, RabbitRouter, TestRabbitBroker from tests.brokers.base.requests import RequestsTestcase +from .basic import RabbitMemoryTestcaseConfig, RabbitTestcaseConfig + class Mid(BaseMiddleware): async def on_receive(self) -> None: @@ -22,19 +21,12 @@ class RabbitRequestsTestcase(RequestsTestcase): def get_middleware(self, **kwargs): return Mid - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) - - def get_router(self, **kwargs): - return RabbitRouter(**kwargs) - @pytest.mark.rabbit() -class TestRealRequests(RabbitRequestsTestcase): +class TestRealRequests(RabbitTestcaseConfig, RabbitRequestsTestcase): pass @pytest.mark.asyncio() -class TestRequestTestClient(RabbitRequestsTestcase): - def patch_broker(self, broker, **kwargs): - return TestRabbitBroker(broker, **kwargs) +class TestRequestTestClient(RabbitMemoryTestcaseConfig, RabbitRequestsTestcase): + pass diff --git a/tests/brokers/rabbit/test_router.py b/tests/brokers/rabbit/test_router.py index 0fb2b8babf..fba7752300 100644 --- a/tests/brokers/rabbit/test_router.py +++ b/tests/brokers/rabbit/test_router.py @@ -1,31 +1,26 @@ import asyncio -from typing import Any import pytest from faststream import Path from faststream.rabbit import ( ExchangeType, - RabbitBroker, RabbitExchange, RabbitPublisher, RabbitQueue, RabbitRoute, RabbitRouter, - TestRabbitBroker, ) from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase +from .basic import RabbitMemoryTestcaseConfig, RabbitTestcaseConfig + @pytest.mark.rabbit() -class TestRouter(RouterTestcase): - broker_class = RabbitRouter +class TestRouter(RabbitTestcaseConfig, RouterTestcase): route_class = RabbitRoute publisher_class = RabbitPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) - async def test_router_path( self, queue, @@ -213,13 +208,6 @@ def response(m) -> None: assert event.is_set() -class TestRouterLocal(RouterLocalTestcase): - broker_class = RabbitRouter +class TestRouterLocal(RabbitMemoryTestcaseConfig, RouterLocalTestcase): route_class = RabbitRoute publisher_class = RabbitPublisher - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: RabbitBroker, **kwargs: Any) -> RabbitBroker: - return TestRabbitBroker(broker, **kwargs) diff --git a/tests/brokers/rabbit/test_schemas.py b/tests/brokers/rabbit/test_schemas.py index ce686c6d69..2224fb976a 100644 --- a/tests/brokers/rabbit/test_schemas.py +++ b/tests/brokers/rabbit/test_schemas.py @@ -3,23 +3,19 @@ def test_same_queue() -> None: assert ( - len( - { - RabbitQueue("test"): 0, - RabbitQueue("test"): 1, - }, - ) + len({ + RabbitQueue("test"): 0, + RabbitQueue("test"): 1, + }) == 1 ) def test_different_queue_routing_key() -> None: assert ( - len( - { - RabbitQueue("test", routing_key="binding-1"): 0, - RabbitQueue("test", routing_key="binding-2"): 1, - }, - ) + len({ + RabbitQueue("test", routing_key="binding-1"): 0, + RabbitQueue("test", routing_key="binding-2"): 1, + }) == 1 ) diff --git a/tests/brokers/rabbit/test_test_client.py b/tests/brokers/rabbit/test_test_client.py index 0784da9bf3..0ddaf24b58 100644 --- a/tests/brokers/rabbit/test_test_client.py +++ b/tests/brokers/rabbit/test_test_client.py @@ -1,5 +1,4 @@ import asyncio -from typing import Any import pytest @@ -7,26 +6,18 @@ from faststream.exceptions import SubscriberNotFound from faststream.rabbit import ( ExchangeType, - RabbitBroker, RabbitExchange, RabbitQueue, - TestRabbitBroker, ) from faststream.rabbit.annotations import RabbitMessage from faststream.rabbit.testing import FakeProducer, apply_pattern from tests.brokers.base.testclient import BrokerTestclientTestcase +from .basic import RabbitMemoryTestcaseConfig -@pytest.mark.asyncio() -class TestTestclient(BrokerTestclientTestcase): - test_class = TestRabbitBroker - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: - return RabbitBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: RabbitBroker, **kwargs: Any) -> RabbitBroker: - return TestRabbitBroker(broker, **kwargs) +@pytest.mark.asyncio() +class TestTestclient(RabbitMemoryTestcaseConfig, BrokerTestclientTestcase): @pytest.mark.rabbit() async def test_with_real_testclient( self, diff --git a/tests/brokers/redis/basic.py b/tests/brokers/redis/basic.py new file mode 100644 index 0000000000..11f424017c --- /dev/null +++ b/tests/brokers/redis/basic.py @@ -0,0 +1,24 @@ +from typing import Any + +from faststream.redis import RedisBroker, RedisRouter, TestRedisBroker +from tests.brokers.base.basic import BaseTestcaseConfig + + +class RedisTestcaseConfig(BaseTestcaseConfig): + def get_broker( + self, + apply_types: bool = False, + **kwargs: Any, + ) -> RedisBroker: + return RedisBroker(apply_types=apply_types, **kwargs) + + def patch_broker(self, broker: RedisBroker, **kwargs: Any) -> RedisBroker: + return broker + + def get_router(self, **kwargs: Any) -> RedisRouter: + return RedisRouter(**kwargs) + + +class RedisMemoryTestcaseConfig(RedisTestcaseConfig): + def patch_broker(self, broker: RedisBroker, **kwargs: Any) -> RedisBroker: + return TestRedisBroker(broker, **kwargs) diff --git a/tests/brokers/redis/test_consume.py b/tests/brokers/redis/test_consume.py index 7c7a1e4152..9aa3b7590b 100644 --- a/tests/brokers/redis/test_consume.py +++ b/tests/brokers/redis/test_consume.py @@ -1,21 +1,19 @@ import asyncio -from typing import Any from unittest.mock import MagicMock, patch import pytest from redis.asyncio import Redis -from faststream.redis import ListSub, PubSub, RedisBroker, RedisMessage, StreamSub +from faststream.redis import ListSub, PubSub, RedisMessage, StreamSub from tests.brokers.base.consume import BrokerRealConsumeTestcase from tests.tools import spy_decorator +from .basic import RedisTestcaseConfig + @pytest.mark.redis() @pytest.mark.asyncio() -class TestConsume(BrokerRealConsumeTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) - +class TestConsume(RedisTestcaseConfig, BrokerRealConsumeTestcase): async def test_consume_native( self, mock: MagicMock, @@ -98,13 +96,7 @@ async def handler(msg) -> None: @pytest.mark.redis() @pytest.mark.asyncio() -class TestConsumeList: - def get_broker(self, apply_types: bool = False): - return RedisBroker(apply_types=apply_types) - - def patch_broker(self, broker): - return broker - +class TestConsumeList(RedisTestcaseConfig): async def test_consume_list( self, queue: str, @@ -365,13 +357,7 @@ async def test_get_one_timeout( @pytest.mark.redis() @pytest.mark.asyncio() -class TestConsumeStream: - def get_broker(self, apply_types: bool = False): - return RedisBroker(apply_types=apply_types) - - def patch_broker(self, broker): - return broker - +class TestConsumeStream(RedisTestcaseConfig): @pytest.mark.slow() async def test_consume_stream( self, diff --git a/tests/brokers/redis/test_fastapi.py b/tests/brokers/redis/test_fastapi.py index 9c66a73230..31c1ff0f70 100644 --- a/tests/brokers/redis/test_fastapi.py +++ b/tests/brokers/redis/test_fastapi.py @@ -1,14 +1,14 @@ import asyncio -from typing import Any from unittest.mock import Mock import pytest -from faststream.redis import ListSub, RedisBroker, RedisRouter, StreamSub +from faststream.redis import ListSub, RedisRouter, StreamSub from faststream.redis.fastapi import RedisRouter as StreamRouter -from faststream.redis.testing import TestRedisBroker from tests.brokers.base.fastapi import FastAPILocalTestcase, FastAPITestcase +from .basic import RedisMemoryTestcaseConfig + @pytest.mark.redis() class TestRouter(FastAPITestcase): @@ -140,13 +140,10 @@ async def handler(msg: list[str]) -> None: mock.assert_called_once_with(["hello"]) -class TestRouterLocal(FastAPILocalTestcase): +class TestRouterLocal(RedisMemoryTestcaseConfig, FastAPILocalTestcase): router_class = StreamRouter broker_router_class = RedisRouter - def patch_broker(self, broker: RedisBroker, **kwargs: Any) -> RedisBroker: - return TestRedisBroker(broker, **kwargs) - async def test_batch_testclient( self, mock: Mock, diff --git a/tests/brokers/redis/test_middlewares.py b/tests/brokers/redis/test_middlewares.py index 8dd6cdee7e..c75e0914cd 100644 --- a/tests/brokers/redis/test_middlewares.py +++ b/tests/brokers/redis/test_middlewares.py @@ -1,21 +1,23 @@ -from typing import Any - import pytest -from faststream.redis import RedisBroker from tests.brokers.base.middlewares import ( ExceptionMiddlewareTestcase, MiddlewareTestcase, + MiddlewaresOrderTestcase, ) +from .basic import RedisMemoryTestcaseConfig, RedisTestcaseConfig + + +class TestMiddlewaresOrder(RedisMemoryTestcaseConfig, MiddlewaresOrderTestcase): + pass + @pytest.mark.redis() -class TestMiddlewares(MiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) +class TestMiddlewares(RedisTestcaseConfig, MiddlewareTestcase): + pass @pytest.mark.redis() -class TestExceptionMiddlewares(ExceptionMiddlewareTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) +class TestExceptionMiddlewares(RedisTestcaseConfig, ExceptionMiddlewareTestcase): + pass diff --git a/tests/brokers/redis/test_parser.py b/tests/brokers/redis/test_parser.py index 8e9093bc04..cf16275b65 100644 --- a/tests/brokers/redis/test_parser.py +++ b/tests/brokers/redis/test_parser.py @@ -1,12 +1,10 @@ -from typing import Any - import pytest -from faststream.redis import RedisBroker from tests.brokers.base.parser import CustomParserTestcase +from .basic import RedisTestcaseConfig + @pytest.mark.redis() -class TestCustomParser(CustomParserTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) +class TestCustomParser(RedisTestcaseConfig, CustomParserTestcase): + pass diff --git a/tests/brokers/redis/test_publish.py b/tests/brokers/redis/test_publish.py index 8967aa0778..ba4d445d78 100644 --- a/tests/brokers/redis/test_publish.py +++ b/tests/brokers/redis/test_publish.py @@ -1,22 +1,20 @@ import asyncio -from typing import Any from unittest.mock import MagicMock, patch import pytest from redis.asyncio import Redis from faststream import Context -from faststream.redis import ListSub, RedisBroker, RedisResponse, StreamSub +from faststream.redis import ListSub, RedisResponse, StreamSub from tests.brokers.base.publish import BrokerPublishTestcase from tests.tools import spy_decorator +from .basic import RedisTestcaseConfig + @pytest.mark.redis() @pytest.mark.asyncio() -class TestPublish(BrokerPublishTestcase): - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) - +class TestPublish(RedisTestcaseConfig, BrokerPublishTestcase): async def test_list_publisher( self, queue: str, diff --git a/tests/brokers/redis/test_requests.py b/tests/brokers/redis/test_requests.py index f1d4fc3c0f..03c1441fa7 100644 --- a/tests/brokers/redis/test_requests.py +++ b/tests/brokers/redis/test_requests.py @@ -3,9 +3,10 @@ import pytest from faststream import BaseMiddleware -from faststream.redis import RedisBroker, RedisRouter, TestRedisBroker from tests.brokers.base.requests import RequestsTestcase +from .basic import RedisMemoryTestcaseConfig, RedisTestcaseConfig + class Mid(BaseMiddleware): async def on_receive(self) -> None: @@ -23,18 +24,11 @@ class RedisRequestsTestcase(RequestsTestcase): def get_middleware(self, **kwargs): return Mid - def get_broker(self, **kwargs): - return RedisBroker(**kwargs) - - def get_router(self, **kwargs): - return RedisRouter(**kwargs) - @pytest.mark.redis() -class TestRealRequests(RedisRequestsTestcase): +class TestRealRequests(RedisTestcaseConfig, RedisRequestsTestcase): pass -class TestRequestTestClient(RedisRequestsTestcase): - def patch_broker(self, broker, **kwargs): - return TestRedisBroker(broker, **kwargs) +class TestRequestTestClient(RedisMemoryTestcaseConfig, RedisRequestsTestcase): + pass diff --git a/tests/brokers/redis/test_router.py b/tests/brokers/redis/test_router.py index ef53e47a37..7a56d30edf 100644 --- a/tests/brokers/redis/test_router.py +++ b/tests/brokers/redis/test_router.py @@ -1,40 +1,28 @@ import asyncio -from typing import Any import pytest from faststream import Path from faststream.redis import ( - RedisBroker, RedisPublisher, RedisRoute, RedisRouter, - TestRedisBroker, ) from tests.brokers.base.router import RouterLocalTestcase, RouterTestcase +from .basic import RedisMemoryTestcaseConfig, RedisTestcaseConfig + @pytest.mark.redis() -class TestRouter(RouterTestcase): - broker_class = RedisRouter +class TestRouter(RedisTestcaseConfig, RouterTestcase): route_class = RedisRoute publisher_class = RedisPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) - -class TestRouterLocal(RouterLocalTestcase): - broker_class = RedisRouter +class TestRouterLocal(RedisMemoryTestcaseConfig, RouterLocalTestcase): route_class = RedisRoute publisher_class = RedisPublisher - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: RedisBroker, **kwargs: Any) -> RedisBroker: - return TestRedisBroker(broker, **kwargs) - async def test_router_path( self, event, diff --git a/tests/brokers/redis/test_test_client.py b/tests/brokers/redis/test_test_client.py index c9bcbe5527..da352a0c7b 100644 --- a/tests/brokers/redis/test_test_client.py +++ b/tests/brokers/redis/test_test_client.py @@ -1,24 +1,17 @@ import asyncio -from typing import Any import pytest from faststream import BaseMiddleware -from faststream.redis import ListSub, RedisBroker, StreamSub, TestRedisBroker +from faststream.redis import ListSub, StreamSub from faststream.redis.testing import FakeProducer from tests.brokers.base.testclient import BrokerTestclientTestcase +from .basic import RedisMemoryTestcaseConfig -@pytest.mark.asyncio() -class TestTestclient(BrokerTestclientTestcase): - test_class = TestRedisBroker - - def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: - return RedisBroker(apply_types=apply_types, **kwargs) - - def patch_broker(self, broker: RedisBroker, **kwargs: Any) -> TestRedisBroker: - return TestRedisBroker(broker, **kwargs) +@pytest.mark.asyncio() +class TestTestclient(RedisMemoryTestcaseConfig, BrokerTestclientTestcase): @pytest.mark.redis() async def test_with_real_testclient( self, @@ -43,7 +36,7 @@ def subscriber(m) -> None: assert event.is_set() - async def test_respect_middleware(self, queue) -> None: + async def test_respect_middleware(self, queue: str) -> None: routes = [] class Middleware(BaseMiddleware): @@ -66,7 +59,7 @@ async def h2(m) -> None: ... assert len(routes) == 2 @pytest.mark.redis() - async def test_real_respect_middleware(self, queue) -> None: + async def test_real_respect_middleware(self, queue: str) -> None: routes = [] class Middleware(BaseMiddleware): @@ -209,10 +202,7 @@ async def m(msg): m.mock.assert_called_once_with("hello") publisher.mock.assert_called_once_with([1, 2, 3]) - async def test_publish_to_none( - self, - queue: str, - ) -> None: + async def test_publish_to_none(self) -> None: broker = self.get_broker() async with self.patch_broker(broker) as br: diff --git a/tests/cli/test_asyncapi_docs.py b/tests/cli/test_asyncapi_docs.py index 9deb1877b9..344ba54013 100644 --- a/tests/cli/test_asyncapi_docs.py +++ b/tests/cli/test_asyncapi_docs.py @@ -1,5 +1,6 @@ import json import sys +import traceback from http.server import HTTPServer from pathlib import Path from unittest.mock import Mock @@ -94,7 +95,7 @@ def test_serve_asyncapi_json_schema( m.setattr(HTTPServer, "serve_forever", mock) r = runner.invoke(cli, SERVE_CMD + [str(schema_path)]) # noqa: RUF005 - assert r.exit_code == 0, r.exc_info + assert r.exit_code == 0, traceback.format_tb(r.exc_info[2]) mock.assert_called_once() schema_path.unlink() @@ -115,7 +116,7 @@ def test_serve_asyncapi_yaml_schema( m.setattr(HTTPServer, "serve_forever", mock) r = runner.invoke(cli, SERVE_CMD + [str(schema_path)]) # noqa: RUF005 - assert r.exit_code == 0, r.exc_info + assert r.exit_code == 0, traceback.format_tb(r.exc_info[2]) mock.assert_called_once() schema_path.unlink() diff --git a/tests/cli/test_run_asgi.py b/tests/cli/test_run_asgi.py index c644c04bb2..49825f932b 100644 --- a/tests/cli/test_run_asgi.py +++ b/tests/cli/test_run_asgi.py @@ -34,14 +34,7 @@ def test_run_as_asgi(runner: CliRunner) -> None: assert result.exit_code == 0 -@pytest.mark.parametrize( - "workers", - ( - pytest.param(1), - pytest.param(2), - pytest.param(5), - ), -) +@pytest.mark.parametrize("workers", (pytest.param(1), pytest.param(2), pytest.param(5))) def test_run_as_asgi_with_workers(runner: CliRunner, workers: int) -> None: app = AsgiFastStream(AsyncMock()) app.run = AsyncMock() diff --git a/tests/examples/kafka/test_ack.py b/tests/examples/kafka/test_ack.py index 3dfe4e8502..c2402f2f6a 100644 --- a/tests/examples/kafka/test_ack.py +++ b/tests/examples/kafka/test_ack.py @@ -4,12 +4,14 @@ from examples.kafka.ack_after_process import app, broker from faststream.kafka import TestApp, TestKafkaBroker -from faststream.kafka.message import KafkaMessage +from faststream.kafka.message import KafkaAckableMessage from tests.tools import spy_decorator @pytest.mark.asyncio() async def test_ack() -> None: - with patch.object(KafkaMessage, "ack", spy_decorator(KafkaMessage.ack)) as m: + with patch.object( + KafkaAckableMessage, "ack", spy_decorator(KafkaAckableMessage.ack) + ) as m: async with TestKafkaBroker(broker), TestApp(app): m.mock.assert_called_once() diff --git a/tests/prometheus/basic.py b/tests/prometheus/basic.py index 6a5f0e303e..e7598aa166 100644 --- a/tests/prometheus/basic.py +++ b/tests/prometheus/basic.py @@ -6,12 +6,13 @@ from prometheus_client import CollectorRegistry from faststream import Context -from faststream.exceptions import RejectMessage +from faststream.exceptions import IgnoredException, RejectMessage from faststream.message import AckStatus from faststream.prometheus import MetricsSettingsProvider from faststream.prometheus.middleware import ( PROCESSING_STATUS_BY_ACK_STATUS, PROCESSING_STATUS_BY_HANDLER_EXCEPTION_MAP, + BasePrometheusMiddleware, ) from faststream.prometheus.types import ProcessingStatus from tests.brokers.base.basic import BaseTestcaseConfig @@ -19,10 +20,10 @@ @pytest.mark.asyncio() class LocalPrometheusTestcase(BaseTestcaseConfig): - def get_broker(self, apply_types=False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any): raise NotImplementedError - def get_middleware(self, **kwargs): + def get_middleware(self, **kwargs: Any) -> BasePrometheusMiddleware: raise NotImplementedError @staticmethod @@ -51,13 +52,26 @@ def settings_provider_factory(self): Exception, id="acked status with not handler exception", ), - pytest.param(AckStatus.ACKED, None, id="acked status without exception"), - pytest.param(AckStatus.NACKED, None, id="nacked status without exception"), + pytest.param( + AckStatus.ACKED, + None, + id="acked status without exception", + ), + pytest.param( + AckStatus.NACKED, + None, + id="nacked status without exception", + ), pytest.param( AckStatus.REJECTED, None, id="rejected status without exception", ), + pytest.param( + AckStatus.ACKED, + IgnoredException, + id="acked status with ignore exception", + ), ), ) async def test_metrics( @@ -65,7 +79,7 @@ async def test_metrics( queue: str, status: AckStatus, exception_class: Optional[type[Exception]], - ): + ) -> None: event = asyncio.Event() middleware = self.get_middleware(registry=CollectorRegistry()) @@ -117,7 +131,7 @@ def assert_consume_metrics( metrics_manager: Any, message: Any, exception_class: Optional[type[Exception]], - ): + ) -> None: settings_provider = self.settings_provider_factory(message.raw_message) consume_attrs = settings_provider.get_consume_attrs_from_message(message) assert metrics_manager.add_received_message.mock_calls == [ @@ -181,7 +195,7 @@ def assert_consume_metrics( ), ] - if status == ProcessingStatus.error: + if exception_class and not issubclass(exception_class, IgnoredException): assert ( metrics_manager.add_received_processed_message_exception.mock_calls == [ @@ -192,8 +206,13 @@ def assert_consume_metrics( ), ] ) + else: + assert ( + metrics_manager.add_received_processed_message_exception.mock_calls + == [] + ) - def assert_publish_metrics(self, metrics_manager: Any): + def assert_publish_metrics(self, metrics_manager: Any) -> None: settings_provider = self.settings_provider_factory(None) assert metrics_manager.observe_published_message_duration.mock_calls == [ call( @@ -245,6 +264,9 @@ async def handle(): class LocalMetricsSettingsProviderTestcase: messaging_system: str + def get_middleware(self, **kwargs) -> BasePrometheusMiddleware: + raise NotImplementedError + @staticmethod def get_provider() -> MetricsSettingsProvider: raise NotImplementedError @@ -253,8 +275,15 @@ def test_messaging_system(self) -> None: provider = self.get_provider() assert provider.messaging_system == self.messaging_system - def test_get_consume_attrs_from_message(self, *args, **kwargs) -> None: - raise NotImplementedError + def test_one_registry_for_some_middlewares(self) -> None: + registry = CollectorRegistry() - def test_get_publish_destination_name_from_cmd(self, *args, **kwargs) -> None: - raise NotImplementedError + middleware_1 = self.get_middleware(registry=registry) + middleware_2 = self.get_middleware(registry=registry) + self.get_broker(middlewares=(middleware_1,)) + self.get_broker(middlewares=(middleware_2,)) + + assert ( + middleware_1._metrics_container.received_messages_total + is middleware_2._metrics_container.received_messages_total + ) diff --git a/tests/prometheus/confluent/basic.py b/tests/prometheus/confluent/basic.py new file mode 100644 index 0000000000..92df8e38f4 --- /dev/null +++ b/tests/prometheus/confluent/basic.py @@ -0,0 +1,32 @@ +from typing import Any + +from faststream import AckPolicy +from faststream.confluent import KafkaBroker +from faststream.confluent.prometheus import KafkaPrometheusMiddleware +from tests.brokers.confluent.basic import ConfluentTestcaseConfig + + +class KafkaPrometheusSettings(ConfluentTestcaseConfig): + messaging_system = "kafka" + + def get_broker(self, apply_types=False, **kwargs: Any) -> KafkaBroker: + return KafkaBroker(apply_types=apply_types, **kwargs) + + def get_middleware(self, **kwargs: Any) -> KafkaPrometheusMiddleware: + return KafkaPrometheusMiddleware(**kwargs) + + def get_subscriber_params( + self, + *topics: Any, + **kwargs: Any, + ) -> tuple[ + tuple[Any, ...], + dict[str, Any], + ]: + topics, kwargs = super().get_subscriber_params(*topics, **kwargs) + + return topics, { + "group_id": "test", + "ack_policy": AckPolicy.REJECT_ON_ERROR, + **kwargs, + } diff --git a/tests/prometheus/confluent/test_confluent.py b/tests/prometheus/confluent/test_confluent.py index 84714bd280..1cf625b803 100644 --- a/tests/prometheus/confluent/test_confluent.py +++ b/tests/prometheus/confluent/test_confluent.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any from unittest.mock import Mock import pytest @@ -7,20 +8,15 @@ from faststream import Context from faststream.confluent import KafkaBroker from faststream.confluent.prometheus.middleware import KafkaPrometheusMiddleware -from tests.brokers.confluent.basic import ConfluentTestcaseConfig from tests.brokers.confluent.test_consume import TestConsume from tests.brokers.confluent.test_publish import TestPublish from tests.prometheus.basic import LocalPrometheusTestcase +from .basic import KafkaPrometheusSettings -@pytest.mark.confluent() -class TestPrometheus(ConfluentTestcaseConfig, LocalPrometheusTestcase): - def get_broker(self, apply_types=False, **kwargs): - return KafkaBroker(apply_types=apply_types, **kwargs) - - def get_middleware(self, **kwargs): - return KafkaPrometheusMiddleware(**kwargs) +@pytest.mark.confluent() +class TestPrometheus(KafkaPrometheusSettings, LocalPrometheusTestcase): async def test_metrics_batch( self, queue: str, @@ -62,7 +58,7 @@ async def handler(m=Context("message")): @pytest.mark.confluent() class TestPublishWithPrometheus(TestPublish): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: return KafkaBroker( middlewares=(KafkaPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, @@ -72,7 +68,7 @@ def get_broker(self, apply_types: bool = False, **kwargs): @pytest.mark.confluent() class TestConsumeWithPrometheus(TestConsume): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> KafkaBroker: return KafkaBroker( middlewares=(KafkaPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, diff --git a/tests/prometheus/confluent/test_provider.py b/tests/prometheus/confluent/test_provider.py index 6949a1ff26..d65c9d6a2a 100644 --- a/tests/prometheus/confluent/test_provider.py +++ b/tests/prometheus/confluent/test_provider.py @@ -11,12 +11,13 @@ from faststream.prometheus import MetricsSettingsProvider from tests.prometheus.basic import LocalMetricsSettingsProviderTestcase +from .basic import KafkaPrometheusSettings + class LocalBaseConfluentMetricsSettingsProviderTestcase( - LocalMetricsSettingsProviderTestcase + KafkaPrometheusSettings, + LocalMetricsSettingsProviderTestcase, ): - messaging_system = "kafka" - def test_get_publish_destination_name_from_cmd(self, queue: str) -> None: expected_destination_name = queue provider = self.get_provider() diff --git a/tests/prometheus/kafka/basic.py b/tests/prometheus/kafka/basic.py new file mode 100644 index 0000000000..0225f7053a --- /dev/null +++ b/tests/prometheus/kafka/basic.py @@ -0,0 +1,31 @@ +from typing import Any + +from faststream import AckPolicy +from faststream.kafka import KafkaBroker +from faststream.kafka.prometheus import KafkaPrometheusMiddleware +from tests.brokers.kafka.basic import KafkaTestcaseConfig + + +class KafkaPrometheusSettings(KafkaTestcaseConfig): + messaging_system = "kafka" + + def get_broker(self, apply_types=False, **kwargs: Any) -> KafkaBroker: + return KafkaBroker(apply_types=apply_types, **kwargs) + + def get_middleware(self, **kwargs: Any) -> KafkaPrometheusMiddleware: + return KafkaPrometheusMiddleware(**kwargs) + + def get_subscriber_params( + self, + *args: Any, + **kwargs: Any, + ) -> tuple[ + tuple[Any, ...], + dict[str, Any], + ]: + args, kwargs = super().get_subscriber_params(*args, **kwargs) + return args, { + "group_id": "test", + "ack_policy": AckPolicy.REJECT_ON_ERROR, + **kwargs, + } diff --git a/tests/prometheus/kafka/test_kafka.py b/tests/prometheus/kafka/test_kafka.py index 7ba5ba6f82..4384add588 100644 --- a/tests/prometheus/kafka/test_kafka.py +++ b/tests/prometheus/kafka/test_kafka.py @@ -11,15 +11,11 @@ from tests.brokers.kafka.test_publish import TestPublish from tests.prometheus.basic import LocalPrometheusTestcase +from .basic import KafkaPrometheusSettings -@pytest.mark.kafka() -class TestPrometheus(LocalPrometheusTestcase): - def get_broker(self, apply_types=False, **kwargs): - return KafkaBroker(apply_types=apply_types, **kwargs) - - def get_middleware(self, **kwargs): - return KafkaPrometheusMiddleware(**kwargs) +@pytest.mark.kafka() +class TestPrometheus(KafkaPrometheusSettings, LocalPrometheusTestcase): async def test_metrics_batch( self, queue: str, diff --git a/tests/prometheus/kafka/test_provider.py b/tests/prometheus/kafka/test_provider.py index 1e0c980981..2b590712b2 100644 --- a/tests/prometheus/kafka/test_provider.py +++ b/tests/prometheus/kafka/test_provider.py @@ -11,12 +11,13 @@ from faststream.prometheus import MetricsSettingsProvider from tests.prometheus.basic import LocalMetricsSettingsProviderTestcase +from .basic import KafkaPrometheusSettings + class LocalBaseKafkaMetricsSettingsProviderTestcase( - LocalMetricsSettingsProviderTestcase + KafkaPrometheusSettings, + LocalMetricsSettingsProviderTestcase, ): - messaging_system = "kafka" - def test_get_publish_destination_name_from_cmd(self, queue: str) -> None: expected_destination_name = queue provider = self.get_provider() diff --git a/tests/prometheus/nats/basic.py b/tests/prometheus/nats/basic.py new file mode 100644 index 0000000000..2bb3abbad9 --- /dev/null +++ b/tests/prometheus/nats/basic.py @@ -0,0 +1,14 @@ +from typing import Any + +from faststream.nats import NatsBroker +from faststream.nats.prometheus import NatsPrometheusMiddleware + + +class NatsPrometheusSettings: + messaging_system = "nats" + + def get_broker(self, apply_types=False, **kwargs: Any) -> NatsBroker: + return NatsBroker(apply_types=apply_types, **kwargs) + + def get_middleware(self, **kwargs: Any) -> NatsPrometheusMiddleware: + return NatsPrometheusMiddleware(**kwargs) diff --git a/tests/prometheus/nats/test_nats.py b/tests/prometheus/nats/test_nats.py index 117b696922..a3bdbed2e0 100644 --- a/tests/prometheus/nats/test_nats.py +++ b/tests/prometheus/nats/test_nats.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any from unittest.mock import Mock import pytest @@ -11,6 +12,8 @@ from tests.brokers.nats.test_publish import TestPublish from tests.prometheus.basic import LocalPrometheusTestcase, LocalRPCPrometheusTestcase +from .basic import NatsPrometheusSettings + @pytest.fixture() def stream(queue): @@ -18,18 +21,14 @@ def stream(queue): @pytest.mark.nats() -class TestPrometheus(LocalPrometheusTestcase, LocalRPCPrometheusTestcase): - def get_broker(self, apply_types=False, **kwargs): - return NatsBroker(apply_types=apply_types, **kwargs) - - def get_middleware(self, **kwargs): - return NatsPrometheusMiddleware(**kwargs) - +class TestPrometheus( + NatsPrometheusSettings, LocalPrometheusTestcase, LocalRPCPrometheusTestcase +): async def test_metrics_batch( self, queue: str, stream: JStream, - ): + ) -> None: event = asyncio.Event() middleware = self.get_middleware(registry=CollectorRegistry()) @@ -69,7 +68,7 @@ async def handler(m=Context("message")): @pytest.mark.nats() class TestPublishWithPrometheus(TestPublish): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: return NatsBroker( middlewares=(NatsPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, @@ -79,7 +78,7 @@ def get_broker(self, apply_types: bool = False, **kwargs): @pytest.mark.nats() class TestConsumeWithPrometheus(TestConsume): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> NatsBroker: return NatsBroker( middlewares=(NatsPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, diff --git a/tests/prometheus/nats/test_provider.py b/tests/prometheus/nats/test_provider.py index 817a37142d..c8ce53dc72 100644 --- a/tests/prometheus/nats/test_provider.py +++ b/tests/prometheus/nats/test_provider.py @@ -12,12 +12,12 @@ from faststream.prometheus import MetricsSettingsProvider from tests.prometheus.basic import LocalMetricsSettingsProviderTestcase +from .basic import NatsPrometheusSettings + class LocalBaseNatsMetricsSettingsProviderTestcase( - LocalMetricsSettingsProviderTestcase + NatsPrometheusSettings, LocalMetricsSettingsProviderTestcase ): - messaging_system = "nats" - def test_get_publish_destination_name_from_cmd(self, queue: str) -> None: expected_destination_name = queue command = SimpleNamespace(destination=queue) @@ -29,8 +29,7 @@ def test_get_publish_destination_name_from_cmd(self, queue: str) -> None: class TestNatsMetricsSettingsProvider(LocalBaseNatsMetricsSettingsProviderTestcase): - @staticmethod - def get_provider() -> MetricsSettingsProvider: + def get_provider(self) -> MetricsSettingsProvider: return NatsMetricsSettingsProvider() def test_get_consume_attrs_from_message(self, queue: str) -> None: diff --git a/tests/prometheus/rabbit/basic.py b/tests/prometheus/rabbit/basic.py new file mode 100644 index 0000000000..d1d192e459 --- /dev/null +++ b/tests/prometheus/rabbit/basic.py @@ -0,0 +1,15 @@ +from typing import Any + +from faststream.rabbit import RabbitBroker +from faststream.rabbit.prometheus import RabbitPrometheusMiddleware +from tests.brokers.rabbit.basic import RabbitTestcaseConfig + + +class RabbitPrometheusSettings(RabbitTestcaseConfig): + messaging_system = "rabbitmq" + + def get_broker(self, apply_types=False, **kwargs: Any) -> RabbitBroker: + return RabbitBroker(apply_types=apply_types, **kwargs) + + def get_middleware(self, **kwargs: Any) -> RabbitPrometheusMiddleware: + return RabbitPrometheusMiddleware(**kwargs) diff --git a/tests/prometheus/rabbit/test_provider.py b/tests/prometheus/rabbit/test_provider.py index 71d47a781b..9b6eb0b7bf 100644 --- a/tests/prometheus/rabbit/test_provider.py +++ b/tests/prometheus/rabbit/test_provider.py @@ -5,12 +5,19 @@ from faststream.prometheus import MetricsSettingsProvider from faststream.rabbit.prometheus.provider import RabbitMetricsSettingsProvider -from tests.prometheus.basic import LocalMetricsSettingsProviderTestcase +from tests.prometheus.basic import ( + LocalMetricsSettingsProviderTestcase, + LocalRPCPrometheusTestcase, +) +from .basic import RabbitPrometheusSettings -class TestRabbitMetricsSettingsProvider(LocalMetricsSettingsProviderTestcase): - messaging_system = "rabbitmq" +class TestRabbitMetricsSettingsProvider( + RabbitPrometheusSettings, + LocalMetricsSettingsProviderTestcase, + LocalRPCPrometheusTestcase, +): @staticmethod def get_provider() -> MetricsSettingsProvider: return RabbitMetricsSettingsProvider() diff --git a/tests/prometheus/rabbit/test_rabbit.py b/tests/prometheus/rabbit/test_rabbit.py index f64786fc4f..dff264063a 100644 --- a/tests/prometheus/rabbit/test_rabbit.py +++ b/tests/prometheus/rabbit/test_rabbit.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from prometheus_client import CollectorRegistry @@ -7,6 +9,8 @@ from tests.brokers.rabbit.test_publish import TestPublish from tests.prometheus.basic import LocalPrometheusTestcase, LocalRPCPrometheusTestcase +from .basic import RabbitPrometheusSettings + @pytest.fixture() def exchange(queue): @@ -14,17 +18,15 @@ def exchange(queue): @pytest.mark.rabbit() -class TestPrometheus(LocalPrometheusTestcase, LocalRPCPrometheusTestcase): - def get_broker(self, apply_types=False, **kwargs): - return RabbitBroker(apply_types=apply_types, **kwargs) - - def get_middleware(self, **kwargs): - return RabbitPrometheusMiddleware(**kwargs) +class TestPrometheus( + RabbitPrometheusSettings, LocalPrometheusTestcase, LocalRPCPrometheusTestcase +): + pass @pytest.mark.rabbit() class TestPublishWithPrometheus(TestPublish): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: return RabbitBroker( middlewares=(RabbitPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, @@ -34,7 +36,7 @@ def get_broker(self, apply_types: bool = False, **kwargs): @pytest.mark.rabbit() class TestConsumeWithPrometheus(TestConsume): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RabbitBroker: return RabbitBroker( middlewares=(RabbitPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, diff --git a/tests/prometheus/redis/basic.py b/tests/prometheus/redis/basic.py new file mode 100644 index 0000000000..1a6ed6cfeb --- /dev/null +++ b/tests/prometheus/redis/basic.py @@ -0,0 +1,15 @@ +from typing import Any + +from faststream.redis import RedisBroker +from faststream.redis.prometheus import RedisPrometheusMiddleware +from tests.brokers.redis.basic import RedisTestcaseConfig + + +class RedisPrometheusSettings(RedisTestcaseConfig): + messaging_system = "redis" + + def get_broker(self, apply_types=False, **kwargs: Any) -> RedisBroker: + return RedisBroker(apply_types=apply_types, **kwargs) + + def get_middleware(self, **kwargs: Any) -> RedisPrometheusMiddleware: + return RedisPrometheusMiddleware(**kwargs) diff --git a/tests/prometheus/redis/test_provider.py b/tests/prometheus/redis/test_provider.py index c1b593b545..58e84cee4d 100644 --- a/tests/prometheus/redis/test_provider.py +++ b/tests/prometheus/redis/test_provider.py @@ -17,12 +17,13 @@ ) from tests.prometheus.basic import LocalMetricsSettingsProviderTestcase +from .basic import RedisPrometheusSettings + class LocalBaseRedisMetricsSettingsProviderTestcase( - LocalMetricsSettingsProviderTestcase + RedisPrometheusSettings, + LocalMetricsSettingsProviderTestcase, ): - messaging_system = "redis" - def test_get_publish_destination_name_from_cmd(self, queue: str) -> None: expected_destination_name = queue provider = self.get_provider() diff --git a/tests/prometheus/redis/test_redis.py b/tests/prometheus/redis/test_redis.py index ee7f62cfd2..4f6954116f 100644 --- a/tests/prometheus/redis/test_redis.py +++ b/tests/prometheus/redis/test_redis.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any from unittest.mock import Mock import pytest @@ -11,15 +12,15 @@ from tests.brokers.redis.test_publish import TestPublish from tests.prometheus.basic import LocalPrometheusTestcase, LocalRPCPrometheusTestcase +from .basic import RedisPrometheusSettings -@pytest.mark.redis() -class TestPrometheus(LocalPrometheusTestcase, LocalRPCPrometheusTestcase): - def get_broker(self, apply_types=False, **kwargs): - return RedisBroker(apply_types=apply_types, **kwargs) - - def get_middleware(self, **kwargs): - return RedisPrometheusMiddleware(**kwargs) +@pytest.mark.redis() +class TestPrometheus( + RedisPrometheusSettings, + LocalPrometheusTestcase, + LocalRPCPrometheusTestcase, +): async def test_metrics_batch( self, queue: str, @@ -60,7 +61,7 @@ async def handler(m=Context("message")): @pytest.mark.redis() class TestPublishWithPrometheus(TestPublish): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: return RedisBroker( middlewares=(RedisPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types, @@ -70,7 +71,7 @@ def get_broker(self, apply_types: bool = False, **kwargs): @pytest.mark.redis() class TestConsumeWithPrometheus(TestConsume): - def get_broker(self, apply_types: bool = False, **kwargs): + def get_broker(self, apply_types: bool = False, **kwargs: Any) -> RedisBroker: return RedisBroker( middlewares=(RedisPrometheusMiddleware(registry=CollectorRegistry()),), apply_types=apply_types,